apache_beam.yaml.yaml_ml module
This module defines yaml wrappings for some ML transforms.
- class apache_beam.yaml.yaml_ml.ModelHandlerProvider(handler, preprocess: dict[str, str] | None = None, postprocess: dict[str, str] | None = None)[source]
Bases: object
- handler_types: dict[str, Callable[[...], ModelHandlerProvider]] = {'VertexAIModelHandlerJSON': <class 'apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider'>}
- classmethod create_handler(model_handler_spec) → ModelHandlerProvider [source]
- class apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider(endpoint_id: str, project: str, location: str, preprocess: dict[str, str], postprocess: dict[str, str] | None = None, experiment: str | None = None, network: str | None = None, private: bool = False, min_batch_size: int | None = None, max_batch_size: int | None = None, max_batch_duration_secs: int | None = None, env_vars: dict[str, Any] | None = None)[source]
Bases: ModelHandlerProvider
ModelHandler for Vertex AI.
This Model Handler can be used with RunInference to load a model hosted on VertexAI. Every model hosted on VertexAI requires three distinct parameters: endpoint_id, project, and location. These parameters tell the Model Handler how to access the model’s endpoint so that input data can be sent using an API request, and inferences can be received as a response.
This Model Handler also requires a preprocess function to be defined. Preprocessing and Postprocessing are described in more detail in the RunInference docs: https://beam.apache.org/releases/yamldoc/current/#runinference
Every model will have a unique input, but all requests should be JSON-formatted. For example, most language models such as Llama and Gemma expect a JSON with the key “prompt” (among other optional keys). In Python, JSON can be expressed as a dictionary.
For example:
    - type: RunInference
      config:
        inference_tag: 'my_inference'
        model_handler:
          type: VertexAIModelHandlerJSON
          config:
            endpoint_id: 9876543210
            project: my-project
            location: us-east1
            preprocess:
              callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
In the above example, which mimics a call to a Llama 3 model hosted on VertexAI, the preprocess function (in this case a lambda) takes in a Beam Row with a single field, “prompt”, and maps it to a dict with the same field. It also specifies an optional parameter, “max_tokens”, that tells the model the allowed token size (in this case input + output token size).
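Concretely, here is an illustrative sketch of that mapping in plain Python (the input value is hypothetical, and the comment about the request is an approximation):

    import apache_beam as beam

    # Hypothetical input element; its single field matches the lambda above.
    row = beam.Row(prompt="Tell me about Apache Beam.")

    # The preprocess callable from the example config.
    preprocess = lambda x: {"prompt": x.prompt, "max_tokens": 50}

    payload = preprocess(row)
    # payload == {"prompt": "Tell me about Apache Beam.", "max_tokens": 50}
    # Roughly speaking, this dict becomes one instance in the prediction request
    # sent to the Vertex AI endpoint.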
- Parameters:
endpoint_id – the numerical ID of the Vertex AI endpoint to query.
project – the GCP project name where the endpoint is deployed.
location – the GCP location where the endpoint is deployed.
preprocess – A Python callable, defined either inline or in a file, that is invoked on the input row before it is sent to the model loaded by this ModelHandler. This parameter is required by the VertexAIModelHandlerJSON ModelHandler.
postprocess – A Python callable, defined either inline or in a file, that is invoked on the PredictionResult output by the ModelHandler before it is parsed into the output Beam Row under the field name defined by inference_tag.
experiment – Experiment label to apply to the queries. See https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments for more information.
network – The full name of the Compute Engine network the endpoint is deployed on; used for private endpoints. The network or subnetwork Dataflow pipeline option must be set and match this network for pipeline execution. For example: “projects/12345/global/networks/myVPC”.
private – If the deployed Vertex AI endpoint is private, set to true. Requires network to be provided as well (see the example configuration following this parameter list).
min_batch_size – The minimum batch size to use when batching inputs.
max_batch_size – The maximum batch size to use when batching inputs.
max_batch_duration_secs – The maximum amount of time to buffer a batch before emitting; used in streaming contexts.
env_vars – Environment variables.
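For illustration, here is a sketch of a model_handler block that exercises several of the optional parameters above; every value shown is a hypothetical placeholder, not a recommendation:

    model_handler:
      type: VertexAIModelHandlerJSON
      config:
        endpoint_id: 1234567890                        # hypothetical endpoint
        project: my-project                            # hypothetical project
        location: us-central1                          # hypothetical region
        private: true                                  # private endpoint...
        network: projects/12345/global/networks/myVPC  # ...so a network is required
        min_batch_size: 1
        max_batch_size: 8
        max_batch_duration_secs: 5
        preprocess:
          callable: 'lambda x: {"prompt": x.prompt}'
        postprocess:
          callable: 'lambda x: x.inference'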
- apache_beam.yaml.yaml_ml.run_inference(model_handler: dict[str, Any], inference_tag: str | None = 'inference', inference_args: dict[str, Any] | None = None) → PCollection[Row] [source]
A transform that takes input rows containing examples (or features) to run against an ML model, and appends the inferences (or predictions) for those examples to the input row.
A ModelHandler must be passed to the model_handler parameter. The ModelHandler is responsible for configuring how the ML model will be loaded and how input data will be passed to it. Every ModelHandler has a config tag, similar to how a transform is defined, where the parameters are defined.
For example:
    - type: RunInference
      config:
        model_handler:
          type: ModelHandler
          config:
            param_1: arg1
            param_2: arg2
            ...
By default, the RunInference transform will return the input row with a single field appended named by the inference_tag parameter (“inference” by default) that contains the inference directly returned by the underlying ModelHandler, after any optional postprocessing.
For example, if the input had the following:
Row(question="What is a car?")
The output row would look like:
Row(question="What is a car?", inference=...)
where the inference tag can be overridden with the inference_tag parameter.
However, if one specified the following transform config:
    - type: RunInference
      config:
        inference_tag: my_inference
        model_handler:
          ...
The output row would look like:
Row(question="What is a car?", my_inference=...)
See more complete documentation on the underlying [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) transform.
### Preprocessing input data
In most cases, the model will be expecting data in a particular data format, whether it be a Python Dict, PyTorch tensor, etc. However, the outputs of all built-in Beam YAML transforms are Beam Rows. To allow for transforming the Beam Row into a data format the model recognizes, each ModelHandler is equipped with a preprocessing parameter for performing necessary data preprocessing. It is possible for a ModelHandler to define a default preprocessing function, but in most cases, one will need to be specified by the caller.
For example, using callable:
    pipeline:
      type: chain
      transforms:
        - type: Create
          config:
            elements:
              - question: "What is a car?"
              - question: "Where is the Eiffel Tower located?"
        - type: RunInference
          config:
            model_handler:
              type: ModelHandler
              config:
                param_1: arg1
                param_2: arg2
                preprocess:
                  callable: 'lambda row: {"prompt": row.question}'
        ...
In the above example, the Create transform generates a collection of two Beam Row elements, each with a single field, “question”. The model, however, expects a Python Dict with a single key, “prompt”. In this case, we can specify a simple lambda function (alternatively, a full function could be defined) to map the data.
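As a sketch of the full-function form, the same mapping could be written as a named function in a multi-line callable rather than a lambda (the function name and placeholder ModelHandler config here are illustrative):

    - type: RunInference
      config:
        model_handler:
          type: ModelHandler
          config:
            param_1: arg1
            param_2: arg2
            preprocess:
              callable: |
                def fn(row):
                  # Map the Row's "question" field to the "prompt" key the
                  # model expects, mirroring the lambda above.
                  return {"prompt": row.question}
        ...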
### Postprocessing predictions
It is also possible to define a postprocessing function to postprocess the data output by the ModelHandler. See the documentation for the ModelHandler you intend to use (list defined below under model_handler parameter doc).
In many cases, before postprocessing, the object will be a [PredictionResult](https://beam.apache.org/releases/pydoc/BEAM_VERSION/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult). This type behaves very similarly to a Beam Row and fields can be accessed using dot notation. However, make sure to check the docs for your ModelHandler to see which fields its PredictionResult contains or if it returns a different object altogether.
For example:
    - type: RunInference
      config:
        model_handler:
          type: ModelHandler
          config:
            param_1: arg1
            param_2: arg2
            postprocess:
              callable: |
                def fn(x: PredictionResult):
                  return beam.Row(
                      example=x.example,
                      inference=x.inference,
                      model_id=x.model_id)
        ...
The above example converts the original output data type (in this case, PredictionResult) to a Beam Row, which allows for easier mapping in a later transform.
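For instance, a sketch of a later transform that consumes those fields, assuming the postprocess above and Beam YAML's MapToFields transform; field access under the default “inference” tag uses dot notation:

    - type: RunInference
      config:
        model_handler:
          ...
    - type: MapToFields
      config:
        language: python
        fields:
          example: inference.example       # the example that was sent to the model
          prediction: inference.inference  # the model's prediction
          model: inference.model_id        # id of the model that produced it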
### File-based pre/postprocessing functions
For both preprocessing and postprocessing, it is also possible to specify a Python UDF (User-defined function) file that contains the function. This is possible by specifying the path to the file (local file or GCS path) and the name of the function in the file.
For example:
    - type: RunInference
      config:
        model_handler:
          type: ModelHandler
          config:
            param_1: arg1
            param_2: arg2
            preprocess:
              path: gs://my-bucket/path/to/preprocess.py
              name: my_preprocess_fn
            postprocess:
              path: gs://my-bucket/path/to/postprocess.py
              name: my_postprocess_fn
        ...
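In that case, the referenced files just need to contain the named functions. A minimal sketch of what gs://my-bucket/path/to/preprocess.py might hold (the field and key names are assumptions carried over from the earlier examples):

    # preprocess.py -- hypothetical contents
    def my_preprocess_fn(row):
      # Map the input Beam Row's "question" field to the "prompt" key
      # the model is assumed to expect.
      return {"prompt": row.question}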
- Parameters:
model_handler –
Specifies the parameters for the respective model handler in YAML/JSON format. To see the full set of handler config parameters, see their corresponding doc pages:
[VertexAIModelHandlerJSON](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider)
inference_tag – The tag to use for the returned inference. Default is ‘inference’.
inference_args – Extra arguments for models whose inference call requires extra parameters. Make sure to check the underlying ModelHandler docs to see which args are allowed.
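As an illustration only, inference_args might be supplied like this; the keys shown are hypothetical and entirely model/handler dependent:

    - type: RunInference
      config:
        model_handler:
          type: VertexAIModelHandlerJSON
          config:
            ...
        inference_args:
          # Hypothetical, model-specific parameters; check your ModelHandler's
          # documentation for the arguments it actually accepts.
          temperature: 0.2
          maxOutputTokens: 128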