#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Support for Dataflow triggers.
Triggers control when, in processing time, windows get emitted.
"""
from __future__ import absolute_import
import collections
import copy
import logging
import numbers
from abc import ABCMeta
from abc import abstractmethod
from builtins import object
from future.moves.itertools import zip_longest
from future.utils import iteritems
from future.utils import with_metaclass
from apache_beam.coders import coder_impl
from apache_beam.coders import observable
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import combiners
from apache_beam.transforms import core
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import WindowedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.timestamp import TIME_GRANULARITY
# AfterCount is experimental. No backwards compatibility guarantees.
# Names exported by a star-import of this module: the accumulation-mode
# constants and the trigger classes. The helper classes defined further down
# (state tags, contexts, state adapters and drivers) are not listed here.
__all__ = [
    'AccumulationMode',
    'TriggerFn',
    'DefaultTrigger',
    'AfterWatermark',
    'AfterProcessingTime',
    'AfterCount',
    'Repeatedly',
    'AfterAny',
    'AfterAll',
    'AfterEach',
    'OrFinally',
    ]
class AccumulationMode(object):
  """Controls what to do with data when a trigger fires multiple times.

  The attributes re-export the values of the runner API
  ``beam_runner_api_pb2.AccumulationMode`` enum under the same names.
  """
  DISCARDING = beam_runner_api_pb2.AccumulationMode.DISCARDING
  ACCUMULATING = beam_runner_api_pb2.AccumulationMode.ACCUMULATING
  # TODO(robertwb): Provide retractions of previous outputs.
  # RETRACTING = 3
class _StateTag(with_metaclass(ABCMeta, object)):
  """Base class for identifiers of typed, combinable state.

  A tag is used to store and retrieve a piece of state; the tag given to the
  constructor must be unique for this step.

  Args:
    tag: the name under which the state is stored.
  """

  def __init__(self, tag):
    # Subclasses derive their prefixed copies and their reprs from this name.
    self.tag = tag
class _ValueStateTag(_StateTag):
  """State tag that names a single stored element."""

  def with_prefix(self, prefix):
    """Return a new value tag whose name is ``prefix`` + this tag's name."""
    prefixed_name = prefix + self.tag
    return _ValueStateTag(prefixed_name)

  def __repr__(self):
    return 'ValueStateTag(%s)' % self.tag
class _SetStateTag(_StateTag):
  """State tag that names a set of elements."""

  def with_prefix(self, prefix):
    """Return a new set tag whose name is ``prefix`` + this tag's name."""
    prefixed_name = prefix + self.tag
    return _SetStateTag(prefixed_name)

  def __repr__(self):
    return 'SetStateTag({})'.format(self.tag)
class _CombiningValueStateTag(_StateTag):
  """State tag that names a value accumulated with a CombineFn.

  The given tag must be unique for this step. The given CombineFn will be
  applied (possibly incrementally and eagerly) when adding elements.

  Args:
    tag: the name under which the state is stored.
    combine_fn: a ``core.CombineFn`` instance, or any other callable, which is
      wrapped with ``core.CombineFn.from_callable``.

  Raises:
    ValueError: if ``combine_fn`` is falsy.
  """
  # TODO(robertwb): Also store the coder (perhaps extracted from the combine_fn)

  def __init__(self, tag, combine_fn):
    super(_CombiningValueStateTag, self).__init__(tag)
    if not combine_fn:
      raise ValueError('combine_fn must be specified.')
    # Plain callables are promoted so that the rest of the code can rely on the
    # CombineFn interface.
    self.combine_fn = (
        combine_fn if isinstance(combine_fn, core.CombineFn)
        else core.CombineFn.from_callable(combine_fn))

  def with_prefix(self, prefix):
    """Return a new combining tag named ``prefix`` + name, same CombineFn."""
    prefixed_name = prefix + self.tag
    return _CombiningValueStateTag(prefixed_name, self.combine_fn)

  def without_extraction(self):
    """Return a same-named tag whose CombineFn outputs the raw accumulator."""
    wrapped = self.combine_fn

    class NoExtractionCombineFn(core.CombineFn):
      # Accumulation is delegated unchanged to the wrapped CombineFn ...
      create_accumulator = wrapped.create_accumulator
      add_input = wrapped.add_input
      merge_accumulators = wrapped.merge_accumulators
      compact = wrapped.compact
      # ... but extraction is the identity, so the accumulator is returned.
      extract_output = staticmethod(lambda x: x)

    return _CombiningValueStateTag(self.tag, NoExtractionCombineFn())

  def __repr__(self):
    return 'CombiningValueStateTag(%s, %s)' % (self.tag, self.combine_fn)
class _ListStateTag(_StateTag):
  """State tag that names a list of elements."""

  def with_prefix(self, prefix):
    """Return a new list tag whose name is ``prefix`` + this tag's name."""
    prefixed_name = prefix + self.tag
    return _ListStateTag(prefixed_name)

  def __repr__(self):
    return 'ListStateTag(%s)' % (self.tag,)
class _WatermarkHoldStateTag(_StateTag):
  """State tag for a watermark hold.

  Args:
    tag: the name under which the state is stored.
    timestamp_combiner_impl: object used to combine the stored values;
      MergeableStateAdapter calls its ``combine_all`` method when reading.
  """

  def __init__(self, tag, timestamp_combiner_impl):
    super(_WatermarkHoldStateTag, self).__init__(tag)
    self.timestamp_combiner_impl = timestamp_combiner_impl

  def with_prefix(self, prefix):
    """Return a new hold tag named ``prefix`` + name, same combiner impl."""
    prefixed_name = prefix + self.tag
    return _WatermarkHoldStateTag(prefixed_name, self.timestamp_combiner_impl)

  def __repr__(self):
    description = (self.tag, self.timestamp_combiner_impl)
    return 'WatermarkHoldStateTag(%s, %s)' % description
# pylint: disable=unused-argument
# TODO(robertwb): Provisional API, Java likely to change as well.
class TriggerFn(with_metaclass(ABCMeta, object)):
  """A TriggerFn determines when window (panes) are emitted.

  See https://beam.apache.org/documentation/programming-guide/#triggers
  """

  @abstractmethod
  def on_element(self, element, window, context):
    """Called when a new element arrives in a window.

    Args:
      element: the element being added
      window: the window to which the element is being added
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers
    """
    pass

  @abstractmethod
  def on_merge(self, to_be_merged, merge_result, context):
    """Called when multiple windows are merged.

    Args:
      to_be_merged: the set of windows to be merged
      merge_result: the window into which the windows are being merged
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers
    """
    pass

  @abstractmethod
  def should_fire(self, time_domain, timestamp, window, context):
    """Whether this trigger should cause the window to fire.

    Args:
      time_domain: WATERMARK for event-time timers and REAL_TIME for
          processing-time timers.
      timestamp: for time_domain WATERMARK, (a lower bound on) the watermark
          of the system; for time_domain REAL_TIME, the timestamp of the
          processing-time timer.
      window: the window whose trigger is being considered
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers

    Returns:
      whether this trigger should cause a firing
    """
    pass

  @abstractmethod
  def on_fire(self, watermark, window, context):
    """Called when a trigger actually fires.

    Args:
      watermark: (a lower bound on) the watermark of the system
      window: the window whose trigger is being fired
      context: a context (e.g. a TriggerContext instance) for managing state
          and setting timers

    Returns:
      whether this trigger is finished
    """
    pass

  @abstractmethod
  def reset(self, window, context):
    """Clear any state and timers used by this TriggerFn."""
    pass
# pylint: enable=unused-argument

  @staticmethod
  def from_runner_api(proto, context):
    """Build the TriggerFn described by a runner API ``Trigger`` proto.

    Dispatches on the field set in the proto's ``trigger`` oneof to the
    ``from_runner_api`` of the matching trigger class.

    Args:
      proto: a ``beam_runner_api_pb2.Trigger`` message.
      context: passed through unchanged to the selected class.

    Returns:
      the reconstructed trigger.

    Raises:
      KeyError: if the oneof is unset or names a kind that has no entry in
          the table below (e.g. the commented-out ``always`` and ``never``).
    """
    return {
        'after_all': AfterAll,
        'after_any': AfterAny,
        'after_each': AfterEach,
        'after_end_of_window': AfterWatermark,
        'after_processing_time': AfterProcessingTime,
        # after_processing_time, after_synchronized_processing_time
        # always
        'default': DefaultTrigger,
        'element_count': AfterCount,
        # never
        'or_finally': OrFinally,
        'repeat': Repeatedly,
    }[proto.WhichOneof('trigger')].from_runner_api(proto, context)

  @abstractmethod
  def to_runner_api(self, unused_context):
    """Return the runner API ``Trigger`` proto describing this trigger."""
    pass

  def __ne__(self, other):
    # TODO(BEAM-5949): Needed for Python 2 compatibility.
    return not self == other
class DefaultTrigger(TriggerFn):
  """Semantically Repeatedly(AfterWatermark()), but more optimized."""

  def __init__(self):
    pass

  def __repr__(self):
    return 'DefaultTrigger()'

  def on_element(self, element, window, context):
    """Set the unnamed watermark timer to the end of the window."""
    context.set_timer('', TimeDomain.WATERMARK, window.end)

  def on_merge(self, to_be_merged, merge_result, context):
    """Clear the watermark timer if a merged window ended elsewhere."""
    # Note: Timer clearing solely an optimization.
    for window in to_be_merged:
      if window.end != merge_result.end:
        context.clear_timer('', TimeDomain.WATERMARK)

  def should_fire(self, time_domain, watermark, window, context):
    """Fire once the watermark reaches the end of the window.

    The time domain is not consulted.
    """
    return watermark >= window.end

  def on_fire(self, watermark, window, context):
    """Always report that the trigger is not finished."""
    return False

  def reset(self, window, context):
    """Clear the unnamed watermark timer."""
    context.clear_timer('', TimeDomain.WATERMARK)

  def __eq__(self, other):
    # The trigger has no configuration, so any two instances are equal.
    return type(self) == type(other)

  def __hash__(self):
    return hash(type(self))

  @staticmethod
  def from_runner_api(proto, context):
    """Return a DefaultTrigger; neither argument is consulted."""
    return DefaultTrigger()

  def to_runner_api(self, unused_context):
    """Return a ``Trigger`` proto with its ``default`` field set."""
    return beam_runner_api_pb2.Trigger(
        default=beam_runner_api_pb2.Trigger.Default())
class AfterProcessingTime(TriggerFn):
  """Fire exactly once after a specified delay from processing time.

  AfterProcessingTime is experimental. No backwards compatibility guarantees.
  """

  def __init__(self, delay=0):
    """Initialize a processing time trigger with a delay in seconds."""
    self.delay = delay

  def __repr__(self):
    # %s rather than %d, so that a fractional delay is not shown truncated.
    return 'AfterProcessingTime(delay=%s)' % self.delay

  def __eq__(self, other):
    # Value equality, consistent with the other configurable triggers in this
    # module (e.g. AfterCount).
    return type(self) == type(other) and self.delay == other.delay

  def __hash__(self):
    return hash((type(self), self.delay))

  def on_element(self, element, window, context):
    """Set the unnamed real-time timer to the current time plus the delay."""
    context.set_timer(
        '', TimeDomain.REAL_TIME, context.get_current_time() + self.delay)

  def on_merge(self, to_be_merged, merge_result, context):
    # timers will be kept through merging
    pass

  def should_fire(self, time_domain, timestamp, window, context):
    """Fire for any real-time (processing-time) timer, and only for those."""
    return time_domain == TimeDomain.REAL_TIME

  def on_fire(self, timestamp, window, context):
    """Report the trigger as finished after its first firing."""
    return True

  def reset(self, window, context):
    pass

  @staticmethod
  def from_runner_api(proto, context):
    """Rebuild the trigger from the first timestamp transform's delay.

    The delay is converted from milliseconds to whole seconds, so any
    sub-second part is dropped.
    """
    return AfterProcessingTime(
        delay=(
            proto.after_processing_time
            .timestamp_transforms[0]
            .delay
            .delay_millis) // 1000)

  def to_runner_api(self, context):
    """Return an ``after_processing_time`` proto with the delay in millis."""
    # delay_millis is an integer proto field; int() keeps a fractional delay
    # in seconds (e.g. 0.5) from being rejected as a float.
    delay_proto = beam_runner_api_pb2.TimestampTransform(
        delay=beam_runner_api_pb2.TimestampTransform.Delay(
            delay_millis=int(self.delay * 1000)))
    return beam_runner_api_pb2.Trigger(
        after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime(
            timestamp_transforms=[delay_proto]))
class AfterWatermark(TriggerFn):
  """Fire exactly once when the watermark passes the end of the window.

  Args:
      early: if not None, a speculative trigger to repeatedly evaluate before
        the watermark passes the end of the window
      late: if not None, a speculative trigger to repeatedly evaluate after
        the watermark passes the end of the window
  """
  # Combined with any(): reads as true once True has been added, which
  # on_fire does when it fires with the watermark at or past the window end.
  LATE_TAG = _CombiningValueStateTag('is_late', any)

  def __init__(self, early=None, late=None):
    # The speculative triggers are wrapped so that they fire repeatedly.
    self.early = Repeatedly(early) if early else None
    self.late = Repeatedly(late) if late else None

  def __repr__(self):
    qualifiers = []
    if self.early:
      qualifiers.append('early=%s' % self.early.underlying)
    if self.late:
      qualifiers.append('late=%s' % self.late.underlying)
    return 'AfterWatermark(%s)' % ', '.join(qualifiers)

  def is_late(self, context):
    """Truthy if a late trigger exists and the late flag has been set."""
    return self.late and context.get_state(self.LATE_TAG)

  def on_element(self, element, window, context):
    """Route the element to the late trigger, or arm the end-of-window timer.

    Before the late phase the element is also given to the early trigger, if
    there is one.
    """
    if self.is_late(context):
      self.late.on_element(element, window, NestedContext(context, 'late'))
    else:
      context.set_timer('', TimeDomain.WATERMARK, window.end)
      if self.early:
        self.early.on_element(element, window, NestedContext(context, 'early'))

  def on_merge(self, to_be_merged, merge_result, context):
    """Forward the merge to whichever speculative trigger is active."""
    # TODO(robertwb): Figure out whether the 'rewind' semantics could be used
    # here.
    if self.is_late(context):
      self.late.on_merge(
          to_be_merged, merge_result, NestedContext(context, 'late'))
    else:
      # Note: Timer clearing solely an optimization.
      for window in to_be_merged:
        if window.end != merge_result.end:
          context.clear_timer('', TimeDomain.WATERMARK)
      if self.early:
        self.early.on_merge(
            to_be_merged, merge_result, NestedContext(context, 'early'))

  def should_fire(self, time_domain, watermark, window, context):
    """Decide whether to fire.

    In the late phase the late trigger decides. Otherwise the trigger fires
    when the watermark has reached the end of the window, or else when the
    early trigger (if any) wants to fire.
    """
    if self.is_late(context):
      return self.late.should_fire(time_domain, watermark,
                                   window, NestedContext(context, 'late'))
    elif watermark >= window.end:
      return True
    elif self.early:
      return self.early.should_fire(time_domain, watermark,
                                    window, NestedContext(context, 'early'))
    return False

  def on_fire(self, watermark, window, context):
    """Record a firing.

    Returns:
      whether this trigger is finished: the late trigger's answer in the late
      phase; ``not self.late`` for the end-of-window firing (which also sets
      the late flag); False for an early firing.
    """
    if self.is_late(context):
      return self.late.on_fire(
          watermark, window, NestedContext(context, 'late'))
    elif watermark >= window.end:
      context.add_state(self.LATE_TAG, True)
      return not self.late
    elif self.early:
      self.early.on_fire(watermark, window, NestedContext(context, 'early'))
    # Explicit False (previously an implicit None when there is no early
    # trigger): anything before the end of the window never finishes.
    return False

  def reset(self, window, context):
    """Clear the late flag and reset the early and late triggers."""
    if self.late:
      context.clear_state(self.LATE_TAG)
    if self.early:
      self.early.reset(window, NestedContext(context, 'early'))
    if self.late:
      self.late.reset(window, NestedContext(context, 'late'))

  def __eq__(self, other):
    return (type(self) == type(other)
            and self.early == other.early
            and self.late == other.late)

  def __hash__(self):
    return hash((type(self), self.early, self.late))

  @staticmethod
  def from_runner_api(proto, context):
    """Rebuild the trigger from an ``after_end_of_window`` proto.

    The early and late triggers are reconstructed only when the matching
    ``early_firings`` / ``late_firings`` field is present.
    """
    return AfterWatermark(
        early=TriggerFn.from_runner_api(
            proto.after_end_of_window.early_firings, context)
        if proto.after_end_of_window.HasField('early_firings')
        else None,
        late=TriggerFn.from_runner_api(
            proto.after_end_of_window.late_firings, context)
        if proto.after_end_of_window.HasField('late_firings')
        else None)

  def to_runner_api(self, context):
    """Return an ``after_end_of_window`` proto.

    The unwrapped (pre-Repeatedly) early and late triggers are serialized;
    None is passed for a trigger that is absent.
    """
    early_proto = self.early.underlying.to_runner_api(
        context) if self.early else None
    late_proto = self.late.underlying.to_runner_api(
        context) if self.late else None
    return beam_runner_api_pb2.Trigger(
        after_end_of_window=beam_runner_api_pb2.Trigger.AfterEndOfWindow(
            early_firings=early_proto,
            late_firings=late_proto))
class AfterCount(TriggerFn):
  """Fire when there are at least count elements in this window pane.

  AfterCount is experimental. No backwards compatibility guarantees.

  Args:
    count: the number of elements required for a firing.

  Raises:
    ValueError: if ``count`` is not an integer of at least 1.
  """
  # Running element count; on_element adds 1 per element.
  COUNT_TAG = _CombiningValueStateTag('count', combiners.CountCombineFn())

  def __init__(self, count):
    if not isinstance(count, numbers.Integral) or count < 1:
      # %r with a 1-tuple formats any object; the previous %d raised
      # TypeError instead of this ValueError for non-numeric arguments.
      raise ValueError("count (%r) must be a positive integer." % (count,))
    self.count = count

  def __repr__(self):
    return 'AfterCount(%s)' % self.count

  def __eq__(self, other):
    return type(self) == type(other) and self.count == other.count

  def __hash__(self):
    return hash(self.count)

  def on_element(self, element, window, context):
    """Add one to the stored element count."""
    context.add_state(self.COUNT_TAG, 1)

  def on_merge(self, to_be_merged, merge_result, context):
    # states automatically merged
    pass

  def should_fire(self, time_domain, watermark, window, context):
    """Fire once the stored count has reached ``self.count``."""
    return context.get_state(self.COUNT_TAG) >= self.count

  def on_fire(self, watermark, window, context):
    """Report the trigger as finished after its first firing."""
    return True

  def reset(self, window, context):
    """Clear the stored element count."""
    context.clear_state(self.COUNT_TAG)

  @staticmethod
  def from_runner_api(proto, unused_context):
    """Rebuild the trigger from an ``element_count`` proto."""
    return AfterCount(proto.element_count.element_count)

  def to_runner_api(self, unused_context):
    """Return an ``element_count`` proto carrying ``self.count``."""
    return beam_runner_api_pb2.Trigger(
        element_count=beam_runner_api_pb2.Trigger.ElementCount(
            element_count=self.count))
class Repeatedly(TriggerFn):
  """Repeatedly invoke the given trigger, never finishing.

  Args:
    underlying: the trigger to delegate to.
  """

  def __init__(self, underlying):
    self.underlying = underlying

  def __repr__(self):
    return 'Repeatedly(%s)' % self.underlying

  def __eq__(self, other):
    return type(self) == type(other) and self.underlying == other.underlying

  def __hash__(self):
    return hash(self.underlying)

  def on_element(self, element, window, context):
    """Delegate to the underlying trigger."""
    self.underlying.on_element(element, window, context)

  def on_merge(self, to_be_merged, merge_result, context):
    """Delegate to the underlying trigger."""
    self.underlying.on_merge(to_be_merged, merge_result, context)

  def should_fire(self, time_domain, watermark, window, context):
    """Delegate to the underlying trigger."""
    return self.underlying.should_fire(time_domain, watermark, window, context)

  def on_fire(self, watermark, window, context):
    """Fire the underlying trigger and reset it whenever it finishes.

    Returns:
      always False: this trigger itself never finishes.
    """
    if self.underlying.on_fire(watermark, window, context):
      self.underlying.reset(window, context)
    return False

  def reset(self, window, context):
    """Reset the underlying trigger."""
    self.underlying.reset(window, context)

  @staticmethod
  def from_runner_api(proto, context):
    """Rebuild the trigger from a ``repeat`` proto and its subtrigger."""
    return Repeatedly(
        TriggerFn.from_runner_api(proto.repeat.subtrigger, context))

  def to_runner_api(self, context):
    """Return a ``repeat`` proto wrapping the underlying trigger's proto."""
    return beam_runner_api_pb2.Trigger(
        repeat=beam_runner_api_pb2.Trigger.Repeat(
            subtrigger=self.underlying.to_runner_api(context)))
class _ParallelTriggerFn(with_metaclass(ABCMeta, TriggerFn)):
  """Base for composite triggers that run all their subtriggers side by side.

  Every subtrigger receives every callback, each inside its own nested
  context keyed by its position. Subclasses supply ``combine_op``, which
  folds the subtriggers' boolean answers into a single answer.

  Args:
    *triggers: the subtriggers.
  """

  def __init__(self, *triggers):
    self.triggers = triggers

  def __repr__(self):
    rendered = ', '.join(str(t) for t in self.triggers)
    return '%s(%s)' % (self.__class__.__name__, rendered)

  def __eq__(self, other):
    return type(self) == type(other) and self.triggers == other.triggers

  def __hash__(self):
    return hash(self.triggers)

  @abstractmethod
  def combine_op(self, trigger_results):
    """Fold an iterable of subtrigger results into one boolean."""
    pass

  @staticmethod
  def _sub_context(context, index):
    """Return a context namespaced by the subtrigger's position."""
    return NestedContext(context, '%d/' % index)

  def on_element(self, element, window, context):
    for index, subtrigger in enumerate(self.triggers):
      subtrigger.on_element(element, window, self._sub_context(context, index))

  def on_merge(self, to_be_merged, merge_result, context):
    for index, subtrigger in enumerate(self.triggers):
      subtrigger.on_merge(
          to_be_merged, merge_result, self._sub_context(context, index))

  def should_fire(self, time_domain, watermark, window, context):
    """Combine the subtriggers' should_fire answers with ``combine_op``."""
    self._time_domain = time_domain
    # A generator, so a short-circuiting combine_op (any/all) may stop
    # consulting subtriggers early.
    answers = (
        subtrigger.should_fire(
            time_domain, watermark, window, self._sub_context(context, index))
        for index, subtrigger in enumerate(self.triggers))
    return self.combine_op(answers)

  def on_fire(self, watermark, window, context):
    """Fire the subtriggers that want to fire; combine their finished flags."""
    finished = []
    for index, subtrigger in enumerate(self.triggers):
      sub_context = self._sub_context(context, index)
      # NOTE(review): the domain is fixed to WATERMARK here although
      # should_fire stores the real one in self._time_domain - confirm that
      # this is intended.
      if not subtrigger.should_fire(
          TimeDomain.WATERMARK, watermark, window, sub_context):
        continue
      finished.append(subtrigger.on_fire(watermark, window, sub_context))
    return self.combine_op(finished)

  def reset(self, window, context):
    for index, subtrigger in enumerate(self.triggers):
      subtrigger.reset(window, self._sub_context(context, index))

  @staticmethod
  def from_runner_api(proto, context):
    """Build an AfterAll or AfterAny from the proto's subtriggers.

    AfterAll is returned when ``proto.after_all.subtriggers`` is non-empty,
    AfterAny otherwise.
    """
    subtrigger_protos = (
        proto.after_all.subtriggers or proto.after_any.subtriggers)
    subtriggers = [
        TriggerFn.from_runner_api(subtrigger_proto, context)
        for subtrigger_proto in subtrigger_protos]
    if proto.after_all.subtriggers:
      return AfterAll(*subtriggers)
    return AfterAny(*subtriggers)

  def to_runner_api(self, context):
    """Return an ``after_all`` or ``after_any`` proto of the subtriggers.

    Raises:
      NotImplementedError: if ``combine_op`` is neither ``all`` nor ``any``.
    """
    subtriggers = [
        subtrigger.to_runner_api(context) for subtrigger in self.triggers]
    if self.combine_op == all:
      return beam_runner_api_pb2.Trigger(
          after_all=beam_runner_api_pb2.Trigger.AfterAll(
              subtriggers=subtriggers))
    if self.combine_op == any:
      return beam_runner_api_pb2.Trigger(
          after_any=beam_runner_api_pb2.Trigger.AfterAny(
              subtriggers=subtriggers))
    raise NotImplementedError(self)
class AfterAny(_ParallelTriggerFn):
  """Fires when any subtrigger fires.

  Also finishes when any subtrigger finishes.
  """
  # Subtrigger results are folded with the builtin any().
  combine_op = any
class AfterAll(_ParallelTriggerFn):
  """Fires when all subtriggers have fired.

  Also finishes when all subtriggers have finished.
  """
  # Subtrigger results are folded with the builtin all().
  combine_op = all
class AfterEach(TriggerFn):
  """Runs the given subtriggers one after another.

  Only the subtrigger at the current index receives callbacks. When its
  on_fire reports that it finished, the index advances to the next
  subtrigger; this trigger finishes when the index moves past the last one.

  Args:
    *triggers: the subtriggers, in the order in which they are used.
  """
  # Index of the active subtrigger: the maximum of the indices added so far,
  # or 0 when none has been added.
  INDEX_TAG = _CombiningValueStateTag('index', (
      lambda indices: 0 if not indices else max(indices)))

  def __init__(self, *triggers):
    self.triggers = triggers

  def __repr__(self):
    return '%s(%s)' % (self.__class__.__name__,
                       ', '.join(str(t) for t in self.triggers))

  def __eq__(self, other):
    return type(self) == type(other) and self.triggers == other.triggers

  def __hash__(self):
    return hash(self.triggers)

  def on_element(self, element, window, context):
    """Give the element to the active subtrigger, if there still is one."""
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      self.triggers[ix].on_element(
          element, window, self._sub_context(context, ix))

  def on_merge(self, to_be_merged, merge_result, context):
    """Forward the merge to the active subtrigger, if there still is one."""
    # This takes the furthest window on merging.
    # TODO(robertwb): Revisit this when merging windows logic is settled for
    # all possible merging situations.
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      self.triggers[ix].on_merge(
          to_be_merged, merge_result, self._sub_context(context, ix))

  def should_fire(self, time_domain, watermark, window, context):
    """Ask the active subtrigger; False once every subtrigger is used up."""
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      return self.triggers[ix].should_fire(
          time_domain, watermark, window, self._sub_context(context, ix))
    # Explicit False instead of the previous implicit None.
    return False

  def on_fire(self, watermark, window, context):
    """Fire the active subtrigger, advancing the index when it finishes.

    Returns:
      True if the last subtrigger has just finished, False otherwise.
    """
    ix = context.get_state(self.INDEX_TAG)
    if ix < len(self.triggers):
      if self.triggers[ix].on_fire(
          watermark, window, self._sub_context(context, ix)):
        ix += 1
        context.add_state(self.INDEX_TAG, ix)
      return ix == len(self.triggers)
    # NOTE(review): the index was already past the last subtrigger. The
    # original fell through with an implicit (falsy) None; that result is
    # kept, now as an explicit False.
    return False

  def reset(self, window, context):
    """Clear the index and reset every subtrigger."""
    context.clear_state(self.INDEX_TAG)
    for ix, trigger in enumerate(self.triggers):
      trigger.reset(window, self._sub_context(context, ix))

  @staticmethod
  def _sub_context(context, index):
    """Return a context namespaced by the subtrigger's position."""
    return NestedContext(context, '%d/' % index)

  @staticmethod
  def from_runner_api(proto, context):
    """Rebuild the trigger from an ``after_each`` proto's subtriggers."""
    return AfterEach(*[
        TriggerFn.from_runner_api(subtrigger, context)
        for subtrigger in proto.after_each.subtriggers])

  def to_runner_api(self, context):
    """Return an ``after_each`` proto of the subtriggers' protos."""
    return beam_runner_api_pb2.Trigger(
        after_each=beam_runner_api_pb2.Trigger.AfterEach(
            subtriggers=[
                subtrigger.to_runner_api(context)
                for subtrigger in self.triggers]))
class OrFinally(AfterAny):
  """An AfterAny that is serialized as the runner API ``or_finally`` trigger.

  Firing behavior is inherited from AfterAny. Serialization expects two
  subtriggers: the first becomes the proto's ``main`` field and the second
  its ``finally`` field.
  """

  @staticmethod
  def from_runner_api(proto, context):
    """Rebuild the trigger from an ``or_finally`` proto."""
    return OrFinally(
        TriggerFn.from_runner_api(proto.or_finally.main, context),
        # getattr is used as finally is a keyword in Python
        TriggerFn.from_runner_api(getattr(proto.or_finally, 'finally'),
                                  context))

  def to_runner_api(self, context):
    """Return an ``or_finally`` proto built from the first two subtriggers."""
    return beam_runner_api_pb2.Trigger(
        or_finally=beam_runner_api_pb2.Trigger.OrFinally(
            main=self.triggers[0].to_runner_api(context),
            # dict keyword argument is used as finally is a keyword in Python
            **{'finally': self.triggers[1].to_runner_api(context)}))
class TriggerContext(object):
  """Binds a state object to one window and a clock.

  Every call is forwarded to ``outer`` with the bound window inserted as the
  first argument, so triggers do not have to pass the window themselves.

  Args:
    outer: the state object the calls are forwarded to.
    window: the window that is prepended to every forwarded call.
    clock: object whose ``time()`` provides the current time.
  """

  def __init__(self, outer, window, clock):
    self._outer, self._window, self._clock = outer, window, clock

  def get_current_time(self):
    """Return the clock's current time."""
    now = self._clock.time()
    return now

  def set_timer(self, name, time_domain, timestamp):
    """Set a timer for the bound window."""
    self._outer.set_timer(self._window, name, time_domain, timestamp)

  def clear_timer(self, name, time_domain):
    """Clear a timer of the bound window."""
    self._outer.clear_timer(self._window, name, time_domain)

  def add_state(self, tag, value):
    """Add ``value`` to the bound window's state under ``tag``."""
    self._outer.add_state(self._window, tag, value)

  def get_state(self, tag):
    """Return the bound window's state for ``tag``."""
    stored = self._outer.get_state(self._window, tag)
    return stored

  def clear_state(self, tag):
    """Clear the bound window's state for ``tag``; returns outer's result."""
    outcome = self._outer.clear_state(self._window, tag)
    return outcome
class NestedContext(object):
  """Namespaced context useful for defining composite triggers.

  Timer names and state tags are prefixed before the call is forwarded to
  the wrapped context, which keeps the subtriggers of a composite trigger
  from sharing timers or state.

  Args:
    outer: the context to forward to.
    prefix: string prepended to timer names and, through ``with_prefix``,
      to state tags.
  """

  def __init__(self, outer, prefix):
    self._outer, self._prefix = outer, prefix

  def get_current_time(self):
    """Return the wrapped context's current time."""
    now = self._outer.get_current_time()
    return now

  def set_timer(self, name, time_domain, timestamp):
    """Set the timer called prefix + name on the wrapped context."""
    qualified_name = self._prefix + name
    self._outer.set_timer(qualified_name, time_domain, timestamp)

  def clear_timer(self, name, time_domain):
    """Clear the timer called prefix + name on the wrapped context."""
    qualified_name = self._prefix + name
    self._outer.clear_timer(qualified_name, time_domain)

  def add_state(self, tag, value):
    """Add ``value`` under the prefixed copy of ``tag``."""
    qualified_tag = tag.with_prefix(self._prefix)
    self._outer.add_state(qualified_tag, value)

  def get_state(self, tag):
    """Return the state stored under the prefixed copy of ``tag``."""
    qualified_tag = tag.with_prefix(self._prefix)
    return self._outer.get_state(qualified_tag)

  def clear_state(self, tag):
    """Clear the state stored under the prefixed copy of ``tag``."""
    qualified_tag = tag.with_prefix(self._prefix)
    self._outer.clear_state(qualified_tag)
# pylint: disable=unused-argument
class SimpleState(with_metaclass(ABCMeta, object)):
  """Basic state storage interface used for triggering.

  Only timers must hold the watermark (by their timestamp).
  """

  def at(self, window, clock):
    """Return a TriggerContext bound to this state, ``window`` and ``clock``."""
    return TriggerContext(self, window, clock)

  @abstractmethod
  def get_window(self, window_id):
    """Return the window associated with ``window_id``."""

  @abstractmethod
  def set_timer(self, window, name, time_domain, timestamp):
    """Set the timer ``name`` of ``window`` in ``time_domain``."""

  @abstractmethod
  def clear_timer(self, window, name, time_domain):
    """Clear the timer ``name`` of ``window`` in ``time_domain``."""

  @abstractmethod
  def add_state(self, window, tag, value):
    """Add ``value`` to the state of ``window`` identified by ``tag``."""

  @abstractmethod
  def get_state(self, window, tag):
    """Return the state of ``window`` identified by ``tag``."""

  @abstractmethod
  def clear_state(self, window, tag):
    """Clear the state of ``window`` identified by ``tag``."""
class UnmergedState(SimpleState):
  """State suitable for use in TriggerDriver.

  This class must be implemented by each backend.
  """

  @abstractmethod
  def get_global_state(self, tag, default=None):
    """Return the global state stored under ``tag``, or ``default``."""

  @abstractmethod
  def set_global_state(self, tag, value):
    """Store ``value`` as the global state for ``tag``."""
# pylint: enable=unused-argument
class MergeableStateAdapter(SimpleState):
  """Wraps an UnmergedState, tracking merged windows.

  Each known window maps to a list of integer ids under which its timers and
  state live in the wrapped ``raw_state``. Merging windows concatenates
  their id lists, so reads combine the values stored under every id while
  writes go to the first id of the list. The mapping itself is kept in the
  raw state's global state under WINDOW_IDS.

  Args:
    raw_state: the wrapped state; the code here uses its ``set_timer``,
      ``clear_timer``, ``add_state``, ``get_state``, ``clear_state``,
      ``get_global_state`` and ``set_global_state`` methods.
  """
  # TODO(robertwb): A similar indirection could be used for sliding windows
  # or other window_fns when a single element typically belongs to many windows.
  WINDOW_IDS = _ValueStateTag('window_ids')

  def __init__(self, raw_state):
    self.raw_state = raw_state
    self.window_ids = self.raw_state.get_global_state(self.WINDOW_IDS, {})
    # Highest id handed out so far; computed lazily by _get_next_counter.
    self.counter = None

  def set_timer(self, window, name, time_domain, timestamp):
    """Set the timer under the window's primary id."""
    self.raw_state.set_timer(self._get_id(window), name, time_domain, timestamp)

  def clear_timer(self, window, name, time_domain):
    """Clear the timer under every id of the window."""
    for window_id in self._get_ids(window):
      self.raw_state.clear_timer(window_id, name, time_domain)

  def add_state(self, window, tag, value):
    """Add ``value`` under the window's primary id.

    Raises:
      ValueError: for a _ValueStateTag, which cannot be merged.
    """
    if isinstance(tag, _ValueStateTag):
      raise ValueError(
          'Merging requested for non-mergeable state tag: %r.' % tag)
    elif isinstance(tag, _CombiningValueStateTag):
      # Store raw accumulators, so that get_state can merge across ids.
      tag = tag.without_extraction()
    self.raw_state.add_state(self._get_id(window), tag, value)

  def get_state(self, window, tag):
    """Return the state for ``tag`` combined over every id of the window.

    Raises:
      ValueError: for a _ValueStateTag (not mergeable) or an unknown tag type.
    """
    # Reject non-mergeable tags before touching the raw state.
    if isinstance(tag, _ValueStateTag):
      raise ValueError(
          'Merging requested for non-mergeable state tag: %r.' % tag)
    original_tag = tag
    if isinstance(tag, _CombiningValueStateTag):
      tag = tag.without_extraction()
    values = [self.raw_state.get_state(window_id, tag)
              for window_id in self._get_ids(window)]
    if isinstance(tag, _CombiningValueStateTag):
      # Merge the per-id accumulators, then apply the original extraction.
      return original_tag.combine_fn.extract_output(
          original_tag.combine_fn.merge_accumulators(values))
    elif isinstance(tag, _ListStateTag):
      return [v for vs in values for v in vs]
    elif isinstance(tag, _SetStateTag):
      return {v for vs in values for v in vs}
    elif isinstance(tag, _WatermarkHoldStateTag):
      return tag.timestamp_combiner_impl.combine_all(values)
    else:
      raise ValueError('Invalid tag.', tag)

  def clear_state(self, window, tag):
    """Clear ``tag`` under every id of the window.

    With ``tag`` None the window is also dropped from the id mapping.
    """
    for window_id in self._get_ids(window):
      self.raw_state.clear_state(window_id, tag)
    if tag is None:
      # pop() instead of del: a window that never received an id is simply
      # skipped rather than raising KeyError.
      if self.window_ids.pop(window, None) is not None:
        self._persist_window_ids()

  def merge(self, to_be_merged, merge_result):
    """Move the ids of the merged windows over to ``merge_result``."""
    changed = False
    for window in to_be_merged:
      if window != merge_result and window in self.window_ids:
        merged_ids = self.window_ids.setdefault(merge_result, [])
        merged_ids.extend(self.window_ids.pop(window))
        changed = True
    # Persist once, after all windows are folded in, and only on a change.
    if changed:
      self._persist_window_ids()

  def known_windows(self):
    """Return a list of the windows that currently have ids."""
    return list(self.window_ids)

  def get_window(self, window_id):
    """Return the window whose id list contains ``window_id``.

    Raises:
      ValueError: if no window owns the id.
    """
    for window, ids in self.window_ids.items():
      if window_id in ids:
        return window
    raise ValueError('No window for %s' % window_id)

  def _get_id(self, window):
    """Return the window's primary id, allocating one if necessary."""
    if window in self.window_ids:
      return self.window_ids[window][0]
    window_id = self._get_next_counter()
    self.window_ids[window] = [window_id]
    self._persist_window_ids()
    return window_id

  def _get_ids(self, window):
    """Return every id of the window; an empty list if it is unknown."""
    return self.window_ids.get(window, [])

  def _get_next_counter(self):
    """Return a fresh id, one above the highest id in use."""
    if not self.window_ids:
      self.counter = 0
    elif self.counter is None:
      # First allocation with restored ids: continue after the largest one.
      self.counter = max(k for ids in self.window_ids.values() for k in ids)
    self.counter += 1
    return self.counter

  def _persist_window_ids(self):
    """Write the id mapping back to the raw state's global state."""
    self.raw_state.set_global_state(self.WINDOW_IDS, self.window_ids)

  def __repr__(self):
    return '\n\t'.join([repr(self.window_ids)] +
                       repr(self.raw_state).split('\n'))
def create_trigger_driver(windowing,
                          is_batch=False, phased_combine_fn=None, clock=None):
  """Create the TriggerDriver for the given windowing and options.

  Args:
    windowing: the windowing strategy; its ``is_default()``, ``windowfn``,
      ``triggerfn`` and ``accumulation_mode`` are inspected.
    is_batch: whether the pipeline runs in batch mode.
    phased_combine_fn: if truthy, the chosen driver is wrapped in a
      CombiningTriggerDriver that uses it.
    clock: passed to GeneralTriggerDriver when that driver is selected.

  Returns:
    the driver to use.
  """
  # TODO(robertwb): We can do more if we know elements are in timestamp
  # sorted order.
  # Two configurations just pass all values through every time: default
  # windowing in batch mode, and discarding AfterCount(1) on global windows.
  passes_everything_through = (
      (windowing.is_default() and is_batch)
      or (windowing.windowfn == GlobalWindows()
          and windowing.triggerfn == AfterCount(1)
          and windowing.accumulation_mode == AccumulationMode.DISCARDING))
  if passes_everything_through:
    driver = DiscardingGlobalTriggerDriver()
  else:
    driver = GeneralTriggerDriver(windowing, clock)
  if not phased_combine_fn:
    return driver
  # TODO(ccy): Refactor GeneralTriggerDriver to combine values eagerly using
  # the known phased_combine_fn here.
  return CombiningTriggerDriver(phased_combine_fn, driver)
class TriggerDriver(with_metaclass(ABCMeta, object)):
  """Breaks a series of bundle and timer firings into window (pane)s."""

  @abstractmethod
  def process_elements(self, state, windowed_values, output_watermark):
    """Process a bundle; returns an iterable of output windowed values."""

  @abstractmethod
  def process_timer(self, window_id, name, time_domain, timestamp, state):
    """Process a fired timer; returns an iterable of output windowed values."""

  def process_entire_key(
      self, key, windowed_values, output_watermark=MIN_TIMESTAMP):
    """Run all values of one key through the driver with in-memory state.

    The values are processed first; afterwards the pending timers are fired
    until none remain. Every output is yielded with its value replaced by
    ``(key, value)``.
    """
    def with_key(wvalue):
      return wvalue.with_value((key, wvalue.value))

    state = InMemoryUnmergedState()
    outputs = self.process_elements(state, windowed_values, output_watermark)
    for wvalue in outputs:
      yield with_key(wvalue)
    # Processing a timer may set further timers, hence the outer loop.
    while state.timers:
      for timer_window, timer in state.get_and_clear_timers():
        name, time_domain, fire_time = timer
        timer_outputs = self.process_timer(
            timer_window, name, time_domain, fire_time, state)
        for wvalue in timer_outputs:
          yield with_key(wvalue)
class _UnwindowedValues(observable.ObservableMixin):
  """Exposes iterable of windowed values as iterable of unwindowed values.

  Every iteration walks the wrapped windowed values again and notifies the
  registered observers of each unwrapped value. Equality and hashing iterate
  as well, so they also trigger those notifications.
  """
  def __init__(self, windowed_values):
    super(_UnwindowedValues, self).__init__()
    self._windowed_values = windowed_values
  def __iter__(self):
    """Yield the bare value of each windowed value, notifying observers."""
    for wv in self._windowed_values:
      unwindowed_value = wv.value
      self.notify_observers(unwindowed_value)
      yield unwindowed_value
  def __repr__(self):
    # Wrap the operand in a 1-tuple: a bare tuple on the right-hand side of %
    # would be unpacked as the format arguments and fail (or print just its
    # only element) when the wrapped iterable is itself a tuple.
    return '<_UnwindowedValues of %s>' % (self._windowed_values,)
  def __reduce__(self):
    # Pickle as a plain list of the unwindowed values.
    return list, (list(self),)
  def __eq__(self, other):
    """Compare element-wise against any iterable.

    Returns NotImplemented for non-iterable operands.
    """
    # The ABC aliases were removed from the top-level collections module in
    # Python 3.10; fall back to the old location only where collections.abc
    # does not exist (Python 2).
    try:
      from collections.abc import Iterable
    except ImportError:
      from collections import Iterable
    if isinstance(other, Iterable):
      # A fresh object() as fill value never equals a real element, so
      # iterables of different lengths compare unequal.
      return all(
          a == b
          for a, b in zip_longest(self, other, fillvalue=object()))
    else:
      return NotImplemented
  def __hash__(self):
    return hash(tuple(self))
  def __ne__(self, other):
    # TODO(BEAM-5949): Needed for Python 2 compatibility.
    return not self == other
# Register _UnwindowedValues with FastPrimitivesCoderImpl as an iterable-like
# type. NOTE(review): presumably this makes the coder encode it the same way as
# a plain iterable -- confirm against register_iterable_like_type.
coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(
    _UnwindowedValues)
class DiscardingGlobalTriggerDriver(TriggerDriver):
  """Emits all received values together as one output in the global window.
  """
  GLOBAL_WINDOW_TUPLE = (GlobalWindow(),)
  def process_elements(self, state, windowed_values, unused_output_watermark):
    """Yield a single WindowedValue wrapping every input value.

    The output carries MIN_TIMESTAMP and the global window; neither the state
    nor the output watermark is consulted.
    """
    everything = _UnwindowedValues(windowed_values)
    single_output = WindowedValue(
        everything, MIN_TIMESTAMP, self.GLOBAL_WINDOW_TUPLE)
    yield single_output
  def process_timer(self, window_id, name, time_domain, timestamp, state):
    """Always raises TypeError: this driver does not handle timers."""
    raise TypeError('Triggers never set or called for batch default windowing.')
class CombiningTriggerDriver(TriggerDriver):
  """Applies a phased_combine_fn to each output of a wrapped TriggerDriver."""
  def __init__(self, phased_combine_fn, underlying):
    self.underlying = underlying
    self.phased_combine_fn = phased_combine_fn
  def _combine_each(self, panes):
    """Yield every pane with its value replaced by the combined value."""
    for pane in panes:
      yield pane.with_value(self.phased_combine_fn.apply(pane.value))
  def process_elements(self, state, windowed_values, output_watermark):
    """Delegate to the wrapped driver and combine each resulting value."""
    panes = self.underlying.process_elements(
        state, windowed_values, output_watermark)
    for combined in self._combine_each(panes):
      yield combined
  def process_timer(self, window_id, name, time_domain, timestamp, state):
    """Delegate the timer to the wrapped driver and combine each value."""
    panes = self.underlying.process_timer(
        window_id, name, time_domain, timestamp, state)
    for combined in self._combine_each(panes):
      yield combined
class GeneralTriggerDriver(TriggerDriver):
  """Breaks a series of bundle and timer firings into window (pane)s.
  Suitable for all variants of Windowing.

  Per-window state used by this driver:
    ELEMENTS: the values buffered for the window since it last emitted (or
      since the start, when not in discarding mode).
    TOMBSTONE: added to once the trigger reports the window finished; windows
      with a truthy tombstone are ignored by later elements and timers.
    WATERMARK_HOLD: the output timestamp to use for the window's next pane.
  """
  ELEMENTS = _ListStateTag('elements')
  TOMBSTONE = _CombiningValueStateTag('tombstone', combiners.CountCombineFn())
  def __init__(self, windowing, clock):
    """Capture the pieces of the windowing strategy this driver needs.

    Args:
      windowing: supplies windowfn, timestamp_combiner, triggerfn and
        accumulation_mode.
      clock: passed to state.at() whenever a trigger context is created.
    """
    self.clock = clock
    self.window_fn = windowing.windowfn
    self.timestamp_combiner_impl = TimestampCombiner.get_impl(
        windowing.timestamp_combiner, self.window_fn)
    # The hold tag is created per instance (not as a class constant) because it
    # wraps this instance's timestamp combiner implementation.
    # pylint: disable=invalid-name
    self.WATERMARK_HOLD = _WatermarkHoldStateTag(
        'watermark', self.timestamp_combiner_impl)
    # pylint: enable=invalid-name
    self.trigger_fn = windowing.triggerfn
    self.accumulation_mode = windowing.accumulation_mode
    # Always treated as merging: state is wrapped in a MergeableStateAdapter
    # on every call.
    self.is_merging = True
  def process_elements(self, state, windowed_values, output_watermark):
    """Add a bundle of elements to their windows and emit any panes that fire.

    Args:
      state: the state to read and update; wrapped in a MergeableStateAdapter
        when is_merging is set.
      windowed_values: the values to add; each is added to every window it
        belongs to.
      output_watermark: elements whose output time is earlier than this do not
        contribute to the window's watermark hold.

    Yields:
      One WindowedValue (see _output) for each window whose trigger fires.
    """
    if self.is_merging:
      state = MergeableStateAdapter(state)
    # Group the incoming (value, timestamp) pairs by window.
    windows_to_elements = collections.defaultdict(list)
    for wv in windowed_values:
      for window in wv.windows:
        windows_to_elements[window].append((wv.value, wv.timestamp))
    # First handle merging.
    if self.is_merging:
      old_windows = set(state.known_windows())
      all_windows = old_windows.union(list(windows_to_elements))
      # Only ask the window fn to merge when this bundle brought new windows.
      if all_windows != old_windows:
        # Maps each window that was merged away to the window it merged into.
        merged_away = {}
        class TriggerMergeContext(WindowFn.MergeContext):
          # The first parameter is named `_` so that `self` below keeps
          # referring to the enclosing GeneralTriggerDriver instance.
          def merge(_, to_be_merged, merge_result):  # pylint: disable=no-self-argument
            for window in to_be_merged:
              if window != merge_result:
                merged_away[window] = merge_result
            state.merge(to_be_merged, merge_result)
            # using the outer self argument.
            self.trigger_fn.on_merge(
                to_be_merged, merge_result, state.at(merge_result, self.clock))
        self.window_fn.merge(TriggerMergeContext(all_windows))
        # Re-key the new elements by the window they finally ended up in,
        # following chains of successive merges.
        merged_windows_to_elements = collections.defaultdict(list)
        for window, values in windows_to_elements.items():
          while window in merged_away:
            window = merged_away[window]
          merged_windows_to_elements[window].extend(values)
        windows_to_elements = merged_windows_to_elements
        # Drop the watermark holds of the windows that no longer exist.
        for window in merged_away:
          state.clear_state(window, self.WATERMARK_HOLD)
    # Next handle element adding.
    for window, elements in windows_to_elements.items():
      # Skip windows that have already been marked finished.
      if state.get_state(window, self.TOMBSTONE):
        continue
      # Add watermark hold.
      # TODO(ccy): Add late data and garbage-collection hold support.
      # Combine the output times of the new elements, ignoring those that fall
      # before the output watermark.
      output_time = self.timestamp_combiner_impl.merge(
          window,
          (element_output_time for element_output_time in
           (self.timestamp_combiner_impl.assign_output_time(window, timestamp)
            for unused_value, timestamp in elements)
           if element_output_time >= output_watermark))
      if output_time is not None:
        # Replace any existing hold for this window with the new output time.
        state.clear_state(window, self.WATERMARK_HOLD)
        state.add_state(window, self.WATERMARK_HOLD, output_time)
      context = state.at(window, self.clock)
      for value, unused_timestamp in elements:
        state.add_state(window, self.ELEMENTS, value)
        self.trigger_fn.on_element(value, window, context)
      # Maybe fire this window.
      # The trigger is always consulted with MIN_TIMESTAMP as the watermark
      # here, not with output_watermark.
      watermark = MIN_TIMESTAMP
      if self.trigger_fn.should_fire(TimeDomain.WATERMARK, watermark,
                                     window, context):
        finished = self.trigger_fn.on_fire(watermark, window, context)
        yield self._output(window, finished, state)
  def process_timer(self, window_id, unused_name, time_domain, timestamp,
                    state):
    """Handle a timer firing for a window, emitting a pane if the trigger fires.

    Args:
      window_id: identifier resolved to a window through state.get_window.
      unused_name: the timer name; not used.
      time_domain: must be TimeDomain.WATERMARK or TimeDomain.REAL_TIME.
      timestamp: the time the timer fired at; passed on to the trigger.
      state: the state to read and update.

    Yields:
      At most one WindowedValue (see _output).

    Raises:
      Exception: if time_domain is neither WATERMARK nor REAL_TIME and the
        window is not already finished.
    """
    if self.is_merging:
      state = MergeableStateAdapter(state)
    window = state.get_window(window_id)
    # Timers for windows that are already finished are ignored.
    if state.get_state(window, self.TOMBSTONE):
      return
    if time_domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
      # When merging, ignore timers for windows the state no longer knows.
      if not self.is_merging or window in state.known_windows():
        context = state.at(window, self.clock)
        if self.trigger_fn.should_fire(time_domain, timestamp,
                                       window, context):
          finished = self.trigger_fn.on_fire(timestamp, window, context)
          yield self._output(window, finished, state)
    else:
      raise Exception('Unexpected time domain: %s' % time_domain)
  def _output(self, window, finished, state):
    """Output window and clean up if appropriate.

    Args:
      window: the window being emitted.
      finished: the result of the trigger's on_fire; when truthy, the window's
        elements are cleared and its tombstone is set.
      state: the state holding the window's elements and watermark hold.

    Returns:
      A WindowedValue holding the window's buffered elements, timestamped with
      the window's watermark hold, or with window.end if no hold was set.
    """
    # Fetch the values before any clearing below.
    values = state.get_state(window, self.ELEMENTS)
    if finished:
      # TODO(robertwb): allowed lateness
      state.clear_state(window, self.ELEMENTS)
      state.add_state(window, self.TOMBSTONE, 1)
    elif self.accumulation_mode == AccumulationMode.DISCARDING:
      # Discarding mode: each pane only contains elements since the last one.
      state.clear_state(window, self.ELEMENTS)
    timestamp = state.get_state(window, self.WATERMARK_HOLD)
    if timestamp is None:
      # If no watermark hold was set, output at end of window.
      timestamp = window.end
    else:
      # The hold has been used for this pane, so release it.
      state.clear_state(window, self.WATERMARK_HOLD)
    return WindowedValue(values, timestamp, (window,))
class InMemoryUnmergedState(UnmergedState):
  """In-memory implementation of UnmergedState.
  Used for batch and testing.
  """
  def __init__(self, defensive_copy=True):
    """Create empty state.

    Args:
      defensive_copy: if true, values are deep-copied before being stored.
    """
    # TODO(robertwb): Skip defensive_copy in production if it's too expensive.
    # window -> {(name, time_domain): timestamp}
    self.timers = collections.defaultdict(dict)
    # window -> {tag name: stored value, or list of stored values}
    self.state = collections.defaultdict(lambda: collections.defaultdict(list))
    # tag name -> value
    self.global_state = {}
    self.defensive_copy = defensive_copy
  def copy(self):
    """Return a copy: timers and global state deep, per-tag state shallow."""
    cloned_object = InMemoryUnmergedState(defensive_copy=self.defensive_copy)
    cloned_object.timers = copy.deepcopy(self.timers)
    cloned_object.global_state = copy.deepcopy(self.global_state)
    for window in self.state:
      for tag in self.state[window]:
        cloned_object.state[window][tag] = copy.copy(self.state[window][tag])
    return cloned_object
  def set_global_state(self, tag, value):
    """Store value under a _ValueStateTag, independent of any window."""
    assert isinstance(tag, _ValueStateTag)
    if self.defensive_copy:
      value = copy.deepcopy(value)
    self.global_state[tag.tag] = value
  def get_global_state(self, tag, default=None):
    """Return the global value stored for tag, or default if there is none."""
    return self.global_state.get(tag.tag, default)
  def set_timer(self, window, name, time_domain, timestamp):
    """Set (or overwrite) the timer (name, time_domain) of window."""
    self.timers[window][(name, time_domain)] = timestamp
  def clear_timer(self, window, name, time_domain):
    """Remove a timer if present, dropping the window's entry once empty."""
    self.timers[window].pop((name, time_domain), None)
    if not self.timers[window]:
      del self.timers[window]
  def get_window(self, window_id):
    """Return window_id itself: here windows serve as their own ids."""
    return window_id
  def add_state(self, window, tag, value):
    """Add value to the state of tag in window.

    A _ValueStateTag overwrites the stored value; every other supported tag
    type appends to the tag's list.

    Raises:
      ValueError: if tag is not one of the supported tag types.
    """
    if self.defensive_copy:
      value = copy.deepcopy(value)
    if isinstance(tag, _ValueStateTag):
      self.state[window][tag.tag] = value
    elif isinstance(tag, (_CombiningValueStateTag, _ListStateTag,
                          _SetStateTag, _WatermarkHoldStateTag)):
      # TODO(robertwb): Store merged accumulators.
      # All of these keep the raw values in a list; get_state interprets the
      # list according to the tag type.
      self.state[window][tag.tag].append(value)
    else:
      raise ValueError('Invalid tag.', tag)
  def get_state(self, window, tag):
    """Return the state of tag in window, interpreted per tag type.

    Combining tags return combine_fn.apply over the stored values and
    watermark-hold tags return timestamp_combiner_impl.combine_all over them;
    the other supported tag types return what is stored.

    Raises:
      ValueError: if tag is not one of the supported tag types.
    """
    # Note: this lookup goes through the nested defaultdicts, so reading a
    # missing tag leaves an empty list behind for it.
    values = self.state[window][tag.tag]
    if isinstance(tag, _ValueStateTag):
      return values
    elif isinstance(tag, _CombiningValueStateTag):
      return tag.combine_fn.apply(values)
    elif isinstance(tag, _ListStateTag):
      return values
    elif isinstance(tag, _SetStateTag):
      return values
    elif isinstance(tag, _WatermarkHoldStateTag):
      return tag.timestamp_combiner_impl.combine_all(values)
    else:
      raise ValueError('Invalid tag.', tag)
  def clear_state(self, window, tag):
    """Remove the state of tag in window, dropping the window once empty."""
    self.state[window].pop(tag.tag, None)
    if not self.state[window]:
      self.state.pop(window, None)
  def get_timers(self, clear=False, watermark=MAX_TIMESTAMP,
                 processing_time=None):
    """Gets expired timers and reports if there
    are any realtime timers set per state.
    Expiration is measured against the watermark for event-time timers,
    and against a wall clock for processing-time timers.

    Args:
      clear: if true, expired timers are removed from the state.
      watermark: marker that WATERMARK timers are compared against.
      processing_time: marker that REAL_TIME timers are compared against.
        NOTE(review): with the default of None, a REAL_TIME timer is compared
        against None -- confirm that callers holding processing-time timers
        always pass a value.

    Returns:
      A tuple (expired, has_realtime_timer). expired is a list of
      (window, (name, time_domain, timestamp)) entries; has_realtime_timer
      is true if any REAL_TIME timer was seen, whether expired or not.
    """
    expired = []
    has_realtime_timer = False
    # Iterate over snapshots, since entries may be deleted along the way.
    for window, timers in list(self.timers.items()):
      for (name, time_domain), timestamp in list(timers.items()):
        if time_domain == TimeDomain.REAL_TIME:
          time_marker = processing_time
          has_realtime_timer = True
        elif time_domain == TimeDomain.WATERMARK:
          time_marker = watermark
        else:
          logging.error(
              'TimeDomain error: No timers defined for time domain %s.',
              time_domain)
          # There is no marker to compare this timer against. Skip it rather
          # than fall through with time_marker unbound or left over from the
          # previous timer.
          continue
        if timestamp <= time_marker:
          expired.append((window, (name, time_domain, timestamp)))
          if clear:
            del timers[(name, time_domain)]
      if not timers and clear:
        del self.timers[window]
    return expired, has_realtime_timer
  def get_and_clear_timers(self, watermark=MAX_TIMESTAMP):
    """Return the timers expired at watermark and remove them from the state."""
    return self.get_timers(clear=True, watermark=watermark)[0]
  def get_earliest_hold(self):
    """Return the earliest watermark hold over all windows.

    Each window's hold is its smallest stored 'watermark' value minus
    TIME_GRANULARITY; MAX_TIMESTAMP is returned when no window has a hold.
    """
    earliest_hold = MAX_TIMESTAMP
    for unused_window, tagged_states in iteritems(self.state):
      # TODO(BEAM-2519): currently, this assumes that the watermark hold tag is
      # named "watermark".  This is currently only true because the only place
      # watermark holds are set is in the GeneralTriggerDriver, where we use
      # this name.  We should fix this by allowing enumeration of the tag types
      # used in adding state.
      if 'watermark' in tagged_states and tagged_states['watermark']:
        hold = min(tagged_states['watermark']) - TIME_GRANULARITY
        earliest_hold = min(earliest_hold, hold)
    return earliest_hold
  def __repr__(self):
    state_str = '\n'.join('%s: %s' % (key, dict(state))
                          for key, state in self.state.items())
    return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)