Serving a PyTorch model
Learn how to serve a PyTorch model with Ray Serve
This tutorial shows you how to serve a more advanced application: a PyTorch ResNet model. It is closely modeled after a similar tutorial for Ray Serve.
Create a project for your application.
$ mkdir hello_world_pytorch
$ cd hello_world_pytorch
$ anyscale init
Project name: hello_world_pytorch
Project prj_6SCoYQrJU4BYzpcpj0mKxk created. View at https://console.anyscale.com/projects/prj_6SCoYQrJU4BYzpcpj0mKxk
Install the necessary dependencies.
pip install "ray[serve]"
pip install torch torchvision
Add imports and create a runtime environment with the necessary dependencies.
import ray
from ray import serve
from io import BytesIO
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.models import resnet18

runtime_env = {
    "pip": ["torch", "torchvision", "ray[serve]"]
}
Define your service.
@serve.deployment(route_prefix="/image_predict")
class ImageModel:
    def __init__(self):
        self.model = resnet18(pretrained=True).eval()
        self.preprocessor = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t[:3, ...]),  # remove alpha channel
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    async def __call__(self, starlette_request):
        image_payload_bytes = await starlette_request.body()
        pil_image = Image.open(BytesIO(image_payload_bytes))
        print("[1/3] Parsed image data: {}".format(pil_image))

        pil_images = [pil_image]  # Our current batch size is one
        input_tensor = torch.cat(
            [self.preprocessor(i).unsqueeze(0) for i in pil_images])
        print("[2/3] Images transformed, tensor shape {}".format(
            input_tensor.shape))

        with torch.no_grad():
            output_tensor = self.model(input_tensor)
        print("[3/3] Inference done!")
        return {"class_index": int(torch.argmax(output_tensor[0]))}
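Optionally, before deploying, you can sanity-check the same preprocessing and model logic locally. The following is a minimal sketch that mirrors the deployment above; the Ray logo URL here is just an example input and is not part of the service itself.

# Local sanity check (illustrative; mirrors the deployment's logic).
model = resnet18(pretrained=True).eval()
preprocessor = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t[:3, ...]),  # remove alpha channel
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Any RGB(A) image works here; the Ray logo is just an example.
image_bytes = requests.get(
    "https://github.com/ray-project/ray/raw/"
    "master/doc/source/images/ray_header_logo.png").content
pil_image = Image.open(BytesIO(image_bytes))

with torch.no_grad():
    output = model(preprocessor(pil_image).unsqueeze(0))
print(int(torch.argmax(output[0])))  # ImageNet class index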
Connect to Anyscale and deploy the model. (autosuspend=-1 prevents the cluster from automatically suspending, and allow_public_internet_traffic=True lets you reach the endpoint from the public internet.)
ray.init(
    "anyscale://my_cluster",
    namespace="my_serve_namespace",
    runtime_env=runtime_env,
    allow_public_internet_traffic=True,
    autosuspend=-1)
serve.start(detached=True)
ImageModel.deploy()
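Before querying, you can confirm that the deployment was created. A minimal check, assuming the serve.list_deployments() API available in Ray 1.x:

# Optional: confirm the "ImageModel" deployment exists (Ray 1.x Serve API).
print(serve.list_deployments())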
Get your session ID from the CLI response and query the endpoint! (Remember: replace 6g8m3we3xifa14umfs6fru with your own ID.)

ray_logo_bytes = requests.get(
    "https://github.com/ray-project/ray/raw/"
    "master/doc/source/images/ray_header_logo.png").content
resp = requests.post(
    "https://serve-session-6g8m3we3xifa14umfs6fru.i.anyscaleuserdata.com/image_predict",
    data=ray_logo_bytes)
print(resp.json())
# Output
# {'class_index': 919}
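The class_index in the response is a raw ImageNet class index. If you want a human-readable label, you can look it up in an ImageNet class list. Continuing from the query above, here is a sketch that assumes the labels file published with the PyTorch Hub examples (an assumption, not part of this tutorial's service):

# Map the class index to a human-readable ImageNet label.
# The labels URL below is the one used in the PyTorch Hub examples
# (an assumption here, not part of this tutorial).
labels_url = ("https://raw.githubusercontent.com/pytorch/hub/"
              "master/imagenet_classes.txt")
labels = requests.get(labels_url).text.splitlines()
print(labels[resp.json()["class_index"]])  # label for class_index 919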