# Source code for apache_beam.transforms.trigger

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Support for Dataflow triggers.

Triggers control when, in processing time, the panes of a window are emitted.
"""

import collections
import copy
import itertools
import logging
import numbers
from abc import ABCMeta
from abc import abstractmethod

from apache_beam.coders import observable
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import combiners
from apache_beam.transforms import core
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import WindowedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.timestamp import TIME_GRANULARITY

# AfterCount is experimental. No backwards compatibility guarantees.

# Public names of this module, i.e. what
# `from apache_beam.transforms.trigger import *` exports.
__all__ = [
    'AccumulationMode',
    'TriggerFn',
    'DefaultTrigger',
    'AfterWatermark',
    'AfterProcessingTime',
    'AfterCount',
    'Repeatedly',
    'AfterAny',
    'AfterAll',
    'AfterEach',
    'OrFinally',
    ]


class AccumulationMode(object):
  """Policy for a window's data when its trigger fires more than once.

  Each constant is an alias of the value with the same name in
  ``beam_runner_api_pb2.AccumulationMode``.
  """
  DISCARDING = beam_runner_api_pb2.AccumulationMode.DISCARDING
  ACCUMULATING = beam_runner_api_pb2.AccumulationMode.ACCUMULATING


# TODO(robertwb): AccumulationMode should also provide retractions of
# previous outputs, e.g.
# RETRACTING = 3


class _StateTag(object):
  """Base identifier under which typed, combinable state is stored.

  A tag must be unique within its stage. Subclasses decide how values written
  under the tag are kept: a single value, a list of values, or values folded
  together by a CombineFn (possibly incrementally and eagerly as they are
  added).
  """
  __metaclass__ = ABCMeta

  def __init__(self, tag):
    self.tag = tag


class _ValueStateTag(_StateTag):
  """StateTag pointing to an element."""

  def __repr__(self):
    return 'ValueStateTag(%s)' % self.tag

  def with_prefix(self, prefix):
    """Returns a copy of this tag whose name is prefixed by ``prefix``."""
    prefixed_name = prefix + self.tag
    return _ValueStateTag(prefixed_name)


class _CombiningValueStateTag(_StateTag):
  """StateTag pointing to an element, accumulated with a combiner."""

  # TODO(robertwb): Also store the coder (perhaps extracted from the combine_fn)
  def __init__(self, tag, combine_fn):
    """
    Args:
      tag: the state name.
      combine_fn: a core.CombineFn, or a callable that is wrapped with
        core.CombineFn.from_callable.

    Raises:
      ValueError: if combine_fn is falsy.
    """
    super(_CombiningValueStateTag, self).__init__(tag)
    if not combine_fn:
      raise ValueError('combine_fn must be specified.')
    # Plain callables are adapted into a CombineFn; CombineFns are kept as-is.
    self.combine_fn = (
        combine_fn if isinstance(combine_fn, core.CombineFn)
        else core.CombineFn.from_callable(combine_fn))

  def __repr__(self):
    return 'CombiningValueStateTag(%s, %s)' % (self.tag, self.combine_fn)

  def with_prefix(self, prefix):
    """Returns a copy of this tag (same combiner) with a prefixed name."""
    prefixed_name = prefix + self.tag
    return _CombiningValueStateTag(prefixed_name, self.combine_fn)


class _ListStateTag(_StateTag):
  """StateTag pointing to a list of elements."""

  def __repr__(self):
    return 'ListStateTag(%s)' % self.tag

  def with_prefix(self, prefix):
    """Returns a copy of this tag whose name is prefixed by ``prefix``."""
    prefixed_name = prefix + self.tag
    return _ListStateTag(prefixed_name)


class _WatermarkHoldStateTag(_StateTag):
  """StateTag for a watermark hold combined by a timestamp combiner impl."""

  def __init__(self, tag, timestamp_combiner_impl):
    super(_WatermarkHoldStateTag, self).__init__(tag)
    self.timestamp_combiner_impl = timestamp_combiner_impl

  def __repr__(self):
    return 'WatermarkHoldStateTag(%s, %s)' % (self.tag,
                                              self.timestamp_combiner_impl)

  def with_prefix(self, prefix):
    """Returns a copy of this tag (same combiner impl) with a prefixed name."""
    prefixed_name = prefix + self.tag
    return _WatermarkHoldStateTag(prefixed_name, self.timestamp_combiner_impl)


# pylint: disable=unused-argument
# TODO(robertwb): Provisional API, Java likely to change as well.
class TriggerFn(object):
  """A TriggerFn determines when window (panes) are emitted.

  See https://beam.apache.org/documentation/programming-guide/#triggers
  """
  __metaclass__ = ABCMeta

  @abstractmethod
  def on_element(self, element, window, context):
    """Called when a new element arrives in a window.

    Args:
      element: the element being added
      window: the window to which the element is being added
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers
    """
    pass

  @abstractmethod
  def on_merge(self, to_be_merged, merge_result, context):
    """Called when multiple windows are merged.

    Args:
      to_be_merged: the set of windows to be merged
      merge_result: the window into which the windows are being merged
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers
    """
    pass

  @abstractmethod
  def should_fire(self, time_domain, timestamp, window, context):
    """Whether this trigger should cause the window to fire.

    Args:
      time_domain: WATERMARK for event-time timers and REAL_TIME for
          processing-time timers.
      timestamp: for time_domain WATERMARK, (a lower bound on) the watermark
          of the system; for time_domain REAL_TIME, the timestamp of the
          processing-time timer.
      window: the window whose trigger is being considered
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers

    Returns:
      whether this trigger should cause a firing
    """
    pass

  @abstractmethod
  def on_fire(self, watermark, window, context):
    """Called when a trigger actually fires.

    Args:
      watermark: (a lower bound on) the watermark of the system
      window: the window whose trigger is being fired
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers

    Returns:
      whether this trigger is finished
    """
    pass

  @abstractmethod
  def reset(self, window, context):
    """Clear any state and timers used by this TriggerFn."""
    pass

  # pylint: enable=unused-argument

  def __ne__(self, other):
    """Inverse of ``==`` for every TriggerFn subclass.

    Python 2 does not derive ``!=`` from ``__eq__``. Without this, two
    triggers that compare equal with ``==`` would also compare unequal with
    ``!=``. On Python 3 this matches the built-in behavior.
    """
    return not self == other

  @staticmethod
  def from_runner_api(proto, context):
    """Reconstructs a TriggerFn from a runner API ``Trigger`` proto.

    Dispatches on the proto's ``trigger`` oneof to the matching class's own
    ``from_runner_api``. The table is built per call because the classes it
    names are defined later in this module.

    Raises:
      KeyError: if the oneof holds a kind with no entry below (the comments
        list e.g. always, never and after_synchronized_processing_time).
    """
    return {
        'after_all': AfterAll,
        'after_any': AfterAny,
        'after_each': AfterEach,
        'after_end_of_window': AfterWatermark,
        'after_processing_time': AfterProcessingTime,
        # after_processing_time, after_synchronized_processing_time
        # always
        'default': DefaultTrigger,
        'element_count': AfterCount,
        # never
        'or_finally': OrFinally,
        'repeat': Repeatedly,
    }[proto.WhichOneof('trigger')].from_runner_api(proto, context)

  @abstractmethod
  def to_runner_api(self, unused_context):
    """Returns this trigger as a runner API ``Trigger`` proto."""
    pass


class DefaultTrigger(TriggerFn):
  """Semantically Repeatedly(AfterWatermark()), but more optimized."""

  def __init__(self):
    pass

  def __repr__(self):
    return 'DefaultTrigger()'

  def on_element(self, element, window, context):
    """Sets the watermark timer to the end of the element's window."""
    context.set_timer('', TimeDomain.WATERMARK, window.end)

  def on_merge(self, to_be_merged, merge_result, context):
    """Clears the watermark timer for each merged window ending elsewhere."""
    # Note: Timer clearing solely an optimization.
    for window in to_be_merged:
      if window.end != merge_result.end:
        context.clear_timer('', TimeDomain.WATERMARK)

  def should_fire(self, time_domain, watermark, window, context):
    """Fires once the given timestamp reaches the end of the window.

    The time domain is not consulted.
    """
    return watermark >= window.end

  def on_fire(self, watermark, window, context):
    """Never reports itself finished, so later panes can still fire."""
    return False

  def reset(self, window, context):
    """Clears the watermark timer."""
    context.clear_timer('', TimeDomain.WATERMARK)

  def __eq__(self, other):
    # All DefaultTrigger instances are interchangeable.
    return type(self) == type(other)

  def __hash__(self):
    # Must agree with __eq__. Defining __eq__ alone makes instances
    # unhashable on Python 3 and identity-hashed on Python 2.
    return hash(type(self))

  @staticmethod
  def from_runner_api(proto, context):
    """Builds a DefaultTrigger; the proto carries no parameters for it."""
    return DefaultTrigger()

  def to_runner_api(self, unused_context):
    """Returns a ``Trigger`` proto with its ``default`` field set."""
    return beam_runner_api_pb2.Trigger(
        default=beam_runner_api_pb2.Trigger.Default())


class AfterProcessingTime(TriggerFn):
  """Fire exactly once after a specified delay from processing time.

  AfterProcessingTime is experimental. No backwards compatibility guarantees.
  """

  def __init__(self, delay=0):
    """
    Args:
      delay: how long after an element arrives the trigger fires. It is added
        directly to ``context.get_current_time()``, so it is in the clock's
        units (presumably seconds, as with ``time.time()``; confirm against
        the runner's clock). It is converted to milliseconds only when
        serialized to the runner API proto.
    """
    self.delay = delay

  def __repr__(self):
    # %s rather than %d so that sub-second delays are not truncated.
    return 'AfterProcessingTime(delay=%s)' % (self.delay,)

  def on_element(self, element, window, context):
    """Sets a processing-time timer ``delay`` after the current time."""
    context.set_timer(
        '', TimeDomain.REAL_TIME, context.get_current_time() + self.delay)

  def on_merge(self, to_be_merged, merge_result, context):
    # timers will be kept through merging
    pass

  def should_fire(self, time_domain, timestamp, window, context):
    """Fires only in response to a processing-time (REAL_TIME) event."""
    return time_domain == TimeDomain.REAL_TIME

  def on_fire(self, timestamp, window, context):
    """Always finished after its first firing."""
    return True

  def reset(self, window, context):
    pass

  def __eq__(self, other):
    return type(self) == type(other) and self.delay == other.delay

  def __hash__(self):
    return hash((type(self), self.delay))

  @staticmethod
  def from_runner_api(proto, context):
    """Builds the trigger from a proto whose delay is in milliseconds."""
    delay_millis = (proto.after_processing_time
                    .timestamp_transforms[0]
                    .delay
                    .delay_millis)
    # The proto field is milliseconds; self.delay is in clock units
    # (seconds). Whole seconds stay integral, sub-second precision is kept.
    seconds, remainder_millis = divmod(delay_millis, 1000)
    return AfterProcessingTime(
        delay=delay_millis / 1000.0 if remainder_millis else seconds)

  def to_runner_api(self, context):
    """Serializes the delay, converted to integral milliseconds."""
    delay_proto = beam_runner_api_pb2.TimestampTransform(
        delay=beam_runner_api_pb2.TimestampTransform.Delay(
            # delay_millis is an integer millisecond count in the proto.
            delay_millis=int(round(self.delay * 1000))))
    return beam_runner_api_pb2.Trigger(
        after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime(
            timestamp_transforms=[delay_proto]))


class AfterWatermark(TriggerFn):
  """Fire exactly once when the watermark passes the end of the window.

  Args:
      early: if not None, a speculative trigger to repeatedly evaluate before
        the watermark passes the end of the window
      late: if not None, a speculative trigger to repeatedly evaluate after
        the watermark passes the end of the window
  """
  LATE_TAG = _CombiningValueStateTag('is_late', any)

  def __init__(self, early=None, late=None):
    # Both optional subtriggers are wrapped so they can fire repeatedly.
    self.early = Repeatedly(early) if early else None
    self.late = Repeatedly(late) if late else None

  def __repr__(self):
    qualifiers = []
    if self.early:
      qualifiers.append('early=%s' % self.early.underlying)
    if self.late:
      qualifiers.append('late=%s' % self.late.underlying)
    return 'AfterWatermark(%s)' % ', '.join(qualifiers)

  def is_late(self, context):
    """Truthy once the on-time pane has fired; always falsy without `late`."""
    return self.late and context.get_state(self.LATE_TAG)

  def on_element(self, element, window, context):
    """Routes the element to the late trigger, or arms the on-time timer."""
    if self.is_late(context):
      self.late.on_element(element, window, NestedContext(context, 'late'))
    else:
      context.set_timer('', TimeDomain.WATERMARK, window.end)
      if self.early:
        self.early.on_element(element, window, NestedContext(context, 'early'))

  def on_merge(self, to_be_merged, merge_result, context):
    """Forwards the merge to the late trigger or to the early/on-time parts."""
    # TODO(robertwb): Figure out whether the 'rewind' semantics could be used
    # here.
    if self.is_late(context):
      self.late.on_merge(
          to_be_merged, merge_result, NestedContext(context, 'late'))
    else:
      # Note: Timer clearing solely an optimization.
      for window in to_be_merged:
        if window.end != merge_result.end:
          context.clear_timer('', TimeDomain.WATERMARK)
      if self.early:
        self.early.on_merge(
            to_be_merged, merge_result, NestedContext(context, 'early'))

  def should_fire(self, time_domain, watermark, window, context):
    """Late trigger decides once late; else end of window, then early."""
    if self.is_late(context):
      return self.late.should_fire(time_domain, watermark,
                                   window, NestedContext(context, 'late'))
    elif watermark >= window.end:
      return True
    elif self.early:
      return self.early.should_fire(time_domain, watermark,
                                    window, NestedContext(context, 'early'))
    return False

  def on_fire(self, watermark, window, context):
    """Fires the active phase.

    Returns:
      True only for the on-time firing when there is no late trigger.
    """
    if self.is_late(context):
      return self.late.on_fire(
          watermark, window, NestedContext(context, 'late'))
    elif watermark >= window.end:
      if self.late:
        # The late flag is only meaningful (is_late) and only cleared (reset)
        # when a late trigger exists, so only record it in that case.
        context.add_state(self.LATE_TAG, True)
      return not self.late
    elif self.early:
      self.early.on_fire(watermark, window, NestedContext(context, 'early'))
    return False

  def reset(self, window, context):
    """Clears the late flag and resets both speculative subtriggers."""
    if self.late:
      context.clear_state(self.LATE_TAG)
    if self.early:
      self.early.reset(window, NestedContext(context, 'early'))
    if self.late:
      self.late.reset(window, NestedContext(context, 'late'))

  def __eq__(self, other):
    return (type(self) == type(other)
            and self.early == other.early
            and self.late == other.late)

  def __hash__(self):
    return hash((type(self), self.early, self.late))

  @staticmethod
  def from_runner_api(proto, context):
    return AfterWatermark(
        early=TriggerFn.from_runner_api(
            proto.after_end_of_window.early_firings, context)
        if proto.after_end_of_window.HasField('early_firings')
        else None,
        late=TriggerFn.from_runner_api(
            proto.after_end_of_window.late_firings, context)
        if proto.after_end_of_window.HasField('late_firings')
        else None)

  def to_runner_api(self, context):
    # Only the wrapped subtriggers are serialized; __init__ re-wraps them in
    # Repeatedly when decoding.
    early_proto = self.early.underlying.to_runner_api(
        context) if self.early else None
    late_proto = self.late.underlying.to_runner_api(
        context) if self.late else None
    return beam_runner_api_pb2.Trigger(
        after_end_of_window=beam_runner_api_pb2.Trigger.AfterEndOfWindow(
            early_firings=early_proto,
            late_firings=late_proto))


class AfterCount(TriggerFn):
  """Fire when there are at least count elements in this window pane.

  AfterCount is experimental. No backwards compatibility guarantees.
  """

  COUNT_TAG = _CombiningValueStateTag('count', combiners.CountCombineFn())

  def __init__(self, count):
    """
    Args:
      count: minimum number of elements in the pane before firing.

    Raises:
      ValueError: if count is not an integral number of at least 1.
    """
    if not isinstance(count, numbers.Integral) or count < 1:
      # %r (with an explicit 1-tuple) rather than %d: %d itself raised
      # TypeError for non-numeric values such as None or '5', masking this
      # ValueError, and it silently truncated floats such as 2.5.
      raise ValueError("count (%r) must be a positive integer." % (count,))
    self.count = count

  def __repr__(self):
    return 'AfterCount(%s)' % self.count

  def __eq__(self, other):
    return type(self) == type(other) and self.count == other.count

  def __hash__(self):
    return hash((type(self), self.count))

  def on_element(self, element, window, context):
    """Counts the element towards the pane's total."""
    context.add_state(self.COUNT_TAG, 1)

  def on_merge(self, to_be_merged, merge_result, context):
    # states automatically merged
    pass

  def should_fire(self, time_domain, watermark, window, context):
    """Fires once the accumulated count reaches ``self.count``."""
    return context.get_state(self.COUNT_TAG) >= self.count

  def on_fire(self, watermark, window, context):
    """Always finished after firing."""
    return True

  def reset(self, window, context):
    """Discards the accumulated count."""
    context.clear_state(self.COUNT_TAG)

  @staticmethod
  def from_runner_api(proto, unused_context):
    return AfterCount(proto.element_count.element_count)

  def to_runner_api(self, unused_context):
    return beam_runner_api_pb2.Trigger(
        element_count=beam_runner_api_pb2.Trigger.ElementCount(
            element_count=self.count))


class Repeatedly(TriggerFn):
  """Repeatedly invoke the given trigger, never finishing."""

  def __init__(self, underlying):
    self.underlying = underlying

  def __repr__(self):
    return 'Repeatedly(%s)' % self.underlying

  def __eq__(self, other):
    return type(self) == type(other) and self.underlying == other.underlying

  def __hash__(self):
    # Without this, defining __eq__ makes Repeatedly unhashable on Python 3,
    # which in turn breaks hashing of triggers that contain one (e.g. the
    # early/late parts of AfterWatermark).
    return hash((type(self), self.underlying))

  def on_element(self, element, window, context):
    self.underlying.on_element(element, window, context)

  def on_merge(self, to_be_merged, merge_result, context):
    self.underlying.on_merge(to_be_merged, merge_result, context)

  def should_fire(self, time_domain, watermark, window, context):
    return self.underlying.should_fire(time_domain, watermark, window, context)

  def on_fire(self, watermark, window, context):
    """Fires the underlying trigger, resetting it whenever it finishes.

    Returns:
      False, always: a Repeatedly trigger never finishes.
    """
    if self.underlying.on_fire(watermark, window, context):
      self.underlying.reset(window, context)
    return False

  def reset(self, window, context):
    self.underlying.reset(window, context)

  @staticmethod
  def from_runner_api(proto, context):
    return Repeatedly(
        TriggerFn.from_runner_api(proto.repeat.subtrigger, context))

  def to_runner_api(self, context):
    return beam_runner_api_pb2.Trigger(
        repeat=beam_runner_api_pb2.Trigger.Repeat(
            subtrigger=self.underlying.to_runner_api(context)))


class _ParallelTriggerFn(TriggerFn):
  """Base for triggers that drive several subtriggers side by side.

  Every subtrigger sees every element/merge/reset through its own
  namespaced context. Subclasses set ``combine_op`` (``any`` or ``all``),
  which reduces the subtriggers' should_fire and on_fire results.
  """
  __metaclass__ = ABCMeta

  def __init__(self, *triggers):
    self.triggers = triggers

  def __repr__(self):
    return '%s(%s)' % (self.__class__.__name__,
                       ', '.join(str(t) for t in self.triggers))

  def __eq__(self, other):
    return type(self) == type(other) and self.triggers == other.triggers

  def __hash__(self):
    # Consistent with __eq__; requires the subtriggers to be hashable.
    return hash((type(self), self.triggers))

  @abstractmethod
  def combine_op(self, trigger_results):
    pass

  def on_element(self, element, window, context):
    for ix, trigger in enumerate(self.triggers):
      trigger.on_element(element, window, self._sub_context(context, ix))

  def on_merge(self, to_be_merged, merge_result, context):
    for ix, trigger in enumerate(self.triggers):
      trigger.on_merge(
          to_be_merged, merge_result, self._sub_context(context, ix))

  def should_fire(self, time_domain, watermark, window, context):
    # NOTE(review): recorded on the (shared) trigger object but not read
    # anywhere in this chunk; kept in case code elsewhere relies on it.
    self._time_domain = time_domain
    return self.combine_op(
        trigger.should_fire(time_domain, watermark, window,
                            self._sub_context(context, ix))
        for ix, trigger in enumerate(self.triggers))

  def on_fire(self, watermark, window, context):
    """Fires each subtrigger that is ready, then combines their results.

    Readiness is re-checked in the WATERMARK domain. Only the subtriggers
    that fire contribute a "finished" flag to combine_op.
    """
    # NOTE(review): if no subtrigger is ready, `finished` is empty and
    # all([]) is True, so AfterAll would report finished; confirm intended.
    finished = []
    for ix, trigger in enumerate(self.triggers):
      nested_context = self._sub_context(context, ix)
      if trigger.should_fire(TimeDomain.WATERMARK, watermark,
                             window, nested_context):
        finished.append(trigger.on_fire(watermark, window, nested_context))
    return self.combine_op(finished)

  def reset(self, window, context):
    for ix, trigger in enumerate(self.triggers):
      trigger.reset(window, self._sub_context(context, ix))

  @staticmethod
  def _sub_context(context, index):
    """Namespaces subtrigger ``index``'s state and timers as '<index>/'."""
    return NestedContext(context, '%d/' % index)

  @staticmethod
  def from_runner_api(proto, context):
    """Decodes an ``after_all`` or ``after_any`` Trigger proto."""
    # Dispatch on which oneof field is set, not on whether its subtrigger
    # list is non-empty; otherwise an AfterAll with no subtriggers would be
    # decoded as an AfterAny.
    if proto.WhichOneof('trigger') == 'after_all':
      trigger_cls, sub_protos = AfterAll, proto.after_all.subtriggers
    else:
      trigger_cls, sub_protos = AfterAny, proto.after_any.subtriggers
    return trigger_cls(*[
        TriggerFn.from_runner_api(sub_proto, context)
        for sub_proto in sub_protos])

  def to_runner_api(self, context):
    subtriggers = [
        subtrigger.to_runner_api(context) for subtrigger in self.triggers]
    if self.combine_op == all:
      return beam_runner_api_pb2.Trigger(
          after_all=beam_runner_api_pb2.Trigger.AfterAll(
              subtriggers=subtriggers))
    elif self.combine_op == any:
      return beam_runner_api_pb2.Trigger(
          after_any=beam_runner_api_pb2.Trigger.AfterAny(
              subtriggers=subtriggers))
    else:
      raise NotImplementedError(self)


class AfterAny(_ParallelTriggerFn):
  """Fires as soon as at least one subtrigger fires.

  Likewise, it finishes once any one of its subtriggers has finished.
  """
  # Reduces the per-subtrigger results: true if any of them is true.
  combine_op = any


class AfterAll(_ParallelTriggerFn):
  """Fires only once every subtrigger has fired.

  Likewise, it finishes only after all of its subtriggers have finished.
  """
  # Reduces the per-subtrigger results: true only if all of them are true.
  combine_op = all


class AfterEach(TriggerFn):
  """Runs its subtriggers one after another, in order.

  Only the subtrigger at the current index receives elements and merges and
  is consulted by should_fire. When it reports finished from on_fire, the
  index advances. The whole trigger is finished after the last subtrigger
  finishes.
  """

  # Current subtrigger index. Combining by max makes merged windows continue
  # from the furthest-advanced index; with no stored value it starts at 0.
  INDEX_TAG = _CombiningValueStateTag('index', (
      lambda indices: 0 if not indices else max(indices)))

  def __init__(self, *triggers):
    self.triggers = triggers

  def __repr__(self):
    return '%s(%s)' % (self.__class__.__name__,
                       ', '.join(str(t) for t in self.triggers))

  def __eq__(self, other):
    return type(self) == type(other) and self.triggers == other.triggers

  def __hash__(self):
    # Consistent with __eq__; requires the subtriggers to be hashable.
    return hash((type(self), self.triggers))

  def on_element(self, element, window, context):
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      self.triggers[ix].on_element(
          element, window, self._sub_context(context, ix))

  def on_merge(self, to_be_merged, merge_result, context):
    # This takes the furthest window on merging.
    # TODO(robertwb): Revisit this when merging windows logic is settled for
    # all possible merging situations.
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      self.triggers[ix].on_merge(
          to_be_merged, merge_result, self._sub_context(context, ix))

  def should_fire(self, time_domain, watermark, window, context):
    """Delegates to the current subtrigger; False once all are done."""
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      return self.triggers[ix].should_fire(
          time_domain, watermark, window, self._sub_context(context, ix))
    # Previously fell through returning None; make the boolean explicit.
    return False

  def on_fire(self, watermark, window, context):
    """Fires the current subtrigger, advancing past it if it finished.

    Returns:
      whether the last subtrigger has now finished.
    """
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      if self.triggers[ix].on_fire(
          watermark, window, self._sub_context(context, ix)):
        ix += 1
        context.add_state(self.INDEX_TAG, ix)
    return ix == len(self.triggers)

  def reset(self, window, context):
    context.clear_state(self.INDEX_TAG)
    for ix, trigger in enumerate(self.triggers):
      trigger.reset(window, self._sub_context(context, ix))

  @staticmethod
  def _sub_context(context, index):
    """Namespaces subtrigger ``index``'s state and timers as '<index>/'."""
    return NestedContext(context, '%d/' % index)

  @staticmethod
  def from_runner_api(proto, context):
    return AfterEach(*[
        TriggerFn.from_runner_api(subtrigger, context)
        for subtrigger in proto.after_each.subtriggers])

  def to_runner_api(self, context):
    return beam_runner_api_pb2.Trigger(
        after_each=beam_runner_api_pb2.Trigger.AfterEach(
            subtriggers=[
                subtrigger.to_runner_api(context)
                for subtrigger in self.triggers]))


class OrFinally(AfterAny):
  """An AfterAny over a main trigger and a 'finally' trigger.

  Built as ``OrFinally(main, finally_trigger)``. The two are stored in
  ``self.triggers`` in that order, which is how they are serialized.
  """

  @staticmethod
  def from_runner_api(proto, context):
    or_finally_proto = proto.or_finally
    main_trigger = TriggerFn.from_runner_api(or_finally_proto.main, context)
    # `finally` is a Python keyword, so the field can only be read by name.
    finally_trigger = TriggerFn.from_runner_api(
        getattr(or_finally_proto, 'finally'), context)
    return OrFinally(main_trigger, finally_trigger)

  def to_runner_api(self, context):
    # `finally` is a Python keyword, so the fields are passed via **kwargs.
    field_values = {
        'main': self.triggers[0].to_runner_api(context),
        'finally': self.triggers[1].to_runner_api(context),
    }
    return beam_runner_api_pb2.Trigger(
        or_finally=beam_runner_api_pb2.Trigger.OrFinally(**field_values))


class TriggerContext(object):
    """Per-window view onto a trigger state store.

    Binds a state store (``outer``), a single window and a clock so triggers
    can manage their timers and state without passing the window around.
    Every call is forwarded to ``outer`` with the bound window prepended.
    """

    def __init__(self, outer, window, clock):
        self._outer = outer
        self._window = window
        self._clock = clock

    def get_current_time(self):
        """Returns ``clock.time()``; requires a clock to have been supplied."""
        return self._clock.time()

    def set_timer(self, name, time_domain, timestamp):
        self._outer.set_timer(self._window, name, time_domain, timestamp)

    def clear_timer(self, name, time_domain):
        self._outer.clear_timer(self._window, name, time_domain)

    def add_state(self, tag, value):
        self._outer.add_state(self._window, tag, value)

    def get_state(self, tag):
        return self._outer.get_state(self._window, tag)

    def clear_state(self, tag):
        return self._outer.clear_state(self._window, tag)


class NestedContext(object):
    """Namespaced context useful for defining composite triggers.

    Timer names are prefixed with ``prefix`` and state tags are rewritten via
    ``tag.with_prefix(prefix)`` before being forwarded to ``outer``, so each
    subtrigger of a composite gets its own namespace within a shared context.
    """

    def __init__(self, outer, prefix):
        self._outer = outer
        self._prefix = prefix

    def set_timer(self, name, time_domain, timestamp):
        self._outer.set_timer(self._prefix + name, time_domain, timestamp)

    def clear_timer(self, name, time_domain):
        self._outer.clear_timer(self._prefix + name, time_domain)

    def add_state(self, tag, value):
        self._outer.add_state(tag.with_prefix(self._prefix), value)

    def get_state(self, tag):
        return self._outer.get_state(tag.with_prefix(self._prefix))

    def clear_state(self, tag):
        self._outer.clear_state(tag.with_prefix(self._prefix))


# pylint: disable=unused-argument
class SimpleState(object):
    """Basic state storage interface used for triggering.

    Only timers must hold the watermark (by their timestamp).
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def set_timer(self, window, name, time_domain, timestamp):
        pass

    @abstractmethod
    def get_window(self, window_id):
        pass

    @abstractmethod
    def clear_timer(self, window, name, time_domain):
        pass

    @abstractmethod
    def add_state(self, window, tag, value):
        pass

    @abstractmethod
    def get_state(self, window, tag):
        pass

    @abstractmethod
    def clear_state(self, window, tag):
        pass

    def at(self, window, clock=None):
        """Returns a TriggerContext bound to ``window`` and ``clock``."""
        return TriggerContext(self, window, clock)


class UnmergedState(SimpleState):
    """State suitable for use in TriggerDriver.

    This class must be implemented by each backend. Besides per-window state
    and timers it also stores global (window-independent) values.
    """

    @abstractmethod
    def set_global_state(self, tag, value):
        pass

    @abstractmethod
    def get_global_state(self, tag, default=None):
        pass
# pylint: enable=unused-argument


class MergeableStateAdapter(SimpleState):
    """Wraps an UnmergedState, tracking merged windows.

    Every window is mapped to a list of integer ids, persisted in the raw
    state's global ``WINDOW_IDS`` entry. Raw state is keyed by id. Merging
    windows concatenates their id lists, and reads combine the values stored
    under each id of the window.
    """
    # TODO(robertwb): A similar indirection could be used for sliding windows
    # or other window_fns when a single element typically belongs to many windows.
    WINDOW_IDS = _ValueStateTag('window_ids')

    def __init__(self, raw_state):
        self.raw_state = raw_state
        self.window_ids = self.raw_state.get_global_state(self.WINDOW_IDS, {})
        # Highest id handed out so far; recomputed lazily from window_ids.
        self.counter = None

    def set_timer(self, window, name, time_domain, timestamp):
        # Timers are always registered under the window's first id.
        self.raw_state.set_timer(
            self._get_id(window), name, time_domain, timestamp)

    def clear_timer(self, window, name, time_domain):
        for window_id in self._get_ids(window):
            self.raw_state.clear_timer(window_id, name, time_domain)

    def add_state(self, window, tag, value):
        if isinstance(tag, _ValueStateTag):
            raise ValueError(
                'Merging requested for non-mergeable state tag: %r.' % tag)
        self.raw_state.add_state(self._get_id(window), tag, value)

    def get_state(self, window, tag):
        """Reads ``tag`` for ``window``, combining values across its ids.

        Raises:
          ValueError: for value-state tags (not mergeable) or unknown tags.
        """
        # Reject non-mergeable tags before touching the raw state at all.
        if isinstance(tag, _ValueStateTag):
            raise ValueError(
                'Merging requested for non-mergeable state tag: %r.' % tag)
        values = [self.raw_state.get_state(window_id, tag)
                  for window_id in self._get_ids(window)]
        if isinstance(tag, _CombiningValueStateTag):
            # TODO(robertwb): Strip combine_fn.extract_output from raw_state tag.
            if not values:
                accumulator = tag.combine_fn.create_accumulator()
            elif len(values) == 1:
                accumulator = values[0]
            else:
                accumulator = tag.combine_fn.merge_accumulators(values)
                # TODO(robertwb): Store the merged value in the first tag.
            return tag.combine_fn.extract_output(accumulator)
        elif isinstance(tag, _ListStateTag):
            return [v for vs in values for v in vs]
        elif isinstance(tag, _WatermarkHoldStateTag):
            return tag.timestamp_combiner_impl.combine_all(values)
        else:
            raise ValueError('Invalid tag.', tag)

    def clear_state(self, window, tag):
        for window_id in self._get_ids(window):
            self.raw_state.clear_state(window_id, tag)
        # NOTE(review): with tag=None the loop above passes None down to
        # raw_state.clear_state, and InMemoryUnmergedState.clear_state reads
        # ``tag.tag``; confirm which backends support this path.
        if tag is None:
            del self.window_ids[window]
            self._persist_window_ids()

    def merge(self, to_be_merged, merge_result):
        """Moves the ids of every window in to_be_merged onto merge_result."""
        for window in to_be_merged:
            if window != merge_result:
                if window in self.window_ids:
                    if merge_result in self.window_ids:
                        merge_window_ids = self.window_ids[merge_result]
                    else:
                        merge_window_ids = self.window_ids[merge_result] = []
                    merge_window_ids.extend(self.window_ids.pop(window))
                    self._persist_window_ids()

    def known_windows(self):
        return self.window_ids.keys()

    def get_window(self, window_id):
        """Returns the window that owns ``window_id``.

        Raises:
          ValueError: if no tracked window owns the id.
        """
        for window, ids in self.window_ids.items():
            if window_id in ids:
                return window
        raise ValueError('No window for %s' % window_id)

    def _get_id(self, window):
        """Returns the window's first id, allocating and persisting one if new."""
        if window in self.window_ids:
            return self.window_ids[window][0]

        window_id = self._get_next_counter()
        self.window_ids[window] = [window_id]
        self._persist_window_ids()
        return window_id

    def _get_ids(self, window):
        return self.window_ids.get(window, [])

    def _get_next_counter(self):
        """Returns a fresh id: 1 when no windows are tracked, else max + 1."""
        if not self.window_ids:
            self.counter = 0
        elif self.counter is None:
            self.counter = max(k for ids in self.window_ids.values() for k in ids)
        self.counter += 1
        return self.counter

    def _persist_window_ids(self):
        self.raw_state.set_global_state(self.WINDOW_IDS, self.window_ids)

    def __repr__(self):
        return '\n\t'.join([repr(self.window_ids)] +
                           repr(self.raw_state).split('\n'))


def create_trigger_driver(windowing, is_batch=False, phased_combine_fn=None,
                          clock=None):
    """Create the TriggerDriver for the given windowing and options.

    Args:
      windowing: windowing strategy (window_fn, trigger, accumulation mode...).
      is_batch: whether input is bounded; with default windowing this selects
        the cheap DefaultGlobalBatchTriggerDriver.
      phased_combine_fn: optional combiner applied to every emitted pane.
      clock: clock passed through to GeneralTriggerDriver.

    Returns:
      A TriggerDriver instance.
    """
    # TODO(robertwb): We can do more if we know elements are in timestamp
    # sorted order.
    if windowing.is_default() and is_batch:
        driver = DefaultGlobalBatchTriggerDriver()
    else:
        driver = GeneralTriggerDriver(windowing, clock)

    if phased_combine_fn:
        # TODO(ccy): Refactor GeneralTriggerDriver to combine values eagerly using
        # the known phased_combine_fn here.
        driver = CombiningTriggerDriver(phased_combine_fn, driver)
    return driver


class TriggerDriver(object):
    """Breaks a series of bundle and timer firings into window (pane)s."""

    __metaclass__ = ABCMeta

    @abstractmethod
    def process_elements(self, state, windowed_values, output_watermark):
        pass

    @abstractmethod
    def process_timer(self, window_id, name, time_domain, timestamp, state):
        pass

    def process_entire_key(
            self, key, windowed_values, output_watermark=MIN_TIMESTAMP):
        """Processes all values of one key, then fires timers until none fire.

        Uses a fresh InMemoryUnmergedState and yields each output pane with
        its value replaced by ``(key, value)``.
        """
        state = InMemoryUnmergedState()
        for wvalue in self.process_elements(
                state, windowed_values, output_watermark):
            yield wvalue.with_value((key, wvalue.value))
        while state.timers:
            fired = state.get_and_clear_timers()
            if not fired:
                # Nothing expired and no timer callback ran, so state is
                # unchanged and every further iteration would be identical;
                # stop instead of spinning forever on timers that cannot fire.
                break
            for timer_window, (name, time_domain, fire_time) in fired:
                for wvalue in self.process_timer(
                        timer_window, name, time_domain, fire_time, state):
                    yield wvalue.with_value((key, wvalue.value))


class _UnwindowedValues(observable.ObservableMixin):
    """Exposes iterable of windowed values as iterable of unwindowed values.

    Observers registered through ObservableMixin are notified of each value
    as it is iterated.
    """

    def __init__(self, windowed_values):
        super(_UnwindowedValues, self).__init__()
        self._windowed_values = windowed_values

    def __iter__(self):
        for wv in self._windowed_values:
            unwindowed_value = wv.value
            self.notify_observers(unwindowed_value)
            yield unwindowed_value

    def __repr__(self):
        return '<_UnwindowedValues of %s>' % self._windowed_values

    def __reduce__(self):
        # Pickles as a plain list of the unwindowed values.
        return list, (list(self),)

    def __eq__(self, other):
        # NOTE(review): collections.Iterable and itertools.izip_longest are
        # Python 2 spellings; they need updating for Python 3.
        if isinstance(other, collections.Iterable):
            # A fresh object() fill value is not expected to equal any real
            # element, so iterables of different lengths compare unequal.
            return all(
                a == b
                for a, b in itertools.izip_longest(self, other,
                                                   fillvalue=object()))
        else:
            return NotImplemented

    def __ne__(self, other):
        return not self == other


class DefaultGlobalBatchTriggerDriver(TriggerDriver):
    """Breaks a bundle into window (pane)s according to the default triggering.

    All values are emitted as a single pane in the global window, timestamped
    MIN_TIMESTAMP.
    """
    GLOBAL_WINDOW_TUPLE = (GlobalWindow(),)

    def __init__(self):
        pass

    def process_elements(self, state, windowed_values, unused_output_watermark):
        yield WindowedValue(
            _UnwindowedValues(windowed_values),
            MIN_TIMESTAMP,
            self.GLOBAL_WINDOW_TUPLE)

    def process_timer(self, window_id, name, time_domain, timestamp, state):
        raise TypeError('Triggers never set or called for batch default windowing.')


class CombiningTriggerDriver(TriggerDriver):
    """Uses a phased_combine_fn to process output of wrapped TriggerDriver."""

    def __init__(self, phased_combine_fn, underlying):
        self.phased_combine_fn = phased_combine_fn
        self.underlying = underlying

    def process_elements(self, state, windowed_values, output_watermark):
        uncombined = self.underlying.process_elements(state, windowed_values,
                                                      output_watermark)
        for output in uncombined:
            yield output.with_value(self.phased_combine_fn.apply(output.value))

    def process_timer(self, window_id, name, time_domain, timestamp, state):
        uncombined = self.underlying.process_timer(window_id, name, time_domain,
                                                   timestamp, state)
        for output in uncombined:
            yield output.with_value(self.phased_combine_fn.apply(output.value))


class GeneralTriggerDriver(TriggerDriver):
    """Breaks a series of bundle and timer firings into window (pane)s.

    Suitable for all variants of Windowing.
    """
    # Buffered element values per window.
    ELEMENTS = _ListStateTag('elements')
    # Non-zero once a window's trigger has finished; later elements and
    # timers for that window are ignored.
    TOMBSTONE = _CombiningValueStateTag('tombstone', combiners.CountCombineFn())

    def __init__(self, windowing, clock):
        self.clock = clock
        self.window_fn = windowing.windowfn
        self.timestamp_combiner_impl = TimestampCombiner.get_impl(
            windowing.timestamp_combiner, self.window_fn)
        # pylint: disable=invalid-name
        self.WATERMARK_HOLD = _WatermarkHoldStateTag(
            'watermark', self.timestamp_combiner_impl)
        # pylint: enable=invalid-name
        self.trigger_fn = windowing.triggerfn
        self.accumulation_mode = windowing.accumulation_mode
        # Raw state is always wrapped in a MergeableStateAdapter.
        self.is_merging = True

    def process_elements(self, state, windowed_values, output_watermark):
        """Merges windows, buffers elements and fires any ready windows.

        Yields one WindowedValue pane per window whose trigger fires.
        """
        if self.is_merging:
            state = MergeableStateAdapter(state)

        windows_to_elements = collections.defaultdict(list)
        for wv in windowed_values:
            for window in wv.windows:
                windows_to_elements[window].append((wv.value, wv.timestamp))

        # First handle merging.
        if self.is_merging:
            old_windows = set(state.known_windows())
            all_windows = old_windows.union(windows_to_elements.keys())

            if all_windows != old_windows:
                merged_away = {}

                class TriggerMergeContext(WindowFn.MergeContext):

                    def merge(_, to_be_merged, merge_result):  # pylint: disable=no-self-argument
                        for window in to_be_merged:
                            if window != merge_result:
                                merged_away[window] = merge_result
                        state.merge(to_be_merged, merge_result)
                        # using the outer self argument.
                        self.trigger_fn.on_merge(
                            to_be_merged, merge_result, state.at(merge_result))

                self.window_fn.merge(TriggerMergeContext(all_windows))

                # Re-key buffered elements onto their final merged window,
                # following chains of merges.
                merged_windows_to_elements = collections.defaultdict(list)
                for window, values in windows_to_elements.items():
                    while window in merged_away:
                        window = merged_away[window]
                    merged_windows_to_elements[window].extend(values)
                windows_to_elements = merged_windows_to_elements

                # NOTE(review): state.merge already moved these windows' ids
                # to the merge result, so through MergeableStateAdapter this
                # clear finds no ids to act on; confirm that is intended.
                for window in merged_away:
                    state.clear_state(window, self.WATERMARK_HOLD)

        # Next handle element adding.
        for window, elements in windows_to_elements.items():
            if state.get_state(window, self.TOMBSTONE):
                continue
            # Add watermark hold.
            # TODO(ccy): Add late data and garbage-collection hold support.
            output_time = self.timestamp_combiner_impl.merge(
                window,
                (element_output_time for element_output_time in
                 (self.timestamp_combiner_impl.assign_output_time(window, timestamp)
                  for unused_value, timestamp in elements)
                 if element_output_time >= output_watermark))
            if output_time is not None:
                state.add_state(window, self.WATERMARK_HOLD, output_time)

            context = state.at(window, self.clock)
            for value, unused_timestamp in elements:
                state.add_state(window, self.ELEMENTS, value)
                self.trigger_fn.on_element(value, window, context)

            # Maybe fire this window.
            # NOTE(review): MIN_TIMESTAMP is passed rather than
            # output_watermark; confirm this is intended.
            watermark = MIN_TIMESTAMP
            if self.trigger_fn.should_fire(TimeDomain.WATERMARK, watermark,
                                           window, context):
                finished = self.trigger_fn.on_fire(watermark, window, context)
                yield self._output(window, finished, state)

    def process_timer(self, window_id, unused_name, time_domain, timestamp,
                      state):
        """Fires the window owning ``window_id`` if its trigger is ready.

        Raises:
          Exception: for time domains other than WATERMARK and REAL_TIME.
        """
        if self.is_merging:
            state = MergeableStateAdapter(state)
        window = state.get_window(window_id)
        if state.get_state(window, self.TOMBSTONE):
            return
        if time_domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
            if not self.is_merging or window in state.known_windows():
                context = state.at(window, self.clock)
                if self.trigger_fn.should_fire(time_domain, timestamp,
                                               window, context):
                    finished = self.trigger_fn.on_fire(timestamp, window, context)
                    yield self._output(window, finished, state)
        else:
            raise Exception('Unexpected time domain: %s' % time_domain)

    def _output(self, window, finished, state):
        """Output window and clean up if appropriate.

        Finished windows drop their elements and get a tombstone; in
        DISCARDING mode elements are dropped after every firing. The pane is
        timestamped with the watermark hold (then cleared) or, if no hold
        was set, with ``window.end``.
        """
        values = state.get_state(window, self.ELEMENTS)
        if finished:
            # TODO(robertwb): allowed lateness
            state.clear_state(window, self.ELEMENTS)
            state.add_state(window, self.TOMBSTONE, 1)
        elif self.accumulation_mode == AccumulationMode.DISCARDING:
            state.clear_state(window, self.ELEMENTS)

        timestamp = state.get_state(window, self.WATERMARK_HOLD)
        if timestamp is None:
            # If no watermark hold was set, output at end of window.
            timestamp = window.end
        else:
            state.clear_state(window, self.WATERMARK_HOLD)

        return WindowedValue(values, timestamp, (window,))


class InMemoryUnmergedState(UnmergedState):
    """In-memory implementation of UnmergedState.

    Used for batch and testing.
    """

    def __init__(self, defensive_copy=True):
        # TODO(robertwb): Skip defensive_copy in production if it's too expensive.
        # window -> {(name, time_domain): timestamp}
        self.timers = collections.defaultdict(dict)
        # window -> {tag name: stored value(s)}
        self.state = collections.defaultdict(lambda: collections.defaultdict(list))
        self.global_state = {}
        # When True, values are deep-copied on write.
        self.defensive_copy = defensive_copy

    def copy(self):
        """Returns a copy; per-tag state lists are copied shallowly."""
        cloned_object = InMemoryUnmergedState(defensive_copy=self.defensive_copy)
        cloned_object.timers = copy.deepcopy(self.timers)
        cloned_object.global_state = copy.deepcopy(self.global_state)
        for window in self.state:
            for tag in self.state[window]:
                cloned_object.state[window][tag] = copy.copy(self.state[window][tag])
        return cloned_object

    def set_global_state(self, tag, value):
        assert isinstance(tag, _ValueStateTag)
        if self.defensive_copy:
            value = copy.deepcopy(value)
        self.global_state[tag.tag] = value

    def get_global_state(self, tag, default=None):
        return self.global_state.get(tag.tag, default)

    def set_timer(self, window, name, time_domain, timestamp):
        self.timers[window][(name, time_domain)] = timestamp

    def clear_timer(self, window, name, time_domain):
        self.timers[window].pop((name, time_domain), None)
        if not self.timers[window]:
            del self.timers[window]

    def get_window(self, window_id):
        # Unmerged state uses the window itself as its id.
        return window_id

    def add_state(self, window, tag, value):
        """Stores ``value``: value tags overwrite, all other tags append.

        Raises:
          ValueError: for an unrecognised tag type.
        """
        if self.defensive_copy:
            value = copy.deepcopy(value)
        if isinstance(tag, _ValueStateTag):
            self.state[window][tag.tag] = value
        elif isinstance(tag, _CombiningValueStateTag):
            self.state[window][tag.tag].append(value)
        elif isinstance(tag, _ListStateTag):
            self.state[window][tag.tag].append(value)
        elif isinstance(tag, _WatermarkHoldStateTag):
            self.state[window][tag.tag].append(value)
        else:
            raise ValueError('Invalid tag.', tag)

    def get_state(self, window, tag):
        """Reads ``tag`` for ``window``.

        Note that reading through the nested defaultdicts inserts an empty
        entry for a missing window/tag.

        Raises:
          ValueError: for an unrecognised tag type.
        """
        values = self.state[window][tag.tag]
        if isinstance(tag, _ValueStateTag):
            return values
        elif isinstance(tag, _CombiningValueStateTag):
            return tag.combine_fn.apply(values)
        elif isinstance(tag, _ListStateTag):
            return values
        elif isinstance(tag, _WatermarkHoldStateTag):
            return tag.timestamp_combiner_impl.combine_all(values)
        else:
            raise ValueError('Invalid tag.', tag)

    def clear_state(self, window, tag):
        self.state[window].pop(tag.tag, None)
        if not self.state[window]:
            self.state.pop(window, None)

    def get_timers(self, clear=False, watermark=MAX_TIMESTAMP,
                   processing_time=None):
        """Gets expired timers and reports whether any real-time timer is set.

        WATERMARK timers are expired against ``watermark`` and REAL_TIME timers
        against ``processing_time``. Timers in any other time domain are logged
        and left in place.

        Args:
          clear: if True, remove expired timers (and windows left with none).
          watermark: reference time for WATERMARK timers.
          processing_time: reference time for REAL_TIME timers.

        Returns:
          A tuple ``(expired, has_realtime_timer)`` where ``expired`` is a list
          of ``(window, (name, time_domain, timestamp))`` and
          ``has_realtime_timer`` is True if any REAL_TIME timer was present.
        """
        expired = []
        has_realtime_timer = False
        for window, timers in list(self.timers.items()):
            for (name, time_domain), timestamp in list(timers.items()):
                if time_domain == TimeDomain.REAL_TIME:
                    time_marker = processing_time
                    has_realtime_timer = True
                elif time_domain == TimeDomain.WATERMARK:
                    time_marker = watermark
                else:
                    logging.error(
                        'TimeDomain error: No timers defined for time domain %s.',
                        time_domain)
                    # There is no reference time for this domain; skip the
                    # timer instead of comparing against an unbound or stale
                    # (previous iteration's) time_marker.
                    continue
                if timestamp <= time_marker:
                    expired.append((window, (name, time_domain, timestamp)))
                    if clear:
                        del timers[(name, time_domain)]
            if not timers and clear:
                del self.timers[window]
        return expired, has_realtime_timer

    def get_and_clear_timers(self, watermark=MAX_TIMESTAMP):
        """Removes and returns the timers that have expired at ``watermark``."""
        return self.get_timers(clear=True, watermark=watermark)[0]

    def get_earliest_hold(self):
        """Returns the earliest watermark hold minus TIME_GRANULARITY.

        Returns MAX_TIMESTAMP when no window holds the watermark.
        """
        earliest_hold = MAX_TIMESTAMP
        for unused_window, tagged_states in self.state.items():
            # TODO(BEAM-2519): currently, this assumes that the watermark hold tag is
            # named "watermark". This is currently only true because the only place
            # watermark holds are set is in the GeneralTriggerDriver, where we use
            # this name. We should fix this by allowing enumeration of the tag types
            # used in adding state.
            if 'watermark' in tagged_states and tagged_states['watermark']:
                hold = min(tagged_states['watermark']) - TIME_GRANULARITY
                earliest_hold = min(earliest_hold, hold)
        return earliest_hold

    def __repr__(self):
        state_str = '\n'.join('%s: %s' % (key, dict(state))
                              for key, state in self.state.items())
        return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)