#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module contains Splittable DoFn logic that is specific to DirectRunner.
"""
# pytype: skip-file
from __future__ import absolute_import
import uuid
from builtins import object
from threading import Lock
from threading import Timer
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Optional
import apache_beam as beam
from apache_beam import TimeDomain
from apache_beam import pvalue
from apache_beam.coders import typecoders
from apache_beam.io.iobase import RestrictionTracker
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PTransformOverride
from apache_beam.runners.common import DoFnContext
from apache_beam.runners.common import DoFnInvoker
from apache_beam.runners.common import DoFnSignature
from apache_beam.runners.common import OutputProcessor
from apache_beam.runners.direct.evaluation_context import DirectStepContext
from apache_beam.runners.direct.util import KeyedWorkItem
from apache_beam.runners.direct.watermark_manager import WatermarkManager
from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import ProcessContinuation
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.trigger import _ValueStateTag
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
from apache_beam.iobase import WatermarkEstimator
class SplittableParDoOverride(PTransformOverride):
  """A transform override for ParDo transforms of SplittableDoFns.

  Replaces the ParDo transform with a SplittableParDo transform that performs
  SDF specific logic.
  """

  # NOTE(review): only `matches()` is defined in this class as shown here;
  # confirm that the replacement-transform hook is provided elsewhere.
  def matches(self, applied_ptransform):
    """Tells whether the applied transform is a ParDo of a splittable DoFn.

    Args:
      applied_ptransform: the `AppliedPTransform` to inspect.

    Returns:
      False when the wrapped transform is not a `ParDo`; otherwise the result
      of `DoFnSignature.is_splittable_dofn()` for the `ParDo`'s `DoFn`.
    """
    assert isinstance(applied_ptransform, AppliedPTransform)
    transform = applied_ptransform.transform
    if not isinstance(transform, ParDo):
      # Be explicit instead of falling off the end and returning None.
      return False
    return DoFnSignature(transform.fn).is_splittable_dofn()
class SplittableParDo(PTransform):
  """A transform that processes a PCollection using a Splittable DoFn."""
  def __init__(self, ptransform):
    """Wraps the original `ParDo` whose `fn` is the Splittable DoFn.

    Args:
      ptransform: the `ParDo` transform being expanded.
    """
    assert isinstance(ptransform, ParDo)
    self._ptransform = ptransform

  def expand(self, pcoll):
    """Expands into pair / split / explode / key steps plus keyed processing.

    Args:
      pcoll: the input PCollection of the original `ParDo`.

    Returns:
      The result of applying `ProcessKeyedElements` to the keyed
      `ElementAndRestriction` values.
    """
    sdf = self._ptransform.fn
    signature = DoFnSignature(sdf)
    restriction_coder = signature.get_restriction_coder()
    element_coder = typecoders.registry.get_coder(pcoll.element_type)

    # Each element is paired with its initial restriction, the restriction is
    # split, windows are exploded and finally a random unique key is attached.
    keyed_elements = (
        pcoll
        | 'pair' >> ParDo(PairWithRestrictionFn(sdf))
        | 'split' >> ParDo(SplitRestrictionFn(sdf))
        | 'explode' >> ParDo(ExplodeWindowsFn())
        | 'random' >> ParDo(RandomUniqueKeyFn()))

    return keyed_elements | ProcessKeyedElements(
        sdf,
        element_coder,
        restriction_coder,
        pcoll.windowing,
        self._ptransform.args,
        self._ptransform.kwargs,
        self._ptransform.side_inputs)
class ElementAndRestriction(object):
  """A holder for an element and a restriction."""
  def __init__(self, element, restriction):
    """Stores the pair.

    Args:
      element: the element to be processed.
      restriction: the restriction associated with `element`.
    """
    self.element = element
    self.restriction = restriction

  def __repr__(self):
    # Debugging aid only; equality and hashing stay identity based.
    return 'ElementAndRestriction(element=%r, restriction=%r)' % (
        self.element, self.restriction)
class PairWithRestrictionFn(beam.DoFn):
  """A transform that pairs each element with a restriction."""
  def __init__(self, do_fn):
    """Keeps the Splittable DoFn whose initial restriction will be queried.

    Args:
      do_fn: the Splittable `DoFn`.
    """
    self._do_fn = do_fn

  def start_bundle(self):
    """Creates the invoker used to ask the DoFn for initial restrictions."""
    signature = DoFnSignature(self._do_fn)
    # process() of the wrapped DoFn is not invoked through this invoker
    # (process_invocation=False); the output processor raises if it is used.
    self._invoker = DoFnInvoker.create_invoker(
        signature,
        output_processor=_NoneShallPassOutputProcessor(),
        process_invocation=False)

  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Yields `element` wrapped together with its initial restriction.

    Args:
      element: the input element.
      window: unused by this method.
      *args: unused.
      **kwargs: unused.

    Yields:
      A single `ElementAndRestriction`.
    """
    initial_restriction = self._invoker.invoke_initial_restriction(element)
    yield ElementAndRestriction(element, initial_restriction)
class SplitRestrictionFn(beam.DoFn):
  """A transform that performs initial splitting of Splittable DoFn inputs."""
  def __init__(self, do_fn):
    """Keeps the Splittable DoFn whose split logic will be invoked.

    Args:
      do_fn: the Splittable `DoFn`.
    """
    self._do_fn = do_fn

  def start_bundle(self):
    """Creates the invoker used to split restrictions."""
    signature = DoFnSignature(self._do_fn)
    # process() of the wrapped DoFn is not invoked through this invoker
    # (process_invocation=False); the output processor raises if it is used.
    self._invoker = DoFnInvoker.create_invoker(
        signature,
        output_processor=_NoneShallPassOutputProcessor(),
        process_invocation=False)

  def process(self, element_and_restriction, *args, **kwargs):
    """Yields one `ElementAndRestriction` per part of the split restriction.

    Args:
      element_and_restriction: an `ElementAndRestriction` holding the element
        and the restriction to split.
      *args: unused.
      **kwargs: unused.

    Yields:
      `ElementAndRestriction` objects, one for every part returned by the
      invoker's `invoke_split()`.
    """
    element = element_and_restriction.element
    restriction = element_and_restriction.restriction
    restriction_parts = self._invoker.invoke_split(element, restriction)
    for part in restriction_parts:
      yield ElementAndRestriction(element, part)
class ExplodeWindowsFn(beam.DoFn):
  """A transform that forces the runner to explode windows.

  This is done to make sure that Splittable DoFn processes an element for each
  of the windows that element belongs to.
  """
  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Re-emits `element` unchanged.

    The `window` parameter is declared but not read in the body.

    Args:
      element: the input element.
      window: unused by this method.
      *args: unused.
      **kwargs: unused.

    Yields:
      The input element.
    """
    yield element
class RandomUniqueKeyFn(beam.DoFn):
  """A transform that assigns a unique key to each element."""
  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Pairs `element` with a freshly generated random key.

    Args:
      element: the input element.
      window: unused by this method.
      *args: unused.
      **kwargs: unused.

    Yields:
      A `(key, element)` tuple where `key` is the 16 bytes of a random UUID.
    """
    # We ignore UUID collisions here since they are extremely rare.
    yield (uuid.uuid4().bytes, element)
class ProcessKeyedElements(PTransform):
  """A primitive transform that performs SplittableDoFn magic.

  Input to this transform should be a PCollection of keyed ElementAndRestriction
  objects.
  """
  def __init__(
      self,
      sdf,
      element_coder,
      restriction_coder,
      windowing_strategy,
      ptransform_args,
      ptransform_kwargs,
      ptransform_side_inputs):
    """Records everything needed to later run the Splittable DoFn.

    Args:
      sdf: the Splittable `DoFn`.
      element_coder: coder for the elements of the original input.
      restriction_coder: coder for the restrictions of `sdf`.
      windowing_strategy: windowing of the original input PCollection.
      ptransform_args: positional arguments of the original `ParDo`.
      ptransform_kwargs: keyword arguments of the original `ParDo`.
      ptransform_side_inputs: side inputs of the original `ParDo`.
    """
    self.sdf = sdf
    self.element_coder = element_coder
    self.restriction_coder = restriction_coder
    self.windowing_strategy = windowing_strategy
    self.ptransform_args = ptransform_args
    self.ptransform_kwargs = ptransform_kwargs
    self.ptransform_side_inputs = ptransform_side_inputs

  def expand(self, pcoll):
    """Returns a new PCollection derived from `pcoll`; no work is done here."""
    return pvalue.PCollection.from_(pcoll)
class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride):
  """A transform override for the ProcessKeyedElements transform."""

  # NOTE(review): only `matches()` is defined in this class as shown here;
  # confirm that the replacement-transform hook is provided elsewhere.
  def matches(self, applied_ptransform):
    """Tells whether the applied transform is a `ProcessKeyedElements`.

    Args:
      applied_ptransform: the applied transform to inspect.

    Returns:
      True if the wrapped transform is a `ProcessKeyedElements` instance.
    """
    return isinstance(applied_ptransform.transform, ProcessKeyedElements)
class ProcessKeyedElementsViaKeyedWorkItems(PTransform):
  """A transform that processes Splittable DoFn input via KeyedWorkItems."""
  def __init__(self, process_keyed_elements_transform):
    """Keeps the `ProcessKeyedElements` transform being replaced.

    Args:
      process_keyed_elements_transform: the transform whose SDF, arguments and
        side inputs are forwarded to `ProcessElements`.
    """
    self._process_keyed_elements_transform = process_keyed_elements_transform

  def expand(self, pcoll):
    """Groups the keyed input by key and applies `ProcessElements` to it.

    Args:
      pcoll: the keyed input PCollection.

    Returns:
      The output of the `ProcessElements` transform.
    """
    source = self._process_keyed_elements_transform
    process_elements = ProcessElements(source)
    # Forward the arguments and side inputs of the original ParDo.
    process_elements.args = source.ptransform_args
    process_elements.kwargs = source.ptransform_kwargs
    process_elements.side_inputs = source.ptransform_side_inputs
    return pcoll | beam.core.GroupByKey() | process_elements
class ProcessElements(PTransform):
  """A primitive transform for processing keyed elements or KeyedWorkItems.

  Will be evaluated by
  `runners.direct.transform_evaluator._ProcessElementsEvaluator`.
  """
  def __init__(self, process_keyed_elements_transform):
    """Keeps the source transform and exposes its Splittable DoFn as `sdf`.

    Args:
      process_keyed_elements_transform: the `ProcessKeyedElements` transform
        this primitive was derived from.
    """
    self._process_keyed_elements_transform = process_keyed_elements_transform
    self.sdf = self._process_keyed_elements_transform.sdf

  def expand(self, pcoll):
    """Returns a new PCollection derived from `pcoll`; no work is done here."""
    return pvalue.PCollection.from_(pcoll)

  def new_process_fn(self, sdf):
    """Builds a `ProcessFn` for `sdf` using the original ParDo arguments.

    Args:
      sdf: the Splittable `DoFn` to wrap.

    Returns:
      A new `ProcessFn`.
    """
    return ProcessFn(
        sdf,
        self._process_keyed_elements_transform.ptransform_args,
        self._process_keyed_elements_transform.ptransform_kwargs)
class ProcessFn(beam.DoFn):
  """A `DoFn` that executes machinery for invoking a Splittable `DoFn`.

  Input to the `ParDo` step that includes a `ProcessFn` will be a `PCollection`
  of `ElementAndRestriction` objects.

  This class is mainly responsible for the following.

  (1) set up the environment for properly invoking a Splittable `DoFn`.
  (2) invoke the `process()` method of a Splittable `DoFn`.
  (3) after the `process()` invocation of the Splittable `DoFn`, determine if a
      re-invocation of the element is needed. If this is the case, set state and
      a timer for a re-invocation and hold output watermark till this
      re-invocation.
  (4) after the final invocation of a given element clear any previous state set
      for re-invoking the element and release the output watermark.
  """
  def __init__(self, sdf, args_for_invoker, kwargs_for_invoker):
    """Creates the state tags and the invoker for the Splittable `DoFn`.

    Args:
      sdf: the Splittable `DoFn` to run.
      args_for_invoker: positional arguments passed to the `DoFnInvoker`.
      kwargs_for_invoker: keyword arguments passed to the `DoFnInvoker`.
    """
    self.sdf = sdf
    self._element_tag = _ValueStateTag('element')
    self._restriction_tag = _ValueStateTag('restriction')
    self.watermark_hold_tag = _ValueStateTag('watermark_hold')
    self._process_element_invoker = None
    self._output_processor = _OutputProcessor()

    self.sdf_invoker = DoFnInvoker.create_invoker(
        DoFnSignature(self.sdf),
        context=DoFnContext('unused_context'),
        output_processor=self._output_processor,
        input_args=args_for_invoker,
        input_kwargs=kwargs_for_invoker)

    self._step_context = None

  @property
  def step_context(self):
    """The `DirectStepContext` used to look up keyed state, or None."""
    return self._step_context

  @step_context.setter
  def step_context(self, step_context):
    assert isinstance(step_context, DirectStepContext)
    self._step_context = step_context

  def set_process_element_invoker(self, process_element_invoker):
    """Sets the `SDFProcessElementInvoker` used by `process()`."""
    assert isinstance(process_element_invoker, SDFProcessElementInvoker)
    self._process_element_invoker = process_element_invoker

  def process(
      self,
      element,
      timestamp=beam.DoFn.TimestampParam,
      window=beam.DoFn.WindowParam,
      *args,
      **kwargs):
    """Runs (or resumes) the Splittable `DoFn` for one keyed input.

    Args:
      element: either a `KeyedWorkItem` (a timer firing) or a
        `(key, values)` pair where `values` holds exactly one item.
      timestamp: timestamp used when building the `WindowedValue` passed to
        the Splittable `DoFn`.
      window: window used for state, timers and the `WindowedValue`.
      *args: forwarded to `invoke_process_element()`.
      **kwargs: forwarded to `invoke_process_element()`.

    Yields:
      The outputs produced by the Splittable `DoFn`.

    Raises:
      AssertionError: if `values` does not hold exactly one item, if no
        process element invoker was set, or if the invoker did not end with a
        `SDFProcessElementInvoker.Result`.
      ValueError: if `values` does not hold exactly one item and assertions
        are disabled.
    """
    if isinstance(element, KeyedWorkItem):
      # Must be a timer firing.
      key = element.encoded_key
    else:
      key, values = element
      values = list(values)

      # TODO: handle key collisions here.
      unique_key_error = (
          'Internal error. Processing of splittable DoFn cannot continue '
          'since elements did not have unique keys.')
      # The assertion keeps raising AssertionError as before (now always with
      # its message); the explicit check keeps the guard alive when assertions
      # are disabled (python -O).
      assert len(values) == 1, unique_key_error
      if len(values) != 1:
        raise ValueError(unique_key_error)

      # Value here will either be a WindowedValue or an ElementAndRestriction
      # object.
      value = values[0]

    state = self._step_context.get_keyed_state(key)
    element_state = state.get_state(window, self._element_tag)
    # Initially element_state is an empty list.
    is_seed_call = not element_state

    if not is_seed_call:
      # Resuming: element and residual restriction were saved by a previous
      # invocation.
      element = state.get_state(window, self._element_tag)
      restriction = state.get_state(window, self._restriction_tag)
      windowed_element = WindowedValue(element, timestamp, [window])
    else:
      # After values iterator is expanded above we should have gotten a list
      # with a single ElementAndRestriction object.
      # NOTE(review): `value` is only bound in the non-timer branch above;
      # presumably a timer never fires without saved state -- confirm.
      assert isinstance(value, ElementAndRestriction)
      element_and_restriction = value
      element = element_and_restriction.element
      restriction = element_and_restriction.restriction

      if isinstance(value, WindowedValue):
        windowed_element = WindowedValue(
            element, value.timestamp, value.windows)
      else:
        windowed_element = WindowedValue(element, timestamp, [window])

    tracker = self.sdf_invoker.invoke_create_tracker(restriction)
    assert self._process_element_invoker
    assert isinstance(self._process_element_invoker, SDFProcessElementInvoker)

    output_values = self._process_element_invoker.invoke_process_element(
        self.sdf_invoker,
        self._output_processor,
        windowed_element,
        tracker,
        *args,
        **kwargs)

    sdf_result = None
    for output in output_values:
      if isinstance(output, SDFProcessElementInvoker.Result):
        # SDFProcessElementInvoker.Result should be the last item yielded.
        sdf_result = output
        break
      yield output

    assert sdf_result, ('SDFProcessElementInvoker must return a '
                        'SDFProcessElementInvoker.Result object as the last '
                        'value of a SDF invoke_process_element() invocation.')

    if not sdf_result.residual_restriction:
      # All work for current residual and restriction pair is complete.
      state.clear_state(window, self._element_tag)
      state.clear_state(window, self._restriction_tag)
      # Releasing output watermark by setting it to positive infinity.
      state.add_state(
          window, self.watermark_hold_tag, WatermarkManager.WATERMARK_POS_INF)
    else:
      state.add_state(window, self._element_tag, element)
      state.add_state(
          window, self._restriction_tag, sdf_result.residual_restriction)
      # Holding output watermark by setting it to negative infinity.
      state.add_state(
          window, self.watermark_hold_tag, WatermarkManager.WATERMARK_NEG_INF)

    # Setting a timer to be reinvoked to continue processing the element.
    # Currently Python SDK only supports setting timers based on watermark. So
    # forcing a reinvocation by setting a timer for watermark negative
    # infinity.
    # TODO(chamikara): update this by setting a timer for the proper
    # processing time when Python SDK supports that.
    state.set_timer(
        window, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_NEG_INF)
class SDFProcessElementInvoker(object):
  """A utility that invokes SDF `process()` method and requests checkpoints.

  This class is responsible for invoking the `process()` method of a Splittable
  `DoFn` and making sure that invocation terminated properly. Based on the input
  configuration, this class may decide to request a checkpoint for a `process()`
  execution so that runner can process current output and resume the invocation
  at a later time.

  More specifically, when initializing a `SDFProcessElementInvoker`, caller may
  specify the number of output elements or processing time after which a
  checkpoint should be requested. This class is responsible for properly
  requesting a checkpoint based on either of these criteria.

  When the `process()` call of Splittable `DoFn` ends, this class performs
  validations to make sure that processing ended gracefully and returns a
  `SDFProcessElementInvoker.Result` that contains information which can be used
  by the caller to perform another `process()` invocation for the residual.

  A `process()` invocation may decide to give up processing voluntarily by
  returning a `ProcessContinuation` object (see documentation of
  `ProcessContinuation` for more details). So if a 'ProcessContinuation' is
  produced this class ends the execution and performs steps to finalize the
  current invocation.
  """
  class Result(object):
    """Yielded as the last item of an `invoke_process_element()` invocation."""
    def __init__(
        self,
        residual_restriction=None,
        process_continuation=None,
        future_output_watermark=None):
      """Stores the outcome of an `invoke_process_element()` invocation.

      Args:
        residual_restriction: a restriction for the unprocessed part of the
                              element.
        process_continuation: a `ProcessContinuation` if one was returned as the
                              last element of the SDF `process()` invocation.
        future_output_watermark: output watermark of the results that will be
                                 produced when invoking the Splittable `DoFn`
                                 for the current element with
                                 `residual_restriction`.
      """
      self.residual_restriction = residual_restriction
      self.process_continuation = process_continuation
      self.future_output_watermark = future_output_watermark

  def __init__(self, max_num_outputs, max_duration):
    """Configures when checkpoints are requested.

    Args:
      max_num_outputs: number of outputs after which a checkpoint is
        requested; a falsy value disables the output-count criterion.
      max_duration: delay, passed to `threading.Timer`, after which a
        checkpoint is requested.
    """
    self._max_num_outputs = max_num_outputs
    self._max_duration = max_duration
    # Serializes checkpoint requests coming from the timer thread and from
    # the thread that consumes the outputs.
    self._checkpoint_lock = Lock()

  def test_method(self):
    """Always raises `ValueError`."""
    raise ValueError

  def invoke_process_element(
      self, sdf_invoker, output_processor, element, tracker, *args, **kwargs):
    """Invokes `process()` method of a Splittable `DoFn` for a given element.

    This is a generator: nothing runs until the caller starts iterating.

    Args:
      sdf_invoker: a `DoFnInvoker` for the Splittable `DoFn`.
      output_processor: the output processor attached to `sdf_invoker`; it is
        reset before the invocation and its `output_iter` is consumed here.
      element: the element to process
      tracker: a `RestrictionTracker` for the element that will be passed when
        invoking the `process()` method of the Splittable `DoFn`.
      *args: additional positional arguments for the `process()` invocation.
      **kwargs: additional keyword arguments for the `process()` invocation.

    Yields:
      Every output of the Splittable `DoFn` except a trailing
      `ProcessContinuation`, followed by a single
      `SDFProcessElementInvoker.Result` object as the last item.
    """
    assert isinstance(sdf_invoker, DoFnInvoker)
    assert isinstance(tracker, RestrictionTracker)

    # A mutable holder is used so that the nested function can update the
    # checkpoint outcome without `nonlocal`.
    class CheckpointState(object):
      def __init__(self):
        self.checkpointed = False
        self.residual_restriction = None

    checkpoint_state = CheckpointState()

    def initiate_checkpoint():
      # May be called from the timer thread or from the consuming thread; the
      # tracker is checkpointed at most once.
      with self._checkpoint_lock:
        if checkpoint_state.checkpointed:
          return
        checkpoint_state.residual_restriction = tracker.checkpoint()
        checkpoint_state.checkpointed = True

    output_processor.reset()
    noop_estimator = (
        NoOpWatermarkEstimatorProvider().create_watermark_estimator(None))
    checkpoint_timer = Timer(self._max_duration, initiate_checkpoint)
    checkpoint_timer.start()
    try:
      sdf_invoker.invoke_process(
          element,
          restriction_tracker=tracker,
          watermark_estimator=noop_estimator,
          additional_args=args,
          additional_kwargs=kwargs)

      assert output_processor.output_iter is not None
      output_count = 0

      # We have to expand and re-yield here to support ending execution for a
      # given number of output elements as well as to capture the
      # ProcessContinuation if one was returned.
      process_continuation = None
      for output in output_processor.output_iter:
        # A ProcessContinuation, if returned, should be the last element.
        assert not process_continuation
        if isinstance(output, ProcessContinuation):
          # Taking a checkpoint so that we can determine primary and residual
          # restrictions.
          initiate_checkpoint()

          # A ProcessContinuation should always be the last element produced by
          # the output iterator.
          # TODO: support continuing after the specified amount of delay.

          # Continuing here instead of breaking to enforce that this is the
          # last element.
          process_continuation = output
          continue

        yield output
        output_count += 1
        if self._max_num_outputs and output_count >= self._max_num_outputs:
          initiate_checkpoint()
    finally:
      # Stop the pending timer once processing is over (or failed, or the
      # generator was closed) so its thread does not linger and does not
      # checkpoint the tracker after the fact.
      checkpoint_timer.cancel()

    tracker.check_done()
    result = (
        SDFProcessElementInvoker.Result(
            residual_restriction=checkpoint_state.residual_restriction)
        if checkpoint_state.residual_restriction else
        SDFProcessElementInvoker.Result())

    yield result
class _OutputProcessor(OutputProcessor):
  """An output processor that only captures the output iterator.

  The captured iterator is exposed as `output_iter` so that the caller can
  consume it; nothing is consumed here.
  """
  def __init__(self):
    # Holds the iterator handed to the latest process_outputs() call, or None
    # before the first call and after reset().
    self.output_iter = None

  def process_outputs(
      self, windowed_input_element, output_iter, watermark_estimator=None):
    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
    """Stores `output_iter`; the other arguments are ignored."""
    self.output_iter = output_iter

  def reset(self):
    """Forgets any previously captured output iterator."""
    self.output_iter = None
class _NoneShallPassOutputProcessor(OutputProcessor):
  """An output processor that rejects every attempt to process outputs."""
  def process_outputs(
      self, windowed_input_element, output_iter, watermark_estimator=None):
    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
    """Always raises `RuntimeError`; all arguments are ignored."""
    raise RuntimeError(
        '_NoneShallPassOutputProcessor does not accept any outputs.')