#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PTransform and descendants.
A PTransform is an object describing (not executing) a computation. The actual
execution semantics for a transform is captured by a runner object. A transform
object always belongs to a pipeline object.
A PTransform derived class needs to define the expand() method that describes
how one or more PValues are created by the transform.
The module defines a few standard transforms: FlatMap (parallel do),
GroupByKey (group by key), etc. Note that the expand() methods for these
classes contain code that will add nodes to the processing graph associated
with a pipeline.
As support for the FlatMap transform, the module also defines a DoFn
class and wrapper class that allows lambda functions to be used as
FlatMap processing functions.
"""
from __future__ import absolute_import
import contextlib
import copy
import itertools
import operator
import os
import sys
import threading
from builtins import hex
from builtins import object
from builtins import zip
from functools import reduce
from functools import wraps
from google.protobuf import message
from apache_beam import error
from apache_beam import pvalue
from apache_beam.internal import pickler
from apache_beam.internal import util
from apache_beam.portability import python_urns
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.display import HasDisplayData
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import typehints
from apache_beam.typehints.decorators import TypeCheckError
from apache_beam.typehints.decorators import WithTypeHints
from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.decorators import getcallargs_forhints
from apache_beam.typehints.trivial_inference import instance_to_type
from apache_beam.typehints.typehints import validate_composite_type_param
from apache_beam.utils import proto_utils
# Public names exported by a star-import of this module.
__all__ = [
'PTransform',
'ptransform_fn',
'label_from_callable',
]
class _PValueishTransform(object):
    """Base visitor that rebuilds a PValueish while walking it.

    A PValueish is a PValue, or a list, tuple or dict of PValueish objects.
    Subclasses provide ``visit(node, *args)``; ``visit_nested`` recurses into
    containers through that method, constructing a (possibly mutated) copy.
    """

    def visit_nested(self, node, *args):
        """Return a copy of ``node`` whose children went through ``visit``.

        Lists and tuples are rebuilt with the node's own class, dicts keep
        their keys and have only their values visited, and any other object
        is returned unchanged. ``args`` is forwarded to every ``visit`` call.
        """
        if isinstance(node, (tuple, list)):
            node_type = node.__class__
            visited = [self.visit(child, *args) for child in node]
            is_namedtuple = (
                isinstance(node, tuple) and hasattr(node_type, '_make'))
            if is_namedtuple:
                # namedtuples require unpacked arguments in their constructor
                return node_type(*visited)
            return node_type(visited)
        if isinstance(node, dict):
            visited_items = dict(
                (key, self.visit(value, *args))
                for key, value in node.items())
            return node.__class__(visited_items)
        return node
class _SetInputPValues(_PValueishTransform):
    """Visitor that swaps nodes for replacements registered by ``id``."""

    def visit(self, node, replacements):
        """Return the replacement stored under ``id(node)``, else recurse.

        Args:
          node: the PValueish currently being visited.
          replacements: mapping from ``id(original)`` to the substitute.
        """
        node_key = id(node)
        if node_key not in replacements:
            return self.visit_nested(node, replacements)
        return replacements[node_key]
# Caches to allow for materialization of values when executing a pipeline
# in-process, in eager mode. This cache allows the same _MaterializedResult
# object to be accessed and used despite Runner API round-trip serialization.
# Entries are keyed by (os.getpid(), id(pipeline)); each entry is a dict that
# maps a result_id to its _MaterializedResult.
_pipeline_materialization_cache = {}
# Held by the helper functions below whenever they touch the cache.
_pipeline_materialization_lock = threading.Lock()
def _allocate_materialized_pipeline(pipeline):
    """Register an empty result table for ``pipeline`` in the cache.

    The table is stored under the current process id and ``id(pipeline)``;
    any table already stored under that key is replaced.
    """
    cache_key = (os.getpid(), id(pipeline))
    with _pipeline_materialization_lock:
        _pipeline_materialization_cache[cache_key] = {}
def _allocate_materialized_result(pipeline):
    """Create and cache a new _MaterializedResult for ``pipeline``.

    The new result's id is the number of results already stored for the
    pipeline.

    Raises:
      ValueError: if no result table was allocated for ``pipeline`` in this
        process.
    """
    cache_key = (os.getpid(), id(pipeline))
    with _pipeline_materialization_lock:
        # Stored tables are always dicts, so None reliably means "missing".
        pipeline_results = _pipeline_materialization_cache.get(cache_key)
        if pipeline_results is None:
            raise ValueError(
                'Materialized pipeline is not allocated for result cache.')
        next_result_id = len(pipeline_results)
        materialized = _MaterializedResult(cache_key[1], next_result_id)
        pipeline_results[next_result_id] = materialized
        return materialized
def _get_materialized_result(pipeline_id, result_id):
    """Look up a cached _MaterializedResult by pipeline id and result id.

    Raises:
      Exception: if the current process holds no result table for
        ``pipeline_id``.
      KeyError: if the table exists but has no entry for ``result_id``.
    """
    cache_key = (os.getpid(), pipeline_id)
    with _pipeline_materialization_lock:
        # Stored tables are always dicts, so None reliably means "missing".
        pipeline_results = _pipeline_materialization_cache.get(cache_key)
        if pipeline_results is None:
            raise Exception(
                'Materialization in out-of-process and remote runners is not '
                'yet supported.')
        return pipeline_results[result_id]
def _release_materialized_pipeline(pipeline):
    """Drop the cached result table of ``pipeline`` for this process.

    Raises:
      KeyError: if no table is stored for ``pipeline``.
    """
    cache_key = (os.getpid(), id(pipeline))
    with _pipeline_materialization_lock:
        _pipeline_materialization_cache.pop(cache_key)
class _MaterializedResult(object):
    """Container for the elements materialized from one pvalue.

    Carries the (pipeline_id, result_id) pair under which it is stored in
    the materialization cache, plus the list of collected ``elements``.
    """

    def __init__(self, pipeline_id, result_id):
        self.elements = []
        self._result_id = result_id
        self._pipeline_id = pipeline_id

    def __reduce__(self):
        # Pickle as a cache lookup instead of by value: when unpickled
        # (during Runner API round-trip serialization) the original
        # _MaterializedResult is fetched from the cache, so values are
        # written to it when run in eager mode.
        lookup_args = (self._pipeline_id, self._result_id)
        return (_get_materialized_result, lookup_args)
class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple):
    """DoOutputsTuple whose tags resolve to already-materialized elements.

    Args:
      deferred: the original DoOutputsTuple; its ``_tags`` and ``_main_tag``
        are passed on to the base class.
      results_by_tag: mapping from tag to an object exposing ``elements``.
    """

    def __init__(self, deferred, results_by_tag):
        super(_MaterializedDoOutputsTuple, self).__init__(
            None, None, deferred._tags, deferred._main_tag)
        self._deferred = deferred
        self._results_by_tag = results_by_tag

    def __getitem__(self, tag):
        """Return the materialized elements recorded for ``tag``.

        Raises:
          KeyError: if ``tag`` has no entry in the results mapping.
        """
        if tag not in self._results_by_tag:
            # Message previously read "is not a a defined" (duplicated word).
            raise KeyError(
                'Tag %r is not a defined output tag of %s.' % (
                    tag, self._deferred))
        return self._results_by_tag[tag].elements
class _AddMaterializationTransforms(_PValueishTransform):
    """Visitor that attaches a materializing ParDo to each pvalue it meets.

    Every visited PValue is replaced by the _MaterializedResult that the
    attached ParDo appends to.
    """

    def _materialize_transform(self, pipeline):
        """Return ``(labelled ParDo, result)`` for one pvalue of ``pipeline``.

        The DoFn inside the ParDo appends every element it processes to
        ``result.elements``.
        """
        materialized = _allocate_materialized_result(pipeline)

        # _MaterializeValuesDoFn has to be defined here, with local imports,
        # to avoid circular dependencies.
        from apache_beam import DoFn
        from apache_beam import ParDo

        class _MaterializeValuesDoFn(DoFn):
            def process(self, element):
                materialized.elements.append(element)

        label = '_MaterializeValues%d' % materialized._result_id
        labelled_transform = label >> ParDo(_MaterializeValuesDoFn())
        return (labelled_transform, materialized)

    def visit(self, node):
        """Materialize ``node``, recursing into tagged outputs and containers."""
        if isinstance(node, pvalue.PValue):
            transform, materialized = self._materialize_transform(
                node.pipeline)
            # The transform is applied for its effect only; the value of the
            # expression is intentionally discarded.
            node | transform
            return materialized
        if isinstance(node, pvalue.DoOutputsTuple):
            all_tags = itertools.chain([node._main_tag], node._tags)
            results_by_tag = dict(
                (tag, self.visit(node[tag])) for tag in all_tags)
            return _MaterializedDoOutputsTuple(node, results_by_tag)
        return self.visit_nested(node)
class _FinalizeMaterialization(_PValueishTransform):
    """Visitor that replaces materialized placeholders by their contents."""

    def visit(self, node):
        """Unwrap ``node``.

        A _MaterializedResult becomes its ``elements`` list, a
        _MaterializedDoOutputsTuple is passed through untouched, and anything
        else is walked recursively.
        """
        if isinstance(node, _MaterializedResult):
            return node.elements
        if isinstance(node, _MaterializedDoOutputsTuple):
            return node
        return self.visit_nested(node)
class _GetPValues(_PValueishTransform):
    """Visitor that gathers every PValue / DoOutputsTuple into a list."""

    def visit(self, node, pvalues):
        """Append ``node`` to ``pvalues`` if it is a leaf, else recurse.

        Always returns None; the caller reads the results from ``pvalues``.
        """
        is_leaf = isinstance(node, (pvalue.PValue, pvalue.DoOutputsTuple))
        if not is_leaf:
            self.visit_nested(node, pvalues)
            return
        pvalues.append(node)
def get_nested_pvalues(pvalueish):
    """Return a flat list of the PValues / DoOutputsTuples in ``pvalueish``."""
    collected = []
    collector = _GetPValues()
    collector.visit(pvalueish, collected)
    return collected
def get_nested_pvalues0(pvalueish):
    """Yield ``(tag, leaf)`` pairs for every leaf nested in ``pvalueish``.

    Lists and tuples are tagged by position and dicts by key; tags of nested
    containers are joined with dots (for example ``'b.0'``). A value that is
    not a list, tuple or dict is a leaf and is yielded with the tag ``None``.

    Args:
      pvalueish: a leaf value, or an arbitrarily nested list/tuple/dict of
        such values.

    Yields:
      ``(tag, leaf)`` tuples in traversal order.
    """
    if isinstance(pvalueish, (tuple, list)):
        tagged_values = enumerate(pvalueish)
    # Must be ``elif``: with a plain ``if`` every list/tuple fell through to
    # the leaf branch below and was never traversed.
    elif isinstance(pvalueish, dict):
        tagged_values = pvalueish.items()
    else:
        yield None, pvalueish
        return
    for tag, subvalue in tagged_values:
        # Recurse into this generator so that (subtag, value) pairs are
        # produced; get_nested_pvalues returns a flat list, not pairs.
        for subtag, subsubvalue in get_nested_pvalues0(subvalue):
            if subtag is None:
                yield tag, subsubvalue
            else:
                # Join with the nested tag (not the value) to form the path.
                yield '%s.%s' % (tag, subtag), subsubvalue
class _ZipPValues(object):
    """Pairs each PValue in a pvalueish with a value from a parallel sibling.

    The sibling should have the same nested structure as the pvalueish. A
    sibling leaf is expanded across the PValues found beneath the matching
    pvalueish list, tuple or dict.

    For example

        ZipPValues().visit({'a': pc1, 'b': (pc2, pc3)},
                           {'a': 'A', 'b': 'B'})

    will return

        [('a', pc1, 'A'), ('b', pc2, 'B'), ('b', pc3, 'B')]
    """

    def visit(self, pvalueish, sibling, pairs=None, context=None):
        """Collect ``(context, pvalue, sibling)`` triples.

        Called without ``pairs`` it returns a fresh list of triples; called
        with ``pairs`` it appends to that list and returns None.
        """
        if pairs is None:
            # Top-level call: accumulate into a new list and hand it back.
            collected = []
            self.visit(pvalueish, sibling, collected, context)
            return collected
        if isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)):
            pairs.append((context, pvalueish, sibling))
            return None
        if isinstance(pvalueish, (list, tuple)):
            self.visit_sequence(pvalueish, sibling, pairs, context)
            return None
        if isinstance(pvalueish, dict):
            self.visit_dict(pvalueish, sibling, pairs, context)
        return None

    def visit_sequence(self, pvalueish, sibling, pairs, context):
        """Visit each item of a list/tuple alongside its sibling entry.

        A sequence sibling is matched by position (padded with None when it
        is shorter); any other sibling is shared by all items.
        """
        if not isinstance(sibling, (list, tuple)):
            for item in pvalueish:
                self.visit(item, sibling, pairs, context)
            return
        padded_sibling = list(sibling) + [None] * len(pvalueish)
        matched = zip(pvalueish, padded_sibling)
        for index, (item, sibling_item) in enumerate(matched):
            self.visit(item, sibling_item, pairs, 'position %s' % index)

    def visit_dict(self, pvalueish, sibling, pairs, context):
        """Visit each value of a dict alongside its sibling entry.

        A dict sibling is matched by key (None when the key is absent); any
        other sibling is shared by all values.
        """
        sibling_is_dict = isinstance(sibling, dict)
        for key, item in pvalueish.items():
            if sibling_is_dict:
                self.visit(item, sibling.get(key), pairs, key)
            else:
                self.visit(item, sibling, pairs, context)
@PTransform.register_urn(python_urns.GENERIC_COMPOSITE_TRANSFORM, None)
def _create_transform(payload, unused_context):
    """Return a bare PTransform carrying ``payload`` as ``_fn_api_payload``.

    Registered for the GENERIC_COMPOSITE_TRANSFORM urn; the context argument
    is not used.
    """
    transform = PTransform()
    transform._fn_api_payload = payload
    return transform
@PTransform.register_urn(python_urns.PICKLED_TRANSFORM, None)
def _unpickle_transform(pickled_bytes, unused_context):
    """Rebuild a transform from its pickled form (PICKLED_TRANSFORM urn).

    The context argument is not used.
    """
    # NOTE(review): unpickling can presumably execute arbitrary code, so the
    # payload is assumed to come from a trusted source -- confirm with callers.
    transform = pickler.loads(pickled_bytes)
    return transform
class _ChainedPTransform(PTransform):
    """A flat sequence of transforms applied one after another."""

    def __init__(self, *parts):
        chain_label = self._chain_label(parts)
        super(_ChainedPTransform, self).__init__(label=chain_label)
        self._parts = parts

    def _chain_label(self, parts):
        """Join the labels of ``parts`` with ``|``."""
        part_labels = [part.label for part in parts]
        return '|'.join(part_labels)

    def __or__(self, right):
        if not isinstance(right, PTransform):
            return NotImplemented
        # Create a flat list rather than a nested tree of composite
        # transforms for better monitoring, etc.
        extended_parts = self._parts + (right,)
        return _ChainedPTransform(*extended_parts)

    def expand(self, pval):
        """Pipe ``pval`` through every part in order and return the result."""
        current = pval
        for part in self._parts:
            current = current | part
        return current
class PTransformWithSideInputs(PTransform):
    """A superclass for any :class:`PTransform` (e.g.
    :func:`~apache_beam.transforms.core.FlatMap` or
    :class:`~apache_beam.transforms.core.CombineFn`)
    invoking user code.

    :class:`PTransform` s like :func:`~apache_beam.transforms.core.FlatMap`
    invoke user-supplied code in some kind of package (e.g. a
    :class:`~apache_beam.transforms.core.DoFn`) and optionally provide
    arguments and side inputs to that code. This internal-use-only class
    contains common functionality for :class:`PTransform` s that fit this
    model.
    """

    def __init__(self, fn, *args, **kwargs):
        if isinstance(fn, type) and issubclass(fn, WithTypeHints):
            # Don't treat Fn class objects as callables.
            raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__))
        has_side_inputs = bool(args or kwargs)
        self.fn = self.make_fn(fn, has_side_inputs)
        # Now that we figure out the label, initialize the super-class.
        super(PTransformWithSideInputs, self).__init__()

        # A bare PCollection is ambiguous as a side input; it has to be
        # wrapped to say how it should be consumed.
        positional_has_pcoll = any(
            isinstance(value, pvalue.PCollection) for value in args)
        keyword_has_pcoll = any(
            isinstance(value, pvalue.PCollection)
            for value in kwargs.values())
        if positional_has_pcoll or keyword_has_pcoll:
            raise error.SideInputError(
                'PCollection used directly as side input argument. Specify '
                'AsIter(pcollection) or AsSingleton(pcollection) to indicate '
                'how the PCollection is to be used.')

        stripped = util.remove_objects_from_args(
            args, kwargs, pvalue.AsSideInput)
        self.args, self.kwargs, self.side_inputs = stripped
        self.raw_side_inputs = args, kwargs

        # Prevent name collisions with fns of the form
        # '<function <lambda> at ...>'
        self._cached_fn = self.fn

        # Ensure fn and side inputs are picklable for remote execution.
        try:
            self.fn = pickler.loads(pickler.dumps(self.fn))
        except RuntimeError as e:
            raise RuntimeError('Unable to pickle fn %s: %s' % (self.fn, e))
        self.args = pickler.loads(pickler.dumps(self.args))
        self.kwargs = pickler.loads(pickler.dumps(self.kwargs))

        # For type hints, because loads(dumps(class)) != class.
        self.fn = self._cached_fn

    def with_input_types(
            self, input_type_hint, *side_inputs_arg_hints,
            **side_input_kwarg_hints):
        """Annotates the types of main inputs and side inputs for the
        PTransform.

        Args:
          input_type_hint: An instance of an allowed built-in type, a custom
            class, or an instance of a typehints.TypeConstraint.
          *side_inputs_arg_hints: A variable length argument composed of
            an allowed built-in type, a custom class, or a
            typehints.TypeConstraint.
          **side_input_kwarg_hints: A dictionary argument composed of
            an allowed built-in type, a custom class, or a
            typehints.TypeConstraint.

        Example of annotating the types of side-inputs::

          FlatMap().with_input_types(int, int, bool)

        Raises:
          :class:`~exceptions.TypeError`: If **type_hint** is not a valid
            type-hint. See
            :func:`~apache_beam.typehints.typehints.validate_composite_type_param`
            for further details.

        Returns:
          :class:`PTransform`: A reference to the instance of this particular
          :class:`PTransform` object. This allows chaining type-hinting
          related methods.
        """
        super(PTransformWithSideInputs, self).with_input_types(
            input_type_hint)

        convert = native_type_compatibility.convert_to_beam_types
        side_inputs_arg_hints = convert(side_inputs_arg_hints)
        side_input_kwarg_hints = convert(side_input_kwarg_hints)

        for positional_hint in side_inputs_arg_hints:
            validate_composite_type_param(
                positional_hint, 'Type hints for a PTransform')
        for keyword_hint in side_input_kwarg_hints.values():
            validate_composite_type_param(
                keyword_hint, 'Type hints for a PTransform')

        self.side_inputs_types = side_inputs_arg_hints
        return WithTypeHints.with_input_types(
            self, input_type_hint, *side_inputs_arg_hints,
            **side_input_kwarg_hints)

    def type_check_inputs(self, pvalueish):
        """Check the main and side input types against the declared hints.

        Does nothing when no input type hints are set.

        Raises:
          TypeCheckError: if a bound argument type is not consistent with
            the hint declared for that argument.
        """
        input_hints = self.get_type_hints().input_types
        if not input_hints:
            return

        side_args, side_kwargs = self.raw_side_inputs

        def type_of(side_input):
            # Declared side inputs expose their element type; plain values
            # have theirs inferred from the instance.
            if isinstance(side_input, pvalue.AsSideInput):
                return side_input.element_type
            return instance_to_type(side_input)

        arg_types = [pvalueish.element_type]
        arg_types.extend([type_of(value) for value in side_args])
        kwargs_types = dict(
            (name, type_of(value)) for name, value in side_kwargs.items())

        argspec_fn = self._process_argspec_fn()
        bindings = getcallargs_forhints(
            argspec_fn, *arg_types, **kwargs_types)
        hints = getcallargs_forhints(
            argspec_fn, *input_hints[0], **input_hints[1])

        for arg, hint in hints.items():
            if arg.startswith('__unknown__') or hint is None:
                continue
            actual = bindings.get(arg, typehints.Any)
            if typehints.is_consistent_with(actual, hint):
                continue
            raise TypeCheckError(
                'Type hint violation for \'%s\': requires %s but got %s for %s'
                % (self.label, hint, bindings[arg], arg))

    def _process_argspec_fn(self):
        """Returns an argspec of the function actually consuming the data.

        Not implemented in this class.
        """
        raise NotImplementedError

    def make_fn(self, fn, has_side_inputs):
        """Return the object to store as ``self.fn``; here ``fn`` itself."""
        # TODO(silviuc): Add comment describing that this is meant to be
        # overridden by methods detecting callables and wrapping them in
        # DoFns.
        return fn

    def default_label(self):
        """Return ``ClassName(<default label of fn>)``."""
        class_name = self.__class__.__name__
        return '%s(%s)' % (class_name, self.fn.default_label())
class _PTransformFnPTransform(PTransform):
    """A class wrapper for a function-based transform."""

    def __init__(self, fn, *args, **kwargs):
        super(_PTransformFnPTransform, self).__init__()
        self._fn = fn
        self._args = args
        self._kwargs = kwargs

    def display_data(self):
        """Return display items for the function and its bound arguments."""
        if hasattr(self._fn, '__name__'):
            fn_item = self._fn.__name__
        else:
            fn_item = self._fn.__class__
        args_item = DisplayDataItem(str(self._args)).drop_if_default('()')
        kwargs_item = DisplayDataItem(
            str(self._kwargs)).drop_if_default('{}')
        return {'fn': fn_item, 'args': args_item, 'kwargs': kwargs_item}

    def expand(self, pcoll):
        """Call the wrapped function on ``pcoll`` with the stored arguments."""
        # Since the PTransform will be implemented entirely as a function
        # (once called), we need to pass through any type-hinting information
        # that may have been annotated via the .with_input_types() and
        # .with_output_types() methods.
        call_kwargs = dict(self._kwargs)
        call_args = tuple(self._args)
        # TODO(BEAM-5878) Support keyword-only arguments.
        try:
            parameters = get_signature(self._fn).parameters
            if 'type_hints' in parameters:
                call_args = (self.get_type_hints(),) + call_args
        except TypeError:
            # Might not be a function.
            pass
        return self._fn(pcoll, *call_args, **call_kwargs)

    def default_label(self):
        """Label from the function, plus its first argument when present."""
        if not self._args:
            return label_from_callable(self._fn)
        return '%s(%s)' % (
            label_from_callable(self._fn),
            label_from_callable(self._args[0]))
def label_from_callable(fn):
    """Return a human-readable label for ``fn``.

    The label is chosen in this order:

    * ``fn.default_label()`` if ``fn`` has such a method;
    * for lambdas, ``'<lambda at FILE:LINE>'`` built from the code object's
      file base name and first line number;
    * ``fn.__name__`` if present;
    * ``str(fn)`` otherwise.

    Args:
      fn: a callable or any other object to derive a label from.

    Returns:
      The label produced by the first rule above that applies.
    """
    # The stray "[docs]" token that preceded ``def`` (documentation-export
    # residue) was removed; it made the definition a syntax error.
    if hasattr(fn, 'default_label'):
        return fn.default_label()
    elif hasattr(fn, '__name__'):
        if fn.__name__ == '<lambda>':
            # Every lambda shares the same __name__, so use its location to
            # tell them apart.
            return '<lambda at %s:%s>' % (
                os.path.basename(fn.__code__.co_filename),
                fn.__code__.co_firstlineno)
        return fn.__name__
    return str(fn)
class _NamedPTransform(PTransform):
    """Pairs a transform with an explicit label.

    Applying it delegates to the wrapped transform's ``__ror__``, passing the
    label along.
    """

    def __init__(self, transform, label):
        super(_NamedPTransform, self).__init__(label)
        self.transform = transform

    def __ror__(self, pvalueish, _unused=None):
        wrapped = self.transform
        return wrapped.__ror__(pvalueish, self.label)

    def expand(self, pvalue):
        # Application is always delegated in __ror__, so reaching this point
        # is treated as an error.
        raise RuntimeError("Should never be expanded directly.")