#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
"""Common utility class to help SDK harness to execute an SDF. """
import logging
import threading
from typing import TYPE_CHECKING
from typing import Any
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from typing import Union
from apache_beam.transforms.core import WatermarkEstimatorProvider
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
from apache_beam.io.iobase import RestrictionProgress
from apache_beam.io.iobase import RestrictionTracker
from apache_beam.io.iobase import WatermarkEstimator
_LOGGER = logging.getLogger(__name__)

# NamedTuple carrying the primary part of a split, as a WindowedValue.
SplitResultPrimary = NamedTuple(
    'SplitResultPrimary',
    [
        ('primary_value', WindowedValue),
    ])

# NamedTuple carrying the residual part of a split: the residual as a
# WindowedValue, a Timestamp watermark, and an optional Duration delay.
SplitResultResidual = NamedTuple(
    'SplitResultResidual',
    [
        ('residual_value', WindowedValue),
        ('current_watermark', Timestamp),
        ('deferred_timestamp', Optional[Duration]),
    ])
class ThreadsafeRestrictionTracker(object):
    """A thread-safe wrapper which wraps a `RestrictionTracker`.

    This wrapper guarantees synchronization of modifying restrictions across
    multiple threads. Calls that read or modify the wrapped tracker, or the
    self-checkpoint state recorded by ``defer_remainder()``, are made while
    holding one shared re-entrant lock.
    """
    def __init__(self, restriction_tracker):
        # type: (RestrictionTracker) -> None
        # Imported at call time rather than at module load.
        from apache_beam.io.iobase import RestrictionTracker
        if not isinstance(restriction_tracker, RestrictionTracker):
            raise ValueError(
                'Initialize ThreadsafeRestrictionTracker requires '
                'RestrictionTracker.')
        self._restriction_tracker = restriction_tracker
        # Records an absolute timestamp when defer_remainder is called.
        self._timestamp = None
        # Re-entrant because defer_remainder() calls try_split() while it
        # already holds the lock.
        self._lock = threading.RLock()
        # Residual produced by a successful defer_remainder() split.
        self._deferred_residual = None
        # Resume time exactly as passed to defer_remainder(); never rewritten.
        self._deferred_timestamp = None  # type: Optional[Union[Timestamp, Duration]]

    def current_restriction(self):
        """Returns the wrapped tracker's current restriction."""
        with self._lock:
            return self._restriction_tracker.current_restriction()

    def try_claim(self, position):
        """Tries to claim ``position`` on the wrapped tracker.

        Returns:
          Whatever the wrapped tracker's ``try_claim`` returns.
        """
        with self._lock:
            return self._restriction_tracker.try_claim(position)

    def defer_remainder(self, deferred_time=None):
        """Performs self-checkpoint on the current processing restriction with
        an expected resuming time.

        Self-checkpoint could happen while processing elements. When executing
        a DoFn.process(), you may want to stop processing an element and
        resume it later. You might do this if the current element has been
        processing for quite a long time, or if you also want outputs from
        other elements. ``defer_remainder()`` can be called per element if
        needed.

        Args:
          deferred_time: Either a relative ``Duration`` or an absolute
            ``Timestamp``. A ``Duration`` is the ideal time gap between now
            and resuming. A ``Timestamp`` is the time at which to resume
            execution. If ``deferred_time`` is None, the deferred work will
            be executed as soon as possible.

        Raises:
          ValueError: if ``deferred_time`` is truthy but is neither a
            ``Duration`` nor a ``Timestamp``.
        """
        with self._lock:
            # Record current time so deferred_status() can compute how much
            # of a relative delay has already elapsed.
            self._timestamp = Timestamp.now()
            if deferred_time and not isinstance(deferred_time,
                                                (Duration, Timestamp)):
                raise ValueError(
                    'The timestamp of defer_remainder() should be a '
                    'Duration or a Timestamp, or None.')
            self._deferred_timestamp = deferred_time
            # Split at fraction 0 of the remainder; on success the residual is
            # kept so deferred_status() can report it.
            checkpoint = self.try_split(0)
            if checkpoint:
                _, self._deferred_residual = checkpoint

    def check_done(self):
        """Delegates to the wrapped tracker's ``check_done()``."""
        with self._lock:
            return self._restriction_tracker.check_done()

    def current_progress(self):
        # type: () -> RestrictionProgress
        """Returns the wrapped tracker's current progress."""
        with self._lock:
            return self._restriction_tracker.current_progress()

    def try_split(self, fraction_of_remainder):
        """Delegates a split of the remaining work to the wrapped tracker.

        Returns:
          Whatever the wrapped tracker's ``try_split`` returns.
        """
        with self._lock:
            return self._restriction_tracker.try_split(fraction_of_remainder)

    def deferred_status(self):
        # type: () -> Optional[Tuple[Any, Duration]]
        """Returns deferred work which is produced by ``defer_remainder()``.

        When a self-checkpoint has been performed, the system needs to fill
        the DelayedBundleApplication of a ProcessBundleResponse with the
        deferred work. The system calls this API to get the deferred residual
        together with the resume delay, which helps the runner schedule the
        future work.

        The stored state is not modified. Each call recomputes the remaining
        delay relative to the current time, so repeated calls do not subtract
        the elapsed time more than once.

        Returns:
          (deferred_residual, time_delay) if there is a residual, else None.
        """
        with self._lock:
            if not self._deferred_residual:
                return None
            delay = self._deferred_timestamp
            if not delay:
                # No resume time requested: resume as soon as possible.
                delay = Duration()
            elif isinstance(delay, Timestamp):
                # Absolute resume time: the delay is the gap between now and
                # that time.
                delay = delay - Timestamp.now()
            elif isinstance(delay, Duration):
                # Relative delay: subtract the time already spent since
                # defer_remainder() was called.
                delay = delay - (Timestamp.now() - self._timestamp)
            return self._deferred_residual, delay

    def is_bounded(self):
        """Returns whether the wrapped tracker's restriction is bounded."""
        with self._lock:
            return self._restriction_tracker.is_bounded()
class RestrictionTrackerView(object):
    """A DoFn view of thread-safe RestrictionTracker.

    The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
    exposes the APIs that will be called by a ``DoFn.process()``. During
    execution, the RestrictionTrackerView is fed into ``DoFn.process`` as its
    restriction_tracker.
    """
    def __init__(self, threadsafe_restriction_tracker):
        # type: (ThreadsafeRestrictionTracker) -> None
        if not isinstance(threadsafe_restriction_tracker,
                          ThreadsafeRestrictionTracker):
            raise ValueError(
                'Initialize RestrictionTrackerView requires '
                'ThreadsafeRestrictionTracker.')
        self._threadsafe_restriction_tracker = threadsafe_restriction_tracker

    def current_restriction(self):
        """Returns the wrapped tracker's current restriction."""
        return self._threadsafe_restriction_tracker.current_restriction()

    def try_claim(self, position):
        """Tries to claim ``position``; returns the wrapped tracker's result."""
        return self._threadsafe_restriction_tracker.try_claim(position)

    def defer_remainder(self, deferred_time=None):
        """Self-checkpoints the current restriction; see the wrapped tracker's
        ``defer_remainder`` for the meaning of ``deferred_time``."""
        self._threadsafe_restriction_tracker.defer_remainder(deferred_time)

    def is_bounded(self):
        """Returns whether the wrapped tracker's restriction is bounded."""
        # The original dropped this value and always returned None.
        return self._threadsafe_restriction_tracker.is_bounded()
class ThreadsafeWatermarkEstimator(object):
    """A threadsafe wrapper which wraps a WatermarkEstimator with a locking
    mechanism to guarantee multi-thread safety.

    Besides the explicitly wrapped methods below, any other attribute of the
    wrapped estimator is exposed through ``__getattr__``. It is returned as a
    callable that invokes the attribute while holding the lock.
    """
    def __init__(self, watermark_estimator):
        # type: (WatermarkEstimator) -> None
        # Imported at call time rather than at module load.
        from apache_beam.io.iobase import WatermarkEstimator
        if not isinstance(watermark_estimator, WatermarkEstimator):
            raise ValueError('Initializing Threadsafe requires a WatermarkEstimator')
        self._watermark_estimator = watermark_estimator
        self._lock = threading.Lock()

    def __getattr__(self, attr):
        """Forwards unknown attributes to the wrapped estimator under the lock.

        ``__getattr__`` only runs when normal attribute lookup fails. The
        wrapped estimator is therefore read from the instance ``__dict__``
        directly. An instance whose ``__init__`` has not run (e.g. one built
        via ``__new__`` by ``copy`` or ``pickle``) then raises
        ``AttributeError``. Without this, it would re-enter
        ``__getattr__('_watermark_estimator')`` until RecursionError.

        Raises:
          AttributeError: if there is no wrapped estimator, or it has no
            attribute named ``attr``.
        """
        estimator = self.__dict__.get('_watermark_estimator')
        if estimator is None or not hasattr(estimator, attr):
            raise AttributeError(attr)

        def method_wrapper(*args, **kw):
            with self._lock:
                return getattr(estimator, attr)(*args, **kw)

        return method_wrapper

    def get_estimator_state(self):
        """Returns the wrapped estimator's state, read under the lock."""
        with self._lock:
            return self._watermark_estimator.get_estimator_state()

    def current_watermark(self):
        # type: () -> Timestamp
        """Returns the wrapped estimator's current watermark."""
        with self._lock:
            return self._watermark_estimator.current_watermark()

    def observe_timestamp(self, timestamp):
        # type: (Timestamp) -> None
        """Forwards ``timestamp`` to the wrapped estimator under the lock.

        Raises:
          ValueError: if ``timestamp`` is not a ``Timestamp``.
        """
        if not isinstance(timestamp, Timestamp):
            raise ValueError(
                'Input of observe_timestamp should be a Timestamp '
                'object')
        with self._lock:
            self._watermark_estimator.observe_timestamp(timestamp)
class NoOpWatermarkEstimatorProvider(WatermarkEstimatorProvider):
    """Provider handing the framework a watermark estimator that does nothing.

    It is provided for the framework when there is no custom one. Its initial
    estimator state is ``None``. Every estimator it creates ignores observed
    timestamps and reports ``None`` as both its watermark and its state.
    """
    def initial_estimator_state(self, element, restriction):
        # Nothing to track, so there is no initial state.
        return None

    def create_watermark_estimator(self, estimator_state):
        # Imported at call time, like the other iobase imports in this module.
        from apache_beam.io import iobase

        class _NoOpWatermarkEstimator(iobase.WatermarkEstimator):
            """A WatermarkEstimator that ignores its inputs; provided for the
            framework if there is no custom one."""
            def observe_timestamp(self, timestamp):
                # Observed timestamps are deliberately discarded.
                return None

            def current_watermark(self):
                return None

            def get_estimator_state(self):
                return None

        return _NoOpWatermarkEstimator()