#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""User-facing interfaces for the Beam State and Timer APIs."""
# pytype: skip-file
# mypy: disallow-untyped-defs
import collections
import types
from collections.abc import Callable
from collections.abc import Iterable
from typing import TYPE_CHECKING
from typing import Any
from typing import NamedTuple
from typing import Optional
from typing import TypeVar
from apache_beam.coders import Coder
from apache_beam.coders import coders
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.utils import windowed_value
from apache_beam.utils.timestamp import Timestamp
if TYPE_CHECKING:
from apache_beam.runners.pipeline_context import PipelineContext
from apache_beam.transforms.core import DoFn
# Type variable used by the on_timer() decorator so that the decorated method
# is returned with its own type unchanged (see _inner's signature there).
CallableT = TypeVar('CallableT', bound=Callable)
[docs]
class StateSpec(object):
  """Base specification for a named user state cell of a stateful DoFn.

  Subclasses describe one kind of state cell and implement
  ``to_runner_api`` to serialize themselves.
  """
  def __init__(self, name: str, coder: Coder) -> None:
    # Validate eagerly so that a misconfigured spec fails at construction.
    if not isinstance(name, str):
      raise TypeError("name is not a string")
    if not isinstance(coder, Coder):
      raise TypeError("coder is not of type Coder")
    self.name = name
    self.coder = coder

  def __repr__(self) -> str:
    return '%s(%s)' % (type(self).__name__, self.name)

  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Serializes this spec to a runner API proto; subclasses must override."""
    raise NotImplementedError
[docs]
class ReadModifyWriteStateSpec(StateSpec):
  """Specification for a user DoFn value (read/modify/write) state cell."""
  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the StateSpec proto, registering ``self.coder`` in ``context``."""
    value_spec = beam_runner_api_pb2.ReadModifyWriteStateSpec(
        coder_id=context.coders.get_id(self.coder))
    # NOTE: the protocol URN is the BAG user-state URN, as used by the other
    # non-ordered-list specs in this module.
    bag_protocol = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        read_modify_write_spec=value_spec, protocol=bag_protocol)
[docs]
class BagStateSpec(StateSpec):
  """Specification for a user DoFn bag state cell."""
  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the StateSpec proto; ``self.coder`` is the element coder."""
    element_coder_id = context.coders.get_id(self.coder)
    bag_protocol = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        bag_spec=beam_runner_api_pb2.BagStateSpec(
            element_coder_id=element_coder_id),
        protocol=bag_protocol)
[docs]
class SetStateSpec(StateSpec):
  """Specification for a user DoFn set state cell."""
  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the StateSpec proto; ``self.coder`` is the element coder."""
    element_coder_id = context.coders.get_id(self.coder)
    # NOTE: set state is advertised with the BAG user-state protocol URN.
    bag_protocol = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        set_spec=beam_runner_api_pb2.SetStateSpec(
            element_coder_id=element_coder_id),
        protocol=bag_protocol)
[docs]
class CombiningValueStateSpec(StateSpec):
  """Specification for a user DoFn state cell that combines added values."""
  def __init__(
      self,
      name: str,
      coder: Optional[Coder] = None,
      combine_fn: Any = None) -> None:
    """Initialize the specification for CombiningValue state.

    Two call forms are supported:

      CombiningValueStateSpec(name, combine_fn) -> Coder-inferred combining
      value state spec.

      CombiningValueStateSpec(name, coder, combine_fn) -> Combining value
      state spec with coder and combine_fn specified.

    Args:
      name (str): The name by which the state is identified.
      coder (Coder): Coder for the CombineFn's accumulator. May be omitted,
        in which case ``combine_fn.get_accumulator_coder()`` is used.
      combine_fn (``CombineFn`` or ``callable``): Function specifying how to
        combine the values passed to state.

    Raises:
      ValueError: if no combine_fn is supplied in either position.
    """
    # Imported here to avoid a circular import.
    from apache_beam.transforms.core import CombineFn

    # For backwards compatibility ``coder`` precedes the required
    # ``combine_fn``. A two-argument call (or combine_fn given by keyword as
    # the only extra argument) lands the combine_fn in the ``coder`` slot,
    # so shift it over.
    if combine_fn is None:
      if coder is None:
        raise ValueError('combine_fn must be provided')
      combine_fn, coder = coder, None

    self.combine_fn = CombineFn.maybe_from_callable(combine_fn)
    # The stored coder must encode the CombineFn's accumulator type.
    accumulator_coder = (
        self.combine_fn.get_accumulator_coder() if coder is None else coder)
    super().__init__(name, accumulator_coder)

  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the StateSpec proto carrying the CombineFn and accumulator coder."""
    combining_spec = beam_runner_api_pb2.CombiningStateSpec(
        combine_fn=self.combine_fn.to_runner_api(context),
        accumulator_coder_id=context.coders.get_id(self.coder))
    bag_protocol = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.BAG.urn)
    return beam_runner_api_pb2.StateSpec(
        combining_spec=combining_spec, protocol=bag_protocol)
[docs]
class OrderedListStateSpec(StateSpec):
  """Specification for a user DoFn ordered list state cell."""
  def to_runner_api(
      self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec:
    """Builds the StateSpec proto using the ORDERED_LIST protocol URN."""
    ordered_spec = beam_runner_api_pb2.OrderedListStateSpec(
        element_coder_id=context.coders.get_id(self.coder))
    ordered_protocol = beam_runner_api_pb2.FunctionSpec(
        urn=common_urns.user_state.ORDERED_LIST.urn)
    return beam_runner_api_pb2.StateSpec(
        ordered_list_spec=ordered_spec, protocol=ordered_protocol)
# TODO(BEAM-9562): Update Timer to have of() and clear() APIs.
class Timer(NamedTuple):
  # Immutable timer record. Timer is a tuple, so the field order below is
  # part of its interface (positional construction and unpacking).
  user_key: Any
  dynamic_timer_tag: str
  windows: tuple['windowed_value.BoundedWindow', ...]
  clear_bit: bool
  fire_timestamp: Optional['Timestamp']
  hold_timestamp: Optional['Timestamp']
  paneinfo: Optional['windowed_value.PaneInfo']
# TODO(BEAM-9562): Plumb through actual key_coder and window_coder.
[docs]
class TimerSpec(object):
  """Specification for a user stateful DoFn timer.

  The stored ``name`` is the user-supplied name with ``prefix`` prepended.
  Only the WATERMARK and REAL_TIME time domains are accepted.
  """
  prefix = "ts-"

  def __init__(self, name: str, time_domain: str) -> None:
    self.name = self.prefix + name
    accepted_domains = (TimeDomain.WATERMARK, TimeDomain.REAL_TIME)
    if time_domain not in accepted_domains:
      raise ValueError('Unsupported TimeDomain: %r.' % (time_domain, ))
    self.time_domain = time_domain
    # Filled in by the @on_timer decorator, which allows only one callback.
    self._attached_callback: Optional[Callable] = None

  def __repr__(self) -> str:
    return '%s(%s)' % (type(self).__name__, self.name)

  def to_runner_api(
      self, context: 'PipelineContext', key_coder: Coder,
      window_coder: Coder) -> beam_runner_api_pb2.TimerFamilySpec:
    """Builds the TimerFamilySpec proto, registering a _TimerCoder built
    from ``key_coder`` and ``window_coder`` in ``context``."""
    domain_proto = TimeDomain.to_runner_api(self.time_domain)
    timer_coder_id = context.coders.get_id(
        coders._TimerCoder(key_coder, window_coder))
    return beam_runner_api_pb2.TimerFamilySpec(
        time_domain=domain_proto, timer_family_coder_id=timer_coder_id)
[docs]
def on_timer(timer_spec: TimerSpec) -> Callable[[CallableT], CallableT]:
  """Decorator for timer firing DoFn method.

  This decorator allows a user to specify an on_timer processing method
  in a stateful DoFn. Sample usage::

    class MyDoFn(DoFn):
      TIMER_SPEC = TimerSpec('timer', TimeDomain.WATERMARK)

      @on_timer(TIMER_SPEC)
      def my_timer_expiry_callback(self):
        logging.info('Timer expired!')

  Args:
    timer_spec: The TimerSpec the decorated method is attached to.

  Returns:
    A decorator that records the method on ``timer_spec`` and returns the
    method itself unchanged.

  Raises:
    ValueError: if ``timer_spec`` is not a TimerSpec, or (when decorating) if
      the target is not callable or the spec already has a callback.
  """
  if not isinstance(timer_spec, TimerSpec):
    raise ValueError('@on_timer decorator expected TimerSpec.')

  def attach(method: CallableT) -> CallableT:
    if not callable(method):
      raise ValueError('@on_timer decorator expected callable.')
    already_attached = bool(timer_spec._attached_callback)
    if already_attached:
      raise ValueError(
          'Multiple on_timer callbacks registered for %r.' % timer_spec)
    timer_spec._attached_callback = method
    return method

  return attach
[docs]
def get_dofn_specs(dofn: 'DoFn') -> tuple[set[StateSpec], set[TimerSpec]]:
  """Gets the state and timer specs for a DoFn, if any.

  Every bound method of ``dofn`` is wrapped in a MethodWrapper. The state and
  timer specs of any _StateDoFnParam / _TimerDoFnParam objects in the
  wrapper's ``defaults`` are collected.

  Args:
    dofn (apache_beam.transforms.core.DoFn): The DoFn instance to introspect
      for timer and state specs.

  Returns:
    A ``(state_specs, timer_specs)`` pair of sets. Both sets are empty when
    no such params are found.

  Raises:
    ValueError: if a single method lists the same DoFn param id twice.
  """
  # Imported lazily to avoid circular imports.
  from apache_beam.runners.common import MethodWrapper
  from apache_beam.transforms.core import _DoFnParam
  from apache_beam.transforms.core import _StateDoFnParam
  from apache_beam.transforms.core import _TimerDoFnParam

  state_specs = set()
  timer_specs = set()
  # Covers process(), start_bundle(), finish_bundle(), on_timer callbacks and
  # every other bound method on the DoFn.
  for attr_name in dir(dofn):
    candidate = getattr(dofn, attr_name, None)
    if not isinstance(candidate, types.MethodType):
      continue
    wrapper = MethodWrapper(dofn, attr_name)

    param_ids = [
        param.param_id for param in wrapper.defaults
        if isinstance(param, _DoFnParam)
    ]
    if len(set(param_ids)) != len(param_ids):
      raise ValueError(
          'DoFn %r has duplicate %s method parameters: %s.' %
          (dofn, attr_name, param_ids))

    for default in wrapper.defaults:
      if isinstance(default, _StateDoFnParam):
        state_specs.add(default.state_spec)
      elif isinstance(default, _TimerDoFnParam):
        timer_specs.add(default.timer_spec)

  return state_specs, timer_specs
[docs]
def is_stateful_dofn(dofn: 'DoFn') -> bool:
  """Returns True if ``dofn`` declares any user state or timer specs.

  A stateful DoFn is defined here as one whose methods reference at least one
  StateSpec or TimerSpec, as reported by get_dofn_specs().
  """
  state_specs, timer_specs = get_dofn_specs(dofn)
  return len(state_specs) > 0 or len(timer_specs) > 0
[docs]
def validate_stateful_dofn(dofn: 'DoFn') -> None:
  """Validates the proper specification of a stateful DoFn.

  Args:
    dofn: The DoFn instance to validate.

  Raises:
    ValueError: if two StateSpecs or two TimerSpecs share a name, if a
      TimerSpec has no @on_timer callback, or if a TimerSpec's callback is not
      the same-named method of ``dofn``. The last case covers a method that
      was overwritten, is missing from the DoFn, or is not a bound method.
  """
  all_state_specs, all_timer_specs = get_dofn_specs(dofn)

  # Reject DoFns that have multiple state or timer specs with the same name.
  if len(all_state_specs) != len({s.name for s in all_state_specs}):
    raise ValueError(
        'DoFn %r has multiple StateSpecs with the same name: %s.' %
        (dofn, all_state_specs))
  if len(all_timer_specs) != len({s.name for s in all_timer_specs}):
    raise ValueError(
        'DoFn %r has multiple TimerSpecs with the same name: %s.' %
        (dofn, all_timer_specs))

  # Reject DoFns that use timer specs without corresponding timer callbacks.
  for timer_spec in all_timer_specs:
    callback = timer_spec._attached_callback
    if not callback:
      raise ValueError((
          'DoFn %r has a TimerSpec without an associated on_timer '
          'callback: %s.') % (dofn, timer_spec))
    # Look up the callback on the DoFn by name and compare the bound method's
    # underlying function. The lookups use getattr defaults so that these
    # cases surface as the ValueError below rather than an AttributeError:
    #   - the callback has no __name__;
    #   - the name is absent from the DoFn;
    #   - the name resolves to something without __func__.
    method_name = getattr(callback, '__name__', None)
    bound_method = (
        getattr(dofn, method_name, None)
        if isinstance(method_name, str) else None)
    if getattr(bound_method, '__func__', None) != callback:
      raise ValueError((
          'The on_timer callback for %s is not the specified .%s method '
          'for DoFn %r (perhaps it was overwritten?).') %
          (timer_spec, method_name, dofn))
[docs]
class BaseTimer(object):
  """Abstract timer interface.

  Both methods must be overridden by a concrete timer class; here they only
  raise NotImplementedError.
  """
  def clear(self, dynamic_timer_tag: str = '') -> None:
    """Clears the timer identified by ``dynamic_timer_tag`` (abstract)."""
    raise NotImplementedError

  def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None:
    """Sets the timer identified by ``dynamic_timer_tag`` (abstract)."""
    raise NotImplementedError
_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp')) # type: ignore[name-match]
[docs]
class RuntimeTimer(BaseTimer):
  """Timer interface object passed to user code.

  Calls to ``set`` and ``clear`` are recorded in ``_timer_recordings``, keyed
  by dynamic timer tag. A later call for the same tag replaces the earlier
  recording.
  """
  def __init__(self) -> None:
    self._timer_recordings: dict[str, _TimerTuple] = {}
    # NOTE(review): the two attributes below are not read in this module;
    # presumably runner code uses them — verify before removing.
    self._cleared = False
    self._new_timestamp: Optional[Timestamp] = None

  def clear(self, dynamic_timer_tag: str = '') -> None:
    """Records that the timer for ``dynamic_timer_tag`` was cleared."""
    cleared_entry = _TimerTuple(cleared=True, timestamp=None)
    self._timer_recordings[dynamic_timer_tag] = cleared_entry

  def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None:
    """Records that the timer for ``dynamic_timer_tag`` was set to
    ``timestamp``."""
    set_entry = _TimerTuple(cleared=False, timestamp=timestamp)
    self._timer_recordings[dynamic_timer_tag] = set_entry
[docs]
class RuntimeState(object):
  """State interface object passed to user code.

  Base class for all runtime state objects; both hooks are no-ops here.
  """
  def prefetch(self) -> None:
    """No-op by default; subclasses may override."""

  def finalize(self) -> None:
    """No-op by default; subclasses may override."""
[docs]
class ReadModifyWriteRuntimeState(RuntimeState):
  """Interface for a read/modify/write (value) runtime state cell.

  Every method is abstract here and raises NotImplementedError carrying the
  concrete type.
  """
  def read(self) -> Any:
    """Returns the stored value (abstract)."""
    raise NotImplementedError(type(self))

  def write(self, value: Any) -> None:
    """Stores ``value`` (abstract)."""
    raise NotImplementedError(type(self))

  def clear(self) -> None:
    """Clears the cell (abstract)."""
    raise NotImplementedError(type(self))

  def commit(self) -> None:
    """Commits pending changes (abstract)."""
    raise NotImplementedError(type(self))
[docs]
class AccumulatingRuntimeState(RuntimeState):
  """Interface for runtime state that values are added to (bag, set,
  combining, ordered list).

  Every method is abstract here and raises NotImplementedError carrying the
  concrete type.
  """
  def read(self) -> Iterable[Any]:
    """Returns the accumulated contents (abstract)."""
    raise NotImplementedError(type(self))

  def add(self, value: Any) -> None:
    """Adds ``value`` to the state (abstract)."""
    raise NotImplementedError(type(self))

  def clear(self) -> None:
    """Clears the state (abstract)."""
    raise NotImplementedError(type(self))

  def commit(self) -> None:
    """Commits pending changes (abstract)."""
    raise NotImplementedError(type(self))
[docs]
class BagRuntimeState(AccumulatingRuntimeState):
  """Bag state interface object handed to user DoFn code."""
[docs]
class SetRuntimeState(AccumulatingRuntimeState):
  """Set state interface object handed to user DoFn code."""
[docs]
class CombiningValueRuntimeState(AccumulatingRuntimeState):
  """Combining value state interface object handed to user DoFn code."""
[docs]
class OrderedListRuntimeState(AccumulatingRuntimeState):
  """Ordered list state interface object passed to user code.

  Elements are ``(Timestamp, value)`` pairs. Every method is abstract here and
  raises NotImplementedError carrying the concrete type.
  """
  def read(self) -> Iterable[tuple[Timestamp, Any]]:
    """Returns the stored ``(timestamp, value)`` pairs (abstract)."""
    raise NotImplementedError(type(self))

  def add(self, value: tuple[Timestamp, Any]) -> None:
    """Adds one ``(timestamp, value)`` pair (abstract)."""
    raise NotImplementedError(type(self))

  def read_range(
      self, min_time_stamp: Timestamp,
      limit_time_stamp: Timestamp) -> Iterable[tuple[Timestamp, Any]]:
    """Returns the pairs whose timestamps lie between the bounds (abstract).

    NOTE(review): bound inclusivity is defined by the implementation;
    confirm before relying on it.
    """
    raise NotImplementedError(type(self))

  def clear_range(
      self, min_time_stamp: Timestamp, limit_time_stamp: Timestamp) -> None:
    """Removes the pairs whose timestamps lie between the bounds (abstract).

    NOTE(review): bound inclusivity is defined by the implementation;
    confirm before relying on it.
    """
    raise NotImplementedError(type(self))
[docs]
class UserStateContext(object):
  """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.

  Every method is abstract here and raises NotImplementedError carrying the
  concrete type.
  """
  def get_timer(
      self,
      timer_spec: TimerSpec,
      key: Any,
      window: 'windowed_value.BoundedWindow',
      timestamp: Timestamp,
      pane: windowed_value.PaneInfo,
  ) -> BaseTimer:
    """Returns a BaseTimer for ``timer_spec`` given the key, window,
    timestamp and pane (abstract)."""
    raise NotImplementedError(type(self))

  def get_state(
      self,
      state_spec: StateSpec,
      key: Any,
      window: 'windowed_value.BoundedWindow',
  ) -> RuntimeState:
    """Returns a RuntimeState for ``state_spec`` given the key and window
    (abstract)."""
    raise NotImplementedError(type(self))

  def commit(self) -> None:
    """Commits pending state changes (abstract)."""
    raise NotImplementedError(type(self))