# Source code for apache_beam.ml.inference.utils

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

"""
Util/helper functions used in apache_beam.ml.inference.
"""
import os
from functools import partial
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Union

import apache_beam as beam
from apache_beam.io.fileio import EmptyMatchTreatment
from apache_beam.io.fileio import MatchContinuously
from apache_beam.ml.inference.base import ModelMetadata
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

# Timestamp captured once, when this module is imported. It is used as the
# default value of the 'max' combining state in _GetLatestFileByTimeStamp, so
# a file's last-updated time is first compared against this value.
_START_TIME_STAMP = Timestamp.now()


def _convert_to_result(
    batch: Iterable,
    predictions: Union[Iterable, Dict[Any, Iterable]],
    model_id: Optional[str] = None,
) -> Iterable[PredictionResult]:
  """Pairs every example in ``batch`` with its prediction.

  Args:
    batch: The examples that were passed to the model.
    predictions: Either an iterable with one prediction per example, or a
      dict mapping each output key to an iterable holding one value per
      example.
    model_id: Optional model identifier attached to every result.

  Returns:
    A list of PredictionResult, one per (example, prediction) pair. Pairing
    uses ``zip``, so it stops at the shorter of the two inputs.
  """
  if isinstance(predictions, dict):
    # Transpose {key: [v0, v1, ...], ...} into one dict per example:
    # [{key: v0, ...}, {key: v1, ...}, ...].
    per_example = [
        dict(zip(predictions.keys(), row))
        for row in zip(*predictions.values())
    ]
  else:
    per_example = predictions

  results = []
  for example, inference in zip(batch, per_example):
    results.append(PredictionResult(example, inference, model_id))
  return results


class _ConvertIterToSingleton(beam.DoFn):
  """
  Internal only; No backwards compatibility.

  Emits the value half of a ``(key, value)`` element only while the 'count'
  state still reads zero, and bumps that state when it does so. Later
  elements that observe a non-zero count produce no output.

  MatchContinuously can report the same file more than once, which would
  yield an Iterable instead of a Singleton downstream. Emitting each file
  path only the first time it is seen lets the output of this transform be
  wrapped with beam.pvalue.AsSingleton().
  """
  COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

  def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
    already_seen = count_state.read()
    if already_seen != 0:
      # This element was emitted before; stay silent.
      return
    count_state.add(1)
    yield element[1]


class _GetLatestFileByTimeStamp(beam.DoFn):
  """
  Internal only; No backwards compatibility.

  Compares the last-updated time of each incoming file against the value held
  in the 'max' state, which defaults to the timestamp captured when this
  module was imported (``_START_TIME_STAMP``).

  When the file is newer, the state is replaced with the file's time and the
  output is ``(path, ModelMetadata)`` where ``model_id`` is the path and
  ``model_name`` is the file's base name without its extension. Otherwise the
  output carries an empty path: ``('', ModelMetadata(model_id='',
  model_name=''))``.
  """
  TIME_STATE = CombiningValueStateSpec(
      'max', combine_fn=partial(max, default=_START_TIME_STAMP))

  def process(self, element, time_state=beam.DoFn.StateParam(TIME_STATE)):
    _, metadata = element
    candidate_ts = metadata.last_updated_in_seconds
    if candidate_ts > time_state.read():
      # Replace, rather than accumulate, the stored timestamp.
      time_state.clear()
      time_state.add(candidate_ts)
      model_path = metadata.path
    else:
      model_path = ''

    base_name = os.path.basename(model_path)
    model_name, _ = os.path.splitext(base_name)
    model_metadata = ModelMetadata(model_id=model_path, model_name=model_name)
    return [(model_path, model_metadata)]


class WatchFilePattern(beam.PTransform):
  def __init__(
      self,
      file_pattern,
      interval=360,
      stop_timestamp=MAX_TIMESTAMP,
  ):
    """
    Watches a directory for updates to files matching a given file pattern.

    Args:
      file_pattern: The file path to read from as a local file path or a
        GCS ``gs://`` path. The path can contain glob characters
        (``*``, ``?``, and ``[...]`` sets).
      interval: Interval at which to check for files matching file_pattern
        in seconds.
      stop_timestamp: Timestamp after which no more files will be checked.

    **Note**:

      1. Any previously used filenames cannot be reused. If a file is added
         or updated to a previously used filename, this transform will ignore
         that update. To trigger a model update, always upload a file with
         unique name.
      2. Initially, before the pipeline startup time, WatchFilePattern expects
         at least one file present that matches the file_pattern.
      3. This transform is supported in streaming mode since
         MatchContinuously produces an unbounded source. Running in batch
         mode can lead to undesired results or result in pipeline being stuck.
    """
    self.file_pattern = file_pattern
    self.interval = interval
    self.stop_timestamp = stop_timestamp

  def expand(self, pcoll) -> beam.PCollection[ModelMetadata]:
    """Builds the file-watching pipeline.

    Steps, in order: match files continuously (an empty match is disallowed),
    key each match by its path, run ``_GetLatestFileByTimeStamp`` and
    ``_ConvertIterToSingleton`` over the keyed matches, then re-window the
    result into the global window with a repeated processing-time trigger in
    discarding mode.

    Args:
      pcoll: The input the ``MatchContinuously`` transform is applied to.

    Returns:
      The PCollection produced by the final windowing step.
    """
    return (
        pcoll
        | 'MatchContinuously' >> MatchContinuously(
            file_pattern=self.file_pattern,
            interval=self.interval,
            stop_timestamp=self.stop_timestamp,
            empty_match_treatment=EmptyMatchTreatment.DISALLOW)
        # Key by path so the stateful DoFns below receive keyed input.
        | "AttachKey" >> beam.Map(lambda x: (x.path, x))
        | "GetLatestFileMetaData" >> beam.ParDo(_GetLatestFileByTimeStamp())
        | "AcceptNewSideInputOnly" >> beam.ParDo(_ConvertIterToSingleton())
        | 'ApplyGlobalWindow' >> beam.transforms.WindowInto(
            window.GlobalWindows(),
            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
            accumulation_mode=trigger.AccumulationMode.DISCARDING))