import collections
import itertools

import apache_beam as beam
from apache_beam import typehints
from apache_beam.utils.windowed_value import WindowedValue
from apache_beam.internal.util import ArgumentPlaceholder

class LiftedCombinePerKey(beam.PTransform):
    """An implementation of CombinePerKey that does mapper-side pre-combining.

    The transform expands into three steps: a ParDo that partially combines
    values per key and window, a GroupByKey, and a ParDo that merges the
    partial accumulators and extracts the final output.
    """

    def __init__(self, combine_fn, args, kwargs):
        """Curries ``combine_fn`` with its extra arguments.

        Args:
          combine_fn: the combine function to apply per key.
          args: extra positional arguments for ``combine_fn``.
          kwargs: extra keyword arguments for ``combine_fn``.

        Raises:
          NotImplementedError: if any entry of ``args`` or ``kwargs`` is an
            ArgumentPlaceholder (i.e. a deferred side input).
        """
        # Run the base-class initializer explicitly rather than relying on
        # class-level fallbacks for state it would normally set up.
        super(LiftedCombinePerKey, self).__init__()
        if any(isinstance(arg, ArgumentPlaceholder)
               for arg in itertools.chain(args, kwargs.values())):
            # This isn't implemented in dataflow either...
            raise NotImplementedError('Deferred CombineFn side inputs.')
        self._combine_fn = beam.transforms.combiners.curry_combine_fn(
            combine_fn, args, kwargs)

    def expand(self, pcoll):
        """Applies partial combine, GroupByKey and final combine to ``pcoll``."""
        return (
            pcoll
            | beam.ParDo(PartialGroupByKeyCombiningValues(self._combine_fn))
            | beam.GroupByKey()
            | beam.ParDo(FinishCombine(self._combine_fn)))
class PartialGroupByKeyCombiningValues(beam.DoFn):
    """Aggregates values into a per-key-window cache.

    As bundles are in-memory-sized, we don't bother flushing until the very
    end.
    """

    def __init__(self, combine_fn):
        self._combine_fn = combine_fn

    def start_bundle(self):
        """Starts the bundle with a fresh cache of empty accumulators."""
        # Missing (key, window) slots are filled lazily with a new accumulator.
        self._cache = collections.defaultdict(
            self._combine_fn.create_accumulator)

    def process(self, element, window=beam.DoFn.WindowParam):
        """Folds the value of a (key, value) element into its accumulator."""
        key, value = element
        slot = (key, window)
        self._cache[slot] = self._combine_fn.add_input(
            self._cache[slot], value)

    def finish_bundle(self):
        """Yields one (key, accumulator) pair per cached (key, window) slot.

        Each pair is wrapped in a WindowedValue timestamped at the end of its
        window and assigned to that window only.
        """
        for slot, accumulator in self._cache.items():
            key, win = slot
            yield WindowedValue((key, accumulator), win.end, (win,))

    def default_type_hints(self):
        """Derives keyed type hints from those of the combine function.

        The first declared input type, if any, is wrapped as ``Tuple[K, T]``;
        otherwise the input is ``Tuple[K, Any]``. The output is always
        ``Tuple[K, Any]``.
        """
        hints = self._combine_fn.get_type_hints().copy()
        key_type = typehints.TypeVariable('K')
        declared = hints.input_types
        if not declared:
            hints.set_input_types(typehints.Tuple[key_type, typehints.Any])
        else:
            positional, keyword = declared
            keyed_first = typehints.Tuple[key_type, positional[0]]
            positional = (keyed_first,) + positional[1:]
            hints.set_input_types(*positional, **keyword)
        hints.set_output_types(typehints.Tuple[key_type, typehints.Any])
        return hints
class FinishCombine(beam.DoFn):
    """Merges partially combined results."""

    def __init__(self, combine_fn):
        self._combine_fn = combine_fn

    def process(self, element):
        """Merges the accumulators of one key and extracts the final output.

        Args:
          element: a (key, accumulators) pair.

        Returns:
          A single-element list holding the (key, output) pair.
        """
        key, accumulators = element
        merged = self._combine_fn.merge_accumulators(accumulators)
        return [(key, self._combine_fn.extract_output(merged))]

    def default_type_hints(self):
        """Derives keyed type hints from those of the combine function.

        The input is ``Tuple[K, Any]``. When the combine function declares
        output types, the output becomes ``Tuple[K, <its simple output type>]``;
        otherwise the copied output hints are left as they are.
        """
        hints = self._combine_fn.get_type_hints().copy()
        key_type = typehints.TypeVariable('K')
        hints.set_input_types(typehints.Tuple[key_type, typehints.Any])
        if not hints.output_types:
            return hints
        combined_type = hints.simple_output_type('')
        hints.set_output_types(typehints.Tuple[key_type, combined_type])
        return hints