# Source code for apache_beam.ml.inference.api

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mypy: ignore-errors

from dataclasses import dataclass
from typing import Tuple
from typing import TypeVar
from typing import Union

import apache_beam as beam
from apache_beam.ml.inference import base

_K = TypeVar('_K')
_INPUT_TYPE = TypeVar('_INPUT_TYPE')
_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')


@dataclass
class PredictionResult:
  """Pairs one input example with the inference produced for it.

  Emitted by ``RunInference`` for every input example (or, for keyed
  inputs, as the value of a ``(key, PredictionResult)`` tuple).
  """
  # The input example (or features) the model was run on.
  example: _INPUT_TYPE
  # The model's output (prediction) for ``example``.
  inference: _OUTPUT_TYPE
@beam.typehints.with_input_types(Union[_INPUT_TYPE, Tuple[_K, _INPUT_TYPE]])
@beam.typehints.with_output_types(Union[PredictionResult, Tuple[_K, PredictionResult]])  # pylint: disable=line-too-long
class RunInference(beam.PTransform):
  """A transform that takes a PCollection of examples (or features) to be used
  on an ML model. It will then output inferences (or predictions) for those
  examples in a PCollection of PredictionResults, containing the input
  examples and output inferences.

  If examples are paired with keys, it will output a tuple
  (key, PredictionResult) for each (key, example) input.

  Models for supported frameworks can be loaded via a URI. Supported services
  can also be used.

  TODO(BEAM-14046): Add and link to help documentation
  """
  def __init__(self, model_loader: base.ModelLoader):
    """
    Args:
      model_loader: the loader that supplies the model used to compute
        the inferences.
    """
    self._model_loader = model_loader

  def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
    # All of the actual work is delegated to the shared base implementation;
    # this class only carries the typehint decorations and the loader.
    inference_transform = base.RunInference(self._model_loader)
    return pcoll | inference_transform