apache_beam.ml.inference.pytorch_inference module
-
class apache_beam.ml.inference.pytorch_inference.PytorchModelHandlerTensor(state_dict_path: str, model_class: Callable[[...], torch.nn.modules.module.Module], model_params: Dict[str, Any], device: str = 'CPU')
Bases: apache_beam.ml.inference.base.ModelHandler
Implementation of the ModelHandler interface for PyTorch models that take torch.Tensor inputs.
Example Usage:
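# MyModel is a placeholder for your torch.nn.Module subclass.
pcoll | RunInference(PytorchModelHandlerTensor(
    state_dict_path="my_uri", model_class=MyModel, model_params={}))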
Parameters: - state_dict_path – path to the saved dictionary of the model state.
- model_class – class of the PyTorch model that defines the model structure.
- model_params – A dictionary of arguments required to instantiate the model class.
- device – the device on which to run the model. If device is 'GPU', a GPU is used when one is available; otherwise the model runs on the CPU. Defaults to 'CPU'.
See https://pytorch.org/tutorials/beginner/saving_loading_models.html for details on saving and loading a model's state dictionary.
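The parameters above imply that `state_dict_path` points at a file holding only the model's weights, and that the network itself is rebuilt from `model_class` and `model_params`. The sketch below shows the producing side under that assumption, using `torch.save(model.state_dict(), path)`. `LinearRegression`, its `input_dim`/`output_dim` arguments, and the temporary path are hypothetical stand-ins, not part of Beam.

```python
import os
import tempfile

import torch


class LinearRegression(torch.nn.Module):
    """Hypothetical model used only for illustration."""

    def __init__(self, input_dim: int = 2, output_dim: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# These would be passed to the handler as model_params.
model_params = {"input_dim": 2, "output_dim": 1}
model = LinearRegression(**model_params)

# Save only the learned weights (the state dict), not the pickled module.
state_dict_path = os.path.join(tempfile.mkdtemp(), "linear_regression.pth")
torch.save(model.state_dict(), state_dict_path)

# Rebuilding: instantiate model_class with model_params, then load the weights.
restored = LinearRegression(**model_params)
restored.load_state_dict(torch.load(state_dict_path))
```

With such a file in place, the handler would be constructed as `PytorchModelHandlerTensor(state_dict_path=state_dict_path, model_class=LinearRegression, model_params=model_params)`; in a real pipeline the path would typically be a URI readable by every worker.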
-
load_model() → torch.nn.modules.module.Module
Loads and initializes a PyTorch model for processing.
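A rough sketch of what loading involves, assuming the device rule described in the parameter list: 'GPU' selects CUDA only when `torch.cuda.is_available()` is true, and anything else falls back to the CPU. The helper `load_pytorch_model` and the `TinyModel` class are hypothetical; Beam's own `load_model` may differ in details such as how the file at `state_dict_path` is opened.

```python
import os
import tempfile
from typing import Any, Callable, Dict

import torch


def load_pytorch_model(
    state_dict_path: str,
    model_class: Callable[..., torch.nn.Module],
    model_params: Dict[str, Any],
    device: str = "CPU",
) -> torch.nn.Module:
    """Hypothetical helper mirroring the documented load_model behavior."""
    # 'GPU' means "use CUDA if present"; any other value means CPU.
    if device == "GPU" and torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    model = model_class(**model_params)
    # map_location lets weights saved on a GPU machine load on a CPU-only one.
    model.load_state_dict(torch.load(state_dict_path, map_location=target))
    model.to(target)
    model.eval()  # disable dropout and batch-norm updates for inference
    return model


class TinyModel(torch.nn.Module):
    """Hypothetical model used only for illustration."""

    def __init__(self, n_features: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


path = os.path.join(tempfile.mkdtemp(), "tiny.pth")
torch.save(TinyModel(n_features=3).state_dict(), path)
loaded = load_pytorch_model(path, TinyModel, {"n_features": 3}, device="GPU")
```

Calling `eval()` matters here because layers such as dropout behave differently during training and inference.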
-
run_inference(batch: Sequence[torch.Tensor], model: torch.nn.modules.module.Module, inference_args: Optional[Dict[str, Any]] = None) → Iterable[apache_beam.ml.inference.base.PredictionResult]
Runs inference on a batch of Tensors and returns an Iterable of Tensor predictions.
This method stacks the sequence of Tensors into a single batched Tensor so that the model is called once for the whole batch rather than once per example.
Parameters: - batch – A sequence of Tensors. These Tensors must all have the same shape, because this method calls torch.stack() and passes the resulting Tensor, with dimensions (batch_size, n_features, ...), to the model’s forward() function.
- model – A PyTorch model.
- inference_args – Non-batchable arguments required as inputs to the model’s forward() function. Unlike the Tensors in batch, these arguments are not dynamically batched.
Returns: An Iterable of type PredictionResult.
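The batching step described above can be sketched as follows, assuming that `inference_args` are forwarded to `forward()` as keyword arguments and that each output row is paired back with its input example. `Prediction` is a hypothetical stand-in for `PredictionResult` so the sketch runs without Beam installed, and `ScaledLinear` with its `scale` argument is illustrative only.

```python
from typing import Any, Dict, Iterable, NamedTuple, Optional, Sequence

import torch


class Prediction(NamedTuple):
    """Hypothetical stand-in for apache_beam.ml.inference.base.PredictionResult."""
    example: torch.Tensor
    inference: torch.Tensor


def run_tensor_batch(
    batch: Sequence[torch.Tensor],
    model: torch.nn.Module,
    inference_args: Optional[Dict[str, Any]] = None,
) -> Iterable[Prediction]:
    inference_args = inference_args or {}
    with torch.no_grad():
        # Equal-shaped (n_features,) tensors become one (batch_size, n_features)
        # tensor, so forward() runs once for the whole batch.
        batched = torch.stack(batch)
        outputs = model(batched, **inference_args)
    # Iterating over the output splits it back into one row per example.
    return [Prediction(x, y) for x, y in zip(batch, outputs)]


class ScaledLinear(torch.nn.Module):
    """Hypothetical model whose forward() takes a non-batchable argument."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return self.linear(x) * scale


model = ScaledLinear().eval()
examples = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), torch.tensor([5.0, 6.0])]
results = run_tensor_batch(examples, model, inference_args={"scale": 2.0})
```

Here `scale` is the same for every example, which is why it travels through `inference_args` instead of being stacked with the inputs.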
-
get_num_bytes(batch: Sequence[torch.Tensor]) → int
Returns: The number of bytes of data for a batch of Tensors.
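One plausible way to compute this figure is to multiply each Tensor's element count by its per-element size in bytes. This is a sketch, not necessarily the formula Beam uses internally, and `tensor_batch_num_bytes` is a hypothetical name.

```python
from typing import Sequence

import torch


def tensor_batch_num_bytes(batch: Sequence[torch.Tensor]) -> int:
    # element_size() is bytes per element (4 for float32); nelement() counts elements.
    return sum(t.element_size() * t.nelement() for t in batch)


batch = [torch.zeros(4, dtype=torch.float32) for _ in range(3)]
print(tensor_batch_num_bytes(batch))  # 3 tensors * 4 elements * 4 bytes = 48
```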
-
class apache_beam.ml.inference.pytorch_inference.PytorchModelHandlerKeyedTensor(state_dict_path: str, model_class: Callable[[...], torch.nn.modules.module.Module], model_params: Dict[str, Any], device: str = 'CPU')
Bases: apache_beam.ml.inference.base.ModelHandler
Implementation of the ModelHandler interface for PyTorch models that take keyed Tensors (Dict[str, torch.Tensor]) as input.
Example Usage:
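# MyModel is a placeholder for your torch.nn.Module subclass.
pcoll | RunInference(PytorchModelHandlerKeyedTensor(
    state_dict_path="my_uri", model_class=MyModel, model_params={}))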
NOTE: This API and its implementation are under development and do not provide backward compatibility guarantees.
See https://pytorch.org/tutorials/beginner/saving_loading_models.html for details on saving and loading a model's state dictionary.
Parameters: - state_dict_path – path to the saved dictionary of the model state.
- model_class – class of the PyTorch model that defines the model structure.
- model_params – A dictionary of arguments required to instantiate the model class.
- device – the device on which to run the model. If device is 'GPU', a GPU is used when one is available; otherwise the model runs on the CPU. Defaults to 'CPU'.
-
load_model() → torch.nn.modules.module.Module
Loads and initializes a PyTorch model for processing.
-
run_inference(batch: Sequence[Dict[str, torch.Tensor]], model: torch.nn.modules.module.Module, inference_args: Optional[Dict[str, Any]] = None) → Iterable[apache_beam.ml.inference.base.PredictionResult]
Runs inference on a batch of keyed Tensors and returns an Iterable of Tensor predictions.
For each key, this method stacks that key's Tensor values from all examples into a single batched Tensor, so that the model is called once for the whole batch rather than once per example.
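The per-key stacking described above can be sketched with a dictionary comprehension, assuming every example in the batch has the same keys and that a given key's Tensors share one shape. The helper name `stack_keyed_batch` and the keys `input_ids` and `attention_mask` are hypothetical.

```python
from typing import Dict, Sequence

import torch


def stack_keyed_batch(batch: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Use the first example's keys; each key gets one stacked tensor whose
    # first dimension is the batch size.
    return {key: torch.stack([example[key] for example in batch]) for key in batch[0]}


batch = [
    {"input_ids": torch.tensor([101, 7592, 102]), "attention_mask": torch.tensor([1, 1, 1])},
    {"input_ids": torch.tensor([101, 2088, 102]), "attention_mask": torch.tensor([1, 1, 0])},
]
stacked = stack_keyed_batch(batch)
print({key: tuple(t.shape) for key, t in stacked.items()})
# {'input_ids': (2, 3), 'attention_mask': (2, 3)}
```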
Parameters: - batch – A sequence of keyed Tensors. For each key, the Tensors must all have the same shape, because this method calls torch.stack() and passes the batched Tensors, with dimensions (batch_size, n_features, ...), to the model’s forward() function.
- model – A PyTorch model.
- inference_args – Non-batchable arguments required as inputs to the model’s forward() function. Unlike the Tensors in batch, these arguments are not dynamically batched.
Returns: An Iterable of type PredictionResult.
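A sketch of the keyed inference call, assuming the stacked Tensors are passed to `forward()` as keyword arguments named after the dictionary keys (so the model's `forward()` parameters must match those keys), with `inference_args` merged in as extra keyword arguments. `KeyedPrediction` stands in for `PredictionResult` so the sketch runs without Beam, and `MaskedSum` with its `features`, `mask`, and `bias` arguments is hypothetical.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, NamedTuple, Optional, Sequence

import torch


class KeyedPrediction(NamedTuple):
    """Hypothetical stand-in for apache_beam.ml.inference.base.PredictionResult."""
    example: Dict[str, torch.Tensor]
    inference: torch.Tensor


def run_keyed_batch(
    batch: Sequence[Dict[str, torch.Tensor]],
    model: torch.nn.Module,
    inference_args: Optional[Dict[str, Any]] = None,
) -> Iterable[KeyedPrediction]:
    inference_args = inference_args or {}
    # Collect the tensors for each key across all examples.
    per_key = defaultdict(list)
    for example in batch:
        for key, tensor in example.items():
            per_key[key].append(tensor)
    batched = {key: torch.stack(tensors) for key, tensors in per_key.items()}
    with torch.no_grad():
        # Keys become keyword arguments of forward().
        outputs = model(**batched, **inference_args)
    return [KeyedPrediction(x, y) for x, y in zip(batch, outputs)]


class MaskedSum(torch.nn.Module):
    """Hypothetical model: sums the unmasked features and adds a bias."""

    def forward(self, features: torch.Tensor, mask: torch.Tensor, bias: float = 0.0) -> torch.Tensor:
        return (features * mask).sum(dim=1) + bias


batch = [
    {"features": torch.tensor([1.0, 2.0, 3.0]), "mask": torch.tensor([1.0, 1.0, 0.0])},
    {"features": torch.tensor([4.0, 5.0, 6.0]), "mask": torch.tensor([0.0, 1.0, 1.0])},
]
results = run_keyed_batch(batch, MaskedSum(), inference_args={"bias": 0.5})
print([r.inference.item() for r in results])  # [3.5, 11.5]
```

Under this assumption, a key that has no matching `forward()` parameter would raise a `TypeError`, so the dictionary keys and the model signature need to be designed together.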
-
get_num_bytes(batch: Sequence[torch.Tensor]) → int
Returns: The number of bytes of data for a batch of dicts of Tensors.
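Following the description, which treats each batch element as a dict of Tensors, the byte count can be sketched by summing over every Tensor value of every dict. This is one plausible formula rather than Beam's confirmed implementation, and `keyed_batch_num_bytes` is a hypothetical name.

```python
from typing import Dict, Sequence

import torch


def keyed_batch_num_bytes(batch: Sequence[Dict[str, torch.Tensor]]) -> int:
    # Sum bytes over every Tensor value of every dict in the batch.
    return sum(
        t.element_size() * t.nelement()
        for example in batch
        for t in example.values()
    )


batch = [
    {"features": torch.zeros(3, dtype=torch.float32), "mask": torch.zeros(3, dtype=torch.int64)},
    {"features": torch.zeros(3, dtype=torch.float32), "mask": torch.zeros(3, dtype=torch.int64)},
]
print(keyed_batch_num_bytes(batch))  # 2 * (3 * 4 + 3 * 8) = 72
```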