# Source code for apache_beam.ml.inference.sklearn_inference

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import enum
import pickle
import sys
from typing import Any
from typing import Iterable
from typing import List

import numpy

from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.api import PredictionResult
from apache_beam.ml.inference.base import InferenceRunner
from apache_beam.ml.inference.base import ModelLoader

try:
  import joblib
except ImportError:
  # joblib is an optional dependency.
  pass


class ModelFileType(enum.Enum):
  """Serialization formats supported for saved scikit-learn models."""
  PICKLE = 1
  JOBLIB = 2
class SklearnInferenceRunner(InferenceRunner):
  """Runs batched inference with a scikit-learn model over numpy examples."""

  def run_inference(self, batch: List[numpy.ndarray],
                    model: Any) -> Iterable[PredictionResult]:
    """Runs the model's ``predict`` over the batch.

    Args:
      batch: list of single-example numpy arrays.
      model: a fitted scikit-learn estimator exposing ``predict``.

    Returns:
      One ``PredictionResult`` per input example, pairing the example
      with its prediction.
    """
    # Stack the examples into one array so predict runs vectorized
    # instead of once per example.
    stacked = numpy.stack(batch, axis=0)
    inferences = model.predict(stacked)
    return [
        PredictionResult(example, inference)
        for example, inference in zip(batch, inferences)
    ]

  def get_num_bytes(self, batch: List[numpy.ndarray]) -> int:
    """Returns the number of bytes of data for a batch."""
    total = 0
    for element in batch:
      total += sys.getsizeof(element)
    return total
class SklearnModelLoader(ModelLoader):
  """Loads a pickled or joblib-serialized scikit-learn model for inference.

  Args:
    model_file_type: serialization format of the model file
      (``ModelFileType.PICKLE`` or ``ModelFileType.JOBLIB``).
    model_uri: URI of the saved model file, readable via
      ``FileSystems.open``.
  """
  def __init__(
      self,
      model_file_type: ModelFileType = ModelFileType.PICKLE,
      model_uri: str = ''):
    self._model_file_type = model_file_type
    self._model_uri = model_uri
    self._inference_runner = SklearnInferenceRunner()

  def load_model(self):
    """Loads and initializes a model for processing.

    Returns:
      The deserialized scikit-learn model object.

    Raises:
      ImportError: if JOBLIB deserialization is requested but joblib is
        not installed in this environment.
      AssertionError: if the configured file type is unsupported.
    """
    file = FileSystems.open(self._model_uri, 'rb')
    # Close the handle whether or not deserialization succeeds; the
    # previous version leaked it.
    try:
      if self._model_file_type == ModelFileType.PICKLE:
        return pickle.load(file)
      elif self._model_file_type == ModelFileType.JOBLIB:
        # When the optional top-of-file import failed, the name `joblib`
        # is undefined (not merely falsy), so `if not joblib:` would raise
        # NameError instead of the intended ImportError. Check sys.modules
        # to detect whether the import actually succeeded.
        if 'joblib' not in sys.modules:
          raise ImportError(
              'Could not import joblib in this execution environment. '
              'For help with managing dependencies on Python workers.'
              'see https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/'  # pylint: disable=line-too-long
          )
        return joblib.load(file)
      raise AssertionError('Unsupported serialization type.')
    finally:
      file.close()

  def get_inference_runner(self) -> SklearnInferenceRunner:
    """Returns the inference runner paired with this loader."""
    return self._inference_runner