Source code for apache_beam.ml.inference.base

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# TODO: https://github.com/apache/beam/issues/21822
# mypy: ignore-errors

"""An extensible run inference transform.

Users of this module can extend the ModelHandler class for any machine learning
framework. A ModelHandler implementation is a required parameter of
RunInference.

The transform handles standard inference functionality, such as metric
collection, sharing the model between threads, and batching elements.
"""

import logging
import pickle
import sys
import time
from typing import Any
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from typing import Union

import apache_beam as beam
from apache_beam.utils import shared

try:
  # pylint: disable=wrong-import-order, wrong-import-position
  import resource
except ImportError:
  resource = None  # type: ignore[assignment]

# Divisors used to scale nanosecond clock readings down to coarser units.
_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000

# Type variables for the model, its inputs and its outputs.
ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
PredictionT = TypeVar('PredictionT')
# Field types of PredictionResult.
_INPUT_TYPE = TypeVar('_INPUT_TYPE')
_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
# Key type for keyed examples and predictions.
KeyT = TypeVar('KeyT')


class PredictionResult(NamedTuple):
  """Pairs an input example with the inference produced for it."""
  example: _INPUT_TYPE
  inference: _OUTPUT_TYPE


def _to_milliseconds(time_ns: int) -> int:
  """Converts a nanosecond reading to whole milliseconds.

  The conversion uses integer division. Going through a float quotient
  (``int(time_ns / divisor)``) is inexact for epoch-scale nanosecond readings:
  the quotient is rounded to the nearest representable double and can round up
  across a millisecond boundary, yielding a result that is one too large.

  Args:
    time_ns: A timestamp or duration in nanoseconds.

  Returns:
    The value in milliseconds, truncated toward zero.
  """
  # Divide the magnitude and restore the sign afterwards so that negative
  # inputs keep truncating toward zero (plain ``//`` would floor instead).
  whole_ms = int(abs(time_ns) // _NANOSECOND_TO_MILLISECOND)
  return whole_ms if time_ns >= 0 else -whole_ms


def _to_microseconds(time_ns: int) -> int:
  """Converts a nanosecond reading to whole microseconds.

  The conversion uses integer division. Going through a float quotient
  (``int(time_ns / divisor)``) is inexact for epoch-scale nanosecond readings:
  the quotient is rounded to the nearest representable double and can round up
  across a microsecond boundary, yielding a result that is one too large.

  Args:
    time_ns: A timestamp or duration in nanoseconds.

  Returns:
    The value in microseconds, truncated toward zero.
  """
  # Divide the magnitude and restore the sign afterwards so that negative
  # inputs keep truncating toward zero (plain ``//`` would floor instead).
  whole_us = int(abs(time_ns) // _NANOSECOND_TO_MICROSECOND)
  return whole_us if time_ns >= 0 else -whole_us


class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
  """Base class that knows how to load an ML model and run it on batches.

  Subclasses must override load_model and run_inference; the remaining
  methods provide defaults that may be overridden.
  """
  def load_model(self) -> ModelT:
    """Loads and initializes the model used for inference.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError(type(self))

  def run_inference(
      self,
      batch: Sequence[ExampleT],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionT]:
    """Runs inference on a batch of examples.

    Args:
      batch: A sequence of examples or features.
      model: The model used to make inferences.
      inference_args: Optional extra arguments for models whose inference
        call requires extra parameters.

    Returns:
      An Iterable of predictions.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError(type(self))

  def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
    """Measures a batch by the length of its pickled representation.

    Returns:
      The number of bytes of data for the batch.
    """
    serialized = pickle.dumps(batch)
    return len(serialized)

  def get_metrics_namespace(self) -> str:
    """
    Returns:
       The namespace under which the RunInference transform reports metrics.
    """
    return 'RunInference'

  def get_resource_hints(self) -> dict:
    """
    Returns:
       Resource hints for the transform; none by default.
    """
    return {}

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    """
    Returns:
       Keyword arguments suitable for beam.BatchElements; none by default.
    """
    return {}
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                        ModelHandler[Tuple[KeyT, ExampleT],
                                     Tuple[KeyT, PredictionT],
                                     ModelT]):
  """A ModelHandler that takes keyed examples and returns keyed predictions.

  For example, if the original model was used with RunInference to take a
  PCollection[E] to a PCollection[P], this would take a
  PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], allowing one to
  associate the outputs with the inputs based on the key.
  """
  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
    # Everything except key handling is delegated to the wrapped handler.
    self._unkeyed = unkeyed

  def load_model(self) -> ModelT:
    """Loads the model through the wrapped handler."""
    return self._unkeyed.load_model()

  def run_inference(
      self,
      batch: Sequence[Tuple[KeyT, ExampleT]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[Tuple[KeyT, PredictionT]]:
    """Strips the keys, runs the wrapped handler, and re-attaches the keys.

    Returns:
      An iterator of (key, prediction) pairs in batch order.
    """
    # Transpose [(k, e), ...] into a tuple of keys and a tuple of examples.
    keys, examples = zip(*batch)
    predictions = self._unkeyed.run_inference(examples, model, inference_args)
    return zip(keys, predictions)

  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
    """Returns the pickled size of the keys plus the wrapped handler's size
    for the examples."""
    keys, examples = zip(*batch)
    key_bytes = len(pickle.dumps(keys))
    example_bytes = self._unkeyed.get_num_bytes(examples)
    return key_bytes + example_bytes

  def get_metrics_namespace(self) -> str:
    """Returns the wrapped handler's metrics namespace."""
    return self._unkeyed.get_metrics_namespace()

  def get_resource_hints(self):
    """Returns the wrapped handler's resource hints."""
    return self._unkeyed.get_resource_hints()

  def batch_elements_kwargs(self):
    """Returns the wrapped handler's beam.BatchElements kwargs."""
    return self._unkeyed.batch_elements_kwargs()
class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                             ModelHandler[Union[ExampleT,
                                                Tuple[KeyT, ExampleT]],
                                          Union[PredictionT,
                                                Tuple[KeyT, PredictionT]],
                                          ModelT]):
  """A ModelHandler that takes possibly keyed examples and returns possibly
  keyed predictions.

  For example, if the original model was used with RunInference to take a
  PCollection[E] to a PCollection[P], this would take either PCollection[E]
  to a PCollection[P] or PCollection[Tuple[K, E]] to a
  PCollection[Tuple[K, P]], depending on whether the elements happen to be
  tuples, allowing one to associate the outputs with the inputs based on the
  key.

  Note that this cannot be used if E happens to be a tuple type. In addition,
  either all examples should be keyed, or none of them; only the first
  element of a batch is inspected to decide.
  """
  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
    # Everything except key handling is delegated to the wrapped handler.
    self._unkeyed = unkeyed

  def load_model(self) -> ModelT:
    """Loads the model through the wrapped handler."""
    return self._unkeyed.load_model()

  def run_inference(
      self,
      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
      model: ModelT,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
    """Runs the wrapped handler, re-attaching keys when the batch is keyed.

    Returns:
      The wrapped handler's predictions for an unkeyed batch, or an iterator
      of (key, prediction) pairs for a keyed batch.
    """
    # Really the input should be
    #    Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
    # but there's not a good way to express (or check) that.
    if not isinstance(batch[0], tuple):
      # Unkeyed batch: pass it straight through.
      return self._unkeyed.run_inference(
          batch, model, inference_args)  # type: ignore[arg-type]

    # Keyed batch: split off the keys, infer, then pair the keys back up.
    keys, examples = zip(*batch)  # type: ignore[arg-type]
    predictions = self._unkeyed.run_inference(examples, model, inference_args)
    return zip(keys, predictions)

  def get_num_bytes(
      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
    """Returns the batch size in bytes, counting pickled keys when keyed."""
    # MyPy can't follow the branching logic.
    if not isinstance(batch[0], tuple):
      return self._unkeyed.get_num_bytes(batch)  # type: ignore[arg-type]

    keys, examples = zip(*batch)  # type: ignore[arg-type]
    key_bytes = len(pickle.dumps(keys))
    example_bytes = self._unkeyed.get_num_bytes(examples)
    return key_bytes + example_bytes

  def get_metrics_namespace(self) -> str:
    """Returns the wrapped handler's metrics namespace."""
    return self._unkeyed.get_metrics_namespace()

  def get_resource_hints(self):
    """Returns the wrapped handler's resource hints."""
    return self._unkeyed.get_resource_hints()

  def batch_elements_kwargs(self):
    """Returns the wrapped handler's beam.BatchElements kwargs."""
    return self._unkeyed.batch_elements_kwargs()
class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                                   beam.PCollection[PredictionT]]):
  def __init__(
      self,
      model_handler: ModelHandler[ExampleT, PredictionT, Any],
      clock=time,
      inference_args: Optional[Dict[str, Any]] = None):
    """A transform that takes a PCollection of examples (or features) to be
    used on an ML model. It will then output inferences (or predictions) for
    those examples in a PCollection of PredictionResults, containing the input
    examples and output inferences.

    Models for supported frameworks can be loaded via a URI. Supported services
    can also be used.

    Args:
        model_handler: An implementation of ModelHandler.
        clock: A clock implementing time_ns; defaults to the time module.
        inference_args: Extra arguments for models whose inference call
          requires extra parameters.
    """
    self._model_handler = model_handler
    self._inference_args = inference_args
    self._clock = clock

  # TODO(BEAM-14046): Add and link to help documentation.
  @classmethod
  def create(cls, model_handler_provider, **kwargs):
    """Multi-language friendly constructor.

    This constructor can be used with fully_qualified_named_transform to
    initialize RunInference transform from PythonCallableSource provided by
    foreign SDKs.

    Args:
      model_handler_provider: A callable object that returns ModelHandler.
      kwargs: Keyword arguments for model_handler_provider.
    """
    handler = model_handler_provider(**kwargs)
    return cls(handler)

  # TODO(https://github.com/apache/beam/issues/21447): Add batch_size back off
  # in the case there are functional reasons large batch sizes cannot be
  # handled.
  def expand(
      self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
    """Batches the input elements and applies the inference DoFn to each
    batch, using the batching kwargs and resource hints supplied by the
    model handler."""
    hints = self._model_handler.get_resource_hints()

    # TODO(https://github.com/apache/beam/issues/21440): Hook into the
    # batching DoFn APIs.
    batched = pcoll | beam.BatchElements(
        **self._model_handler.batch_elements_kwargs())

    # inference_args is forwarded to the DoFn's process() as an extra
    # positional argument.
    inference_dofn = _RunInferenceDoFn(self._model_handler, self._clock)
    run_inference = beam.ParDo(
        inference_dofn, self._inference_args).with_resource_hints(**hints)
    return batched | run_inference
class _MetricsCollector:
  """A metrics collector that tracks ML related performance and memory usage."""
  def __init__(self, namespace: str):
    new_counter = beam.metrics.Metrics.counter
    new_distribution = beam.metrics.Metrics.distribution

    # Per-batch inference metrics.
    self._inference_counter = new_counter(namespace, 'num_inferences')
    self._inference_request_batch_size = new_distribution(
        namespace, 'inference_request_batch_size')
    self._inference_request_batch_byte_size = new_distribution(
        namespace, 'inference_request_batch_byte_size')
    # Batch inference latency in microseconds.
    self._inference_batch_latency_micro_secs = new_distribution(
        namespace, 'inference_batch_latency_micro_secs')

    # Model-load metrics.
    self._model_byte_size = new_distribution(namespace, 'model_byte_size')
    # Model load latency in milliseconds.
    self._load_model_latency_milli_secs = new_distribution(
        namespace, 'load_model_latency_milli_secs')

    # Values held back until update_metrics_with_cache() reports them.
    self._load_model_latency_milli_secs_cache = None
    self._model_byte_size_cache = None

  def update_metrics_with_cache(self):
    """Reports any cached model-load metrics, then clears the cache."""
    pending_latency = self._load_model_latency_milli_secs_cache
    if pending_latency is not None:
      self._load_model_latency_milli_secs.update(pending_latency)
      self._load_model_latency_milli_secs_cache = None

    pending_size = self._model_byte_size_cache
    if pending_size is not None:
      self._model_byte_size.update(pending_size)
      self._model_byte_size_cache = None

  def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size):
    """Stores model-load metrics for a later update_metrics_with_cache()."""
    self._load_model_latency_milli_secs_cache = load_model_latency_ms
    self._model_byte_size_cache = model_byte_size

  def update(
      self,
      examples_count: int,
      examples_byte_size: int,
      latency_micro_secs: int):
    """Records the latency, element count and byte size of one batch."""
    self._inference_batch_latency_micro_secs.update(latency_micro_secs)
    self._inference_counter.inc(examples_count)
    self._inference_request_batch_size.update(examples_count)
    self._inference_request_batch_byte_size.update(examples_byte_size)


class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
  """A DoFn implementation generic to frameworks."""
  def __init__(
      self, model_handler: ModelHandler[ExampleT, PredictionT, Any], clock):
    self._model_handler = model_handler
    self._shared_model_handle = shared.Shared()
    self._clock = clock
    self._model = None

  def _load_model(self):
    """Acquires the model through the shared handle, timing and measuring
    the load whenever the handle invokes the loader."""
    def load():
      """Function for constructing shared LoadedModel."""
      rss_before = _get_current_process_memory_in_bytes()
      load_started_ms = _to_milliseconds(self._clock.time_ns())
      loaded_model = self._model_handler.load_model()
      load_finished_ms = _to_milliseconds(self._clock.time_ns())
      rss_after = _get_current_process_memory_in_bytes()

      # Model size is estimated as the growth in process memory over the load.
      self._metrics_collector.cache_load_model_metrics(
          load_finished_ms - load_started_ms, rss_after - rss_before)
      return loaded_model

    # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
    # model.
    return self._shared_model_handle.acquire(load)

  def setup(self):
    """Creates the metrics collector, then loads the model."""
    # The collector must exist first: the loader writes to it.
    namespace = self._model_handler.get_metrics_namespace()
    self._metrics_collector = _MetricsCollector(namespace)
    self._model = self._load_model()

  def process(self, batch, inference_args):
    """Runs inference on one batch, records its metrics, and returns the
    predictions as a list."""
    started_us = _to_microseconds(self._clock.time_ns())
    # Materialize the results inside the timed region so lazily produced
    # predictions are counted in the latency.
    predictions = list(
        self._model_handler.run_inference(batch, self._model, inference_args))
    finished_us = _to_microseconds(self._clock.time_ns())

    elapsed_us = finished_us - started_us
    batch_bytes = self._model_handler.get_num_bytes(batch)
    batch_len = len(batch)
    self._metrics_collector.update(batch_len, batch_bytes, elapsed_us)

    return predictions

  def finish_bundle(self):
    """Reports any cached model-load metrics."""
    # TODO(https://github.com/apache/beam/issues/21435): Figure out why there
    # is a cache.
    self._metrics_collector.update_metrics_with_cache()


def _is_darwin() -> bool:
  """Returns True when sys.platform is 'darwin'."""
  return sys.platform == 'darwin'


def _get_current_process_memory_in_bytes():
  """Reads the ru_maxrss field of this process's resource usage.

  The reading is returned as-is on darwin and multiplied by 1024 elsewhere
  (presumably because ru_maxrss is in bytes on macOS and kilobytes on other
  platforms).

  Returns:
    memory usage in bytes, or 0 when the resource module is unavailable.
  """
  if resource is None:
    logging.warning(
        'Resource module is not available for current platform, '
        'memory usage cannot be fetched.')
    return 0

  max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  return max_rss if _is_darwin() else max_rss * 1024