"""User-facing interfaces for the Beam State and Timer APIs."""

# pytype: skip-file
# mypy: disallow-untyped-defs

import collections
import types
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import NamedTuple
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TypeVar

from apache_beam.coders import Coder
from apache_beam.coders import coders
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.utils import windowed_value
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
  from apache_beam.runners.pipeline_context import PipelineContext
  from apache_beam.transforms.core import DoFn

# Type variable for the method decorated by ``on_timer``: the decorator's
# signature is ``Callable[[CallableT], CallableT]``, so the decorated method's
# exact type is preserved for type checkers.
CallableT = TypeVar('CallableT', bound=Callable)

class StateSpec(object):
  """Specification for a user DoFn state cell.

  Concrete subclasses describe a particular kind of state cell and override
  ``to_runner_api`` to produce the portable ``StateSpec`` proto.
  """
  def __init__(self, name: str, coder: Coder) -> None:
    """Validates and stores the identity of the state cell.

    Args:
      name: The name by which the state cell is identified.
      coder: The coder associated with the state cell's contents.

    Raises:
      TypeError: If ``name`` is not a ``str`` or ``coder`` is not a ``Coder``.
    """
    if not isinstance(name, str):
      raise TypeError("name is not a string")
    if not isinstance(coder, Coder):
      raise TypeError("coder is not of type Coder")
    # ``name`` is read by __repr__ and by the duplicate-name validation of
    # stateful DoFns, so it must be stored on the instance.
    self.name = name
    self.coder = coder

  def __repr__(self) -> str:
    return '%s(%s)' % (self.__class__.__name__, self.name)

  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Converts this spec into a ``beam_runner_api_pb2.StateSpec`` proto.

    Raises:
      NotImplementedError: Always; subclasses provide the conversion.
    """
    raise NotImplementedError
class ReadModifyWriteStateSpec(StateSpec):
  """Specification for a user DoFn value state cell.

  The cell's value is encoded with ``self.coder``.
  """
  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the portable proto describing this value state cell."""
    value_coder_id = context.coders.get_id(self.coder)
    rmw_spec = beam_runner_api_pb2.ReadModifyWriteStateSpec(
        coder_id=value_coder_id)
    # The Fn API protocol advertised for this cell is the bag user-state URN.
    protocol_spec = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        read_modify_write_spec=rmw_spec, protocol=protocol_spec)
class BagStateSpec(StateSpec):
  """Specification for a user DoFn bag state cell.

  Each element added to the bag is encoded with ``self.coder``.
  """
  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the portable proto describing this bag state cell."""
    element_coder_id = context.coders.get_id(self.coder)
    bag_spec = beam_runner_api_pb2.BagStateSpec(
        element_coder_id=element_coder_id)
    protocol_spec = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        bag_spec=bag_spec, protocol=protocol_spec)
class SetStateSpec(StateSpec):
  """Specification for a user DoFn Set State cell.

  Each element of the set is encoded with ``self.coder``.
  """
  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the portable proto describing this set state cell."""
    element_coder_id = context.coders.get_id(self.coder)
    set_spec = beam_runner_api_pb2.SetStateSpec(
        element_coder_id=element_coder_id)
    # The Fn API protocol advertised for this cell is the bag user-state URN.
    protocol_spec = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        set_spec=set_spec, protocol=protocol_spec)
class CombiningValueStateSpec(StateSpec):
  """Specification for a user DoFn combining value state cell."""
  def __init__(
      self,
      name: str,
      coder: Optional[Coder] = None,
      combine_fn: Any = None) -> None:
    """Initialize the specification for CombiningValue state.

    Two call forms are accepted::

      CombiningValueStateSpec(name, combine_fn)
          -> Coder-inferred combining value state spec.
      CombiningValueStateSpec(name, coder, combine_fn)
          -> Combining value state spec with coder and combine_fn specified.

    Args:
      name (str): The name by which the state is identified.
      coder (Coder): Coder specifying how to encode the values to be combined.
        May be inferred.
      combine_fn (``CombineFn`` or ``callable``): Function specifying how to
        combine the values passed to state.

    Raises:
      ValueError: If neither ``coder`` nor ``combine_fn`` is supplied.
    """
    # Avoid circular import.
    from apache_beam.transforms.core import CombineFn

    # ``coder`` precedes the mandatory ``combine_fn`` and cannot be reordered
    # for backwards compatibility, so a two-argument call (or a call passing
    # only ``combine_fn`` by keyword) is detected here and the value sitting in
    # the ``coder`` slot is reinterpreted as the combine_fn.
    if combine_fn is None:
      if coder is None:
        raise ValueError('combine_fn must be provided')
      combine_fn, coder = coder, None

    self.combine_fn = CombineFn.maybe_from_callable(combine_fn)

    # The stored coder is for the CombineFn's accumulator type; infer it from
    # the CombineFn when the caller did not provide one.
    accumulator_coder = (
        self.combine_fn.get_accumulator_coder() if coder is None else coder)
    super().__init__(name, accumulator_coder)

  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the portable proto describing this combining state cell."""
    combine_fn_proto = self.combine_fn.to_runner_api(context)
    accumulator_coder_id = context.coders.get_id(self.coder)
    combining_spec = beam_runner_api_pb2.CombiningStateSpec(
        combine_fn=combine_fn_proto,
        accumulator_coder_id=accumulator_coder_id)
    # The Fn API protocol advertised for this cell is the bag user-state URN.
    protocol_spec = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        combining_spec=combining_spec, protocol=protocol_spec)
# TODO(BEAM-9562): Update Timer to have of() and clear() APIs.
class Timer(NamedTuple):
  """Immutable record describing a single user timer."""
  user_key: Any
  dynamic_timer_tag: str
  windows: Tuple['windowed_value.BoundedWindow', ...]
  clear_bit: bool
  fire_timestamp: Optional['Timestamp']
  hold_timestamp: Optional['Timestamp']
  paneinfo: Optional['windowed_value.PaneInfo']


# TODO(BEAM-9562): Plumb through actual key_coder and window_coder.
class TimerSpec(object):
  """Specification for a user stateful DoFn timer."""
  # Prepended to the user-supplied name to form ``self.name``.
  prefix = "ts-"

  def __init__(self, name: str, time_domain: str) -> None:
    """Creates a timer specification.

    Args:
      name: The user-facing timer name; stored as ``prefix + name``.
      time_domain: Either ``TimeDomain.WATERMARK`` or
        ``TimeDomain.REAL_TIME``.

    Raises:
      ValueError: If ``time_domain`` is not one of the supported domains.
    """
    # ``name`` is read by __repr__ and by the duplicate-name validation of
    # stateful DoFns, so it must be stored on the instance.
    self.name = self.prefix + name
    if time_domain not in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
      raise ValueError('Unsupported TimeDomain: %r.' % (time_domain, ))
    self.time_domain = time_domain
    # Populated by the ``on_timer`` decorator with the callback method.
    self._attached_callback: Optional[Callable] = None

  def __repr__(self) -> str:
    return '%s(%s)' % (self.__class__.__name__, self.name)

  def to_runner_api(
      self,
      context: 'PipelineContext',
      key_coder: Coder,
      window_coder: Coder) -> beam_runner_api_pb2.TimerFamilySpec:
    """Builds the ``TimerFamilySpec`` proto for this timer.

    The timer family coder registered with ``context`` is a ``_TimerCoder``
    built from ``key_coder`` and ``window_coder``.
    """
    return beam_runner_api_pb2.TimerFamilySpec(
        time_domain=TimeDomain.to_runner_api(self.time_domain),
        timer_family_coder_id=context.coders.get_id(
            coders._TimerCoder(key_coder, window_coder)))
def on_timer(timer_spec: TimerSpec) -> Callable[[CallableT], CallableT]:
  """Decorator for timer firing DoFn method.

  This decorator allows a user to specify an on_timer processing method
  in a stateful DoFn.  Sample usage::

    class MyDoFn(DoFn):
      TIMER_SPEC = TimerSpec('timer', TimeDomain.WATERMARK)

      @on_timer(TIMER_SPEC)
      def my_timer_expiry_callback(self):
        logging.info('Timer expired!')

  Args:
    timer_spec: The ``TimerSpec`` whose callback is being declared.

  Returns:
    A decorator that records the decorated method as ``timer_spec``'s
    callback and returns the method unchanged.

  Raises:
    ValueError: Immediately, if ``timer_spec`` is not a ``TimerSpec``; when
      the decorator is applied, if the decorated object is not callable or if
      ``timer_spec`` already has a callback attached.
  """
  if not isinstance(timer_spec, TimerSpec):
    raise ValueError('@on_timer decorator expected TimerSpec.')

  def _inner(method: CallableT) -> CallableT:
    if not callable(method):
      raise ValueError('@on_timer decorator expected callable.')
    # A TimerSpec may only be bound to a single callback.
    if timer_spec._attached_callback:
      raise ValueError(
          'Multiple on_timer callbacks registered for %r.' % timer_spec)
    timer_spec._attached_callback = method
    return method

  return _inner
def get_dofn_specs(dofn: 'DoFn') -> Tuple[Set[StateSpec], Set[TimerSpec]]:
  """Gets the state and timer specs for a DoFn, if any.

  Every attribute of ``dofn`` that is a bound method is inspected, and specs
  are collected from those of its default parameter values that are state or
  timer DoFn params.

  Args:
    dofn (apache_beam.transforms.core.DoFn): The DoFn instance to introspect
      for timer and state specs.

  Returns:
    A ``(state_specs, timer_specs)`` pair of sets.

  Raises:
    ValueError: If a method declares the same DoFn parameter more than once.
  """
  # Avoid circular import.
  from apache_beam.runners.common import MethodWrapper
  from apache_beam.transforms.core import _DoFnParam
  from apache_beam.transforms.core import _StateDoFnParam
  from apache_beam.transforms.core import _TimerDoFnParam

  state_specs = set()
  timer_specs = set()

  # Validate params to process(), start_bundle(), finish_bundle() and to
  # any on_timer callbacks. The generator is lazy, so each name is checked
  # and then wrapped before the next name is looked at.
  bound_method_names = (
      attr_name for attr_name in dir(dofn)
      if isinstance(getattr(dofn, attr_name, None), types.MethodType))

  for attr_name in bound_method_names:
    wrapper = MethodWrapper(dofn, attr_name)
    declared_ids = [
        param.param_id for param in wrapper.defaults
        if isinstance(param, _DoFnParam)
    ]
    if len(set(declared_ids)) < len(declared_ids):
      raise ValueError(
          'DoFn %r has duplicate %s method parameters: %s.' %
          (dofn, attr_name, declared_ids))

    for param in wrapper.defaults:
      if isinstance(param, _StateDoFnParam):
        state_specs.add(param.state_spec)
      elif isinstance(param, _TimerDoFnParam):
        timer_specs.add(param.timer_spec)

  return state_specs, timer_specs
def is_stateful_dofn(dofn: 'DoFn') -> bool:
  """Determines whether a given DoFn is a stateful DoFn.

  A DoFn counts as stateful when it declares at least one user state spec or
  at least one timer spec.
  """
  state_specs, timer_specs = get_dofn_specs(dofn)
  return len(state_specs) > 0 or len(timer_specs) > 0
def validate_stateful_dofn(dofn: 'DoFn') -> None:
  """Validates the proper specification of a stateful DoFn.

  Args:
    dofn: The DoFn instance to validate.

  Raises:
    ValueError: If two StateSpecs or two TimerSpecs share a name, if a
      TimerSpec has no ``on_timer`` callback attached, or if the attached
      callback is not the underlying function of the same-named method
      currently found on ``dofn``.
  """
  # Get state and timer specs.
  all_state_specs, all_timer_specs = get_dofn_specs(dofn)

  # Reject DoFns that have multiple state or timer specs with the same name.
  if len(all_state_specs) != len({s.name for s in all_state_specs}):
    raise ValueError(
        'DoFn %r has multiple StateSpecs with the same name: %s.' %
        (dofn, all_state_specs))
  if len(all_timer_specs) != len({s.name for s in all_timer_specs}):
    raise ValueError(
        'DoFn %r has multiple TimerSpecs with the same name: %s.' %
        (dofn, all_timer_specs))

  # Reject DoFns that use timer specs without corresponding timer callbacks.
  for timer_spec in all_timer_specs:
    callback = timer_spec._attached_callback
    if not callback:
      raise ValueError((
          'DoFn %r has a TimerSpec without an associated on_timer '
          'callback: %s.') % (dofn, timer_spec))
    method_name = callback.__name__
    # Compare against the function underlying the method currently on the
    # DoFn. Both lookups use defaults so that a missing attribute, or one that
    # is no longer a bound method, reports the intended ValueError instead of
    # an AttributeError.
    current_func = getattr(getattr(dofn, method_name, None), '__func__', None)
    if callback != current_func:
      raise ValueError((
          'The on_timer callback for %s is not the specified .%s method '
          'for DoFn %r (perhaps it was overwritten?).') %
                       (timer_spec, method_name, dofn))
class BaseTimer(object):
  """Abstract timer interface; concrete subclasses implement both methods."""
  def clear(self, dynamic_timer_tag: str = '') -> None:
    """Abstract: clear the timer for ``dynamic_timer_tag``."""
    raise NotImplementedError()

  def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None:
    """Abstract: set the timer for ``dynamic_timer_tag`` to ``timestamp``."""
    raise NotImplementedError()
# Last action recorded for one dynamic timer tag: whether it was cleared and,
# when it was set instead, the requested timestamp.
_TimerTuple = collections.namedtuple('timer_tuple', 'cleared timestamp')
class RuntimeTimer(BaseTimer):
  """Timer interface object passed to user code.

  Each ``set``/``clear`` call is stored in ``_timer_recordings`` under its
  dynamic timer tag; a later call for the same tag replaces the earlier one.
  """
  def __init__(self) -> None:
    self._timer_recordings: Dict[str, _TimerTuple] = {}
    # NOTE(review): the two attributes below are not read within this class;
    # presumably consumed elsewhere — confirm before removing.
    self._cleared = False
    self._new_timestamp: Optional[Timestamp] = None

  def clear(self, dynamic_timer_tag: str = '') -> None:
    """Records that the timer for ``dynamic_timer_tag`` was cleared."""
    record = _TimerTuple(cleared=True, timestamp=None)
    self._timer_recordings[dynamic_timer_tag] = record

  def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None:
    """Records that the timer for ``dynamic_timer_tag`` was set."""
    record = _TimerTuple(cleared=False, timestamp=timestamp)
    self._timer_recordings[dynamic_timer_tag] = record
class RuntimeState(object):
  """State interface object passed to user code.

  Both hooks here are no-ops; subclasses may override them.
  """
  def prefetch(self) -> None:
    """Optional hook; the default implementation does nothing."""
    return None

  def finalize(self) -> None:
    """Optional hook; the default implementation does nothing."""
    return None
class ReadModifyWriteRuntimeState(RuntimeState):
  """Abstract runtime interface for a read-modify-write (value) state cell.

  Every method raises ``NotImplementedError`` naming the concrete type until
  a subclass overrides it.
  """
  def read(self) -> Any:
    """Abstract: return the cell's value."""
    raise NotImplementedError(type(self))

  def write(self, value: Any) -> None:
    """Abstract: replace the cell's value with ``value``."""
    raise NotImplementedError(type(self))

  def clear(self) -> None:
    """Abstract: clear the cell."""
    raise NotImplementedError(type(self))

  def commit(self) -> None:
    """Abstract: commit pending changes."""
    raise NotImplementedError(type(self))
class AccumulatingRuntimeState(RuntimeState):
  """Abstract runtime interface for state cells that accumulate added values.

  Every method raises ``NotImplementedError`` naming the concrete type until
  a subclass overrides it.
  """
  def read(self) -> Iterable[Any]:
    """Abstract: return the accumulated contents."""
    raise NotImplementedError(type(self))

  def add(self, value: Any) -> None:
    """Abstract: add ``value`` to the cell."""
    raise NotImplementedError(type(self))

  def clear(self) -> None:
    """Abstract: clear the cell."""
    raise NotImplementedError(type(self))

  def commit(self) -> None:
    """Abstract: commit pending changes."""
    raise NotImplementedError(type(self))
class BagRuntimeState(AccumulatingRuntimeState):
  """Bag state interface object passed to user code."""
  pass
class SetRuntimeState(AccumulatingRuntimeState):
  """Set state interface object passed to user code."""
  pass
class CombiningValueRuntimeState(AccumulatingRuntimeState):
  """Combining value state interface object passed to user code."""
  pass
class UserStateContext(object):
  """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.

  All methods are abstract here; concrete subclasses implement them.
  """
  def get_timer(
      self,
      timer_spec: TimerSpec,
      key: Any,
      window: 'windowed_value.BoundedWindow',
      timestamp: Timestamp,
      pane: windowed_value.PaneInfo,
  ) -> BaseTimer:
    """Abstract: return the timer for ``timer_spec``, ``key`` and ``window``."""
    raise NotImplementedError(type(self))

  def get_state(
      self,
      state_spec: StateSpec,
      key: Any,
      window: 'windowed_value.BoundedWindow',
  ) -> RuntimeState:
    """Abstract: return the state for ``state_spec``, ``key`` and ``window``."""
    raise NotImplementedError(type(self))

  def commit(self) -> None:
    """Abstract: commit pending state and timer changes."""
    raise NotImplementedError(type(self))