Source code for apache_beam.transforms.trigger

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Support for Apache Beam triggers.

Triggers control when in processing time windows get emitted.
"""

# pytype: skip-file

import collections
import copy
import logging
import numbers
from abc import ABCMeta
from abc import abstractmethod
from collections import abc as collections_abc  # ambiguty with direct abc
from enum import Flag
from enum import auto
from itertools import zip_longest

from apache_beam.coders import coder_impl
from apache_beam.coders import observable
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import combiners
from apache_beam.transforms import core
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import WindowedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.utils import windowed_value
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.timestamp import TIME_GRANULARITY

# AfterCount is experimental. No backwards compatibility guarantees.

__all__ = [
    'AccumulationMode',
    'TriggerFn',
    'DefaultTrigger',
    'AfterWatermark',
    'AfterProcessingTime',
    'AfterCount',
    'Repeatedly',
    'AfterAny',
    'AfterAll',
    'AfterEach',
    'OrFinally',
]

_LOGGER = logging.getLogger(__name__)


[docs]class AccumulationMode(object): """Controls what to do with data when a trigger fires multiple times.""" DISCARDING = beam_runner_api_pb2.AccumulationMode.DISCARDING ACCUMULATING = beam_runner_api_pb2.AccumulationMode.ACCUMULATING
# TODO(robertwb): Provide retractions of previous outputs. # RETRACTING = 3 class _StateTag(metaclass=ABCMeta): """An identifier used to store and retrieve typed, combinable state. The given tag must be unique for this step.""" def __init__(self, tag): self.tag = tag class _ReadModifyWriteStateTag(_StateTag): """StateTag pointing to an element.""" def __repr__(self): return 'ValueStateTag(%s)' % (self.tag) def with_prefix(self, prefix): return _ReadModifyWriteStateTag(prefix + self.tag) class _SetStateTag(_StateTag): """StateTag pointing to an element.""" def __repr__(self): return 'SetStateTag({tag})'.format(tag=self.tag) def with_prefix(self, prefix): return _SetStateTag(prefix + self.tag) class _CombiningValueStateTag(_StateTag): """StateTag pointing to an element, accumulated with a combiner. The given tag must be unique for this step. The given CombineFn will be applied (possibly incrementally and eagerly) when adding elements.""" # TODO(robertwb): Also store the coder (perhaps extracted from the combine_fn) def __init__(self, tag, combine_fn): super().__init__(tag) if not combine_fn: raise ValueError('combine_fn must be specified.') if not isinstance(combine_fn, core.CombineFn): combine_fn = core.CombineFn.from_callable(combine_fn) self.combine_fn = combine_fn def __repr__(self): return 'CombiningValueStateTag(%s, %s)' % (self.tag, self.combine_fn) def with_prefix(self, prefix): return _CombiningValueStateTag(prefix + self.tag, self.combine_fn) def without_extraction(self): class NoExtractionCombineFn(core.CombineFn): setup = self.combine_fn.setup create_accumulator = self.combine_fn.create_accumulator add_input = self.combine_fn.add_input merge_accumulators = self.combine_fn.merge_accumulators compact = self.combine_fn.compact extract_output = staticmethod(lambda x: x) teardown = self.combine_fn.teardown return _CombiningValueStateTag(self.tag, NoExtractionCombineFn()) class _ListStateTag(_StateTag): """StateTag pointing to a list of elements.""" def __repr__(self): return 'ListStateTag(%s)' % self.tag def with_prefix(self, prefix): return _ListStateTag(prefix + self.tag) class _WatermarkHoldStateTag(_StateTag): def __init__(self, tag, timestamp_combiner_impl): super().__init__(tag) self.timestamp_combiner_impl = timestamp_combiner_impl def __repr__(self): return 'WatermarkHoldStateTag(%s, %s)' % ( self.tag, self.timestamp_combiner_impl) def with_prefix(self, prefix): return _WatermarkHoldStateTag( prefix + self.tag, self.timestamp_combiner_impl) class DataLossReason(Flag): """Enum defining potential reasons that a trigger may cause data loss. These flags should only cover when the trigger is the cause, though windowing can be taken into account. For instance, AfterWatermark may not flag itself as finishing if the windowing doesn't allow lateness. """ # Trigger will never be the source of data loss. NO_POTENTIAL_LOSS = 0 # Trigger may finish. In this case, data that comes in after the trigger may # be lost. Example: AfterCount(1) will stop firing after the first element. MAY_FINISH = auto() # Deprecated: Beam will emit buffered data at GC time. Any other behavior # should be treated as a bug with the runner used. CONDITION_NOT_GUARANTEED = auto() # Convenience functions for checking if a flag is included. Each is equivalent # to `reason & flag == flag` def _IncludesMayFinish(reason): # type: (DataLossReason) -> bool return reason & DataLossReason.MAY_FINISH == DataLossReason.MAY_FINISH # pylint: disable=unused-argument # TODO(robertwb): Provisional API, Java likely to change as well.
[docs]class TriggerFn(metaclass=ABCMeta): """A TriggerFn determines when window (panes) are emitted. See https://beam.apache.org/documentation/programming-guide/#triggers """
[docs] @abstractmethod def on_element(self, element, window, context): """Called when a new element arrives in a window. Args: element: the element being added window: the window to which the element is being added context: a context (e.g. a TriggerContext instance) for managing state and setting timers """ pass
[docs] @abstractmethod def on_merge(self, to_be_merged, merge_result, context): """Called when multiple windows are merged. Args: to_be_merged: the set of windows to be merged merge_result: the window into which the windows are being merged context: a context (e.g. a TriggerContext instance) for managing state and setting timers """ pass
[docs] @abstractmethod def should_fire(self, time_domain, timestamp, window, context): """Whether this trigger should cause the window to fire. Args: time_domain: WATERMARK for event-time timers and REAL_TIME for processing-time timers. timestamp: for time_domain WATERMARK, it represents the watermark: (a lower bound on) the watermark of the system and for time_domain REAL_TIME, it represents the trigger: timestamp of the processing-time timer. window: the window whose trigger is being considered context: a context (e.g. a TriggerContext instance) for managing state and setting timers Returns: whether this trigger should cause a firing """ pass
[docs] @abstractmethod def has_ontime_pane(self): """Whether this trigger creates an empty pane even if there are no elements. Returns: True if this trigger guarantees that there will always be an ON_TIME pane even if there are no elements in that pane. """ pass
[docs] @abstractmethod def on_fire(self, watermark, window, context): """Called when a trigger actually fires. Args: watermark: (a lower bound on) the watermark of the system window: the window whose trigger is being fired context: a context (e.g. a TriggerContext instance) for managing state and setting timers Returns: whether this trigger is finished """ pass
[docs] @abstractmethod def reset(self, window, context): """Clear any state and timers used by this TriggerFn.""" pass
[docs] def may_lose_data(self, unused_windowing): # type: (core.Windowing) -> DataLossReason """Returns whether or not this trigger could cause data loss. A trigger can cause data loss in the following scenarios: * The trigger has a chance to finish. For instance, AfterWatermark() without a late trigger would cause all late data to be lost. This scenario is only accounted for if the windowing strategy allows late data. Otherwise, the trigger is not responsible for the data loss. Note that this only returns the potential for loss. It does not mean that there will be data loss. It also only accounts for loss related to the trigger, not other potential causes. Args: windowing: The Windowing that this trigger belongs to. It does not need to be the top-level trigger. Returns: The DataLossReason. If there is no potential loss, DataLossReason.NO_POTENTIAL_LOSS is returned. Otherwise, all the potential reasons are returned as a single value. """ # For backwards compatibility's sake, we're assuming the trigger is safe. return DataLossReason.NO_POTENTIAL_LOSS
# pylint: enable=unused-argument
[docs] @staticmethod def from_runner_api(proto, context): return { 'after_all': AfterAll, 'after_any': AfterAny, 'after_each': AfterEach, 'after_end_of_window': AfterWatermark, 'after_processing_time': AfterProcessingTime, # after_processing_time, after_synchronized_processing_time 'always': Always, 'default': DefaultTrigger, 'element_count': AfterCount, 'never': _Never, 'or_finally': OrFinally, 'repeat': Repeatedly, }[proto.WhichOneof('trigger')].from_runner_api(proto, context)
[docs] @abstractmethod def to_runner_api(self, unused_context): pass
[docs]class DefaultTrigger(TriggerFn): """Semantically Repeatedly(AfterWatermark()), but more optimized.""" def __init__(self): pass def __repr__(self): return 'DefaultTrigger()'
[docs] def on_element(self, element, window, context): context.set_timer(str(window), TimeDomain.WATERMARK, window.end)
[docs] def on_merge(self, to_be_merged, merge_result, context): for window in to_be_merged: context.clear_timer(str(window), TimeDomain.WATERMARK)
[docs] def should_fire(self, time_domain, watermark, window, context): if watermark >= window.end: # Explicitly clear the timer so that late elements are not emitted again # when the timer is fired. context.clear_timer(str(window), TimeDomain.WATERMARK) return watermark >= window.end
[docs] def on_fire(self, watermark, window, context): return False
[docs] def reset(self, window, context): context.clear_timer(str(window), TimeDomain.WATERMARK)
[docs] def may_lose_data(self, unused_windowing): return DataLossReason.NO_POTENTIAL_LOSS
def __eq__(self, other): return type(self) == type(other) def __hash__(self): return hash(type(self))
[docs] @staticmethod def from_runner_api(proto, context): return DefaultTrigger()
[docs] def to_runner_api(self, unused_context): return beam_runner_api_pb2.Trigger( default=beam_runner_api_pb2.Trigger.Default())
[docs] def has_ontime_pane(self): return True
[docs]class AfterProcessingTime(TriggerFn): """Fire exactly once after a specified delay from processing time. AfterProcessingTime is experimental. No backwards compatibility guarantees. """ def __init__(self, delay=0): """Initialize a processing time trigger with a delay in seconds.""" self.delay = delay def __repr__(self): return 'AfterProcessingTime(delay=%d)' % self.delay
[docs] def on_element(self, element, window, context): context.set_timer( '', TimeDomain.REAL_TIME, context.get_current_time() + self.delay)
[docs] def on_merge(self, to_be_merged, merge_result, context): # timers will be kept through merging pass
[docs] def should_fire(self, time_domain, timestamp, window, context): if time_domain == TimeDomain.REAL_TIME: return True
[docs] def on_fire(self, timestamp, window, context): return True
[docs] def reset(self, window, context): pass
[docs] def may_lose_data(self, unused_windowing): """AfterProcessingTime may finish.""" return DataLossReason.MAY_FINISH
[docs] @staticmethod def from_runner_api(proto, context): return AfterProcessingTime( delay=( proto.after_processing_time.timestamp_transforms[0].delay. delay_millis) // 1000)
[docs] def to_runner_api(self, context): delay_proto = beam_runner_api_pb2.TimestampTransform( delay=beam_runner_api_pb2.TimestampTransform.Delay( delay_millis=self.delay * 1000)) return beam_runner_api_pb2.Trigger( after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime( timestamp_transforms=[delay_proto]))
[docs] def has_ontime_pane(self): return False
class Always(TriggerFn): """Repeatedly invoke the given trigger, never finishing.""" def __init__(self): pass def __repr__(self): return 'Always' def __eq__(self, other): return type(self) == type(other) def __hash__(self): return 1 def on_element(self, element, window, context): pass def on_merge(self, to_be_merged, merge_result, context): pass def has_ontime_pane(self): return False def reset(self, window, context): pass def should_fire(self, time_domain, watermark, window, context): return True def on_fire(self, watermark, window, context): return False def may_lose_data(self, unused_windowing): """No potential loss, since the trigger always fires.""" return DataLossReason.NO_POTENTIAL_LOSS @staticmethod def from_runner_api(proto, context): return Always() def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( always=beam_runner_api_pb2.Trigger.Always()) class _Never(TriggerFn): """A trigger that never fires. Data may still be released at window closing. """ def __init__(self): pass def __repr__(self): return 'Never' def __eq__(self, other): return type(self) == type(other) def __hash__(self): return hash(type(self)) def on_element(self, element, window, context): pass def on_merge(self, to_be_merged, merge_result, context): pass def has_ontime_pane(self): False def reset(self, window, context): pass def should_fire(self, time_domain, watermark, window, context): return False def on_fire(self, watermark, window, context): return True def may_lose_data(self, unused_windowing): """No potential data loss. Though Never doesn't explicitly trigger, it still collects data on windowing closing. """ return DataLossReason.NO_POTENTIAL_LOSS @staticmethod def from_runner_api(proto, context): return _Never() def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( never=beam_runner_api_pb2.Trigger.Never())
[docs]class AfterWatermark(TriggerFn): """Fire exactly once when the watermark passes the end of the window. Args: early: if not None, a speculative trigger to repeatedly evaluate before the watermark passes the end of the window late: if not None, a speculative trigger to repeatedly evaluate after the watermark passes the end of the window """ LATE_TAG = _CombiningValueStateTag('is_late', any) def __init__(self, early=None, late=None): # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly self.early = Repeatedly(early) if early else None self.late = Repeatedly(late) if late else None def __repr__(self): qualifiers = [] if self.early: qualifiers.append('early=%s' % self.early.underlying) if self.late: qualifiers.append('late=%s' % self.late.underlying) return 'AfterWatermark(%s)' % ', '.join(qualifiers)
[docs] def is_late(self, context): return self.late and context.get_state(self.LATE_TAG)
[docs] def on_element(self, element, window, context): if self.is_late(context): self.late.on_element(element, window, NestedContext(context, 'late')) else: context.set_timer('', TimeDomain.WATERMARK, window.end) if self.early: self.early.on_element(element, window, NestedContext(context, 'early'))
[docs] def on_merge(self, to_be_merged, merge_result, context): # TODO(robertwb): Figure out whether the 'rewind' semantics could be used # here. if self.is_late(context): self.late.on_merge( to_be_merged, merge_result, NestedContext(context, 'late')) else: # Note: Timer clearing solely an optimization. for window in to_be_merged: if window.end != merge_result.end: context.clear_timer('', TimeDomain.WATERMARK) if self.early: self.early.on_merge( to_be_merged, merge_result, NestedContext(context, 'early'))
[docs] def should_fire(self, time_domain, watermark, window, context): if self.is_late(context): return self.late.should_fire( time_domain, watermark, window, NestedContext(context, 'late')) elif watermark >= window.end: # Explicitly clear the timer so that late elements are not emitted again # when the timer is fired. context.clear_timer('', TimeDomain.WATERMARK) return True elif self.early: return self.early.should_fire( time_domain, watermark, window, NestedContext(context, 'early')) return False
[docs] def on_fire(self, watermark, window, context): if self.is_late(context): return self.late.on_fire( watermark, window, NestedContext(context, 'late')) elif watermark >= window.end: context.add_state(self.LATE_TAG, True) return not self.late elif self.early: self.early.on_fire(watermark, window, NestedContext(context, 'early')) return False
[docs] def reset(self, window, context): if self.late: context.clear_state(self.LATE_TAG) if self.early: self.early.reset(window, NestedContext(context, 'early')) if self.late: self.late.reset(window, NestedContext(context, 'late'))
[docs] def may_lose_data(self, windowing): """May cause data loss if lateness allowed and no late trigger set.""" if windowing.allowed_lateness == 0: return DataLossReason.NO_POTENTIAL_LOSS if self.late is None: return DataLossReason.MAY_FINISH return self.late.may_lose_data(windowing)
def __eq__(self, other): return ( type(self) == type(other) and self.early == other.early and self.late == other.late) def __hash__(self): return hash((type(self), self.early, self.late))
[docs] @staticmethod def from_runner_api(proto, context): return AfterWatermark( early=TriggerFn.from_runner_api( proto.after_end_of_window.early_firings, context) if proto.after_end_of_window.HasField('early_firings') else None, late=TriggerFn.from_runner_api( proto.after_end_of_window.late_firings, context) if proto.after_end_of_window.HasField('late_firings') else None)
[docs] def to_runner_api(self, context): early_proto = self.early.underlying.to_runner_api( context) if self.early else None late_proto = self.late.underlying.to_runner_api( context) if self.late else None return beam_runner_api_pb2.Trigger( after_end_of_window=beam_runner_api_pb2.Trigger.AfterEndOfWindow( early_firings=early_proto, late_firings=late_proto))
[docs] def has_ontime_pane(self): return True
[docs]class AfterCount(TriggerFn): """Fire when there are at least count elements in this window pane. AfterCount is experimental. No backwards compatibility guarantees. """ COUNT_TAG = _CombiningValueStateTag('count', combiners.CountCombineFn()) def __init__(self, count): if not isinstance(count, numbers.Integral) or count < 1: raise ValueError("count (%d) must be a positive integer." % count) self.count = count def __repr__(self): return 'AfterCount(%s)' % self.count def __eq__(self, other): return type(self) == type(other) and self.count == other.count def __hash__(self): return hash(self.count)
[docs] def on_element(self, element, window, context): context.add_state(self.COUNT_TAG, 1)
[docs] def on_merge(self, to_be_merged, merge_result, context): # states automatically merged pass
[docs] def should_fire(self, time_domain, watermark, window, context): return context.get_state(self.COUNT_TAG) >= self.count
[docs] def on_fire(self, watermark, window, context): return True
[docs] def reset(self, window, context): context.clear_state(self.COUNT_TAG)
[docs] def may_lose_data(self, unused_windowing): """AfterCount may finish.""" return DataLossReason.MAY_FINISH
[docs] @staticmethod def from_runner_api(proto, unused_context): return AfterCount(proto.element_count.element_count)
[docs] def to_runner_api(self, unused_context): return beam_runner_api_pb2.Trigger( element_count=beam_runner_api_pb2.Trigger.ElementCount( element_count=self.count))
[docs] def has_ontime_pane(self): return False
[docs]class Repeatedly(TriggerFn): """Repeatedly invoke the given trigger, never finishing.""" def __init__(self, underlying): self.underlying = underlying def __repr__(self): return 'Repeatedly(%s)' % self.underlying def __eq__(self, other): return type(self) == type(other) and self.underlying == other.underlying def __hash__(self): return hash(self.underlying)
[docs] def on_element(self, element, window, context): self.underlying.on_element(element, window, context)
[docs] def on_merge(self, to_be_merged, merge_result, context): self.underlying.on_merge(to_be_merged, merge_result, context)
[docs] def should_fire(self, time_domain, watermark, window, context): return self.underlying.should_fire(time_domain, watermark, window, context)
[docs] def on_fire(self, watermark, window, context): if self.underlying.on_fire(watermark, window, context): self.underlying.reset(window, context) return False
[docs] def reset(self, window, context): self.underlying.reset(window, context)
[docs] def may_lose_data(self, windowing): """Repeatedly will run in a loop and pick up whatever is left at GC.""" return DataLossReason.NO_POTENTIAL_LOSS
[docs] @staticmethod def from_runner_api(proto, context): return Repeatedly( TriggerFn.from_runner_api(proto.repeat.subtrigger, context))
[docs] def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( repeat=beam_runner_api_pb2.Trigger.Repeat( subtrigger=self.underlying.to_runner_api(context)))
[docs] def has_ontime_pane(self): return self.underlying.has_ontime_pane()
class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta): def __init__(self, *triggers): self.triggers = triggers def __repr__(self): return '%s(%s)' % ( self.__class__.__name__, ', '.join(str(t) for t in self.triggers)) def __eq__(self, other): return type(self) == type(other) and self.triggers == other.triggers def __hash__(self): return hash(self.triggers) @abstractmethod def combine_op(self, trigger_results): pass def on_element(self, element, window, context): for ix, trigger in enumerate(self.triggers): trigger.on_element(element, window, self._sub_context(context, ix)) def on_merge(self, to_be_merged, merge_result, context): for ix, trigger in enumerate(self.triggers): trigger.on_merge( to_be_merged, merge_result, self._sub_context(context, ix)) def should_fire(self, time_domain, watermark, window, context): self._time_domain = time_domain return self.combine_op( trigger.should_fire( time_domain, watermark, window, self._sub_context(context, ix)) for ix, trigger in enumerate(self.triggers)) def on_fire(self, watermark, window, context): finished = [] for ix, trigger in enumerate(self.triggers): nested_context = self._sub_context(context, ix) if trigger.should_fire(TimeDomain.WATERMARK, watermark, window, nested_context): finished.append(trigger.on_fire(watermark, window, nested_context)) return self.combine_op(finished) def may_lose_data(self, windowing): may_finish = self.combine_op( _IncludesMayFinish(t.may_lose_data(windowing)) for t in self.triggers) return ( DataLossReason.MAY_FINISH if may_finish else DataLossReason.NO_POTENTIAL_LOSS) def reset(self, window, context): for ix, trigger in enumerate(self.triggers): trigger.reset(window, self._sub_context(context, ix)) @staticmethod def _sub_context(context, index): return NestedContext(context, '%d/' % index) @staticmethod def from_runner_api(proto, context): subtriggers = [ TriggerFn.from_runner_api(subtrigger, context) for subtrigger in proto.after_all.subtriggers or proto.after_any.subtriggers ] if proto.after_all.subtriggers: return AfterAll(*subtriggers) else: return AfterAny(*subtriggers) def to_runner_api(self, context): subtriggers = [ subtrigger.to_runner_api(context) for subtrigger in self.triggers ] if self.combine_op == all: return beam_runner_api_pb2.Trigger( after_all=beam_runner_api_pb2.Trigger.AfterAll( subtriggers=subtriggers)) elif self.combine_op == any: return beam_runner_api_pb2.Trigger( after_any=beam_runner_api_pb2.Trigger.AfterAny( subtriggers=subtriggers)) else: raise NotImplementedError(self) def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers)
[docs]class AfterAny(_ParallelTriggerFn): """Fires when any subtrigger fires. Also finishes when any subtrigger finishes. """ combine_op = any
[docs]class AfterAll(_ParallelTriggerFn): """Fires when all subtriggers have fired. Also finishes when all subtriggers have finished. """ combine_op = all
[docs]class AfterEach(TriggerFn): INDEX_TAG = _CombiningValueStateTag( 'index', (lambda indices: 0 if not indices else max(indices))) def __init__(self, *triggers): self.triggers = triggers def __repr__(self): return '%s(%s)' % ( self.__class__.__name__, ', '.join(str(t) for t in self.triggers)) def __eq__(self, other): return type(self) == type(other) and self.triggers == other.triggers def __hash__(self): return hash(self.triggers)
[docs] def on_element(self, element, window, context): ix = context.get_state(self.INDEX_TAG) if ix < len(self.triggers): self.triggers[ix].on_element( element, window, self._sub_context(context, ix))
[docs] def on_merge(self, to_be_merged, merge_result, context): # This takes the furthest window on merging. # TODO(robertwb): Revisit this when merging windows logic is settled for # all possible merging situations. ix = context.get_state(self.INDEX_TAG) if ix < len(self.triggers): self.triggers[ix].on_merge( to_be_merged, merge_result, self._sub_context(context, ix))
[docs] def should_fire(self, time_domain, watermark, window, context): ix = context.get_state(self.INDEX_TAG) if ix < len(self.triggers): return self.triggers[ix].should_fire( time_domain, watermark, window, self._sub_context(context, ix))
[docs] def on_fire(self, watermark, window, context): ix = context.get_state(self.INDEX_TAG) if ix < len(self.triggers): if self.triggers[ix].on_fire(watermark, window, self._sub_context(context, ix)): ix += 1 context.add_state(self.INDEX_TAG, ix) return ix == len(self.triggers)
[docs] def reset(self, window, context): context.clear_state(self.INDEX_TAG) for ix, trigger in enumerate(self.triggers): trigger.reset(window, self._sub_context(context, ix))
[docs] def may_lose_data(self, windowing): """If all sub-triggers may finish, this may finish.""" may_finish = all( _IncludesMayFinish(t.may_lose_data(windowing)) for t in self.triggers) return ( DataLossReason.MAY_FINISH if may_finish else DataLossReason.NO_POTENTIAL_LOSS)
@staticmethod def _sub_context(context, index): return NestedContext(context, '%d/' % index)
[docs] @staticmethod def from_runner_api(proto, context): return AfterEach( *[ TriggerFn.from_runner_api(subtrigger, context) for subtrigger in proto.after_each.subtriggers ])
[docs] def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( after_each=beam_runner_api_pb2.Trigger.AfterEach( subtriggers=[ subtrigger.to_runner_api(context) for subtrigger in self.triggers ]))
[docs] def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers)
[docs]class OrFinally(AfterAny):
[docs] @staticmethod def from_runner_api(proto, context): return OrFinally( TriggerFn.from_runner_api(proto.or_finally.main, context), # getattr is used as finally is a keyword in Python TriggerFn.from_runner_api( getattr(proto.or_finally, 'finally'), context))
[docs] def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( or_finally=beam_runner_api_pb2.Trigger.OrFinally( main=self.triggers[0].to_runner_api(context), # dict keyword argument is used as finally is a keyword in Python **{'finally': self.triggers[1].to_runner_api(context)}))
class TriggerContext(object): def __init__(self, outer, window, clock): self._outer = outer self._window = window self._clock = clock def get_current_time(self): return self._clock.time() def set_timer(self, name, time_domain, timestamp): self._outer.set_timer(self._window, name, time_domain, timestamp) def clear_timer(self, name, time_domain): self._outer.clear_timer(self._window, name, time_domain) def add_state(self, tag, value): self._outer.add_state(self._window, tag, value) def get_state(self, tag): return self._outer.get_state(self._window, tag) def clear_state(self, tag): return self._outer.clear_state(self._window, tag) class NestedContext(object): """Namespaced context useful for defining composite triggers.""" def __init__(self, outer, prefix): self._outer = outer self._prefix = prefix def get_current_time(self): return self._outer.get_current_time() def set_timer(self, name, time_domain, timestamp): self._outer.set_timer(self._prefix + name, time_domain, timestamp) def clear_timer(self, name, time_domain): self._outer.clear_timer(self._prefix + name, time_domain) def add_state(self, tag, value): self._outer.add_state(tag.with_prefix(self._prefix), value) def get_state(self, tag): return self._outer.get_state(tag.with_prefix(self._prefix)) def clear_state(self, tag): self._outer.clear_state(tag.with_prefix(self._prefix)) # pylint: disable=unused-argument class SimpleState(metaclass=ABCMeta): """Basic state storage interface used for triggering. Only timers must hold the watermark (by their timestamp). """ @abstractmethod def set_timer( self, window, name, time_domain, timestamp, dynamic_timer_tag=''): pass @abstractmethod def get_window(self, window_id): pass @abstractmethod def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''): pass @abstractmethod def add_state(self, window, tag, value): pass @abstractmethod def get_state(self, window, tag): pass @abstractmethod def clear_state(self, window, tag): pass def at(self, window, clock): return NestedContext(TriggerContext(self, window, clock), 'trigger') class UnmergedState(SimpleState): """State suitable for use in TriggerDriver. This class must be implemented by each backend. """ @abstractmethod def set_global_state(self, tag, value): pass @abstractmethod def get_global_state(self, tag, default=None): pass # pylint: enable=unused-argument class MergeableStateAdapter(SimpleState): """Wraps an UnmergedState, tracking merged windows.""" # TODO(robertwb): A similar indirection could be used for sliding windows # or other window_fns when a single element typically belongs to many windows. WINDOW_IDS = _ReadModifyWriteStateTag('window_ids') def __init__(self, raw_state): self.raw_state = raw_state self.window_ids = self.raw_state.get_global_state(self.WINDOW_IDS, {}) self.counter = None def set_timer( self, window, name, time_domain, timestamp, dynamic_timer_tag=''): self.raw_state.set_timer( self._get_id(window), name, time_domain, timestamp, dynamic_timer_tag=dynamic_timer_tag) def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''): for window_id in self._get_ids(window): self.raw_state.clear_timer( window_id, name, time_domain, dynamic_timer_tag=dynamic_timer_tag) def add_state(self, window, tag, value): if isinstance(tag, _ReadModifyWriteStateTag): raise ValueError( 'Merging requested for non-mergeable state tag: %r.' % tag) elif isinstance(tag, _CombiningValueStateTag): tag = tag.without_extraction() self.raw_state.add_state(self._get_id(window), tag, value) def get_state(self, window, tag): if isinstance(tag, _CombiningValueStateTag): original_tag, tag = tag, tag.without_extraction() values = [ self.raw_state.get_state(window_id, tag) for window_id in self._get_ids(window) ] if isinstance(tag, _ReadModifyWriteStateTag): raise ValueError( 'Merging requested for non-mergeable state tag: %r.' % tag) elif isinstance(tag, _CombiningValueStateTag): return original_tag.combine_fn.extract_output( original_tag.combine_fn.merge_accumulators(values)) elif isinstance(tag, _ListStateTag): return [v for vs in values for v in vs] elif isinstance(tag, _SetStateTag): return {v for vs in values for v in vs} elif isinstance(tag, _WatermarkHoldStateTag): return tag.timestamp_combiner_impl.combine_all(values) else: raise ValueError('Invalid tag.', tag) def clear_state(self, window, tag): for window_id in self._get_ids(window): self.raw_state.clear_state(window_id, tag) if tag is None: del self.window_ids[window] self._persist_window_ids() def merge(self, to_be_merged, merge_result): for window in to_be_merged: if window != merge_result: if window in self.window_ids: if merge_result in self.window_ids: merge_window_ids = self.window_ids[merge_result] else: merge_window_ids = self.window_ids[merge_result] = [] merge_window_ids.extend(self.window_ids.pop(window)) self._persist_window_ids() def known_windows(self): return list(self.window_ids) def get_window(self, window_id): for window, ids in self.window_ids.items(): if window_id in ids: return window raise ValueError('No window for %s' % window_id) def _get_id(self, window): if window in self.window_ids: return self.window_ids[window][0] window_id = self._get_next_counter() self.window_ids[window] = [window_id] self._persist_window_ids() return window_id def _get_ids(self, window): return self.window_ids.get(window, []) def _get_next_counter(self): if not self.window_ids: self.counter = 0 elif self.counter is None: self.counter = max(k for ids in self.window_ids.values() for k in ids) self.counter += 1 return self.counter def _persist_window_ids(self): self.raw_state.set_global_state(self.WINDOW_IDS, self.window_ids) def __repr__(self): return '\n\t'.join([repr(self.window_ids)] + repr(self.raw_state).split('\n')) def create_trigger_driver( windowing, is_batch=False, phased_combine_fn=None, clock=None): """Create the TriggerDriver for the given windowing and options.""" # TODO(BEAM-10149): Respect closing and on-time behaviors. # For batch, we should always fire once, no matter what. if is_batch and windowing.triggerfn == _Never(): windowing = copy.copy(windowing) windowing.triggerfn = Always() # TODO(robertwb): We can do more if we know elements are in timestamp # sorted order. if windowing.is_default() and is_batch: driver = BatchGlobalTriggerDriver() elif (windowing.windowfn == GlobalWindows() and (windowing.triggerfn in [AfterCount(1), Always()]) and is_batch): # Here we also just pass through all the values exactly once. driver = BatchGlobalTriggerDriver() else: driver = GeneralTriggerDriver(windowing, clock) if phased_combine_fn: # TODO(ccy): Refactor GeneralTriggerDriver to combine values eagerly using # the known phased_combine_fn here. driver = CombiningTriggerDriver(phased_combine_fn, driver) return driver class TriggerDriver(metaclass=ABCMeta): """Breaks a series of bundle and timer firings into window (pane)s.""" @abstractmethod def process_elements( self, state, windowed_values, output_watermark, input_watermark=MIN_TIMESTAMP): pass @abstractmethod def process_timer( self, window_id, name, time_domain, timestamp, state, input_watermark=None): pass def process_entire_key(self, key, windowed_values): # This state holds per-key, multi-window state. state = InMemoryUnmergedState() for wvalue in self.process_elements(state, windowed_values, MIN_TIMESTAMP, MIN_TIMESTAMP): yield wvalue.with_value((key, wvalue.value)) while state.timers: fired = state.get_and_clear_timers() for timer_window, (name, time_domain, fire_time, _) in fired: for wvalue in self.process_timer(timer_window, name, time_domain, fire_time, state): yield wvalue.with_value((key, wvalue.value)) class _UnwindowedValues(observable.ObservableMixin): """Exposes iterable of windowed values as iterable of unwindowed values.""" def __init__(self, windowed_values): super().__init__() self._windowed_values = windowed_values def __iter__(self): for wv in self._windowed_values: unwindowed_value = wv.value self.notify_observers(unwindowed_value) yield unwindowed_value def __repr__(self): return '<_UnwindowedValues of %s>' % self._windowed_values def __reduce__(self): return list, (list(self), ) def __eq__(self, other): if isinstance(other, collections_abc.Iterable): return all( a == b for a, b in zip_longest(self, other, fillvalue=object())) else: return NotImplemented def __hash__(self): return hash(tuple(self)) coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type( _UnwindowedValues) class BatchGlobalTriggerDriver(TriggerDriver): """Groups all received values together. """ GLOBAL_WINDOW_TUPLE = (GlobalWindow(), ) ONLY_FIRING = windowed_value.PaneInfo( is_first=True, is_last=True, timing=windowed_value.PaneInfoTiming.ON_TIME, index=0, nonspeculative_index=0) def process_elements( self, state, windowed_values, unused_output_watermark, unused_input_watermark=MIN_TIMESTAMP): yield WindowedValue( _UnwindowedValues(windowed_values), MIN_TIMESTAMP, self.GLOBAL_WINDOW_TUPLE, self.ONLY_FIRING) def process_timer( self, window_id, name, time_domain, timestamp, state, input_watermark=None): raise TypeError('Triggers never set or called for batch default windowing.') class CombiningTriggerDriver(TriggerDriver): """Uses a phased_combine_fn to process output of wrapped TriggerDriver.""" def __init__(self, phased_combine_fn, underlying): self.phased_combine_fn = phased_combine_fn self.underlying = underlying def process_elements( self, state, windowed_values, output_watermark, input_watermark=MIN_TIMESTAMP): uncombined = self.underlying.process_elements( state, windowed_values, output_watermark, input_watermark) for output in uncombined: yield output.with_value(self.phased_combine_fn.apply(output.value)) def process_timer( self, window_id, name, time_domain, timestamp, state, input_watermark=None): uncombined = self.underlying.process_timer( window_id, name, time_domain, timestamp, state, input_watermark) for output in uncombined: yield output.with_value(self.phased_combine_fn.apply(output.value)) class GeneralTriggerDriver(TriggerDriver): """Breaks a series of bundle and timer firings into window (pane)s. Suitable for all variants of Windowing. """ ELEMENTS = _ListStateTag('elements') TOMBSTONE = _CombiningValueStateTag('tombstone', combiners.CountCombineFn()) INDEX = _CombiningValueStateTag('index', combiners.CountCombineFn()) NONSPECULATIVE_INDEX = _CombiningValueStateTag( 'nonspeculative_index', combiners.CountCombineFn()) def __init__(self, windowing, clock): self.clock = clock self.allowed_lateness = windowing.allowed_lateness self.window_fn = windowing.windowfn self.timestamp_combiner_impl = TimestampCombiner.get_impl( windowing.timestamp_combiner, self.window_fn) # pylint: disable=invalid-name self.WATERMARK_HOLD = _WatermarkHoldStateTag( 'watermark', self.timestamp_combiner_impl) # pylint: enable=invalid-name self.trigger_fn = windowing.triggerfn self.accumulation_mode = windowing.accumulation_mode self.is_merging = True def process_elements( self, state, windowed_values, output_watermark, input_watermark=MIN_TIMESTAMP): if self.is_merging: state = MergeableStateAdapter(state) windows_to_elements = collections.defaultdict(list) for wv in windowed_values: for window in wv.windows: # ignore expired windows if input_watermark > window.end + self.allowed_lateness: continue windows_to_elements[window].append((wv.value, wv.timestamp)) # First handle merging. if self.is_merging: old_windows = set(state.known_windows()) all_windows = old_windows.union(list(windows_to_elements)) if all_windows != old_windows: merged_away = {} class TriggerMergeContext(WindowFn.MergeContext): def merge(_, to_be_merged, merge_result): # pylint: disable=no-self-argument for window in to_be_merged: if window != merge_result: merged_away[window] = merge_result # Clear state associated with PaneInfo since it is # not preserved across merges. state.clear_state(window, self.INDEX) state.clear_state(window, self.NONSPECULATIVE_INDEX) state.merge(to_be_merged, merge_result) # using the outer self argument. self.trigger_fn.on_merge( to_be_merged, merge_result, state.at(merge_result, self.clock)) self.window_fn.merge(TriggerMergeContext(all_windows)) merged_windows_to_elements = collections.defaultdict(list) for window, values in windows_to_elements.items(): while window in merged_away: window = merged_away[window] merged_windows_to_elements[window].extend(values) windows_to_elements = merged_windows_to_elements for window in merged_away: state.clear_state(window, self.WATERMARK_HOLD) # Next handle element adding. for window, elements in windows_to_elements.items(): if state.get_state(window, self.TOMBSTONE): continue # Add watermark hold. # TODO(ccy): Add late data and garbage-collection hold support. output_time = self.timestamp_combiner_impl.merge( window, ( element_output_time for element_output_time in ( self.timestamp_combiner_impl.assign_output_time( window, timestamp) for unused_value, timestamp in elements) if element_output_time >= output_watermark)) if output_time is not None: state.add_state(window, self.WATERMARK_HOLD, output_time) context = state.at(window, self.clock) for value, unused_timestamp in elements: state.add_state(window, self.ELEMENTS, value) self.trigger_fn.on_element(value, window, context) # Maybe fire this window. if self.trigger_fn.should_fire(TimeDomain.WATERMARK, input_watermark, window, context): finished = self.trigger_fn.on_fire(input_watermark, window, context) yield self._output(window, finished, state, output_watermark, False) def process_timer( self, window_id, unused_name, time_domain, timestamp, state, input_watermark=None): if input_watermark is None: input_watermark = timestamp if self.is_merging: state = MergeableStateAdapter(state) window = state.get_window(window_id) if state.get_state(window, self.TOMBSTONE): return if time_domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME): if not self.is_merging or window in state.known_windows(): context = state.at(window, self.clock) if self.trigger_fn.should_fire(time_domain, timestamp, window, context): finished = self.trigger_fn.on_fire(timestamp, window, context) yield self._output( window, finished, state, timestamp, time_domain == TimeDomain.WATERMARK) else: raise Exception('Unexpected time domain: %s' % time_domain) def _output(self, window, finished, state, output_watermark, maybe_ontime): """Output window and clean up if appropriate.""" index = state.get_state(window, self.INDEX) state.add_state(window, self.INDEX, 1) if output_watermark <= window.max_timestamp(): nonspeculative_index = -1 timing = windowed_value.PaneInfoTiming.EARLY if state.get_state(window, self.NONSPECULATIVE_INDEX): nonspeculative_index = state.get_state( window, self.NONSPECULATIVE_INDEX) state.add_state(window, self.NONSPECULATIVE_INDEX, 1) _LOGGER.warning( 'Watermark moved backwards in time ' 'or late data moved window end forward.') else: nonspeculative_index = state.get_state(window, self.NONSPECULATIVE_INDEX) state.add_state(window, self.NONSPECULATIVE_INDEX, 1) timing = ( windowed_value.PaneInfoTiming.ON_TIME if maybe_ontime and nonspeculative_index == 0 else windowed_value.PaneInfoTiming.LATE) pane_info = windowed_value.PaneInfo( index == 0, finished, timing, index, nonspeculative_index) values = state.get_state(window, self.ELEMENTS) if finished: # TODO(robertwb): allowed lateness state.clear_state(window, self.ELEMENTS) state.add_state(window, self.TOMBSTONE, 1) elif self.accumulation_mode == AccumulationMode.DISCARDING: state.clear_state(window, self.ELEMENTS) timestamp = state.get_state(window, self.WATERMARK_HOLD) if timestamp is None: # If no watermark hold was set, output at end of window. timestamp = window.max_timestamp() elif output_watermark < window.end and self.trigger_fn.has_ontime_pane(): # Hold the watermark in case there is an empty pane that needs to be fired # at the end of the window. pass else: state.clear_state(window, self.WATERMARK_HOLD) return WindowedValue(values, timestamp, (window, ), pane_info) class InMemoryUnmergedState(UnmergedState): """In-memory implementation of UnmergedState. Used for batch and testing. """ def __init__(self, defensive_copy=False): # TODO(robertwb): Clean defensive_copy. It is too expensive in production. self.timers = collections.defaultdict(dict) self.state = collections.defaultdict(lambda: collections.defaultdict(list)) self.global_state = {} self.defensive_copy = defensive_copy def copy(self): cloned_object = InMemoryUnmergedState(defensive_copy=self.defensive_copy) cloned_object.timers = copy.deepcopy(self.timers) cloned_object.global_state = copy.deepcopy(self.global_state) for window in self.state: for tag in self.state[window]: cloned_object.state[window][tag] = copy.copy(self.state[window][tag]) return cloned_object def set_global_state(self, tag, value): assert isinstance(tag, _ReadModifyWriteStateTag) if self.defensive_copy: value = copy.deepcopy(value) self.global_state[tag.tag] = value def get_global_state(self, tag, default=None): return self.global_state.get(tag.tag, default) def set_timer( self, window, name, time_domain, timestamp, dynamic_timer_tag=''): self.timers[window][(name, time_domain, dynamic_timer_tag)] = timestamp def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''): self.timers[window].pop((name, time_domain, dynamic_timer_tag), None) if not self.timers[window]: del self.timers[window] def get_window(self, window_id): return window_id def add_state(self, window, tag, value): if self.defensive_copy: value = copy.deepcopy(value) if isinstance(tag, _ReadModifyWriteStateTag): self.state[window][tag.tag] = value elif isinstance(tag, _CombiningValueStateTag): # TODO(robertwb): Store merged accumulators. self.state[window][tag.tag].append(value) elif isinstance(tag, _ListStateTag): self.state[window][tag.tag].append(value) elif isinstance(tag, _SetStateTag): self.state[window][tag.tag].append(value) elif isinstance(tag, _WatermarkHoldStateTag): self.state[window][tag.tag].append(value) else: raise ValueError('Invalid tag.', tag) def get_state(self, window, tag): values = self.state[window][tag.tag] if isinstance(tag, _ReadModifyWriteStateTag): return values elif isinstance(tag, _CombiningValueStateTag): return tag.combine_fn.apply(values) elif isinstance(tag, _ListStateTag): return values elif isinstance(tag, _SetStateTag): return values elif isinstance(tag, _WatermarkHoldStateTag): return tag.timestamp_combiner_impl.combine_all(values) else: raise ValueError('Invalid tag.', tag) def clear_state(self, window, tag): self.state[window].pop(tag.tag, None) if not self.state[window]: self.state.pop(window, None) def get_timers( self, clear=False, watermark=MAX_TIMESTAMP, processing_time=None): """Gets expired timers and reports if there are any realtime timers set per state. Expiration is measured against the watermark for event-time timers, and against a wall clock for processing-time timers. """ expired = [] has_realtime_timer = False for window, timers in list(self.timers.items()): for (name, time_domain, dynamic_timer_tag), timestamp in list( timers.items()): if time_domain == TimeDomain.REAL_TIME: time_marker = processing_time has_realtime_timer = True elif time_domain == TimeDomain.WATERMARK: time_marker = watermark else: _LOGGER.error( 'TimeDomain error: No timers defined for time domain %s.', time_domain) if timestamp <= time_marker: expired.append( (window, (name, time_domain, timestamp, dynamic_timer_tag))) if clear: del timers[(name, time_domain, dynamic_timer_tag)] if not timers and clear: del self.timers[window] return expired, has_realtime_timer def get_and_clear_timers(self, watermark=MAX_TIMESTAMP): return self.get_timers(clear=True, watermark=watermark)[0] def get_earliest_hold(self): earliest_hold = MAX_TIMESTAMP for unused_window, tagged_states in self.state.items(): # TODO(BEAM-2519): currently, this assumes that the watermark hold tag is # named "watermark". This is currently only true because the only place # watermark holds are set is in the GeneralTriggerDriver, where we use # this name. We should fix this by allowing enumeration of the tag types # used in adding state. if 'watermark' in tagged_states and tagged_states['watermark']: hold = min(tagged_states['watermark']) - TIME_GRANULARITY earliest_hold = min(earliest_hold, hold) return earliest_hold def __repr__(self): state_str = '\n'.join( '%s: %s' % (key, dict(state)) for key, state in self.state.items()) return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)