apache_beam.ml.inference.agent_development_kit module

ModelHandler for running agents built with the Google Agent Development Kit (ADK).

This module provides ADKAgentModelHandler, a Beam ModelHandler that wraps an ADK agent (google.adk.agents.llm_agent.LlmAgent) so that it can be used with the RunInference transform.

Typical usage:

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler
from google.adk.agents import LlmAgent

agent = LlmAgent(
    name="my_agent",
    model="gemini-2.0-flash",
    instruction="You are a helpful assistant.",
)

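with beam.Pipeline() as p:
    results = (
        p
        | beam.Create(["What is the capital of France?"])
        | RunInference(ADKAgentModelHandler(agent=agent))
    )
    # Each output is a PredictionResult; its inference field holds the
    # agent's final response text.
    results | beam.Map(lambda result: print(result.inference))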

If your agent contains state that is not picklable (e.g. tool closures that capture unpicklable objects), pass a zero-arg factory callable instead:

handler = ADKAgentModelHandler(agent=lambda: LlmAgent(...))
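The following stdlib-only sketch illustrates why the factory form helps. FakeAgent, make_agent, is_picklable, and resolve_agent are hypothetical stand-ins and are not part of ADK or Beam. An object holding a threading.Lock cannot be pickled. A module-level factory function, by contrast, pickles as a plain reference and builds the agent only when it is called. The standard-library pickle used here cannot serialize lambdas. Beam's own pickler can, which is why the lambda form above also works.

```python
import pickle
import threading


class FakeAgent:
    """Hypothetical stand-in for an ADK agent whose state cannot be pickled."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Stands in for e.g. a tool closure holding a lock; locks are unpicklable.
        self._lock = threading.Lock()


def make_agent() -> FakeAgent:
    """Zero-arg factory: pickle stores only a reference to this function."""
    return FakeAgent(name="my_agent")


def is_picklable(obj) -> bool:
    """Return True if the standard-library pickler can serialize obj."""
    try:
        pickle.dumps(obj)
    except (TypeError, pickle.PicklingError):
        return False
    return True


def resolve_agent(agent_or_factory):
    """Return the agent, calling the factory first if one was supplied."""
    if isinstance(agent_or_factory, FakeAgent):
        return agent_or_factory
    return agent_or_factory()


print(is_picklable(FakeAgent("my_agent")))  # False: the lock cannot be pickled
print(is_picklable(make_agent))  # True: only the function reference is pickled

# On the "worker": unpickle the factory and only then build the agent.
factory = pickle.loads(pickle.dumps(make_agent))
print(resolve_agent(factory).name)  # my_agent
```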
class apache_beam.ml.inference.agent_development_kit.ADKAgentModelHandler(agent: Agent | Callable[[], Agent], app_name: str = 'beam_inference', session_service_factory: Callable[[], BaseSessionService] | None = None, *, min_batch_size: int | None = None, max_batch_size: int | None = None, max_batch_duration_secs: int | None = None, max_batch_weight: int | None = None, element_size_fn: Callable[[Any], int] | None = None, **kwargs)

Bases: ModelHandler[str | Any, PredictionResult, Runner]

ModelHandler for running ADK agents with the Beam RunInference transform.

Accepts either a fully constructed google.adk.agents.Agent or a zero-arg factory callable that produces one. The factory form is useful when the agent contains state that is not picklable and therefore cannot be serialized alongside the pipeline graph.

Each call to run_inference() invokes the agent once per element in the batch. By default, every invocation uses a fresh, isolated session, so calls are stateless. To hold a stateful multi-turn conversation, pass a "session_id" key in inference_args. Elements processed with the same session_id continue the same conversation history.
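Here is a minimal stdlib sketch of the session semantics described above. It assumes that sessions behave like an in-memory map from session ID to conversation history. FakeSessionStore and run_batch are hypothetical names. The real handler delegates to an ADK session service instead.

```python
import uuid
from collections import defaultdict


class FakeSessionStore:
    """Hypothetical in-memory store mapping session_id to a list of user turns."""

    def __init__(self) -> None:
        self._history = defaultdict(list)

    def add_turn(self, session_id: str, message: str) -> int:
        """Record a user turn and return the session's history length."""
        self._history[session_id].append(message)
        return len(self._history[session_id])


def run_batch(store, batch, inference_args=None):
    """Send each element as one turn; return the history length seen per element."""
    inference_args = inference_args or {}
    shared_id = inference_args.get("session_id")
    lengths = []
    for message in batch:
        # Without a session_id, every element gets a fresh, unique session.
        session_id = shared_id or str(uuid.uuid4())
        lengths.append(store.add_turn(session_id, message))
    return lengths


store = FakeSessionStore()
# Stateless: each element starts a new session, so every history has length 1.
print(run_batch(store, ["Hi", "What's 2+2?", "Thanks"]))
# Shared session: each element extends the same conversation.
print(run_batch(store, ["Hi", "What's 2+2?", "Thanks"], {"session_id": "conv-1"}))
```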

Parameters:
  • agent – A pre-constructed Agent instance, or a zero-arg callable that returns one. The callable form defers agent construction until load_model() runs on the worker, which is useful when the agent cannot be serialized.

  • app_name – The ADK application name used to namespace sessions. Defaults to "beam_inference".

  • session_service_factory – Optional zero-arg callable returning a BaseSessionService. When None, an InMemorySessionService is created automatically.

  • min_batch_size – Optional minimum batch size.

  • max_batch_size – Optional maximum batch size.

  • max_batch_duration_secs – Optional maximum time to buffer a batch before emitting; used in streaming contexts.

  • max_batch_weight – Optional maximum total weight of a batch.

  • element_size_fn – Optional function that returns the size (weight) of an element.
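The simplified greedy batcher below illustrates how max_batch_size, max_batch_weight, and element_size_fn constrain a batch. make_batches is a hypothetical helper. Beam's actual batching logic is more involved (for instance, it can adapt batch sizes over time), so treat this as a sketch of the constraints rather than of Beam's algorithm. Weighing string prompts with len, as element_size_fn=len would, caps the total number of characters per batch.

```python
from typing import Any, Callable, Iterable, Iterator, List, Optional


def make_batches(
    elements: Iterable[Any],
    max_batch_size: int,
    max_batch_weight: Optional[int] = None,
    element_size_fn: Callable[[Any], int] = lambda _: 1,
) -> Iterator[List[Any]]:
    """Greedily group elements, closing a batch before it would exceed either limit."""
    batch: List[Any] = []
    weight = 0
    for element in elements:
        size = element_size_fn(element)
        too_many = len(batch) >= max_batch_size
        too_heavy = max_batch_weight is not None and weight + size > max_batch_weight
        if batch and (too_many or too_heavy):
            yield batch
            batch, weight = [], 0
        # An element heavier than max_batch_weight still gets a batch of its own.
        batch.append(element)
        weight += size
    if batch:
        yield batch


prompts = ["short", "a much longer prompt here", "hi", "ok", "tiny"]
# Weigh prompts by character count, as element_size_fn=len would.
for b in make_batches(prompts, max_batch_size=3, max_batch_weight=30, element_size_fn=len):
    print(b)
```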

load_model() → Runner

Instantiates the ADK Runner on the worker.

Resolves the agent (calling the factory if a callable was provided), then creates a Runner backed by the configured session service.

Returns:

A fully initialised Runner.
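The steps listed above can be sketched with stand-ins. FakeRunner and InMemorySessions are hypothetical placeholders for google.adk.runners.Runner and ADK's InMemorySessionService. Treating any callable as a factory is a simplifying assumption about how the agent is resolved.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


class InMemorySessions:
    """Hypothetical stand-in for ADK's InMemorySessionService."""


@dataclass
class FakeRunner:
    """Hypothetical stand-in for an ADK Runner."""
    agent: Any
    app_name: str
    session_service: Any


def load_model(
    agent_or_factory: Any,
    app_name: str = "beam_inference",
    session_service_factory: Optional[Callable[[], Any]] = None,
) -> FakeRunner:
    # 1. Resolve the agent, calling the factory only now, on the worker.
    agent = agent_or_factory() if callable(agent_or_factory) else agent_or_factory
    # 2. Use the configured session service, or fall back to an in-memory one.
    session_service = (session_service_factory or InMemorySessions)()
    # 3. Bind the agent and session service together in a runner.
    return FakeRunner(agent=agent, app_name=app_name, session_service=session_service)


runner = load_model(lambda: "my_agent")
print(runner.app_name, type(runner.session_service).__name__)
```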

run_inference(batch: Sequence[str | Any], model: Runner, inference_args: dict[str, Any] | None = None) → Iterable[PredictionResult]

Runs the ADK agent on each element in the batch.

Each element is sent to the agent as a new user turn. The final response text from the agent is returned as the inference field of a PredictionResult.
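ADK agents typically emit a stream of events, such as intermediate tool-call events, before their final answer. Extracting "the final response text" therefore means picking out the final event. The sketch below assumes a simplified event shape. FakeEvent and final_response_text are hypothetical, and PredictionResult here is a local stand-in that mirrors only the example and inference fields of Beam's class. Real ADK events expose an is_final_response() method for this check.

```python
from dataclasses import dataclass, field
from typing import Any, Iterable, List, NamedTuple, Optional


@dataclass
class FakeEvent:
    """Hypothetical stand-in for an ADK event: some text parts plus a 'final' flag."""
    texts: List[str] = field(default_factory=list)
    final: bool = False


class PredictionResult(NamedTuple):
    """Local stand-in mirroring the example/inference fields of Beam's PredictionResult."""
    example: Any
    inference: Any


def final_response_text(events: Iterable[FakeEvent]) -> Optional[str]:
    """Return the joined text of the last final event, ignoring intermediate ones."""
    text = None
    for event in events:
        if event.final:
            text = "".join(event.texts)
    return text


events = [
    FakeEvent(texts=["Let me check..."]),  # intermediate, e.g. around a tool call
    FakeEvent(texts=["The capital of France ", "is Paris."], final=True),
]
result = PredictionResult(
    example="What is the capital of France?",
    inference=final_response_text(events),
)
print(result.inference)  # The capital of France is Paris.
```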

Parameters:
  • batch – A sequence of inputs, each of which is either a str (the user message text) or a google.genai.types.Content object (for richer multi-part messages).

  • model – The Runner returned by load_model().

  • inference_args – Optional dict of extra arguments. Supported keys:

    • "session_id" (str): If supplied, all elements in this batch share this session ID, enabling stateful multi-turn conversations. If omitted, each element receives a unique, auto-generated session ID.

    • "user_id" (str): The user identifier to pass to the runner. Defaults to "beam_user".
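The helper below sketches how the batch inputs and the inference_args keys documented above might combine into one request per element. It normalizes plain strings into a single-part user message and applies the documented defaults. FakeContent stands in for google.genai.types.Content, and build_requests, its dictionary keys, and the session-ID format are all hypothetical.

```python
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union


@dataclass
class FakeContent:
    """Hypothetical stand-in for google.genai.types.Content."""
    role: str
    parts: List[str] = field(default_factory=list)


def build_requests(
    batch: Sequence[Union[str, FakeContent]],
    inference_args: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Turn each batch element into the arguments of one agent call."""
    args = inference_args or {}
    user_id = args.get("user_id", "beam_user")  # documented default
    shared_session = args.get("session_id")
    requests = []
    for element in batch:
        # Plain strings become a single-part user message; Content passes through.
        if isinstance(element, str):
            message = FakeContent(role="user", parts=[element])
        else:
            message = element
        requests.append({
            "user_id": user_id,
            "session_id": shared_session or f"session-{uuid.uuid4()}",
            "new_message": message,
        })
    return requests


reqs = build_requests(
    ["Hi", FakeContent(role="user", parts=["Describe this", "<image bytes>"])],
    {"user_id": "alice"},
)
print([r["user_id"] for r in reqs])  # ['alice', 'alice']
print(reqs[0]["new_message"].parts)  # ['Hi']
```

In a pipeline, these keys would be supplied through RunInference's inference_args parameter, for example RunInference(handler, inference_args={"session_id": "conv-1"}).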

Returns:

An iterable of PredictionResult, one per input element.

get_metrics_namespace() → str