Source code for apache_beam.runners.direct.sdf_direct_runner

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This module contains Splittable DoFn logic that is specific to DirectRunner.
"""

# pytype: skip-file

from __future__ import absolute_import

import uuid
from builtins import object
from threading import Lock
from threading import Timer
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Optional

import apache_beam as beam
from apache_beam import TimeDomain
from apache_beam import pvalue
from apache_beam.coders import typecoders
from apache_beam.io.iobase import RestrictionTracker
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PTransformOverride
from apache_beam.runners.common import DoFnContext
from apache_beam.runners.common import DoFnInvoker
from apache_beam.runners.common import DoFnSignature
from apache_beam.runners.common import OutputProcessor
from apache_beam.runners.direct.evaluation_context import DirectStepContext
from apache_beam.runners.direct.util import KeyedWorkItem
from apache_beam.runners.direct.watermark_manager import WatermarkManager
from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import ProcessContinuation
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.trigger import _ValueStateTag
from apache_beam.utils.windowed_value import WindowedValue

if TYPE_CHECKING:
  from apache_beam.iobase import WatermarkEstimator


[docs]class SplittableParDoOverride(PTransformOverride): """A transform override for ParDo transformss of SplittableDoFns. Replaces the ParDo transform with a SplittableParDo transform that performs SDF specific logic. """
[docs] def matches(self, applied_ptransform): assert isinstance(applied_ptransform, AppliedPTransform) transform = applied_ptransform.transform if isinstance(transform, ParDo): signature = DoFnSignature(transform.fn) return signature.is_splittable_dofn()
[docs] def get_replacement_transform(self, ptransform): assert isinstance(ptransform, ParDo) do_fn = ptransform.fn signature = DoFnSignature(do_fn) if signature.is_splittable_dofn(): return SplittableParDo(ptransform) else: return ptransform
[docs]class SplittableParDo(PTransform): """A transform that processes a PCollection using a Splittable DoFn.""" def __init__(self, ptransform): assert isinstance(ptransform, ParDo) self._ptransform = ptransform
[docs] def expand(self, pcoll): sdf = self._ptransform.fn signature = DoFnSignature(sdf) restriction_coder = signature.get_restriction_coder() element_coder = typecoders.registry.get_coder(pcoll.element_type) keyed_elements = ( pcoll | 'pair' >> ParDo(PairWithRestrictionFn(sdf)) | 'split' >> ParDo(SplitRestrictionFn(sdf)) | 'explode' >> ParDo(ExplodeWindowsFn()) | 'random' >> ParDo(RandomUniqueKeyFn())) return keyed_elements | ProcessKeyedElements( sdf, element_coder, restriction_coder, pcoll.windowing, self._ptransform.args, self._ptransform.kwargs, self._ptransform.side_inputs)
[docs]class ElementAndRestriction(object): """A holder for an element and a restriction.""" def __init__(self, element, restriction): self.element = element self.restriction = restriction
[docs]class PairWithRestrictionFn(beam.DoFn): """A transform that pairs each element with a restriction.""" def __init__(self, do_fn): self._do_fn = do_fn
[docs] def start_bundle(self): signature = DoFnSignature(self._do_fn) self._invoker = DoFnInvoker.create_invoker( signature, output_processor=_NoneShallPassOutputProcessor(), process_invocation=False)
[docs] def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): initial_restriction = self._invoker.invoke_initial_restriction(element) yield ElementAndRestriction(element, initial_restriction)
[docs]class SplitRestrictionFn(beam.DoFn): """A transform that perform initial splitting of Splittable DoFn inputs.""" def __init__(self, do_fn): self._do_fn = do_fn
[docs] def start_bundle(self): signature = DoFnSignature(self._do_fn) self._invoker = DoFnInvoker.create_invoker( signature, output_processor=_NoneShallPassOutputProcessor(), process_invocation=False)
[docs] def process(self, element_and_restriction, *args, **kwargs): element = element_and_restriction.element restriction = element_and_restriction.restriction restriction_parts = self._invoker.invoke_split(element, restriction) for part in restriction_parts: yield ElementAndRestriction(element, part)
[docs]class ExplodeWindowsFn(beam.DoFn): """A transform that forces the runner to explode windows. This is done to make sure that Splittable DoFn proceses an element for each of the windows that element belongs to. """
[docs] def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): yield element
[docs]class RandomUniqueKeyFn(beam.DoFn): """A transform that assigns a unique key to each element."""
[docs] def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): # We ignore UUID collisions here since they are extremely rare. yield (uuid.uuid4().bytes, element)
[docs]class ProcessKeyedElements(PTransform): """A primitive transform that performs SplittableDoFn magic. Input to this transform should be a PCollection of keyed ElementAndRestriction objects. """ def __init__( self, sdf, element_coder, restriction_coder, windowing_strategy, ptransform_args, ptransform_kwargs, ptransform_side_inputs): self.sdf = sdf self.element_coder = element_coder self.restriction_coder = restriction_coder self.windowing_strategy = windowing_strategy self.ptransform_args = ptransform_args self.ptransform_kwargs = ptransform_kwargs self.ptransform_side_inputs = ptransform_side_inputs
[docs] def expand(self, pcoll): return pvalue.PCollection.from_(pcoll)
[docs]class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride): """A transform override for ProcessElements transform."""
[docs] def matches(self, applied_ptransform): return isinstance(applied_ptransform.transform, ProcessKeyedElements)
[docs] def get_replacement_transform(self, ptransform): return ProcessKeyedElementsViaKeyedWorkItems(ptransform)
[docs]class ProcessKeyedElementsViaKeyedWorkItems(PTransform): """A transform that processes Splittable DoFn input via KeyedWorkItems.""" def __init__(self, process_keyed_elements_transform): self._process_keyed_elements_transform = process_keyed_elements_transform
[docs] def expand(self, pcoll): process_elements = ProcessElements(self._process_keyed_elements_transform) process_elements.args = ( self._process_keyed_elements_transform.ptransform_args) process_elements.kwargs = ( self._process_keyed_elements_transform.ptransform_kwargs) process_elements.side_inputs = ( self._process_keyed_elements_transform.ptransform_side_inputs) return pcoll | beam.core.GroupByKey() | process_elements
[docs]class ProcessElements(PTransform): """A primitive transform for processing keyed elements or KeyedWorkItems. Will be evaluated by `runners.direct.transform_evaluator._ProcessElementsEvaluator`. """ def __init__(self, process_keyed_elements_transform): self._process_keyed_elements_transform = process_keyed_elements_transform self.sdf = self._process_keyed_elements_transform.sdf
[docs] def expand(self, pcoll): return pvalue.PCollection.from_(pcoll)
[docs] def new_process_fn(self, sdf): return ProcessFn( sdf, self._process_keyed_elements_transform.ptransform_args, self._process_keyed_elements_transform.ptransform_kwargs)
[docs]class ProcessFn(beam.DoFn): """A `DoFn` that executes machineary for invoking a Splittable `DoFn`. Input to the `ParDo` step that includes a `ProcessFn` will be a `PCollection` of `ElementAndRestriction` objects. This class is mainly responsible for following. (1) setup environment for properly invoking a Splittable `DoFn`. (2) invoke `process()` method of a Splittable `DoFn`. (3) after the `process()` invocation of the Splittable `DoFn`, determine if a re-invocation of the element is needed. If this is the case, set state and a timer for a re-invocation and hold output watermark till this re-invocation. (4) after the final invocation of a given element clear any previous state set for re-invoking the element and release the output watermark. """ def __init__(self, sdf, args_for_invoker, kwargs_for_invoker): self.sdf = sdf self._element_tag = _ValueStateTag('element') self._restriction_tag = _ValueStateTag('restriction') self.watermark_hold_tag = _ValueStateTag('watermark_hold') self._process_element_invoker = None self._output_processor = _OutputProcessor() self.sdf_invoker = DoFnInvoker.create_invoker( DoFnSignature(self.sdf), context=DoFnContext('unused_context'), output_processor=self._output_processor, input_args=args_for_invoker, input_kwargs=kwargs_for_invoker) self._step_context = None @property def step_context(self): return self._step_context @step_context.setter def step_context(self, step_context): assert isinstance(step_context, DirectStepContext) self._step_context = step_context
[docs] def set_process_element_invoker(self, process_element_invoker): assert isinstance(process_element_invoker, SDFProcessElementInvoker) self._process_element_invoker = process_element_invoker
[docs] def process( self, element, timestamp=beam.DoFn.TimestampParam, window=beam.DoFn.WindowParam, *args, **kwargs): if isinstance(element, KeyedWorkItem): # Must be a timer firing. key = element.encoded_key else: key, values = element values = list(values) assert len(values) == 1 # Value here will either be a WindowedValue or an ElementAndRestriction # object. # TODO: handle key collisions here. assert len(values) == 1, 'Internal error. Processing of splittable ' \ 'DoFn cannot continue since elements did not ' \ 'have unique keys.' value = values[0] if len(values) != 1: raise ValueError('') state = self._step_context.get_keyed_state(key) element_state = state.get_state(window, self._element_tag) # Initially element_state is an empty list. is_seed_call = not element_state if not is_seed_call: element = state.get_state(window, self._element_tag) restriction = state.get_state(window, self._restriction_tag) windowed_element = WindowedValue(element, timestamp, [window]) else: # After values iterator is expanded above we should have gotten a list # with a single ElementAndRestriction object. assert isinstance(value, ElementAndRestriction) element_and_restriction = value element = element_and_restriction.element restriction = element_and_restriction.restriction if isinstance(value, WindowedValue): windowed_element = WindowedValue( element, value.timestamp, value.windows) else: windowed_element = WindowedValue(element, timestamp, [window]) tracker = self.sdf_invoker.invoke_create_tracker(restriction) assert self._process_element_invoker assert isinstance(self._process_element_invoker, SDFProcessElementInvoker) output_values = self._process_element_invoker.invoke_process_element( self.sdf_invoker, self._output_processor, windowed_element, tracker, *args, **kwargs) sdf_result = None for output in output_values: if isinstance(output, SDFProcessElementInvoker.Result): # SDFProcessElementInvoker.Result should be the last item yielded. sdf_result = output break yield output assert sdf_result, ('SDFProcessElementInvoker must return a ' 'SDFProcessElementInvoker.Result object as the last ' 'value of a SDF invoke_process_element() invocation.') if not sdf_result.residual_restriction: # All work for current residual and restriction pair is complete. state.clear_state(window, self._element_tag) state.clear_state(window, self._restriction_tag) # Releasing output watermark by setting it to positive infinity. state.add_state( window, self.watermark_hold_tag, WatermarkManager.WATERMARK_POS_INF) else: state.add_state(window, self._element_tag, element) state.add_state( window, self._restriction_tag, sdf_result.residual_restriction) # Holding output watermark by setting it to negative infinity. state.add_state( window, self.watermark_hold_tag, WatermarkManager.WATERMARK_NEG_INF) # Setting a timer to be reinvoked to continue processing the element. # Currently Python SDK only supports setting timers based on watermark. So # forcing a reinvocation by setting a timer for watermark negative # infinity. # TODO(chamikara): update this by setting a timer for the proper # processing time when Python SDK supports that. state.set_timer( window, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_NEG_INF)
[docs]class SDFProcessElementInvoker(object): """A utility that invokes SDF `process()` method and requests checkpoints. This class is responsible for invoking the `process()` method of a Splittable `DoFn` and making sure that invocation terminated properly. Based on the input configuration, this class may decide to request a checkpoint for a `process()` execution so that runner can process current output and resume the invocation at a later time. More specifically, when initializing a `SDFProcessElementInvoker`, caller may specify the number of output elements or processing time after which a checkpoint should be requested. This class is responsible for properly requesting a checkpoint based on either of these criteria. When the `process()` call of Splittable `DoFn` ends, this class performs validations to make sure that processing ended gracefully and returns a `SDFProcessElementInvoker.Result` that contains information which can be used by the caller to perform another `process()` invocation for the residual. A `process()` invocation may decide to give up processing voluntarily by returning a `ProcessContinuation` object (see documentation of `ProcessContinuation` for more details). So if a 'ProcessContinuation' is produced this class ends the execution and performs steps to finalize the current invocation. """
[docs] class Result(object): def __init__( self, residual_restriction=None, process_continuation=None, future_output_watermark=None): """Returned as a result of a `invoke_process_element()` invocation. Args: residual_restriction: a restriction for the unprocessed part of the element. process_continuation: a `ProcessContinuation` if one was returned as the last element of the SDF `process()` invocation. future_output_watermark: output watermark of the results that will be produced when invoking the Splittable `DoFn` for the current element with `residual_restriction`. """ self.residual_restriction = residual_restriction self.process_continuation = process_continuation self.future_output_watermark = future_output_watermark
def __init__(self, max_num_outputs, max_duration): self._max_num_outputs = max_num_outputs self._max_duration = max_duration self._checkpoint_lock = Lock()
[docs] def test_method(self): raise ValueError
[docs] def invoke_process_element( self, sdf_invoker, output_processor, element, tracker, *args, **kwargs): """Invokes `process()` method of a Splittable `DoFn` for a given element. Args: sdf_invoker: a `DoFnInvoker` for the Splittable `DoFn`. element: the element to process tracker: a `RestrictionTracker` for the element that will be passed when invoking the `process()` method of the Splittable `DoFn`. Returns: a `SDFProcessElementInvoker.Result` object. """ assert isinstance(sdf_invoker, DoFnInvoker) assert isinstance(tracker, RestrictionTracker) class CheckpointState(object): def __init__(self): self.checkpointed = None self.residual_restriction = None checkpoint_state = CheckpointState() def initiate_checkpoint(): with self._checkpoint_lock: if checkpoint_state.checkpointed: return checkpoint_state.residual_restriction = tracker.checkpoint() checkpoint_state.checkpointed = object() output_processor.reset() noop_estimator = ( NoOpWatermarkEstimatorProvider().create_watermark_estimator(None)) Timer(self._max_duration, initiate_checkpoint).start() sdf_invoker.invoke_process( element, restriction_tracker=tracker, watermark_estimator=noop_estimator, additional_args=args, additional_kwargs=kwargs) assert output_processor.output_iter is not None output_count = 0 # We have to expand and re-yield here to support ending execution for a # given number of output elements as well as to capture the # ProcessContinuation of one was returned. process_continuation = None for output in output_processor.output_iter: # A ProcessContinuation, if returned, should be the last element. assert not process_continuation if isinstance(output, ProcessContinuation): # Taking a checkpoint so that we can determine primary and residual # restrictions. initiate_checkpoint() # A ProcessContinuation should always be the last element produced by # the output iterator. # TODO: support continuing after the specified amount of delay. # Continuing here instead of breaking to enforce that this is the last # element. process_continuation = output continue yield output output_count += 1 if self._max_num_outputs and output_count >= self._max_num_outputs: initiate_checkpoint() tracker.check_done() result = ( SDFProcessElementInvoker.Result( residual_restriction=checkpoint_state.residual_restriction) if checkpoint_state.residual_restriction else SDFProcessElementInvoker.Result()) yield result
class _OutputProcessor(OutputProcessor): def __init__(self): self.output_iter = None def process_outputs( self, windowed_input_element, output_iter, watermark_estimator=None): # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None self.output_iter = output_iter def reset(self): self.output_iter = None class _NoneShallPassOutputProcessor(OutputProcessor): def process_outputs( self, windowed_input_element, output_iter, watermark_estimator=None): # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None raise RuntimeError()