Serving a PyTorch model

Learn how to serve a PyTorch model with Ray Serve.

This tutorial shows you how to serve a more advanced application: a PyTorch ResNet model. It is closely modeled after a similar tutorial for Ray Serve.

  1. Create a project for your application.

    $ mkdir hello_world_pytorch
    $ cd hello_world_pytorch
    $ anyscale init
    Project name: hello_world_pytorch
    Project prj_6SCoYQrJU4BYzpcpj0mKxk created. View at https://console.anyscale.com/projects/prj_6SCoYQrJU4BYzpcpj0mKxk
  2. Install necessary dependencies.

    pip install "ray[serve]"
    pip install torch torchvision
  3. Add imports and create a runtime environment with the necessary dependencies.

    import ray
    from ray import serve
    
    from io import BytesIO
    from PIL import Image
    import requests
    
    import torch
    from torchvision import transforms
    from torchvision.models import resnet18
    
    runtime_env = {
     "pip": ["torch", "torchvision", "ray[serve]"]
    }
  4. Define your service.

    @serve.deployment(route_prefix="/image_predict")
    class ImageModel:
       def __init__(self):
           self.model = resnet18(pretrained=True).eval()
           self.preprocessor = transforms.Compose([
               transforms.Resize(224),
               transforms.CenterCrop(224),
               transforms.ToTensor(),
               transforms.Lambda(lambda t: t[:3, ...]),  # remove alpha channel
               transforms.Normalize(
                   mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
           ])
    
       async def __call__(self, starlette_request):
           image_payload_bytes = await starlette_request.body()
           pil_image = Image.open(BytesIO(image_payload_bytes))
           print("[1/3] Parsed image data: {}".format(pil_image))
    
           pil_images = [pil_image]  # Our current batch size is one
           input_tensor = torch.cat(
               [self.preprocessor(i).unsqueeze(0) for i in pil_images])
           print("[2/3] Images transformed, tensor shape {}".format(
               input_tensor.shape))
    
           with torch.no_grad():
               output_tensor = self.model(input_tensor)
           print("[3/3] Inference done!")
           return {"class_index": int(torch.argmax(output_tensor[0]))}
  5. Connect to Anyscale and deploy the service.

    ray.init(
        "anyscale://my_cluster",
        namespace="my_serve_namespace",
        runtime_env=runtime_env,
        allow_public_internet_traffic=True,
        autosuspend=-1,
    )
    
    serve.start(detached=True)
    
    ImageModel.deploy()
  6. Get your session id from the CLI response and query the endpoint! (Remember: replace 6g8m3we3xifa14umfs6fru with your own id.)

    ray_logo_bytes = requests.get(
       "https://github.com/ray-project/ray/raw/"
       "master/doc/source/images/ray_header_logo.png").content
    
    resp = requests.post(
       "https://serve-session-6g8m3we3xifa14umfs6fru.i.anyscaleuserdata.com/image_predict", data=ray_logo_bytes)
    print(resp.json())
    # Output
    # {'class_index': 919}
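
The endpoint URL in step 6 combines your session id with the deployment's route prefix. Here is a minimal sketch of building that URL, assuming the host pattern shown above also applies to your session. `build_endpoint_url` is a hypothetical helper written for illustration; it is not part of Ray or the Anyscale client.

```python
def build_endpoint_url(session_id: str, route_prefix: str = "/image_predict") -> str:
    """Hypothetical helper: join a session id and a route prefix into a URL.

    Assumes the host pattern used in step 6 of this tutorial.
    """
    # Make sure the route prefix starts with a slash before appending it.
    if not route_prefix.startswith("/"):
        route_prefix = "/" + route_prefix
    return f"https://serve-session-{session_id}.i.anyscaleuserdata.com{route_prefix}"


# The session id below is the example id from step 6; replace it with your own.
url = build_endpoint_url("6g8m3we3xifa14umfs6fru")
print(url)
```

You can then pass `url` to `requests.post` in place of the hard-coded string.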

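
To make the service in step 4 more concrete: `transforms.Normalize` computes `(value - mean) / std` for each channel, and `torch.argmax` returns the index of the largest score, which the service reports as `class_index`. The sketch below reproduces both operations in plain Python. The names `normalize_pixel` and `argmax` are illustrative, and the scores are made-up numbers, not real model output.

```python
# Per-channel statistics used by the preprocessor in step 4.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


def argmax(scores):
    """Return the index of the largest score."""
    return max(range(len(scores)), key=lambda i: scores[i])


# A pixel equal to the channel means normalizes to zero in every channel.
print(normalize_pixel(MEAN))

# Invented scores for three classes; the second entry is the largest.
print(argmax([0.1, 2.5, -0.3]))
```

The real model produces one score per class, and the service returns only the index of the highest one.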