# Source code for apache_beam.ml.transforms.base

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pytype: skip-file

import abc
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Sequence
from typing import TypeVar

import apache_beam as beam

# Names exported by `from apache_beam.ml.transforms.base import *`.
__all__ = [
    'MLTransform',
    'ProcessHandler',
    'BaseOperation',
]

# Type parameters of ProcessHandler and MLTransform: the element type that
# flows into MLTransform and the element type it emits.
ExampleT = TypeVar('ExampleT')
MLTransformOutputT = TypeVar('MLTransformOutputT')

# Type parameters of BaseOperation: the data handed to apply_transform() and
# the value type of the dict it returns.
OperationInputT = TypeVar('OperationInputT')
OperationOutputT = TypeVar('OperationOutputT')

# Not referenced by the definitions in this module; presumably meant for the
# dataset/metadata types of ProcessHandler implementations — TODO confirm.
TransformedDatasetT = TypeVar('TransformedDatasetT')
TransformedMetadataT = TypeVar('TransformedMetadataT')


class ArtifactMode:
  """String constants naming how MLTransform uses its artifact location."""
  # Selected when MLTransform is given a write_artifact_location.
  PRODUCE = 'produce'
  # Selected when MLTransform is given a read_artifact_location.
  CONSUME = 'consume'


class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC):
  """Base operation class for data processing transformations.

  Subclasses implement apply_transform() and get_artifacts(). Calling an
  instance runs both and merges their results into a single dict.
  """
  def __init__(self, columns: List[str]) -> None:
    """
    Base operation class for data processing transformations.

    Args:
      columns: List of column names to apply the transformation.
    """
    self.columns = columns

  @abc.abstractmethod
  def apply_transform(self, data: OperationInputT,
                      output_column_name: str) -> Dict[str, OperationOutputT]:
    """
    Define any processing logic in the apply_transform() method.
    The processing logic is applied to the input data and returns the
    transformed output.

    Args:
      data: input data.
      output_column_name: name under which the transformed output is stored.

    Returns:
      A dict mapping string keys to transformed values.
    """

  @abc.abstractmethod
  def get_artifacts(
      self, data: OperationInputT,
      output_column_prefix: str) -> Optional[Dict[str, OperationOutputT]]:
    """
    If the operation generates any artifacts, they can be returned from this
    method. Returning None (or an empty dict) means nothing is merged into
    the output of __call__().

    Args:
      data: input data.
      output_column_prefix: prefix for the keys of the returned artifacts.
    """

  def __call__(self, data: OperationInputT,
               output_column_name: str) -> Dict[str, OperationOutputT]:
    """
    This method is called when the instance of the class is called.
    It invokes apply_transform() and get_artifacts() and merges their
    results; on key collisions the artifact values take precedence.

    Args:
      data: input data.
      output_column_name: passed to apply_transform() as the output column
        name and to get_artifacts() as the output column prefix.
    """
    transformed_data = self.apply_transform(data, output_column_name)
    artifacts = self.get_artifacts(data, output_column_name)
    if artifacts:
      # Build a new dict instead of updating in place so the dict returned by
      # apply_transform() is never mutated.
      transformed_data = {**transformed_data, **artifacts}
    return transformed_data
class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC):
  """Abstract backend that performs the processing requested by MLTransform.

  Only for internal use. No backwards compatibility guarantees.
  """
  @abc.abstractmethod
  def append_transform(self, transform: BaseOperation):
    """Register one more BaseOperation with this ProcessHandler."""

  @abc.abstractmethod
  def process_data(
      self, pcoll: beam.PCollection[ExampleT]
  ) -> beam.PCollection[MLTransformOutputT]:
    """Apply this handler's processing logic to ``pcoll``.

    MLTransform.expand() delegates to this method, so it is the entry point
    for processing the incoming data.
    """
class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
                                  beam.PCollection[MLTransformOutputT]],
                  Generic[ExampleT, MLTransformOutputT]):
  def __init__(
      self,
      *,
      write_artifact_location: Optional[str] = None,
      read_artifact_location: Optional[str] = None,
      transforms: Optional[Sequence[BaseOperation]] = None):
    """
    MLTransform is a Beam PTransform that can be used to apply
    transformations to the data. MLTransform is used to wrap the
    data processing transforms provided by Apache Beam. MLTransform
    works in two modes: write and read. In the write mode,
    MLTransform will apply the transforms to the data and store the
    artifacts in the write_artifact_location. In the read mode,
    MLTransform will read the artifacts from the
    read_artifact_location and apply the transforms to the data. The
    artifact location should be a valid storage path where the artifacts
    can be written to or read from.

    Note that when consuming artifacts, it is not necessary to pass the
    transforms since they are inherently stored within the artifacts
    themselves.

    Args:
      write_artifact_location: A storage location for artifacts resulting
        from MLTransform. These artifacts include transformations applied to
        the dataset and generated values like min, max from ScaleTo01, and
        mean, var from ScaleToZScore. Artifacts are produced and written to
        this location when `write_artifact_location` is set. Later
        MLTransforms can reuse the produced artifacts by passing this
        location as `read_artifact_location` instead.
        The value assigned to `write_artifact_location` should be a valid
        storage directory that the artifacts from this transform can be
        written to. If no directory exists at this location, one will be
        created. This will overwrite any artifacts already in this location,
        so distinct locations should be used for each instance of
        MLTransform. Only one of write_artifact_location and
        read_artifact_location should be specified.
      read_artifact_location: A storage location to read artifacts resulting
        from a previous MLTransform. These artifacts include transformations
        applied to the dataset and generated values like min, max from
        ScaleTo01, and mean, var from ScaleToZScore. Note that when consuming
        artifacts, it is not necessary to pass the transforms since they are
        inherently stored within the artifacts themselves. The value assigned
        to `read_artifact_location` should be a valid storage path where the
        artifacts can be read from. Only one of write_artifact_location and
        read_artifact_location should be specified.
      transforms: A list of transforms to apply to the data. All the
        transforms are applied in the order they are specified. The input of
        the i-th transform is the output of the (i-1)-th transform.
        Multi-input transforms are not supported yet.

    Raises:
      TypeError: if any element of `transforms` is not a BaseOperation.
      ValueError: if both or neither of read_artifact_location and
        write_artifact_location are specified.
    """
    # Initialize PTransform state (label, type hints) like any other
    # PTransform subclass would.
    super().__init__()

    # Validate eagerly so misconfigured pipelines fail at construction time.
    for transform in transforms or ():
      self._validate_transform(transform)

    if read_artifact_location and write_artifact_location:
      raise ValueError(
          'Only one of read_artifact_location or write_artifact_location can '
          'be specified to initialize MLTransform')

    if not read_artifact_location and not write_artifact_location:
      raise ValueError(
          'Either a read_artifact_location or write_artifact_location must be '
          'specified to initialize MLTransform')

    if read_artifact_location:
      artifact_location = read_artifact_location
      artifact_mode = ArtifactMode.CONSUME
    else:
      artifact_location = write_artifact_location  # type: ignore[assignment]
      artifact_mode = ArtifactMode.PRODUCE

    # avoid circular import
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.ml.transforms.handlers import TFTProcessHandler
    # TODO: When new ProcessHandlers(eg: JaxProcessHandler) are introduced,
    # create a mapping between transforms and ProcessHandler since
    # ProcessHandler is not exposed to the user.
    process_handler: ProcessHandler = TFTProcessHandler(
        artifact_location=artifact_location,
        artifact_mode=artifact_mode,
        transforms=transforms)  # type: ignore[arg-type]

    self._process_handler = process_handler

  def expand(
      self, pcoll: beam.PCollection[ExampleT]
  ) -> beam.PCollection[MLTransformOutputT]:
    """
    This is the entrypoint for the MLTransform. This method will
    invoke the process_data() method of the ProcessHandler instance
    to process the incoming data.

    process_data takes in a PCollection and applies the PTransforms
    necessary to process the data and returns a PCollection of
    transformed data.

    Args:
      pcoll: A PCollection of ExampleT type.
    Returns:
      A PCollection of MLTransformOutputT type.
    """
    return self._process_handler.process_data(pcoll)

  def with_transform(self, transform: BaseOperation) -> 'MLTransform':
    """
    Add a transform to the MLTransform pipeline.

    Args:
      transform: A BaseOperation instance.
    Returns:
      This MLTransform instance, so calls can be chained.
    Raises:
      TypeError: if `transform` is not a BaseOperation.
    """
    self._validate_transform(transform)
    self._process_handler.append_transform(transform)
    return self

  def _validate_transform(self, transform):
    """Raise TypeError unless `transform` is a BaseOperation instance."""
    if not isinstance(transform, BaseOperation):
      raise TypeError(
          'transform must be a subclass of BaseOperation. '
          'Got: %s instead.' % type(transform))