apache_beam.yaml.yaml_ml module

This module defines yaml wrappings for some ML transforms.

class apache_beam.yaml.yaml_ml.ModelHandlerProvider(handler, preprocess: dict[str, str] | None = None, postprocess: dict[str, str] | None = None)[source]

Bases: object

handler_types: dict[str, Callable[[...], ModelHandlerProvider]] = {'VertexAIModelHandlerJSON': <class 'apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider'>}
inference_output_type()[source]
static parse_processing_transform(processing_transform, typ)[source]
underlying_handler()[source]
static default_preprocess_fn()[source]
static default_postprocess_fn()[source]
static validate(model_handler_spec)[source]
classmethod register_handler_type(type_name)[source]
classmethod create_handler(model_handler_spec) ModelHandlerProvider[source]
class apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider(endpoint_id: str, project: str, location: str, preprocess: dict[str, str], postprocess: dict[str, str] | None = None, experiment: str | None = None, network: str | None = None, private: bool = False, min_batch_size: int | None = None, max_batch_size: int | None = None, max_batch_duration_secs: int | None = None, env_vars: dict[str, Any] | None = None)[source]

Bases: ModelHandlerProvider

ModelHandler for Vertex AI.

This Model Handler can be used with RunInference to load a model hosted on VertexAI. Every model hosted on VertexAI has three distinct, required parameters: endpoint_id, project and location. These parameters tell the Model Handler how to access the model’s endpoint so that input data can be sent in an API request and inferences received in the response.

This Model Handler also requires a preprocess function to be defined. Preprocessing and Postprocessing are described in more detail in the RunInference docs: https://beam.apache.org/releases/yamldoc/current/#runinference

Every model will have a unique input, but all requests should be JSON-formatted. For example, most language models such as Llama and Gemma expect a JSON with the key “prompt” (among other optional keys). In Python, JSON can be expressed as a dictionary.

For example:

- type: RunInference
  config:
    inference_tag: 'my_inference'
    model_handler:
      type: VertexAIModelHandlerJSON
      config:
        endpoint_id: 9876543210
        project: my-project
        location: us-east1
        preprocess:
          callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'

In the above example, which mimics a call to a Llama 3 model hosted on VertexAI, the preprocess function (in this case a lambda) takes in a Beam Row with a single field, “prompt”, and maps it to a dict with a key of the same name. It also specifies an optional parameter, “max_tokens”, which tells the model the allowed token size (in this case, input plus output tokens).

Parameters:
  • endpoint_id – the numerical ID of the Vertex AI endpoint to query.

  • project – the GCP project name where the endpoint is deployed.

  • location – the GCP location where the endpoint is deployed.

  • preprocess – A Python callable, defined either inline or in a file, that is invoked on each input row before it is sent to the model loaded by this ModelHandler. This parameter is required by the VertexAIModelHandlerJSON ModelHandler.

  • postprocess – A Python callable, defined either inline or in a file, that is invoked on the PredictionResult returned by the ModelHandler before it is parsed into the output Beam Row under the field name defined by inference_tag.

  • experiment – Experiment label to apply to the queries. See https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments for more information.

  • network – The full name of the Compute Engine network the endpoint is deployed on; used for private endpoints. The network or subnetwork Dataflow pipeline option must be set and match this network for pipeline execution. Ex: “projects/12345/global/networks/myVPC”

  • private – If the deployed Vertex AI endpoint is private, set to true. Requires a network to be provided as well.

  • min_batch_size – The minimum batch size to use when batching inputs (see the example following this parameter list).

  • max_batch_size – The maximum batch size to use when batching inputs.

  • max_batch_duration_secs – The maximum amount of time to buffer a batch before emitting; used in streaming contexts.

  • env_vars – Environment variables.
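For example, a VertexAIModelHandlerJSON configuration that also sets the optional batching parameters might look like the following (the endpoint, project, and batch values shown are purely illustrative):

- type: RunInference
  config:
    model_handler:
      type: VertexAIModelHandlerJSON
      config:
        endpoint_id: 9876543210
        project: my-project
        location: us-east1
        preprocess:
          callable: 'lambda x: {"prompt": x.prompt}'
        # Optional batching controls; the values below are illustrative.
        min_batch_size: 1
        max_batch_size: 8
        max_batch_duration_secs: 10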

static validate(model_handler_spec)[source]
inference_output_type()[source]
apache_beam.yaml.yaml_ml.get_user_schema_fields(user_type)[source]
apache_beam.yaml.yaml_ml.run_inference(model_handler: dict[str, Any], inference_tag: str | None = 'inference', inference_args: dict[str, Any] | None = None) PCollection[Row][source]

A transform that takes input rows containing examples (or features) for an ML model and appends the inferences (or predictions) for those examples to each input row.

A ModelHandler must be passed to the model_handler parameter. The ModelHandler is responsible for configuring how the ML model will be loaded and how input data will be passed to it. Every ModelHandler has a config tag, similar to how a transform is defined, where the parameters are defined.

For example:

- type: RunInference
  config:
    model_handler:
      type: ModelHandler
      config:
        param_1: arg1
        param_2: arg2
        ...

By default, the RunInference transform returns the input row with a single appended field, named by the inference_tag parameter (“inference” by default), that contains the inference directly returned by the underlying ModelHandler, after any optional postprocessing.

For example, if the input had the following:

Row(question="What is a car?")

The output row would look like:

Row(question="What is a car?", inference=...)

where the inference tag can be overridden with the inference_tag parameter.

However, if one specified the following transform config:

- type: RunInference
  config:
    inference_tag: my_inference
    model_handler: ...

The output row would look like:

Row(question="What is a car?", my_inference=...)

See more complete documentation on the underlying [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) transform.

### Preprocessing input data

In most cases, the model will be expecting data in a particular data format, whether it be a Python Dict, PyTorch tensor, etc. However, the outputs of all built-in Beam YAML transforms are Beam Rows. To allow for transforming the Beam Row into a data format the model recognizes, each ModelHandler is equipped with a preprocessing parameter for performing necessary data preprocessing. It is possible for a ModelHandler to define a default preprocessing function, but in most cases, one will need to be specified by the caller.

For example, using callable:

pipeline:
  type: chain

  transforms:
    - type: Create
      config:
        elements:
          - question: "What is a car?"
          - question: "Where is the Eiffel Tower located?"

    - type: RunInference
      config:
        model_handler:
          type: ModelHandler
          config:
            param_1: arg1
            param_2: arg2
            preprocess:
              callable: 'lambda row: {"prompt": row.question}'
            ...

In the above example, the Create transform generates a collection of two Beam Row elements, each with a single field, “question”. The model, however, expects a Python Dict with a single key, “prompt”. In this case, we can specify a simple lambda function (alternatively, a full function could be defined) to map the data.

### Postprocessing predictions

It is also possible to define a postprocessing function to postprocess the data output by the ModelHandler. See the documentation for the ModelHandler you intend to use (the available handlers are listed under the model_handler parameter below).

In many cases, before postprocessing, the object will be a [PredictionResult](https://beam.apache.org/releases/pydoc/BEAM_VERSION/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult). This type behaves very similarly to a Beam Row, and its fields can be accessed using dot notation. However, make sure to check the docs for your ModelHandler to see which fields its PredictionResult contains or whether it returns a different object altogether.

For example:

- type: RunInference
  config:
    model_handler:
      type: ModelHandler
      config:
        param_1: arg1
        param_2: arg2
        postprocess:
          callable: |
            def fn(x: PredictionResult):
              return beam.Row(example=x.example, inference=x.inference, model_id=x.model_id)
        ...

The above example converts the original output data type (in this case a PredictionResult) into a Beam Row, which allows for easier mapping in a later transform.

### File-based pre/postprocessing functions

For both preprocessing and postprocessing, it is also possible to specify a Python UDF (user-defined function) file that contains the function. This is done by specifying the path to the file (a local file or GCS path) and the name of the function in the file. A sketch of what such a file might contain follows the example below.

For example:

- type: RunInference
  config:
    model_handler:
      type: ModelHandler
      config:
        param_1: arg1
        param_2: arg2
        preprocess:
          path: gs://my-bucket/path/to/preprocess.py
          name: my_preprocess_fn
        postprocess:
          path: gs://my-bucket/path/to/postprocess.py
          name: my_postprocess_fn
        ...
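The referenced file is an ordinary Python module containing the named function. As a rough sketch (the path and function name come from the hypothetical example above, and the exact row fields depend on your pipeline), gs://my-bucket/path/to/preprocess.py might contain:

# Hypothetical contents of the preprocess.py file referenced above.
# The input is the Beam Row produced by the preceding transform;
# the field name "question" is illustrative.
def my_preprocess_fn(row):
  return {"prompt": row.question}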
Parameters:
  • model_handler – Specifies the parameters for the respective model handler in a YAML/JSON format. To see the full set of handler config parameters, see the corresponding handler provider’s documentation (for example, VertexAIModelHandlerJSONProvider above).

  • inference_tag – The tag to use for the returned inference. Default is ‘inference’.

  • inference_args – Extra arguments for models whose inference call requires extra parameters, as in the sketch below. Make sure to check the underlying ModelHandler docs to see which args are allowed.
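For example, a RunInference config passing extra arguments through inference_args might look like the following sketch (the argument names shown are illustrative, not taken from any particular ModelHandler):

- type: RunInference
  config:
    model_handler:
      type: VertexAIModelHandlerJSON
      config:
        ...
    # Argument names below are illustrative; check the underlying
    # ModelHandler docs for the supported keys.
    inference_args:
      temperature: 0.7
      max_output_tokens: 128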

apache_beam.yaml.yaml_ml.ml_transform(write_artifact_location: str | None = None, read_artifact_location: str | None = None, transforms: list[Any] | None = None)[source]
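A minimal sketch of how this might be used in a Beam YAML pipeline, assuming the transform is exposed as MLTransform and that ComputeAndApplyVocabulary (with a columns key) is available as a sub-transform; these names are assumptions drawn from Beam's MLTransform API, not documented on this page:

- type: MLTransform
  config:
    # The sub-transform and config key names below are assumptions based on
    # Beam's MLTransform API; consult the Beam YAML docs to confirm them.
    write_artifact_location: gs://my-bucket/artifacts
    transforms:
      - type: ComputeAndApplyVocabulary
        config:
          columns: [feature]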