apache_beam.ml.inference.base module

An extensible run inference transform.

Users of this module can extend the ModelHandler class for any machine learning framework. A ModelHandler implementation is a required parameter of RunInference.

The transform handles standard inference functionality, such as metric collection, sharing models between threads, and batching elements.
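
For example, a pipeline can apply RunInference directly to a PCollection of examples. A minimal sketch, assuming the scikit-learn handler from apache_beam.ml.inference.sklearn_inference and a placeholder model path:

  import apache_beam as beam
  import numpy as np
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([np.array([1.0, 2.0]), np.array([3.0, 4.0])])
        # '/tmp/model.pkl' is a placeholder for a pickled scikit-learn model.
        | RunInference(SklearnModelHandlerNumpy(model_uri='/tmp/model.pkl'))
        | beam.Map(print))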

class apache_beam.ml.inference.base.PredictionResult[source]

Bases: tuple

A NamedTuple containing both input and output from the inference.
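
Downstream steps can read both fields to post-process predictions alongside their inputs. A minimal sketch of a formatting function using the NamedTuple's example and inference fields:

  from apache_beam.ml.inference.base import PredictionResult

  def format_result(result: PredictionResult):
    # result.example is the original input; result.inference is the model
    # output produced for that input.
    return {'example': result.example, 'inference': result.inference}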

class apache_beam.ml.inference.base.ModelMetadata(model_id, model_name)[source]

Bases: tuple

Create new instance of ModelMetadata(model_id, model_name)

model_id

Unique identifier for the model. This can be a file path or a URL where the model can be accessed. It is used to load the model for inference.

model_name

Human-readable name for the model. This can be used to identify the model in the metrics generated by the RunInference transform.

class apache_beam.ml.inference.base.ModelHandler[source]

Bases: typing.Generic

Has the ability to load and apply an ML model.
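
A minimal sketch of a custom handler, assuming a hypothetical my_framework module that exposes load(path) and models with a predict(batch) method; load_model and run_inference are the essential overrides, and the remaining methods have workable defaults:

  from apache_beam.ml.inference.base import ModelHandler, PredictionResult

  import my_framework  # hypothetical ML framework

  class MyModelHandler(ModelHandler):
    def __init__(self, model_uri: str):
      self._model_uri = model_uri

    def load_model(self):
      # Called once per worker; the loaded model is shared across threads.
      return my_framework.load(self._model_uri)  # hypothetical call

    def run_inference(self, batch, model, inference_args=None):
      predictions = model.predict(batch)  # hypothetical call
      # Pair each input example with its model output.
      return [
          PredictionResult(example, inference)
          for example, inference in zip(batch, predictions)
      ]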

load_model() → ModelT[source]

Loads and initializes a model for processing.

run_inference(batch: Sequence[ExampleT], model: ModelT, inference_args: Optional[Dict[str, Any]] = None) → Iterable[PredictionT][source]

Runs inferences on a batch of examples.

Parameters:
  • batch – A sequence of examples or features.
  • model – The model used to make inferences.
  • inference_args – Extra arguments for models whose inference call requires extra parameters.
Returns:

An Iterable of Predictions.

get_num_bytes(batch: Sequence[ExampleT]) → int[source]

Returns: The number of bytes of data for a batch.

get_metrics_namespace() → str[source]

Returns: A namespace for metrics collected by the RunInference transform.

get_resource_hints() → dict[source]

Returns: Resource hints for the transform.

batch_elements_kwargs() → Mapping[str, Any][source]

Returns: kwargs suitable for beam.BatchElements.

validate_inference_args(inference_args: Optional[Dict[str, Any]])[source]

Validates inference_args passed in the inference call.

Because most frameworks do not need extra arguments in their predict() call, the default behavior is to raise an error if inference_args are present.
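
A handler whose framework does accept extra arguments in its predict() call can override this default. A minimal sketch (the handler name is illustrative, and load_model and run_inference are omitted):

  from typing import Any, Dict, Optional

  from apache_beam.ml.inference.base import ModelHandler

  class ArgAcceptingHandler(ModelHandler):
    def validate_inference_args(
        self, inference_args: Optional[Dict[str, Any]]) -> None:
      # Accept arbitrary inference_args instead of raising; run_inference
      # is then responsible for forwarding them to the model call.
      pass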

update_model_path(model_path: Optional[str] = None)[source]

Updates the model path using values produced by side inputs.

class apache_beam.ml.inference.base.KeyedModelHandler(unkeyed: apache_beam.ml.inference.base.ModelHandler[ExampleT, PredictionT, ModelT])[source]

Bases: apache_beam.ml.inference.base.ModelHandler

A ModelHandler that takes keyed examples and returns keyed predictions.

For example, if the original model is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible to use the key to associate the outputs with the inputs.

Parameters: unkeyed – An implementation of ModelHandler that does not require keys.
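
A minimal sketch of wrapping an unkeyed handler; the scikit-learn handler and model path are illustrative:

  import apache_beam as beam
  import numpy as np
  from apache_beam.ml.inference.base import KeyedModelHandler, RunInference
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  keyed_handler = KeyedModelHandler(
      SklearnModelHandlerNumpy(model_uri='/tmp/model.pkl'))  # placeholder path

  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([('id-0', np.array([1.0])), ('id-1', np.array([2.0]))])
        | RunInference(keyed_handler)  # emits (key, PredictionResult) pairs
        | beam.MapTuple(lambda key, result: (key, result.inference)))
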
load_model() → ModelT[source]

run_inference(batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT, inference_args: Optional[Dict[str, Any]] = None) → Iterable[Tuple[KeyT, PredictionT]][source]

get_num_bytes(batch: Sequence[Tuple[KeyT, ExampleT]]) → int[source]

get_metrics_namespace() → str[source]

get_resource_hints()[source]

batch_elements_kwargs()[source]

validate_inference_args(inference_args: Optional[Dict[str, Any]])[source]

update_model_path(model_path: Optional[str] = None)[source]

class apache_beam.ml.inference.base.MaybeKeyedModelHandler(unkeyed: apache_beam.ml.inference.base.ModelHandler[ExampleT, PredictionT, ModelT])[source]

Bases: apache_beam.ml.inference.base.ModelHandler

A ModelHandler that takes examples that might have keys and returns predictions that might have keys.

For example, if the original model is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take either a PCollection[E] to a PCollection[P] or a PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], depending on whether the elements are tuples. This pattern makes it possible to associate the outputs with the inputs based on the key.

Note that you cannot use this ModelHandler if E is a tuple type. In addition, either all examples must be keyed, or none of them may be.

Parameters: unkeyed – An implementation of ModelHandler that does not require keys.
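
A minimal sketch; the wrapped handler then accepts keyed or unkeyed elements, provided a single pipeline does not mix the two (handler and path are illustrative):

  from apache_beam.ml.inference.base import MaybeKeyedModelHandler
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  handler = MaybeKeyedModelHandler(
      SklearnModelHandlerNumpy(model_uri='/tmp/model.pkl'))  # placeholder path
  # RunInference(handler) now accepts either a PCollection of arrays or a
  # PCollection of (key, array) tuples, but not a mixture of both.
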
load_model() → ModelT[source]

run_inference(batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]], model: ModelT, inference_args: Optional[Dict[str, Any]] = None) → Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]][source]

get_num_bytes(batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) → int[source]

get_metrics_namespace() → str[source]

get_resource_hints()[source]

batch_elements_kwargs()[source]

validate_inference_args(inference_args: Optional[Dict[str, Any]])[source]

update_model_path(model_path: Optional[str] = None)[source]

class apache_beam.ml.inference.base.RunInference(model_handler: apache_beam.ml.inference.base.ModelHandler[ExampleT, PredictionT, Any], clock=<module 'time' (built-in)>, inference_args: Optional[Dict[str, Any]] = None, metrics_namespace: Optional[str] = None, *, model_metadata_pcoll: apache_beam.pvalue.PCollection[apache_beam.ml.inference.base.ModelMetadata] = None)[source]

Bases: apache_beam.transforms.ptransform.PTransform

A transform that applies an ML model to a PCollection of examples (or features) and outputs the inferences (or predictions) as a PCollection of PredictionResults, each containing an input example and its corresponding inference.

Models for supported frameworks can be loaded from a URI; supported model-serving services can be used as well.

This transform attempts to batch examples using the beam.BatchElements transform. Batching can be configured using the ModelHandler.
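
For example, a handler can bound the batch sizes handed to run_inference by overriding batch_elements_kwargs. A minimal sketch; the base handler and the bounds are illustrative:

  from typing import Any, Mapping

  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  class BatchTunedHandler(SklearnModelHandlerNumpy):
    def batch_elements_kwargs(self) -> Mapping[str, Any]:
      # Forwarded to beam.BatchElements.
      return {'min_batch_size': 8, 'max_batch_size': 64}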

Parameters:
  • model_handler – An implementation of ModelHandler.
  • clock – A clock implementing time_ns. Used for unit testing.
  • inference_args – Extra arguments for models whose inference call requires extra parameters.
  • metrics_namespace – Namespace under which the transform's metrics are collected.
  • model_metadata_pcoll – A PCollection that emits a singleton ModelMetadata containing the model path and model name, used as a side input to the _RunInferenceDoFn (see the sketch below).
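
A minimal sketch of wiring up this side input; in a real pipeline the metadata typically comes from a file-watching transform, and the handler and paths here are illustrative:

  import apache_beam as beam
  import numpy as np
  from apache_beam.ml.inference.base import ModelMetadata, RunInference
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  with beam.Pipeline() as p:
    model_metadata = p | 'ModelUpdates' >> beam.Create(
        [ModelMetadata(model_id='/tmp/model_v2.pkl', model_name='model_v2')])
    _ = (
        p
        | beam.Create([np.array([1.0])])
        | RunInference(
            SklearnModelHandlerNumpy(model_uri='/tmp/model_v1.pkl'),
            model_metadata_pcoll=model_metadata))
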
classmethod from_callable(model_handler_provider, **kwargs)[source]

Multi-language friendly constructor.

Use this constructor with fully_qualified_named_transform to initialize the RunInference transform from a PythonCallableSource provided by foreign SDKs.

Parameters:
  • model_handler_provider – A callable object that returns a ModelHandler.
  • kwargs – Keyword arguments for model_handler_provider.
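
Within Python, from_callable simply defers handler construction to the callable. A minimal sketch with an illustrative handler and placeholder path:

  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  # Equivalent to RunInference(SklearnModelHandlerNumpy(model_uri=...)).
  transform = RunInference.from_callable(
      SklearnModelHandlerNumpy, model_uri='/tmp/model.pkl')
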
expand(pcoll: apache_beam.pvalue.PCollection[ExampleT]) → apache_beam.pvalue.PCollection[PredictionT][source]