Source code for apache_beam.ml.inference.pytorch

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List

import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.api import PredictionResult
from apache_beam.ml.inference.base import InferenceRunner
from apache_beam.ml.inference.base import ModelLoader


class PytorchInferenceRunner(InferenceRunner):
  """
  This class runs PyTorch inferences with the run_inference method. It also
  provides helpers that report the byte size of a batch of Tensors and the
  metrics namespace for PyTorch models.
  """
  def __init__(self, device: torch.device):
    self._device = device

  def run_inference(
      self, batch: List[torch.Tensor],
      model: torch.nn.Module) -> Iterable[PredictionResult]:
    """
    Runs inferences on a batch of Tensors and returns an Iterable of
    PredictionResults, one per input Tensor.

    This method stacks the list of Tensors in a vectorized format to optimize
    the inference call.
    """
    torch_batch = torch.stack(batch)
    if torch_batch.device != self._device:
      torch_batch = torch_batch.to(self._device)
    predictions = model(torch_batch)
    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
    """Returns the number of bytes of data for a batch of Tensors."""
    return sum((el.element_size() for tensor in batch for el in tensor))

  def get_metrics_namespace(self) -> str:
    """
    Returns a namespace for metrics collected by the RunInference transform.
    """
    return 'RunInferencePytorch'
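

# The snippet below is a minimal, illustrative sketch (not part of the Beam
# API) showing how PytorchInferenceRunner pairs each input Tensor with its
# prediction. ``_ToyDoubler`` and ``_example_run_inference`` are hypothetical
# names that exist only for this example.
class _ToyDoubler(torch.nn.Module):
  """Hypothetical model that doubles its input."""
  def forward(self, x):
    return 2 * x


def _example_run_inference():
  runner = PytorchInferenceRunner(device=torch.device('cpu'))
  batch = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
  # get_num_bytes reports the byte size used for batching metrics.
  _ = runner.get_num_bytes(batch)
  # Each PredictionResult pairs an input example with its inference.
  return list(runner.run_inference(batch, _ToyDoubler()))
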
class PytorchModelLoader(ModelLoader):
  """Loads a PyTorch model."""
  def __init__(
      self,
      state_dict_path: str,
      model_class: Callable[..., torch.nn.Module],
      model_params: Dict[str, Any],
      device: str = 'CPU'):
    """
    Initializes a PytorchModelLoader.

    :param state_dict_path: path to the saved dictionary of the model state.
    :param model_class: class of the PyTorch model that defines the model
      structure.
    :param model_params: keyword arguments passed to ``model_class`` when
      constructing the model.
    :param device: the device on which you wish to run the model. If
      ``device = GPU`` then a GPU device will be used if it is available.
      Otherwise, it will be CPU.

    See https://pytorch.org/tutorials/beginner/saving_loading_models.html
    for details.
    """
    self._state_dict_path = state_dict_path
    if device == 'GPU' and torch.cuda.is_available():
      self._device = torch.device('cuda')
    else:
      self._device = torch.device('cpu')
    self._model_class = model_class
    self.model_params = model_params
    self._inference_runner = PytorchInferenceRunner(device=self._device)

  def load_model(self) -> torch.nn.Module:
    """Loads and initializes a PyTorch model for processing."""
    model = self._model_class(**self.model_params)
    model.to(self._device)
    file = FileSystems.open(self._state_dict_path, 'rb')
    model.load_state_dict(torch.load(file))
    model.eval()
    return model

  def get_inference_runner(self) -> InferenceRunner:
    """Returns a PyTorch implementation of InferenceRunner."""
    return self._inference_runner
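

# A minimal end-to-end usage sketch for PytorchModelLoader. The
# ``_ExampleLinear`` model class, the state-dict path, and
# ``_example_model_loader`` are hypothetical and exist only to illustrate the
# load_model / get_inference_runner flow; a real pipeline would point
# state_dict_path at a pre-trained model (for example, a file on GCS).
class _ExampleLinear(torch.nn.Module):
  """Hypothetical single-layer linear model."""
  def __init__(self, input_dim: int = 1, output_dim: int = 1):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x):
    return self.linear(x)


def _example_model_loader(state_dict_path: str = '/tmp/example_linear.pt'):
  # Save an (untrained) state dict so load_model() has something to read.
  torch.save(_ExampleLinear().state_dict(), state_dict_path)

  loader = PytorchModelLoader(
      state_dict_path=state_dict_path,
      model_class=_ExampleLinear,
      model_params={'input_dim': 1, 'output_dim': 1},
      device='GPU')  # falls back to CPU when no GPU is available
  model = loader.load_model()
  runner = loader.get_inference_runner()
  return list(runner.run_inference([torch.tensor([10.0])], model))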