Source code for apache_beam.transforms.util

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Simple utility PTransforms.
"""

# pytype: skip-file

import collections
import contextlib
import logging
import random
import re
import threading
import time
import uuid
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import List
from typing import Tuple
from typing import TypeVar
from typing import Union

from apache_beam import coders
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsSideInput
from apache_beam.transforms import window
from apache_beam.transforms.combiners import CountCombineFn
from apache_beam.transforms.core import CombinePerKey
from apache_beam.transforms.core import Create
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import FlatMap
from apache_beam.transforms.core import Flatten
from apache_beam.transforms.core import GroupByKey
from apache_beam.transforms.core import Map
from apache_beam.transforms.core import MapTuple
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.ptransform import ptransform_fn
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.trigger import AccumulationMode
from apache_beam.transforms.trigger import Always
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.userstate import TimerSpec
from apache_beam.transforms.userstate import on_timer
from apache_beam.transforms.window import NonMergingWindowFn
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
from apache_beam.utils.sharded_key import ShardedKey

if TYPE_CHECKING:
  from apache_beam import pvalue
  from apache_beam.runners.pipeline_context import PipelineContext

__all__ = [
    'BatchElements',
    'CoGroupByKey',
    'Distinct',
    'Keys',
    'KvSwap',
    'LogElements',
    'Regex',
    'Reify',
    'RemoveDuplicates',
    'Reshuffle',
    'ToString',
    'Values',
    'WithKeys',
    'GroupIntoBatches'
]

K = TypeVar('K')
V = TypeVar('V')
T = TypeVar('T')


class CoGroupByKey(PTransform):
  """Groups results across several PCollections by key.

  Given an input dict of serializable keys (called "tags") to 0 or more
  PCollections of (key, value) tuples, it creates a single output PCollection
  of (key, value) tuples whose keys are the unique input keys from all inputs,
  and whose values are dicts mapping each tag to an iterable of whatever values
  were under the key in the corresponding PCollection, in this manner::

      ('some key', {'tag1': ['value 1 under "some key" in pcoll1',
                             'value 2 under "some key" in pcoll1',
                             ...],
                    'tag2': ... ,
                    ... })

  where `[]` refers to an iterable, not a list.

  For example, given::

      {'tag1': pc1, 'tag2': pc2, 333: pc3}

  where::

      pc1 = beam.Create([(k1, v1)])
      pc2 = beam.Create([])
      pc3 = beam.Create([(k1, v31), (k1, v32), (k2, v33)])

  The output PCollection would consist of items::

      [(k1, {'tag1': [v1], 'tag2': [], 333: [v31, v32]}),
       (k2, {'tag1': [], 'tag2': [], 333: [v33]})]

  where `[]` refers to an iterable, not a list.

  CoGroupByKey also works for tuples, lists, or other flat iterables of
  PCollections, in which case the values of the resulting PCollections
  will be tuples whose nth value is the iterable of values from the nth
  PCollection---conceptually, the "tags" are the indices into the input.
  Thus, for this input::

      (pc1, pc2, pc3)

  the output would be::

      [(k1, ([v1], [], [v31, v32])),
       (k2, ([], [], [v33]))]

  where, again, `[]` refers to an iterable, not a list.

  Args:
    pipeline: (keyword-only, optional) the pipeline that "owns" this
      PTransform. Ordinarily CoGroupByKey can obtain this information from one
      of the input PCollections, but if there are none (or if there's a chance
      there may be none), this argument is the only way to provide pipeline
      information, and should be considered mandatory.
  """
  def __init__(self, *, pipeline=None):
    self.pipeline = pipeline

  def _extract_input_pvalues(self, pvalueish):
    """Returns ``(pvalueish, tuple_of_pvalues)`` for dict or iterable input."""
    try:
      # Only dict-like inputs expose .values().
      dict_values = tuple(pvalueish.values())
    except AttributeError:
      # Cast iterables to a tuple so that they can be iterated more than once.
      as_tuple = tuple(pvalueish)
      return as_tuple, as_tuple
    return pvalueish, dict_values

  def expand(self, pcolls):
    """Co-groups ``pcolls`` (a dict or a flat iterable of PCollections)."""
    if not pcolls:
      pcolls = (self.pipeline | Create([]), )

    is_dict_input = isinstance(pcolls, dict)
    if is_dict_input:
      tags = list(pcolls.keys())
      if all(isinstance(tag, str) and len(tag) < 10 for tag in tags):
        # Small, string tags. Pass them as data.
        pcolls_dict = pcolls
        restore_tags = None
      else:
        # Replace each tag by its index for the grouping; the original tags
        # are put back afterwards through the restore_tags closure.
        pcolls_dict = {str(ix): pcolls[tag] for ix, tag in enumerate(tags)}
        restore_tags = lambda vs: {
            tag: vs[str(ix)]
            for (ix, tag) in enumerate(tags)
        }
    else:
      # Tags are tuple indices.
      tags = [str(ix) for ix in range(len(pcolls))]
      pcolls_dict = dict(zip(tags, pcolls))
      restore_tags = lambda vs: tuple(vs[tag] for tag in tags)

    # Derive the output type hints from the element types of the inputs.
    key_hints = []
    value_hints = []
    for tagged_pcoll in pcolls_dict.values():
      key_hint, value_hint = typehints.trivial_inference.key_value_types(
          tagged_pcoll.element_type)
      key_hints.append(key_hint)
      value_hints.append(value_hint)
    output_key_type = typehints.Union[tuple(key_hints)]
    iterable_input_value_types = tuple(
        typehints.Iterable[hint] for hint in value_hints)

    output_value_type = typehints.Dict[
        str, typehints.Union[iterable_input_value_types or [typehints.Any]]]
    grouping = _CoGBKImpl(pipeline=self.pipeline).with_output_types(
        typehints.Tuple[output_key_type, output_value_type])
    result = pcolls_dict | 'CoGroupByKeyImpl' >> grouping

    if not restore_tags:
      return result

    if is_dict_input:
      dict_key_type = typehints.Union[tuple(
          trivial_inference.instance_to_type(tag) for tag in tags)]
      output_value_type = typehints.Dict[
          dict_key_type, typehints.Union[iterable_input_value_types]]
    else:
      output_value_type = typehints.Tuple[iterable_input_value_types]
    result |= 'RestoreTags' >> MapTuple(
        lambda k, vs: (k, restore_tags(vs))).with_output_types(
            typehints.Tuple[output_key_type, output_value_type])
    return result
class _CoGBKImpl(PTransform):
  """Tags, flattens and groups a dict of keyed PCollections.

  Produces ``(key, {tag: [values...]})`` elements in which every input tag is
  present, with an empty list for tags that had no value under the key.
  """
  def __init__(self, *, pipeline=None):
    self.pipeline = pipeline

  def expand(self, pcolls):
    # Every input must be a PCollection and, when a pipeline was given
    # explicitly, must belong to that pipeline.
    for candidate in pcolls.values():
      self._check_pcollection(candidate)
      if self.pipeline:
        assert candidate.pipeline == self.pipeline, (
            'All input PCollections must belong to the same pipeline.')

    tags = list(pcolls)

    def add_tag(tag):
      # Bind the tag through a factory so each lambda keeps its own tag.
      return lambda k, v: (k, (tag, v))

    def collect_values(key, tagged_values):
      # Start from an empty list per tag so that absent tags still show up.
      grouped_values = {tag: [] for tag in tags}
      for tag, value in tagged_values:
        grouped_values[tag].append(value)
      return key, grouped_values

    tagged_pcolls = [
        pcoll | 'Tag[%s]' % tag >> MapTuple(add_tag(tag))
        for tag, pcoll in pcolls.items()
    ]
    return (
        tagged_pcolls
        | Flatten(pipeline=self.pipeline)
        | GroupByKey()
        | MapTuple(collect_values))
@ptransform_fn
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(K)
def Keys(pcoll, label='Keys'):  # pylint: disable=invalid-name
  """Produces a PCollection holding the first element of each input 2-tuple.

  Args:
    pcoll: a PCollection of 2-tuples.
    label: label given to the underlying ``MapTuple`` step.
  """
  take_first = MapTuple(lambda k, _: k)
  return pcoll | label >> take_first
@ptransform_fn
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(V)
def Values(pcoll, label='Values'):  # pylint: disable=invalid-name
  """Produces a PCollection holding the second element of each input 2-tuple.

  Args:
    pcoll: a PCollection of 2-tuples.
    label: label given to the underlying ``MapTuple`` step.
  """
  take_second = MapTuple(lambda _, v: v)
  return pcoll | label >> take_second
@ptransform_fn
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[V, K])
def KvSwap(pcoll, label='KvSwap'):  # pylint: disable=invalid-name
  """Produces a PCollection in which every input 2-tuple is reversed.

  Args:
    pcoll: a PCollection of 2-tuples.
    label: label given to the underlying ``MapTuple`` step.
  """
  swap = MapTuple(lambda k, v: (v, k))
  return pcoll | label >> swap
@ptransform_fn
@typehints.with_input_types(T)
@typehints.with_output_types(T)
def Distinct(pcoll):  # pylint: disable=invalid-name
  """Produces a PCollection containing distinct elements of a PCollection.

  Each element is used as a key with a ``None`` value, the values are combined
  per key into ``None``, and the keys are then extracted again.
  """
  keyed = pcoll | 'ToPairs' >> Map(lambda v: (v, None))
  one_per_key = keyed | 'Group' >> CombinePerKey(lambda vs: None)
  return one_per_key | 'Distinct' >> Keys()
@deprecated(since='2.12', current='Distinct')
@ptransform_fn
@typehints.with_input_types(T)
@typehints.with_output_types(T)
def RemoveDuplicates(pcoll):
  """Produces a PCollection containing distinct elements of a PCollection.

  Deprecated alias that simply applies :func:`Distinct`.
  """
  distinct = Distinct()
  return pcoll | 'RemoveDuplicates' >> distinct
class _BatchSizeEstimator(object):
  """Estimates the best size for batches given historical timing.

  Timings are collected through :meth:`record_time` as ``(batch_size,
  elapsed_seconds)`` pairs, and a linear model ``time = a + b * batch_size``
  is fitted to them to choose the next batch size.
  """

  # Once this many data points have been collected, some are dropped.
  _MAX_DATA_POINTS = 100
  # The next batch size is capped at this multiple of the last one.
  _MAX_GROWTH_FACTOR = 2

  def __init__(
      self,
      min_batch_size=1,
      max_batch_size=10000,
      target_batch_overhead=.05,
      target_batch_duration_secs=10,
      target_batch_duration_secs_including_fixed_cost=None,
      variance=0.25,
      clock=time.time,
      ignore_first_n_seen_per_batch_size=0,
      record_metrics=True):
    """Validates and stores the estimator configuration.

    Args:
      min_batch_size: smallest batch size to propose.
      max_batch_size: largest batch size to propose.
      target_batch_overhead: target for fixed_cost / total_time; when set it
        must be in (0, 1].
      target_batch_duration_secs: target per-batch duration excluding the
        fixed cost; when set it must be positive.
      target_batch_duration_secs_including_fixed_cost: target per-batch
        duration including the fixed cost; when set it must be positive.
      variance: relative spread applied around the computed target size.
      clock: callable returning the current time in seconds.
      ignore_first_n_seen_per_batch_size: number of initial timings to discard
        for every distinct batch size.
      record_metrics: whether to report batch size and duration to Beam
        distribution metrics.

    Raises:
      ValueError: if the arguments are inconsistent or out of range, or if
        none of the three targets is set.
    """
    if min_batch_size > max_batch_size:
      raise ValueError(
          "Minimum (%s) must not be greater than maximum (%s)" %
          (min_batch_size, max_batch_size))
    if target_batch_overhead and not 0 < target_batch_overhead <= 1:
      raise ValueError(
          "target_batch_overhead (%s) must be between 0 and 1" %
          (target_batch_overhead))
    if target_batch_duration_secs and target_batch_duration_secs <= 0:
      raise ValueError(
          "target_batch_duration_secs (%s) must be positive" %
          (target_batch_duration_secs))
    if (target_batch_duration_secs_including_fixed_cost and
        target_batch_duration_secs_including_fixed_cost <= 0):
      raise ValueError(
          "target_batch_duration_secs_including_fixed_cost "
          "(%s) must be positive" %
          (target_batch_duration_secs_including_fixed_cost))
    if not (target_batch_overhead or target_batch_duration_secs or
            target_batch_duration_secs_including_fixed_cost):
      raise ValueError(
          "At least one of target_batch_overhead or "
          "target_batch_duration_secs or "
          "target_batch_duration_secs_including_fixed_cost must be positive.")
    if ignore_first_n_seen_per_batch_size < 0:
      raise ValueError(
          'ignore_first_n_seen_per_batch_size (%s) must be non '
          'negative' % (ignore_first_n_seen_per_batch_size))
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size
    self._target_batch_overhead = target_batch_overhead
    self._target_batch_duration_secs = target_batch_duration_secs
    self._target_batch_duration_secs_including_fixed_cost = (
        target_batch_duration_secs_including_fixed_cost)
    self._variance = variance
    self._clock = clock
    # Recorded (batch_size, elapsed_seconds) pairs.
    self._data = []
    self._ignore_next_timing = False
    self._ignore_first_n_seen_per_batch_size = (
        ignore_first_n_seen_per_batch_size)
    # Maps a batch size to the number of times it has been handed out.
    self._batch_size_num_seen = {}
    self._replay_last_batch_size = None
    self._record_metrics = record_metrics
    self._element_count = 0
    self._batch_count = 0
    if record_metrics:
      self._size_distribution = Metrics.distribution(
          'BatchElements', 'batch_size')
      self._time_distribution = Metrics.distribution(
          'BatchElements', 'msec_per_batch')
    else:
      self._size_distribution = self._time_distribution = None
    # Beam distributions only accept integer values, so we use this to
    # accumulate under-reported values until they add up to whole milliseconds.
    # (Milliseconds are chosen because that's conventionally used elsewhere in
    # profiling-style counters.)
    self._remainder_msecs = 0

  def ignore_next_timing(self):
    """Call to indicate the next timing should be ignored.

    For example, the first emit of a ParDo operation is known to be anomalous
    due to setup that may occur.
    """
    self._ignore_next_timing = True

  @contextlib.contextmanager
  def record_time(self, batch_size):
    """Context manager timing the processing of one batch of ``batch_size``.

    The bookkeeping below runs only when the managed block finishes without
    raising, since it follows a bare ``yield``.
    """
    start = self._clock()
    yield
    elapsed = self._clock() - start
    elapsed_msec = 1e3 * elapsed + self._remainder_msecs
    if self._record_metrics:
      self._size_distribution.update(batch_size)
      self._time_distribution.update(int(elapsed_msec))
    self._element_count += batch_size
    self._batch_count += 1
    # Carry the sub-millisecond fraction over to the next measurement.
    self._remainder_msecs = elapsed_msec - int(elapsed_msec)
    # If we ignore the next timing, replay the batch size to get accurate
    # timing.
    if self._ignore_next_timing:
      self._ignore_next_timing = False
      self._replay_last_batch_size = min(batch_size, self._max_batch_size)
    else:
      self._data.append((batch_size, elapsed))
      if len(self._data) >= self._MAX_DATA_POINTS:
        self._thin_data()

  def _thin_data(self):
    """Drops two random data points, one from each of two leading ranges."""
    # Make sure we don't change the parity of len(self._data)
    # As it's used below to alternate jitter.
    self._data.pop(random.randrange(len(self._data) // 4))
    self._data.pop(random.randrange(len(self._data) // 2))

  @staticmethod
  def linear_regression_no_numpy(xs, ys):
    """Returns ``(a, b)`` of the least squares fit ``y = a + b * x``."""
    # Least squares fit for y = a + bx over all points.
    n = float(len(xs))
    xbar = sum(xs) / n
    ybar = sum(ys) / n
    if xbar == 0:
      return ybar, 0
    if all(xs[0] == x for x in xs):
      # Simply use the mean if all values in xs are same.
      return 0, ybar / xbar
    b = (
        sum([(x - xbar) * (y - ybar)
             for x, y in zip(xs, ys)]) / sum([(x - xbar)**2 for x in xs]))
    a = ybar - b * xbar
    return a, b

  @staticmethod
  def linear_regression_numpy(xs, ys):
    """Returns ``(a, b)`` of a fit ``y = a + b * x``, discounting outliers."""
    # pylint: disable=wrong-import-order, wrong-import-position
    import numpy as np
    n = len(xs)
    if all(xs[0] == x for x in xs):
      # If all values of xs are same then fallback to linear_regression_no_numpy
      return _BatchSizeEstimator.linear_regression_no_numpy(xs, ys)
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    # First do a simple least squares fit for y = a + bx over all points.
    b, a = np.polyfit(xs, ys, 1)

    if n < 10:
      return a, b
    else:
      # Refine this by throwing out outliers, according to Cook's distance.
      # https://en.wikipedia.org/wiki/Cook%27s_distance
      # np.sum is used explicitly rather than shadowing the builtin sum.
      sum_x = np.sum(xs)
      sum_x2 = np.sum(xs**2)
      errs = a + b * xs - ys
      s2 = np.sum(errs**2) / (n - 2)
      if s2 == 0:
        # It's an exact fit!
        return a, b
      h = (sum_x2 - 2 * sum_x * xs + n * xs**2) / (n * sum_x2 - sum_x**2)
      cook_ds = 0.5 / s2 * errs**2 * (h / (1 - h)**2)

      # Re-compute the regression, excluding those points with Cook's distance
      # greater than 0.5, and weighting by the inverse of x to give a more
      # stable y-intercept (as small batches have relatively more information
      # about the fixed overhead).
      weight = (cook_ds <= 0.5) / xs
      b, a = np.polyfit(xs, ys, 1, w=weight)
      return a, b

  # Pick the numpy-based implementation when numpy can be imported.
  try:
    # pylint: disable=wrong-import-order, wrong-import-position
    import numpy as np
    linear_regression = linear_regression_numpy
  except ImportError:
    linear_regression = linear_regression_no_numpy

  def _calculate_next_batch_size(self):
    """Computes the next batch size from the timings recorded so far."""
    if self._min_batch_size == self._max_batch_size:
      return self._min_batch_size
    elif len(self._data) < 1:
      return self._min_batch_size
    elif len(self._data) < 2:
      # Force some variety so we have distinct batch sizes on which to do
      # linear regression below.
      return int(
          max(
              min(
                  self._max_batch_size,
                  self._min_batch_size * self._MAX_GROWTH_FACTOR),
              self._min_batch_size + 1))

    # There tends to be a lot of noise in the top quantile, which also
    # has outsized influence in the regression.  If we have enough data,
    # simply declare the top 20% to be outliers.
    trimmed_data = sorted(self._data)[:max(20, len(self._data) * 4 // 5)]

    # Linear regression for y = a + bx, where x is batch size and y is time.
    xs, ys = zip(*trimmed_data)
    a, b = self.linear_regression(xs, ys)

    # Avoid nonsensical or division-by-zero errors below due to noise.
    a = max(a, 1e-10)
    b = max(b, 1e-20)

    last_batch_size = self._data[-1][0]
    cap = min(last_batch_size * self._MAX_GROWTH_FACTOR, self._max_batch_size)

    target = self._max_batch_size

    if self._target_batch_duration_secs_including_fixed_cost:
      # Solution to
      # a + b*x = self._target_batch_duration_secs_including_fixed_cost.
      target = min(
          target,
          (self._target_batch_duration_secs_including_fixed_cost - a) / b)

    if self._target_batch_duration_secs:
      # Solution to b*x = self._target_batch_duration_secs.
      # We ignore the fixed cost in this computation as it has negligible
      # impact when it is small and unhelpfully forces the minimum batch size
      # when it is large.
      target = min(target, self._target_batch_duration_secs / b)

    if self._target_batch_overhead:
      # Solution to a / (a + b*x) = self._target_batch_overhead.
      target = min(target, (a / b) * (1 / self._target_batch_overhead - 1))

    # Avoid getting stuck at a single batch size (especially the minimal
    # batch size) which would not allow us to extrapolate to other batch
    # sizes.
    # Jitter alternates between 0 and 1.
    jitter = len(self._data) % 2

    # Smear our samples across a range centered at the target.
    if len(self._data) > 10:
      target += int(target * self._variance * 2 * (random.random() - .5))

    return int(max(self._min_batch_size + jitter, min(target, cap)))

  def next_batch_size(self):
    """Returns the batch size to use next and updates the seen counters."""
    # Check if we should replay a previous batch size due to it not being
    # recorded.
    if self._replay_last_batch_size:
      result = self._replay_last_batch_size
      self._replay_last_batch_size = None
    else:
      result = self._calculate_next_batch_size()

    seen_count = self._batch_size_num_seen.get(result, 0) + 1
    if seen_count <= self._ignore_first_n_seen_per_batch_size:
      self.ignore_next_timing()
    self._batch_size_num_seen[result] = seen_count
    return result

  def stats(self):
    """Returns a one-line summary of counters, next size and timings."""
    return "element_count=%s batch_count=%s next_batch_size=%s timings=%s" % (
        self._element_count,
        self._batch_count,
        self._calculate_next_batch_size(),
        self._data)


class _GlobalWindowsBatchingDoFn(DoFn):
  """Batches elements into lists, emitting them in the global window."""
  def __init__(self, batch_size_estimator, element_size_fn):
    self._batch_size_estimator = batch_size_estimator
    self._element_size_fn = element_size_fn

  def start_bundle(self):
    self._batch = []
    self._running_batch_size = 0
    self._target_batch_size = self._batch_size_estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._batch_size_estimator.ignore_next_timing()

  def process(self, element):
    self._batch.append(element)
    self._running_batch_size += self._element_size_fn(element)
    if self._running_batch_size >= self._target_batch_size:
      # The yield sits inside the timer, so whatever runs before this
      # generator is resumed is what gets measured.
      with self._batch_size_estimator.record_time(self._running_batch_size):
        yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
      self._batch = []
      self._running_batch_size = 0
      self._target_batch_size = self._batch_size_estimator.next_batch_size()

  def finish_bundle(self):
    # Flush whatever is left over from the bundle.
    if self._batch:
      with self._batch_size_estimator.record_time(self._running_batch_size):
        yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
      self._batch = None
      self._running_batch_size = 0
    self._target_batch_size = self._batch_size_estimator.next_batch_size()
    logging.info(
        "BatchElements statistics: %s", self._batch_size_estimator.stats())


class _SizedBatch():
  """A list of elements together with their accumulated size."""
  def __init__(self):
    self.elements = []
    self.size = 0


class _WindowAwareBatchingDoFn(DoFn):
  """Batches elements into lists, keeping one pending batch per window."""

  # When more windows than this have a pending batch, the largest pending
  # batch is emitted early.
  _MAX_LIVE_WINDOWS = 10

  def __init__(self, batch_size_estimator, element_size_fn):
    self._batch_size_estimator = batch_size_estimator
    self._element_size_fn = element_size_fn

  def start_bundle(self):
    self._batches = collections.defaultdict(_SizedBatch)
    self._target_batch_size = self._batch_size_estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._batch_size_estimator.ignore_next_timing()

  def process(self, element, window=DoFn.WindowParam):
    batch = self._batches[window]
    batch.elements.append(element)
    batch.size += self._element_size_fn(element)
    if batch.size >= self._target_batch_size:
      with self._batch_size_estimator.record_time(batch.size):
        yield windowed_value.WindowedValue(
            batch.elements, window.max_timestamp(), (window, ))
      del self._batches[window]
      self._target_batch_size = self._batch_size_estimator.next_batch_size()
    elif len(self._batches) > self._MAX_LIVE_WINDOWS:
      # Too many windows are pending: emit the largest pending batch.
      window, batch = max(
          self._batches.items(),
          key=lambda window_batch: window_batch[1].size)
      with self._batch_size_estimator.record_time(batch.size):
        yield windowed_value.WindowedValue(
            batch.elements, window.max_timestamp(), (window, ))
      del self._batches[window]
      self._target_batch_size = self._batch_size_estimator.next_batch_size()

  def finish_bundle(self):
    for window, batch in self._batches.items():
      # Test the element list: a _SizedBatch instance itself is always truthy.
      if batch.elements:
        with self._batch_size_estimator.record_time(batch.size):
          yield windowed_value.WindowedValue(
              batch.elements, window.max_timestamp(), (window, ))
    self._batches = None
    self._target_batch_size = self._batch_size_estimator.next_batch_size()
@typehints.with_input_types(T)
@typehints.with_output_types(List[T])
class BatchElements(PTransform):
  """A Transform that batches elements for amortized processing.

  This transform is designed to precede operations whose processing cost
  is of the form

      time = fixed_cost + num_elements * per_element_cost

  where the per element cost is (often significantly) smaller than the fixed
  cost and could be amortized over multiple elements.  It consumes a PCollection
  of element type T and produces a PCollection of element type List[T].

  This transform attempts to find the best batch size between the minimum
  and maximum parameters by profiling the time taken by (fused) downstream
  operations. For a fixed batch size, set the min and max to be equal.

  Elements are batched per-window and batches emitted in the window
  corresponding to its contents. Each batch is emitted with a timestamp at
  the end of their window.

  Args:
    min_batch_size: (optional) the smallest size of a batch
    max_batch_size: (optional) the largest size of a batch
    target_batch_overhead: (optional) a target for fixed_cost / time,
        as used in the formula above
    target_batch_duration_secs: (optional) a target for total time per bundle,
        in seconds, excluding fixed cost
    target_batch_duration_secs_including_fixed_cost: (optional) a target for
        total time per bundle, in seconds, including fixed cost
    element_size_fn: (optional) A mapping of an element to its contribution to
        batch size, defaulting to every element having size 1.  When provided,
        attempts to provide batches of optimal total size which may consist of
        a varying number of elements.
    variance: (optional) the permitted (relative) amount of deviation from the
        (estimated) ideal batch size used to produce a wider base for
        linear interpolation
    clock: (optional) an alternative to time.time for measuring the cost of
        downstream operations (mostly for testing)
    record_metrics: (optional) whether or not to record beam metrics on
        distributions of the batch size. Defaults to True.
  """
  def __init__(
      self,
      min_batch_size=1,
      max_batch_size=10000,
      target_batch_overhead=.05,
      target_batch_duration_secs=10,
      target_batch_duration_secs_including_fixed_cost=None,
      *,
      element_size_fn=lambda x: 1,
      variance=0.25,
      clock=time.time,
      record_metrics=True):
    self._element_size_fn = element_size_fn
    # All tuning parameters are forwarded to the estimator unchanged.
    self._batch_size_estimator = _BatchSizeEstimator(
        min_batch_size=min_batch_size,
        max_batch_size=max_batch_size,
        target_batch_overhead=target_batch_overhead,
        target_batch_duration_secs=target_batch_duration_secs,
        target_batch_duration_secs_including_fixed_cost=(
            target_batch_duration_secs_including_fixed_cost),
        variance=variance,
        clock=clock,
        record_metrics=record_metrics)

  def expand(self, pcoll):
    if getattr(pcoll.pipeline.runner, 'is_streaming', False):
      raise NotImplementedError("Requires stateful processing (BEAM-2687)")

    if pcoll.windowing.is_default():
      # This is the same logic as _WindowAwareBatchingDoFn, but optimized
      # for that simpler case.
      batching_fn = _GlobalWindowsBatchingDoFn(
          self._batch_size_estimator, self._element_size_fn)
    else:
      batching_fn = _WindowAwareBatchingDoFn(
          self._batch_size_estimator, self._element_size_fn)
    return pcoll | ParDo(batching_fn)
class _IdentityWindowFn(NonMergingWindowFn):
  """Windowing function that preserves existing windows.

  To be used internally with the Reshuffle transform.
  Will raise an exception when used after DoFns that return TimestampedValue
  elements.
  """
  def __init__(self, window_coder):
    """Create a new WindowFn with compatible coder.
    To be applied to PCollections with windows that are compatible with the
    given coder.

    Arguments:
      window_coder: coders.Coder object to be used on windows.

    Raises:
      ValueError: if ``window_coder`` is None.
    """
    super().__init__()
    if window_coder is None:
      raise ValueError('window_coder should not be None')
    self._window_coder = window_coder

  def assign(self, assign_context):
    """Returns the element's existing window as its only window."""
    existing_window = assign_context.window
    if existing_window is None:
      raise ValueError(
          'assign_context.window should not be None. '
          'This might be due to a DoFn returning a TimestampedValue.')
    return [existing_window]

  def get_window_coder(self):
    return self._window_coder


@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[K, V])
class ReshufflePerKey(PTransform):
  """PTransform that returns a PCollection equivalent to its input,
  but operationally provides some of the side effects of a GroupByKey,
  in particular checkpointing, and preventing fusion of the surrounding
  transforms.
  """
  def expand(self, pcoll):
    windowing_saved = pcoll.windowing
    if windowing_saved.is_default():
      # In this (common) case we can use a trivial trigger driver
      # and avoid the (expensive) window param.
      globally_windowed = window.GlobalWindows.windowed_value(None)
      MIN_TIMESTAMP = window.MIN_TIMESTAMP

      def reify_timestamps(element, timestamp=DoFn.TimestampParam):
        key, value = element
        # The minimum timestamp is carried as None and restored as the
        # default globally-windowed value.
        carried = None if timestamp == MIN_TIMESTAMP else timestamp
        return key, (value, carried)

      def restore_timestamps(element):
        key, values = element
        restored = []
        for value, timestamp in values:
          if timestamp is None:
            restored.append(globally_windowed.with_value((key, value)))
          else:
            restored.append(
                window.GlobalWindows.windowed_value((key, value), timestamp))
        return restored
    else:

      # typing: All conditional function variants must have identical signatures
      def reify_timestamps(  # type: ignore[misc]
          element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
        key, value = element
        # Transport the window as part of the value and restore it later.
        return key, windowed_value.WindowedValue(value, timestamp, [window])

      def restore_timestamps(element):
        key, windowed_values = element
        return [wv.with_value((key, wv.value)) for wv in windowed_values]

    ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)

    # TODO(https://github.com/apache/beam/issues/19785) Using global window as
    # one of the standard window. This is to mitigate the Dataflow Java Runner
    # Harness limitation to accept only standard coders.
    grouping_windowing = Windowing(
        window.GlobalWindows(),
        triggerfn=Always(),
        accumulation_mode=AccumulationMode.DISCARDING,
        timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
    ungrouped._windowing = grouping_windowing

    grouped = ungrouped | GroupByKey()
    result = grouped | FlatMap(restore_timestamps).with_output_types(Any)
    # Put the windowing of the input back on the output.
    result._windowing = windowing_saved
    return result
@typehints.with_input_types(T)
@typehints.with_output_types(T)
class Reshuffle(PTransform):
  """PTransform that returns a PCollection equivalent to its input,
  but operationally provides some of the side effects of a GroupByKey,
  in particular checkpointing, and preventing fusion of the surrounding
  transforms.

  Reshuffle adds a temporary random key to each element, performs a
  ReshufflePerKey, and finally removes the temporary key.
  """

  # We use 32-bit integer as the default number of buckets.
  _DEFAULT_NUM_BUCKETS = 1 << 32

  def __init__(self, num_buckets=None):
    """
    :param num_buckets: If set, specifies the maximum random keys that would be
      generated.
    :raises ValueError: if ``num_buckets`` is given but is not an integer
      greater than 0.
    """
    if num_buckets:
      self.num_buckets = num_buckets
    else:
      self.num_buckets = self._DEFAULT_NUM_BUCKETS

    if num_buckets is not None:
      if not (isinstance(num_buckets, int) and num_buckets > 0):
        raise ValueError(
            'If `num_buckets` is set, it has to be an '
            'integer greater than 0, got %s' % num_buckets)

  def expand(self, pcoll):
    # type: (pvalue.PValue) -> pvalue.PCollection
    add_random_keys = Map(
        lambda t: (random.randrange(0, self.num_buckets), t)).with_input_types(
            T).with_output_types(Tuple[int, T])
    remove_random_keys = Map(lambda t: t[1]).with_input_types(
        Tuple[int, T]).with_output_types(T)
    return (
        pcoll
        | 'AddRandomKeys' >> add_random_keys
        | ReshufflePerKey()
        | 'RemoveRandomKeys' >> remove_random_keys)

  def to_runner_api_parameter(self, unused_context):
    # type: (PipelineContext) -> Tuple[str, None]
    # Only the URN is sent; there is no payload.
    return common_urns.composites.RESHUFFLE.urn, None

  @staticmethod
  @PTransform.register_urn(common_urns.composites.RESHUFFLE.urn, None)
  def from_runner_api_parameter(
      unused_ptransform, unused_parameter, unused_context):
    # Rebuilt with default arguments.
    return Reshuffle()
def fn_takes_side_inputs(fn):
  """Returns whether ``fn`` may accept arguments beyond a single element.

  True is returned when the signature has more than one parameter or any
  ``*args``/``**kwargs`` parameter, and also when the signature cannot be
  inspected at all.
  """
  target = getattr(fn, '_argspec_fn', fn)
  try:
    params = get_signature(target).parameters
  except TypeError:
    # We can't tell; maybe it does.
    return True

  if len(params) > 1:
    return True
  return any(
      p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) for p in params.values())
@ptransform_fn
def WithKeys(pcoll, k, *args, **kwargs):
  """PTransform that takes a PCollection, and either a constant key or a
  callable, and returns a PCollection of (K, V), where each of the values in
  the input PCollection has been paired with either the constant key or a key
  computed from the value.  The callable may optionally accept positional or
  keyword arguments, which should be passed to WithKeys directly.  These may
  be either SideInputs or static (non-PCollection) values, such as ints.
  """
  if not callable(k):
    # Constant key.
    return pcoll | Map(lambda v: (k, v))

  if not fn_takes_side_inputs(k):
    return pcoll | Map(lambda v: (k(v), v))

  every_arg_is_side_input = all(
      isinstance(arg, AsSideInput) for arg in args) and all(
          isinstance(kwarg, AsSideInput) for kwarg in kwargs.values())
  if every_arg_is_side_input:
    # Hand the side inputs to Map, which passes them on to the key function.
    return pcoll | Map(
        lambda v, *args, **kwargs: (k(v, *args, **kwargs), v),
        *args,
        **kwargs)
  # Otherwise the extra arguments are captured in the closure as they are.
  return pcoll | Map(lambda v: (k(v, *args, **kwargs), v))
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[K, Iterable[V]])
class GroupIntoBatches(PTransform):
  """PTransform that batches the input into desired batch size. Elements are
  buffered until they are equal to batch size provided in the argument at which
  point they are output to the output Pcollection.

  Windows are preserved (batches will contain elements from the same window)
  """
  def __init__(
      self, batch_size, max_buffering_duration_secs=None, clock=time.time):
    """Create a new GroupIntoBatches.

    Arguments:
      batch_size: (required) How many elements should be in a batch
      max_buffering_duration_secs: (optional) How long in seconds at most an
        incomplete batch of elements is allowed to be buffered in the states.
        The duration must be a positive second duration and should be given as
        an int or float. Setting this parameter to zero effectively means no
        buffering limit.
      clock: (optional) an alternative to time.time (mostly for testing)
    """
    self.clock = clock
    self.params = _GroupIntoBatchesParams(
        batch_size, max_buffering_duration_secs)

  def expand(self, pcoll):
    input_coder = coders.registry.get_coder(pcoll)
    batching_fn = _pardo_group_into_batches(
        input_coder,
        self.params.batch_size,
        self.params.max_buffering_duration_secs,
        self.clock)
    return pcoll | ParDo(batching_fn)

  def to_runner_api_parameter(
      self,
      unused_context  # type: PipelineContext
  ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
    urn = common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn
    return urn, self.params.get_payload()

  @staticmethod
  @PTransform.register_urn(
      common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn,
      beam_runner_api_pb2.GroupIntoBatchesPayload)
  def from_runner_api_parameter(unused_ptransform, proto, unused_context):
    batch_size, max_buffering_duration_secs = (
        _GroupIntoBatchesParams.parse_payload(proto))
    return GroupIntoBatches(batch_size, max_buffering_duration_secs)

  @typehints.with_input_types(Tuple[K, V])
  @typehints.with_output_types(
      typehints.Tuple[
          ShardedKeyType[typehints.TypeVariable(K)],  # type: ignore[misc]
          typehints.Iterable[typehints.TypeVariable(V)]])
  class WithShardedKey(PTransform):
    """A GroupIntoBatches transform that outputs batched elements associated
    with sharded input keys.

    By default, keys are sharded to such that the input elements with the same
    key are spread to all available threads executing the transform. Runners may
    override the default sharding to do a better load balancing during the
    execution time.
    """
    def __init__(
        self, batch_size, max_buffering_duration_secs=None, clock=time.time):
      """Create a new GroupIntoBatches with sharded output.
      See ``GroupIntoBatches`` transform for a description of input parameters.
      """
      self.clock = clock
      self.params = _GroupIntoBatchesParams(
          batch_size, max_buffering_duration_secs)

    # Random prefix generated when this class body is executed.
    _shard_id_prefix = uuid.uuid4().bytes

    def expand(self, pcoll):
      key_type, value_type = pcoll.element_type.tuple_types
      sharded_element_type = typehints.Tuple[
          ShardedKeyType[key_type],  # type: ignore[misc]
          value_type]
      # The prefix is looked up inside the lambda on purpose, so that it is
      # resolved where the lambda runs rather than captured here.
      sharded_pcoll = pcoll | Map(
          lambda key_value: (
              ShardedKey(
                  key_value[0],
                  # Use [uuid, thread id] as the shard id.
                  GroupIntoBatches.WithShardedKey._shard_id_prefix +
                  threading.get_ident().to_bytes(8, 'big')),
              key_value[1])).with_output_types(sharded_element_type)
      batching = GroupIntoBatches(
          self.params.batch_size,
          self.params.max_buffering_duration_secs,
          self.clock)
      return sharded_pcoll | batching

    def to_runner_api_parameter(
        self,
        unused_context  # type: PipelineContext
    ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
      urn = common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn
      return urn, self.params.get_payload()

    @staticmethod
    @PTransform.register_urn(
        common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn,
        beam_runner_api_pb2.GroupIntoBatchesPayload)
    def from_runner_api_parameter(unused_ptransform, proto, unused_context):
      batch_size, max_buffering_duration_secs = (
          _GroupIntoBatchesParams.parse_payload(proto))
      return GroupIntoBatches.WithShardedKey(
          batch_size, max_buffering_duration_secs)
class _GroupIntoBatchesParams: """This class represents the parameters for :class:`apache_beam.utils.GroupIntoBatches` transform, used to define how elements should be batched. """ def __init__(self, batch_size, max_buffering_duration_secs): self.batch_size = batch_size self.max_buffering_duration_secs = ( 0 if max_buffering_duration_secs is None else max_buffering_duration_secs) self._validate() def __eq__(self, other): if other is None or not isinstance(other, _GroupIntoBatchesParams): return False return ( self.batch_size == other.batch_size and self.max_buffering_duration_secs == other.max_buffering_duration_secs) def _validate(self): assert self.batch_size is not None and self.batch_size > 0, ( 'batch_size must be a positive value') assert ( self.max_buffering_duration_secs is not None and self.max_buffering_duration_secs >= 0), ( 'max_buffering_duration must be a non-negative value') def get_payload(self): return beam_runner_api_pb2.GroupIntoBatchesPayload( batch_size=self.batch_size, max_buffering_duration_millis=int( self.max_buffering_duration_secs * 1000)) @staticmethod def parse_payload( proto # type: beam_runner_api_pb2.GroupIntoBatchesPayload ): return proto.batch_size, proto.max_buffering_duration_millis / 1000 def _pardo_group_into_batches( input_coder, batch_size, max_buffering_duration_secs, clock=time.time): ELEMENT_STATE = BagStateSpec('values', input_coder) COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn()) WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK) BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME) class _GroupIntoBatchesDoFn(DoFn): def process( self, element, window=DoFn.WindowParam, element_state=DoFn.StateParam(ELEMENT_STATE), count_state=DoFn.StateParam(COUNT_STATE), window_timer=DoFn.TimerParam(WINDOW_TIMER), buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)): # Allowed lateness not supported in Python SDK # 
https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data window_timer.set(window.end) element_state.add(element) count_state.add(1) count = count_state.read() if count == 1 and max_buffering_duration_secs > 0: # This is the first element in batch. Start counting buffering time if a # limit was set. # pylint: disable=deprecated-method buffering_timer.set(clock() + max_buffering_duration_secs) if count >= batch_size: return self.flush_batch(element_state, count_state, buffering_timer) @on_timer(WINDOW_TIMER) def on_window_timer( self, element_state=DoFn.StateParam(ELEMENT_STATE), count_state=DoFn.StateParam(COUNT_STATE), buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)): return self.flush_batch(element_state, count_state, buffering_timer) @on_timer(BUFFERING_TIMER) def on_buffering_timer( self, element_state=DoFn.StateParam(ELEMENT_STATE), count_state=DoFn.StateParam(COUNT_STATE), buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)): return self.flush_batch(element_state, count_state, buffering_timer) def flush_batch(self, element_state, count_state, buffering_timer): batch = [element for element in element_state.read()] if not batch: return key, _ = batch[0] batch_values = [v for (k, v) in batch] element_state.clear() count_state.clear() buffering_timer.clear() yield key, batch_values return _GroupIntoBatchesDoFn()
class ToString(object):
  """
  Factory of PTransforms that turn PCollection elements, KVs or iterables
  into strings.
  """

  # pylint: disable=invalid-name

  @staticmethod
  def Element():
    """
    Transforms each element of the PCollection to a string by applying
    ``str`` to it.
    """
    return 'ElementToString' >> Map(str)

  @staticmethod
  def Iterables(delimiter=None):
    """
    Transforms each iterable element of the PCollection to a single string:
    every item is converted with ``str`` and the results are joined with
    ``delimiter`` (``','`` when not given). There is no trailing delimiter.
    """
    separator = ',' if delimiter is None else delimiter
    join_items = Map(
        lambda items: separator.join(str(item) for item in items))
    typed_join = join_items.with_input_types(
        Iterable[Any]).with_output_types(str)
    return 'IterablesToString' >> typed_join

  # An alias for Iterables.
  Kvs = Iterables
@typehints.with_input_types(T)
@typehints.with_output_types(T)
class LogElements(PTransform):
  """
  PTransform for printing the elements of a PCollection.

  Elements are passed through unchanged.

  Args:
    label (str): (optional) A custom label for the transform.
    prefix (str): (optional) A prefix string to prepend to each logged
      element.
    with_timestamp (bool): (optional) Whether to include element's timestamp.
    with_window (bool): (optional) Whether to include element's window.
    level: (optional) The logging level for the output (e.g. `logging.DEBUG`,
      `logging.INFO`, `logging.WARNING`, `logging.ERROR`). If not specified,
      the log is printed to stdout.
  """
  class _LoggingFn(DoFn):
    """DoFn that formats each element, emits the line and re-yields it."""
    def __init__(
        self, prefix='', with_timestamp=False, with_window=False, level=None):
      super().__init__()
      self.level = level
      self.prefix = prefix
      self.with_window = with_window
      self.with_timestamp = with_timestamp

    def process(
        self,
        element,
        timestamp=DoFn.TimestampParam,
        window=DoFn.WindowParam,
        **kwargs):
      pieces = [self.prefix + str(element)]
      if self.with_timestamp:
        pieces.append(', timestamp=' + repr(timestamp.to_rfc3339()))
      if self.with_window:
        pieces.append(', window(start=' + window.start.to_rfc3339())
        pieces.append(', end=' + window.end.to_rfc3339() + ')')
      log_line = ''.join(pieces)

      # Looked up on every call so the current ``logging`` functions are used.
      emitters = (
          (logging.DEBUG, logging.debug),
          (logging.INFO, logging.info),
          (logging.WARNING, logging.warning),
          (logging.ERROR, logging.error),
          (logging.CRITICAL, logging.critical),
      )
      for known_level, emit in emitters:
        if self.level == known_level:
          emit(log_line)
          break
      else:
        # No (or an unrecognized) level: write to stdout instead.
        print(log_line)

      yield element

  def __init__(
      self,
      label=None,
      prefix='',
      with_timestamp=False,
      with_window=False,
      level=None):
    super().__init__(label)
    self.level = level
    self.prefix = prefix
    self.with_window = with_window
    self.with_timestamp = with_timestamp

  def expand(self, input):
    """Apply the logging DoFn, configured from this transform's options."""
    logging_fn = self._LoggingFn(
        self.prefix, self.with_timestamp, self.with_window, self.level)
    return input | ParDo(logging_fn)
class Reify(object):
  """PTransforms for converting between explicit and implicit form of various
  Beam values."""

  @typehints.with_input_types(T)
  @typehints.with_output_types(T)
  class Timestamp(PTransform):
    """PTransform to wrap a value in a TimestampedValue with its
    associated timestamp."""

    @staticmethod
    def add_timestamp_info(element, timestamp=DoFn.TimestampParam):
      """Yield ``element`` wrapped together with its own timestamp."""
      stamped = TimestampedValue(element, timestamp)
      yield stamped

    def expand(self, pcoll):
      reify_fn = self.add_timestamp_info
      return pcoll | ParDo(reify_fn)

  @typehints.with_input_types(T)
  @typehints.with_output_types(T)
  class Window(PTransform):
    """PTransform to convert an element in a PCollection into a tuple of
    (element, timestamp, window), wrapped in a TimestampedValue with its
    associated timestamp."""

    @staticmethod
    def add_window_info(
        element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
      """Yield ``(element, timestamp, window)`` stamped with ``timestamp``."""
      reified = (element, timestamp, window)
      yield TimestampedValue(reified, timestamp)

    def expand(self, pcoll):
      reify_fn = self.add_window_info
      return pcoll | ParDo(reify_fn)

  @typehints.with_input_types(Tuple[K, V])
  @typehints.with_output_types(Tuple[K, V])
  class TimestampInValue(PTransform):
    """PTransform to wrap the Value in a KV pair in a TimestampedValue with
    the element's associated timestamp."""

    @staticmethod
    def add_timestamp_info(element, timestamp=DoFn.TimestampParam):
      """Yield ``(key, TimestampedValue(value, timestamp))``."""
      k, v = element
      yield k, TimestampedValue(v, timestamp)

    def expand(self, pcoll):
      reify_fn = self.add_timestamp_info
      return pcoll | ParDo(reify_fn)

  @typehints.with_input_types(Tuple[K, V])
  @typehints.with_output_types(Tuple[K, V])
  class WindowInValue(PTransform):
    """PTransform to convert the Value in a KV pair into a tuple of
    (value, timestamp, window), with the whole element being wrapped inside a
    TimestampedValue."""

    @staticmethod
    def add_window_info(
        element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
      """Yield ``(key, (value, timestamp, window))`` stamped with
      ``timestamp``."""
      k, v = element
      reified = (k, (v, timestamp, window))
      yield TimestampedValue(reified, timestamp)

    def expand(self, pcoll):
      reify_fn = self.add_window_info
      return pcoll | ParDo(reify_fn)
class Regex(object):
  """
  PTransform to use Regular Expression to process the elements in a
  PCollection.

  Every method accepts ``regex`` either as a pattern string or as an already
  compiled pattern.
  """

  # Value to pass as ``group`` to ``find_all`` to select its tuple output.
  ALL = "__regex_all_groups"

  @staticmethod
  def _regex_compile(regex):
    """Return re.compile if the regex has a string value"""
    if isinstance(regex, str):
      regex = re.compile(regex)
    return regex

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def matches(pcoll, regex, group=0):
    """
    Returns the matches (group 0 by default) if zero or more characters at the
    beginning of string match the regular expression. To match the entire
    string, add "$" sign at the end of regex expression.

    Elements that do not match produce no output.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      group: (optional) name/number of the group, it can be integer or a string
        value. Defaults to 0, meaning the entire matched string will be
        returned.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      m = regex.match(element)
      if m:
        yield m.group(group)

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(List[str])
  @ptransform_fn
  def all_matches(pcoll, regex):
    """
    Returns all matches (groups) if zero or more characters at the beginning
    of string match the regular expression.

    For each matching element the output is the list
    ``[group 0, group 1, ...]`` up to the last group that took part in the
    match; a pattern without groups yields ``[group 0]``. Elements that do not
    match produce no output.

    Args:
      regex: the regular expression string or (re.compile) pattern.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      m = regex.match(element)
      if m:
        # ``lastindex`` is None when no capturing group took part in the match
        # (e.g. the pattern defines none); fall back to group 0 alone instead
        # of failing on ``None + 1``.
        last_group = m.lastindex or 0
        yield [m.group(ix) for ix in range(last_group + 1)]

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(Tuple[str, str])
  @ptransform_fn
  def matches_kv(pcoll, regex, keyGroup, valueGroup=0):
    """
    Returns the KV pairs if the string matches the regular expression, deriving
    the key & value from the specified group of the regular expression.

    Matching is anchored at the beginning of the string; elements that do not
    match produce no output.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      keyGroup: The Regex group to use as the key. Can be int or str.
      valueGroup: (optional) Regex group to use the value. Can be int or str.
        The default value "0" returns entire matched string.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      match = regex.match(element)
      if match:
        yield (match.group(keyGroup), match.group(valueGroup))

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def find(pcoll, regex, group=0):
    """
    Returns the matches if a portion of the line matches the Regex. Returns
    the entire group (group 0 by default). Group can be integer value or a
    string value.

    Only the first match found in each element is output; elements without a
    match produce no output.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      group: (optional) name of the group, it can be integer or a string value.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      r = regex.search(element)
      if r:
        yield r.group(group)

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(Union[List[str], List[Tuple[str, str]]])
  @ptransform_fn
  def find_all(pcoll, regex, group=0, outputEmpty=True):
    """
    Returns the matches if a portion of the line matches the Regex. By default,
    list of group 0 will return with empty items. One list is output per
    element, holding an entry for every match found in it.

    When the `Regex.ALL` flag is passed in the `group` parameter, every match
    is output as an ``(entire match, first group)`` tuple; the pattern is
    expected to define at least one group in that mode.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      group: (optional) name of the group, it can be integer or a string value.
      outputEmpty: (optional) Should empty be output. True to output empties
        and false if not.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      matches = regex.finditer(element)
      if group == Regex.ALL:
        yield [(m.group(), m.groups()[0]) for m in matches
               if outputEmpty or m.groups()[0]]
      else:
        yield [m.group(group) for m in matches
               if outputEmpty or m.group(group)]

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(Tuple[str, str])
  @ptransform_fn
  def find_kv(pcoll, regex, keyGroup, valueGroup=0):
    """
    Returns the matches if a portion of the line matches the Regex. Returns the
    specified groups as the key and value pair.

    One pair is output for every match found in an element.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      keyGroup: The Regex group to use as the key. Can be int or str.
      valueGroup: (optional) Regex group to use the value. Can be int or str.
        The default value "0" returns entire matched string.
    """
    regex = Regex._regex_compile(regex)

    def _process(element):
      # ``finditer`` returns an iterator, which is always truthy, so no
      # emptiness check is needed: an element without matches yields nothing.
      for match in regex.finditer(element):
        yield (match.group(keyGroup), match.group(valueGroup))

    return pcoll | FlatMap(_process)

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def replace_all(pcoll, regex, replacement):
    """
    Returns each line with all matches of the regex replaced by the
    replacement string. Lines without a match are output unchanged.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      replacement: the string to be substituted for each match.
    """
    regex = Regex._regex_compile(regex)
    return pcoll | Map(lambda elem: regex.sub(replacement, elem))

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(str)
  @ptransform_fn
  def replace_first(pcoll, regex, replacement):
    """
    Returns each line with the first match of the regex replaced by the
    replacement string. Lines without a match are output unchanged.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      replacement: the string to be substituted for the first match.
    """
    regex = Regex._regex_compile(regex)
    return pcoll | Map(lambda elem: regex.sub(replacement, elem, 1))

  @staticmethod
  @typehints.with_input_types(str)
  @typehints.with_output_types(List[str])
  @ptransform_fn
  def split(pcoll, regex, outputEmpty=False):
    """
    Returns the list of strings obtained by splitting each element on the
    regular expression. Empty items are not output by default.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      outputEmpty: (optional) Should empty be output. True to output empties
          and false if not.
    """
    regex = Regex._regex_compile(regex)
    outputEmpty = bool(outputEmpty)

    def _process(element):
      r = regex.split(element)
      if r and not outputEmpty:
        # Drop the empty strings produced by the split.
        r = list(filter(None, r))
      yield r

    return pcoll | FlatMap(_process)