Serving a PyTorch model
Learn how to serve a PyTorch model with Ray Serve.
# Project setup (run in a shell):
#
#   $ mkdir hello_world_pytorch
#   $ cd hello_world_pytorch
#   $ anyscale init
#   Project name: hello_world_pytorch
#   Project prj_6SCoYQrJU4BYzpcpj0mKxk created. View at https://console.anyscale.com/projects/prj_6SCoYQrJU4BYzpcpj0mKxk
#
# Install the local dependencies:
#
#   pip install "ray[serve]"
#   pip install torch torchvision

import ray
from ray import serve

from io import BytesIO
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.models import resnet18

# Packages listed here are passed to ray.init() below as the runtime
# environment for the cluster.
runtime_env = {
    "pip": ["torch", "torchvision", "ray[serve]"]
}

# Upper bound, in seconds, for each HTTP call made by the client part of this
# script, so a stalled connection cannot hang the script forever.
REQUEST_TIMEOUT_S = 60


@serve.deployment(route_prefix="/image_predict")
class ImageModel:
    """Ray Serve deployment that classifies one image with a ResNet-18.

    The deployment is registered under the ``/image_predict`` route. The raw
    request body is expected to be the bytes of an image file that PIL can
    open.
    """

    def __init__(self):
        # Load the pretrained weights once per replica and switch the network
        # to evaluation mode.
        self.model = resnet18(pretrained=True).eval()
        self.preprocessor = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t[:3, ...]),  # remove alpha channel
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
        ])

    async def __call__(self, starlette_request):
        """Run inference on the image carried in the request body.

        Args:
            starlette_request: Incoming HTTP request; its body is read as the
                raw image bytes.

        Returns:
            A dict ``{"class_index": int}`` holding the index of the
            highest-scoring model output for the image.
        """
        image_payload_bytes = await starlette_request.body()
        pil_image = Image.open(BytesIO(image_payload_bytes))
        print(f"[1/3] Parsed image data: {pil_image}")

        pil_images = [pil_image]  # Our current batch size is one
        # Stack the preprocessed images along a new leading batch dimension.
        input_tensor = torch.stack(
            [self.preprocessor(image) for image in pil_images])
        print(f"[2/3] Images transformed, tensor shape {input_tensor.shape}")

        # Gradients are not needed for inference.
        with torch.no_grad():
            output_tensor = self.model(input_tensor)
        print("[3/3] Inference done!")
        return {"class_index": int(torch.argmax(output_tensor[0]))}


# Connect to the Anyscale cluster, start Serve, and deploy the model.
# NOTE(review): autosuspend=-1 presumably disables cluster auto-suspend --
# confirm against the Anyscale documentation.
ray.init(
    "anyscale://my_cluster",
    namespace="my_serve_namespace",
    runtime_env=runtime_env,
    allow_public_internet_traffic=True,
    autosuspend=-1,
)
serve.start(detached=True)
ImageModel.deploy()

# Query the deployment: download a sample image and POST its bytes.
logo_response = requests.get(
    "https://github.com/ray-project/ray/raw/"
    "master/doc/source/images/ray_header_logo.png",
    timeout=REQUEST_TIMEOUT_S)
# Fail here rather than sending an HTTP error page to the model as an "image".
logo_response.raise_for_status()
ray_logo_bytes = logo_response.content

# Replace the host below with the Serve session URL of your own cluster.
resp = requests.post(
    "https://serve-session-6g8m3we3xifa14umfs6fru.i.anyscaleuserdata.com/image_predict",
    data=ray_logo_bytes,
    timeout=REQUEST_TIMEOUT_S)
# Surface HTTP errors directly instead of as a JSON decoding failure.
resp.raise_for_status()
print(resp.json())
# Output
# {'class_index': 919}
Last updated
Was this helpful?
