# Source code for apache_beam.transforms.util

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Simple utility PTransforms.
"""

from __future__ import absolute_import
from __future__ import division

import collections
import contextlib
import random
import time
from builtins import object
from builtins import range
from builtins import zip

from future.utils import itervalues

from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.portability import common_urns
from apache_beam.transforms import window
from apache_beam.transforms.core import CombinePerKey
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import FlatMap
from apache_beam.transforms.core import Flatten
from apache_beam.transforms.core import GroupByKey
from apache_beam.transforms.core import Map
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.ptransform import ptransform_fn
from apache_beam.transforms.trigger import AccumulationMode
from apache_beam.transforms.trigger import AfterCount
from apache_beam.transforms.window import NonMergingWindowFn
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.utils import windowed_value

__all__ = [
    'BatchElements',
    'CoGroupByKey',
    'Keys',
    'KvSwap',
    'RemoveDuplicates',
    'Reshuffle',
    'Values',
    ]

K = typehints.TypeVariable('K')
V = typehints.TypeVariable('V')
T = typehints.TypeVariable('T')


class CoGroupByKey(PTransform):
  """Groups results across several PCollections by key.

  The input is either a dict mapping serializable "tags" to PCollections of
  (key, value) tuples, or a flat iterable (tuple, list, ...) of such
  PCollections.  The output is a single PCollection of (key, value) tuples
  with one element per distinct key found across all inputs.

  For a dict input, each output value is a dict mapping every tag to the list
  of values found under that key in the corresponding PCollection::

    ('some key', {'tag1': ['value 1 under "some key" in pcoll1',
                           'value 2 under "some key" in pcoll1',
                           ...],
                  'tag2': ... ,
                  ... })

  For example, given::

    {'tag1': pc1, 'tag2': pc2, 333: pc3}

  where::

    pc1 = [(k1, v1)]
    pc2 = []
    pc3 = [(k1, v31), (k1, v32), (k2, v33)]

  the output PCollection would be::

    [(k1, {'tag1': [v1], 'tag2': [], 333: [v31, v32]}),
     (k2, {'tag1': [], 'tag2': [], 333: [v33]})]

  For a tuple, list, or other flat iterable of PCollections, each output value
  is a tuple whose nth entry is the list of values from the nth
  PCollection---conceptually, the "tags" are the indices into the input.
  Thus, for this input::

    (pc1, pc2, pc3)

  the output would be::

    [(k1, ([v1], [], [v31, v32])),
     (k2, ([], [], [v33]))]

  Attributes:
    **kwargs: Accepts a single named argument "pipeline", which specifies the
      pipeline that "owns" this PTransform. Ordinarily CoGroupByKey can obtain
      this information from one of the input PCollections, but if there are
      none (or if there's a chance there may be none), this argument is the
      only way to provide pipeline information, and should be considered
      mandatory.
  """

  def __init__(self, **kwargs):
    super(CoGroupByKey, self).__init__()
    self.pipeline = kwargs.pop('pipeline', None)
    # "pipeline" is the only keyword accepted; anything left over is an error.
    if kwargs:
      unexpected = list(kwargs.keys())
      raise ValueError('Unexpected keyword arguments: %s' % unexpected)

  def _extract_input_pvalues(self, pvalueish):
    try:
      # Dict-like inputs expose their PCollections as the dict values.
      dict_values = tuple(itervalues(pvalueish))
    except AttributeError:
      # Anything else is treated as a flat iterable of PCollections.
      as_tuple = tuple(pvalueish)
      return as_tuple, as_tuple
    return pvalueish, dict_values

  def expand(self, pcolls):
    """Performs CoGroupByKey on argument pcolls; see class docstring."""

    # Remembers which input PCollection each value of a K-V pair came from.
    def _pair_tag_with_value(key_value, tag):
      key, value = key_value
      return key, (tag, value)

    # Builds one output (key, value) pair.  The value container is a dict or
    # a tuple of lists (per the class docstring), created by calling
    # result_ctor(result_ctor_arg).
    def _merge_tagged_vals_under_key(key_grouped, result_ctor, result_ctor_arg):
      key, tagged_values = key_grouped
      merged = result_ctor(result_ctor_arg)
      for tag, value in tagged_values:
        merged[tag].append(value)
      return key, merged

    try:
      # Dict input: the tags are the dict keys and the result container is a
      # dict with one (initially empty) list per tag.  Non-dict inputs raise
      # AttributeError on .items() and are handled below.
      result_ctor_arg = list(pcolls)
      result_ctor = lambda tags: {tag: [] for tag in tags}
      pcolls = pcolls.items()
    except AttributeError:
      # List/tuple input: the tags are the positions and the result container
      # is a tuple with len(pcolls) list slots.
      pcolls = list(enumerate(pcolls))
      result_ctor_arg = len(pcolls)
      result_ctor = lambda size: tuple([] for _ in range(size))

    # Every input must be a PCollection and, when a pipeline was given
    # explicitly, must belong to that pipeline.
    for _, pcoll in pcolls:
      self._check_pcollection(pcoll)
      if self.pipeline:
        assert pcoll.pipeline == self.pipeline

    tagged_pcolls = [
        pcoll | (('pair_with_%s' % tag) >> Map(_pair_tag_with_value, tag))
        for tag, pcoll in pcolls]
    grouped = (tagged_pcolls
               | Flatten(pipeline=self.pipeline)
               | GroupByKey())
    return grouped | Map(
        _merge_tagged_vals_under_key, result_ctor, result_ctor_arg)
def Keys(label='Keys'):  # pylint: disable=invalid-name
  """Produces a PCollection of first elements of 2-tuples in a PCollection.

  Args:
    label: label attached to the returned transform.
  """
  take_first = Map(lambda pair: pair[0])
  return label >> take_first
def Values(label='Values'):  # pylint: disable=invalid-name
  """Produces a PCollection of second elements of 2-tuples in a PCollection.

  Args:
    label: label attached to the returned transform.
  """
  take_second = Map(lambda pair: pair[1])
  return label >> take_second
def KvSwap(label='KvSwap'):  # pylint: disable=invalid-name
  """Produces a PCollection reversing 2-tuples in a PCollection.

  Args:
    label: label attached to the returned transform.
  """
  swap = Map(lambda pair: (pair[1], pair[0]))
  return label >> swap
@ptransform_fn
def RemoveDuplicates(pcoll):  # pylint: disable=invalid-name
  """Produces a PCollection containing the unique elements of a PCollection."""
  # Key every element by itself, collapse each key's values to None, and keep
  # only the keys.
  return (pcoll
          | 'ToPairs' >> Map(lambda v: (v, None))
          | 'Group' >> CombinePerKey(lambda vs: None)
          | 'RemoveDuplicates' >> Keys())


class _BatchSizeEstimator(object):
  """Estimates the best size for batches given historical timing.

  Timings are collected through record_time() as (batch_size, seconds) pairs
  and fitted with a linear model ``time = a + b * batch_size``;
  next_batch_size() then solves that model for the configured targets.

  Args:
    min_batch_size: lower bound used when choosing batch sizes.
    max_batch_size: upper bound used when choosing batch sizes.
    target_batch_overhead: target for ``a / (a + b * batch_size)``, i.e. the
      fraction of a batch's time spent on the fixed cost.  Must be in (0, 1]
      when set.
    target_batch_duration_secs: target for ``a + b * batch_size``, i.e. the
      time taken by one batch.  Must be positive when set.
    variance: relative width of the random spread applied around the target
      batch size once more than 10 data points have been collected.
    clock: zero-argument callable returning the current time in seconds.

  Raises:
    ValueError: if min_batch_size exceeds max_batch_size, if a target is out
      of range, or if neither target is set.
  """

  _MAX_DATA_POINTS = 100
  _MAX_GROWTH_FACTOR = 2

  def __init__(self,
               min_batch_size=1,
               max_batch_size=1000,
               target_batch_overhead=.1,
               target_batch_duration_secs=1,
               variance=0.25,
               clock=time.time):
    if min_batch_size > max_batch_size:
      raise ValueError("Minimum (%s) must not be greater than maximum (%s)" % (
          min_batch_size, max_batch_size))
    if target_batch_overhead and not 0 < target_batch_overhead <= 1:
      raise ValueError("target_batch_overhead (%s) must be between 0 and 1" % (
          target_batch_overhead))
    if target_batch_duration_secs and target_batch_duration_secs <= 0:
      raise ValueError("target_batch_duration_secs (%s) must be positive" % (
          target_batch_duration_secs))
    if not (target_batch_overhead or target_batch_duration_secs):
      raise ValueError("At least one of target_batch_overhead or "
                       "target_batch_duration_secs must be positive.")
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size
    self._target_batch_overhead = target_batch_overhead
    self._target_batch_duration_secs = target_batch_duration_secs
    self._variance = variance
    self._clock = clock
    # Collected (batch_size, elapsed_seconds) samples, in arrival order.
    self._data = []
    self._ignore_next_timing = False
    self._size_distribution = Metrics.distribution(
        'BatchElements', 'batch_size')
    self._time_distribution = Metrics.distribution(
        'BatchElements', 'msec_per_batch')
    # Beam distributions only accept integer values, so we use this to
    # accumulate under-reported values until they add up to whole milliseconds.
    # (Milliseconds are chosen because that's conventionally used elsewhere in
    # profiling-style counters.)
    self._remainder_msecs = 0

  def ignore_next_timing(self):
    """Call to indicate the next timing should be ignored.

    For example, the first emit of a ParDo operation is known to be anomalous
    due to setup that may occur.
    """
    # Raise the flag; record_time() skips one sample and then clears it.
    self._ignore_next_timing = True

  @contextlib.contextmanager
  def record_time(self, batch_size):
    """Context manager timing the enclosed block as one batch.

    The elapsed time is always reported to the size and time distributions.
    It is added to the regression data unless ignore_next_timing() was called
    since the previous recording.

    Args:
      batch_size: the batch size to associate with the measured time.
    """
    start = self._clock()
    yield
    elapsed = self._clock() - start
    # Carry the sub-millisecond remainder over to the next report.
    elapsed_msec = 1e3 * elapsed + self._remainder_msecs
    self._size_distribution.update(batch_size)
    self._time_distribution.update(int(elapsed_msec))
    self._remainder_msecs = elapsed_msec - int(elapsed_msec)
    if self._ignore_next_timing:
      self._ignore_next_timing = False
    else:
      self._data.append((batch_size, elapsed))
      if len(self._data) >= self._MAX_DATA_POINTS:
        self._thin_data()

  def _thin_data(self):
    """Drops two samples, chosen at random from the front of the data."""
    # Make sure we don't change the parity of len(self._data)
    # As it's used below to alternate jitter.
    self._data.pop(random.randrange(len(self._data) // 4))
    self._data.pop(random.randrange(len(self._data) // 2))

  @staticmethod
  def linear_regression_no_numpy(xs, ys):
    """Least squares fit of ``y = a + b * x``; returns ``(a, b)``."""
    # Least squares fit for y = a + bx over all points.
    n = float(len(xs))
    xbar = sum(xs) / n
    ybar = sum(ys) / n
    if xbar == 0:
      return ybar, 0
    if all(xs[0] == x for x in xs):
      # Simply use the mean if all values in xs are same.
      return 0, ybar / xbar
    b = (sum([(x - xbar) * (y - ybar) for x, y in zip(xs, ys)])
         / sum([(x - xbar)**2 for x in xs]))
    a = ybar - b * xbar
    return a, b

  @staticmethod
  def linear_regression_numpy(xs, ys):
    """Fit of ``y = a + b * x`` using numpy; returns ``(a, b)``.

    With 10 or more points the fit is refined by discarding outliers.
    """
    # pylint: disable=wrong-import-order, wrong-import-position
    import numpy as np
    from numpy import sum
    n = len(xs)
    if all(xs[0] == x for x in xs):
      # If all values of xs are same then fallback to linear_regression_no_numpy
      return _BatchSizeEstimator.linear_regression_no_numpy(xs, ys)
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    # First do a simple least squares fit for y = a + bx over all points.
    b, a = np.polyfit(xs, ys, 1)

    if n < 10:
      return a, b
    else:
      # Refine this by throwing out outliers, according to Cook's distance.
      # https://en.wikipedia.org/wiki/Cook%27s_distance
      sum_x = sum(xs)
      sum_x2 = sum(xs**2)
      errs = a + b * xs - ys
      s2 = sum(errs**2) / (n - 2)
      if s2 == 0:
        # It's an exact fit!
        return a, b
      h = (sum_x2 - 2 * sum_x * xs + n * xs**2) / (n * sum_x2 - sum_x**2)
      cook_ds = 0.5 / s2 * errs**2 * (h / (1 - h)**2)

      # Re-compute the regression, excluding those points with Cook's distance
      # greater than 0.5, and weighting by the inverse of x to give a more
      # stable y-intercept (as small batches have relatively more information
      # about the fixed overhead).
      weight = (cook_ds <= 0.5) / xs
      b, a = np.polyfit(xs, ys, 1, w=weight)
      return a, b

  # Use the numpy-based regression when numpy can be imported.
  try:
    # pylint: disable=wrong-import-order, wrong-import-position
    import numpy as np
    linear_regression = linear_regression_numpy
  except ImportError:
    linear_regression = linear_regression_no_numpy

  def next_batch_size(self):
    """Returns the batch size to use next, based on the recorded timings."""
    if self._min_batch_size == self._max_batch_size:
      return self._min_batch_size
    elif len(self._data) < 1:
      return self._min_batch_size
    elif len(self._data) < 2:
      # Force some variety so we have distinct batch sizes on which to do
      # linear regression below.
      return int(max(
          min(self._max_batch_size,
              self._min_batch_size * self._MAX_GROWTH_FACTOR),
          self._min_batch_size + 1))

    # There tends to be a lot of noise in the top quantile, which also
    # has outsized influence in the regression.  If we have enough data,
    # simply declare the top 20% to be outliers.
    trimmed_data = sorted(self._data)[:max(20, len(self._data) * 4 // 5)]

    # Linear regression for y = a + bx, where x is batch size and y is time.
    xs, ys = zip(*trimmed_data)
    a, b = self.linear_regression(xs, ys)

    # Avoid nonsensical or division-by-zero errors below due to noise.
    a = max(a, 1e-10)
    b = max(b, 1e-20)

    # Never grow by more than _MAX_GROWTH_FACTOR over the last batch size.
    last_batch_size = self._data[-1][0]
    cap = min(last_batch_size * self._MAX_GROWTH_FACTOR, self._max_batch_size)

    target = self._max_batch_size

    if self._target_batch_duration_secs:
      # Solution to a + b*x = self._target_batch_duration_secs.
      target = min(target, (self._target_batch_duration_secs - a) / b)

    if self._target_batch_overhead:
      # Solution to a / (a + b*x) = self._target_batch_overhead.
      target = min(target, (a / b) * (1 / self._target_batch_overhead - 1))

    # Avoid getting stuck at a single batch size (especially the minimal
    # batch size) which would not allow us to extrapolate to other batch
    # sizes.
    # Jitter alternates between 0 and 1.
    jitter = len(self._data) % 2
    # Smear our samples across a range centered at the target.
    if len(self._data) > 10:
      target += int(target * self._variance * 2 * (random.random() - .5))

    return int(max(self._min_batch_size + jitter, min(target, cap)))


class _GlobalWindowsBatchingDoFn(DoFn):
  """Batches elements into lists without tracking windows.

  Args:
    batch_size_estimator: the _BatchSizeEstimator used to choose batch sizes
      and to record how long each emitted batch takes.
  """

  def __init__(self, batch_size_estimator):
    self._batch_size_estimator = batch_size_estimator

  def start_bundle(self):
    self._batch = []
    self._batch_size = self._batch_size_estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._batch_size_estimator.ignore_next_timing()

  def process(self, element):
    self._batch.append(element)
    if len(self._batch) >= self._batch_size:
      # The yield happens inside record_time so the timer covers the time
      # until this generator is resumed.
      with self._batch_size_estimator.record_time(self._batch_size):
        yield self._batch
      self._batch = []
      self._batch_size = self._batch_size_estimator.next_batch_size()

  def finish_bundle(self):
    # Flush whatever is left over as a final batch in the global window.
    if self._batch:
      with self._batch_size_estimator.record_time(self._batch_size):
        yield window.GlobalWindows.windowed_value(self._batch)
      self._batch = None
      self._batch_size = self._batch_size_estimator.next_batch_size()


class _WindowAwareBatchingDoFn(DoFn):
  """Batches elements into lists, keeping a separate batch per window.

  Args:
    batch_size_estimator: the _BatchSizeEstimator used to choose batch sizes
      and to record how long each emitted batch takes.
  """

  _MAX_LIVE_WINDOWS = 10

  def __init__(self, batch_size_estimator):
    self._batch_size_estimator = batch_size_estimator

  def start_bundle(self):
    # Maps each window to the list of elements buffered for it.
    self._batches = collections.defaultdict(list)
    self._batch_size = self._batch_size_estimator.next_batch_size()
    # The first emit often involves non-trivial setup.
    self._batch_size_estimator.ignore_next_timing()

  def process(self, element, window=DoFn.WindowParam):
    self._batches[window].append(element)
    if len(self._batches[window]) >= self._batch_size:
      with self._batch_size_estimator.record_time(self._batch_size):
        yield windowed_value.WindowedValue(
            self._batches[window], window.max_timestamp(), (window,))
      del self._batches[window]
      self._batch_size = self._batch_size_estimator.next_batch_size()
    elif len(self._batches) > self._MAX_LIVE_WINDOWS:
      # Too many windows are buffered at once: emit the largest batch.
      window, _ = sorted(
          self._batches.items(),
          key=lambda window_batch: len(window_batch[1]),
          reverse=True)[0]
      with self._batch_size_estimator.record_time(self._batch_size):
        yield windowed_value.WindowedValue(
            self._batches[window], window.max_timestamp(), (window,))
      del self._batches[window]
      self._batch_size = self._batch_size_estimator.next_batch_size()

  def finish_bundle(self):
    # Flush every non-empty per-window batch.
    for window, batch in self._batches.items():
      if batch:
        with self._batch_size_estimator.record_time(self._batch_size):
          yield windowed_value.WindowedValue(
              batch, window.max_timestamp(), (window,))
    self._batches = None
    self._batch_size = self._batch_size_estimator.next_batch_size()
@typehints.with_input_types(T)
@typehints.with_output_types(typehints.List[T])
class BatchElements(PTransform):
  """A Transform that batches elements for amortized processing.

  This transform is designed to precede operations whose processing cost
  is of the form

      time = fixed_cost + num_elements * per_element_cost

  where the per element cost is (often significantly) smaller than the fixed
  cost and could be amortized over multiple elements.  It consumes a
  PCollection of element type T and produces a PCollection of element type
  List[T].

  This transform attempts to find the best batch size between the minimum
  and maximum parameters by profiling the time taken by (fused) downstream
  operations. For a fixed batch size, set the min and max to be equal.

  Elements are batched per-window and batches emitted in the window
  corresponding to its contents.

  Args:
    min_batch_size: (optional) the smallest number of elements per batch
    max_batch_size: (optional) the largest number of elements per batch
    target_batch_overhead: (optional) a target for fixed_cost / time,
        as used in the formula above
    target_batch_duration_secs: (optional) a target for total time per bundle,
        in seconds
    variance: (optional) the permitted (relative) amount of deviation from the
        (estimated) ideal batch size used to produce a wider base for
        linear interpolation
    clock: (optional) an alternative to time.time for measuring the cost of
        downstream operations (mostly for testing)
  """

  def __init__(self,
               min_batch_size=1,
               max_batch_size=10000,
               target_batch_overhead=.05,
               target_batch_duration_secs=1,
               variance=0.25,
               clock=time.time):
    # All tuning knobs are handed straight to the estimator.
    estimator_config = dict(
        min_batch_size=min_batch_size,
        max_batch_size=max_batch_size,
        target_batch_overhead=target_batch_overhead,
        target_batch_duration_secs=target_batch_duration_secs,
        variance=variance,
        clock=clock)
    self._batch_size_estimator = _BatchSizeEstimator(**estimator_config)

  def expand(self, pcoll):
    if getattr(pcoll.pipeline.runner, 'is_streaming', False):
      raise NotImplementedError("Requires stateful processing (BEAM-2687)")

    estimator = self._batch_size_estimator
    if pcoll.windowing.is_default():
      # Default windowing: use the DoFn that does not track windows.
      batching_dofn = _GlobalWindowsBatchingDoFn(estimator)
    else:
      batching_dofn = _WindowAwareBatchingDoFn(estimator)
    return pcoll | ParDo(batching_dofn)
class _IdentityWindowFn(NonMergingWindowFn):
  """Windowing function that preserves existing windows.

  To be used internally with the Reshuffle transform.
  Will raise an exception when used after DoFns that return TimestampedValue
  elements.
  """

  def __init__(self, window_coder):
    """Create a new WindowFn with compatible coder.
    To be applied to PCollections with windows that are compatible with the
    given coder.

    Arguments:
      window_coder: coders.Coder object to be used on windows.

    Raises:
      ValueError: if window_coder is None.
    """
    super(_IdentityWindowFn, self).__init__()
    if window_coder is None:
      raise ValueError('window_coder should not be None')
    self._window_coder = window_coder

  def assign(self, assign_context):
    """Returns the element's current window as its only window."""
    current_window = assign_context.window
    if current_window is None:
      raise ValueError(
          'assign_context.window should not be None. '
          'This might be due to a DoFn returning a TimestampedValue.')
    return [current_window]

  def get_window_coder(self):
    return self._window_coder


@typehints.with_input_types(typehints.KV[K, V])
@typehints.with_output_types(typehints.KV[K, V])
class ReshufflePerKey(PTransform):
  """PTransform that returns a PCollection equivalent to its input,
  but operationally provides some of the side effects of a GroupByKey,
  in particular preventing fusion of the surrounding transforms,
  checkpointing, and deduplication by id.

  ReshufflePerKey is experimental. No backwards compatibility guarantees.
  """

  def expand(self, pcoll):
    windowing_saved = pcoll.windowing
    if windowing_saved.is_default():
      # In this (common) case we can use a trivial trigger driver
      # and avoid the (expensive) window param.
      globally_windowed = window.GlobalWindows.windowed_value(None)
      window_fn = window.GlobalWindows()
      MIN_TIMESTAMP = window.MIN_TIMESTAMP

      def reify_timestamps(element, timestamp=DoFn.TimestampParam):
        key, value = element
        # MIN_TIMESTAMP is carried through the shuffle as None.
        reified = None if timestamp == MIN_TIMESTAMP else timestamp
        return key, (value, reified)

      def restore_timestamps(element):
        key, values = element
        restored = []
        for value, timestamp in values:
          if timestamp is None:
            restored.append(globally_windowed.with_value((key, value)))
          else:
            restored.append(
                window.GlobalWindows.windowed_value((key, value), timestamp))
        return restored

    else:
      # The linter is confused.
      # hash(1) is used to force "runtime" selection of _IdentityWindowFn
      # pylint: disable=abstract-class-instantiated
      cls = hash(1) and _IdentityWindowFn
      saved_window_coder = windowing_saved.windowfn.get_window_coder()
      window_fn = cls(saved_window_coder)

      def reify_timestamps(element, timestamp=DoFn.TimestampParam):
        key, value = element
        return key, TimestampedValue(value, timestamp)

      def restore_timestamps(element, window=DoFn.WindowParam):
        # Pass the current window since _IdentityWindowFn wouldn't know how
        # to generate it.
        key, values = element
        return [
            windowed_value.WindowedValue(
                (key, stamped.value), stamped.timestamp, [window])
            for stamped in values]

    # Group with a fire-on-every-element trigger, then put the original
    # windowing back on the result.
    ungrouped = pcoll | Map(reify_timestamps)
    ungrouped._windowing = Windowing(
        window_fn,
        triggerfn=AfterCount(1),
        accumulation_mode=AccumulationMode.DISCARDING,
        timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
    grouped = ungrouped | GroupByKey()
    result = grouped | FlatMap(restore_timestamps)
    result._windowing = windowing_saved
    return result
@typehints.with_input_types(T)
@typehints.with_output_types(T)
class Reshuffle(PTransform):
  """PTransform that returns a PCollection equivalent to its input,
  but operationally provides some of the side effects of a GroupByKey,
  in particular preventing fusion of the surrounding transforms,
  checkpointing, and deduplication by id.

  Reshuffle adds a temporary random key to each element, performs a
  ReshufflePerKey, and finally removes the temporary key.

  Reshuffle is experimental. No backwards compatibility guarantees.
  """

  def expand(self, pcoll):
    # Pair each element with a random 32-bit key, reshuffle on that key, then
    # strip the key again.
    keyed = pcoll | 'AddRandomKeys' >> Map(
        lambda element: (random.getrandbits(32), element))
    reshuffled = keyed | ReshufflePerKey()
    return reshuffled | 'RemoveRandomKeys' >> Map(lambda keyed_elem: keyed_elem[1])

  def to_runner_api_parameter(self, unused_context):
    # Reshuffle carries no payload; only its URN identifies it.
    reshuffle_urn = common_urns.composites.RESHUFFLE.urn
    return reshuffle_urn, None

  @PTransform.register_urn(common_urns.composites.RESHUFFLE.urn, None)
  def from_runner_api_parameter(unused_parameter, unused_context):
    return Reshuffle()