#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PTransform and descendants.
A PTransform is an object describing (not executing) a computation. The actual
execution semantics for a transform is captured by a runner object. A transform
object always belongs to a pipeline object.
A PTransform derived class needs to define the expand() method that describes
how one or more PValues are created by the transform.
The module defines a few standard transforms: FlatMap (parallel do),
GroupByKey (group by key), etc. Note that the expand() methods for these
classes contain code that will add nodes to the processing graph associated
with a pipeline.
As support for the FlatMap transform, the module also defines a DoFn
class and wrapper class that allows lambda functions to be used as
FlatMap processing functions.
"""
from __future__ import absolute_import
import copy
import inspect
import itertools
import operator
import os
import sys
import threading
from functools import reduce
from google.protobuf import message
from apache_beam import error
from apache_beam import pvalue
from apache_beam.internal import pickler
from apache_beam.internal import util
from apache_beam.portability import python_urns
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.display import HasDisplayData
from apache_beam.typehints import typehints
from apache_beam.typehints.decorators import TypeCheckError
from apache_beam.typehints.decorators import WithTypeHints
from apache_beam.typehints.decorators import getcallargs_forhints
from apache_beam.typehints.trivial_inference import instance_to_type
from apache_beam.typehints.typehints import validate_composite_type_param
from apache_beam.utils import proto_utils
# Names exported by a star-import of this module.
__all__ = [
'PTransform',
'ptransform_fn',
'label_from_callable',
]
class _PValueishTransform(object):
  """Base visitor that rebuilds PValueish objects.

  A PValueish is a PValue, or a list, tuple, or dict of PValueish objects.
  Visiting a PValueish constructs a (possibly mutated) copy of it.

  Subclasses supply ``visit(node, *args)``; this class only provides the
  recursion over containers.
  """

  def visit_nested(self, node, *args):
    """Rebuilds a container by visiting each of its children.

    Args:
      node: the object to descend into.
      *args: extra arguments forwarded unchanged to every ``self.visit`` call.

    Returns:
      A new container of the same class as ``node`` holding the visited
      children, or ``node`` itself when it is not a list, tuple or dict.
    """
    if isinstance(node, (tuple, list)):
      node_type = node.__class__
      children = [self.visit(child, *args) for child in node]
      if isinstance(node, tuple) and hasattr(node_type, '_make'):
        # namedtuples take their fields as separate positional arguments.
        return node_type(*children)
      return node_type(children)
    if isinstance(node, dict):
      visited_items = {}
      for key, value in node.items():
        visited_items[key] = self.visit(value, *args)
      return node.__class__(visited_items)
    return node
class _SetInputPValues(_PValueishTransform):
  """Visitor that swaps nodes for their registered replacements.

  ``replacements`` maps ``id(node)`` to the object that takes the node's
  place in the rebuilt structure; nodes without an entry are recursed into.
  """

  def visit(self, node, replacements):
    node_id = id(node)
    if node_id not in replacements:
      return self.visit_nested(node, replacements)
    return replacements[node_id]
# Cache that allows values to be materialized when a pipeline is executed
# in-process, in eager mode. It lets the same _MaterializedResult object be
# looked up and reused despite Runner API round-trip serialization.
# Keys are (process id, id(pipeline)) tuples; each value maps a result id to
# its _MaterializedResult.
_pipeline_materialization_cache = dict()
# Held around every read and write of _pipeline_materialization_cache.
_pipeline_materialization_lock = threading.Lock()
def _allocate_materialized_pipeline(pipeline):
  """Creates (or resets) the empty result cache entry for ``pipeline``.

  The entry is keyed by the current process id and ``id(pipeline)``.
  """
  cache_key = (os.getpid(), id(pipeline))
  with _pipeline_materialization_lock:
    _pipeline_materialization_cache[cache_key] = {}
def _allocate_materialized_result(pipeline):
  """Registers a new _MaterializedResult for ``pipeline`` and returns it.

  Result ids are assigned sequentially: the new id is the number of results
  already stored for the pipeline.

  Raises:
    ValueError: if no cache entry was allocated for ``pipeline`` in the
      current process.
  """
  pipeline_id = id(pipeline)
  cache_key = (os.getpid(), pipeline_id)
  with _pipeline_materialization_lock:
    if cache_key not in _pipeline_materialization_cache:
      raise ValueError('Materialized pipeline is not allocated for result '
                       'cache.')
    pipeline_results = _pipeline_materialization_cache[cache_key]
    next_id = len(pipeline_results)
    new_result = _MaterializedResult(pipeline_id, next_id)
    pipeline_results[next_id] = new_result
    return new_result
def _get_materialized_result(pipeline_id, result_id):
  """Looks up a previously allocated _MaterializedResult.

  Args:
    pipeline_id: the ``id()`` of the pipeline the result was allocated for.
    result_id: the id of the result within that pipeline's cache entry.

  Raises:
    Exception: if the current process holds no cache entry for
      ``pipeline_id``.
    KeyError: if the entry exists but contains no ``result_id``.
  """
  cache_key = (os.getpid(), pipeline_id)
  with _pipeline_materialization_lock:
    if cache_key in _pipeline_materialization_cache:
      return _pipeline_materialization_cache[cache_key][result_id]
    raise Exception(
        'Materialization in out-of-process and remote runners is not yet '
        'supported.')
def _release_materialized_pipeline(pipeline):
  """Drops the result cache entry of ``pipeline`` for the current process.

  Raises:
    KeyError: if no entry exists for ``pipeline`` in the current process.
  """
  cache_key = (os.getpid(), id(pipeline))
  with _pipeline_materialization_lock:
    del _pipeline_materialization_cache[cache_key]
class _MaterializedResult(object):
  """Holder for the elements collected for one materialized value.

  Instances are identified by a ``(pipeline_id, result_id)`` pair, which is
  also what gets pickled in place of the object itself.
  """

  def __init__(self, pipeline_id, result_id):
    self._pipeline_id, self._result_id = pipeline_id, result_id
    # Filled in as elements are materialized.
    self.elements = list()

  def __reduce__(self):
    # Unpickling (which happens during Runner API round-trip serialization)
    # resolves to the cached _MaterializedResult instead of creating a copy,
    # so that values are written to the original object when run in eager
    # mode.
    lookup_args = (self._pipeline_id, self._result_id)
    return (_get_materialized_result, lookup_args)
class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple):
  """A DoOutputsTuple whose tags resolve to materialized element lists.

  Args:
    deferred: the original DoOutputsTuple; its ``_tags`` and ``_main_tag``
      are passed on to the base class.
    results_by_tag: dict mapping each output tag to an object exposing an
      ``elements`` attribute.
  """

  def __init__(self, deferred, results_by_tag):
    super(_MaterializedDoOutputsTuple, self).__init__(
        None, None, deferred._tags, deferred._main_tag)
    self._deferred = deferred
    self._results_by_tag = results_by_tag

  def __getitem__(self, tag):
    """Returns the materialized elements stored for ``tag``.

    Raises:
      KeyError: if ``tag`` has no entry in ``results_by_tag``.
    """
    if tag not in self._results_by_tag:
      raise KeyError(
          'Tag %r is not a defined output tag of %s.' % (
              tag, self._deferred))
    return self._results_by_tag[tag].elements
class _AddMaterializationTransforms(_PValueishTransform):
  """Visitor that applies a materializing ParDo to every PValue it visits.

  Each visited PValue is replaced by the _MaterializedResult that the ParDo
  appends its elements to.
  """

  def _materialize_transform(self, pipeline):
    """Builds a labelled ParDo that appends elements to a new result.

    Returns:
      A ``(transform, result)`` pair, where ``result`` is the newly allocated
      _MaterializedResult that ``transform`` writes into.
    """
    result = _allocate_materialized_result(pipeline)

    # DoFn and ParDo are imported here, and the DoFn is defined locally, to
    # avoid circular dependencies.
    from apache_beam import DoFn
    from apache_beam import ParDo

    class _MaterializeValuesDoFn(DoFn):
      def process(self, element):
        result.elements.append(element)

    labelled_pardo = (
        ('_MaterializeValues%d' % result._result_id)
        >> ParDo(_MaterializeValuesDoFn()))
    return labelled_pardo, result

  def visit(self, node):
    if isinstance(node, pvalue.PValue):
      transform, result = self._materialize_transform(node.pipeline)
      # The value produced by applying the transform is intentionally
      # discarded; only the materialized result is handed back.
      node | transform
      return result
    if isinstance(node, pvalue.DoOutputsTuple):
      all_tags = itertools.chain([node._main_tag], node._tags)
      return _MaterializedDoOutputsTuple(
          node, {tag: self.visit(node[tag]) for tag in all_tags})
    return self.visit_nested(node)
class _FinalizeMaterialization(_PValueishTransform):
  """Visitor that replaces each _MaterializedResult by its element list.

  _MaterializedDoOutputsTuple nodes are returned untouched; any other node is
  recursed into.
  """

  def visit(self, node):
    if isinstance(node, _MaterializedResult):
      return node.elements
    if isinstance(node, _MaterializedDoOutputsTuple):
      return node
    return self.visit_nested(node)
class _GetPValues(_PValueishTransform):
  """Visitor that appends every PValue or DoOutputsTuple to a list."""

  def visit(self, node, pvalues):
    if not isinstance(node, (pvalue.PValue, pvalue.DoOutputsTuple)):
      # Not a leaf of interest: keep descending into the structure.
      self.visit_nested(node, pvalues)
      return
    pvalues.append(node)
def get_nested_pvalues(pvalueish):
  """Returns a list of the PValues and DoOutputsTuples within ``pvalueish``."""
  collected = []
  collector = _GetPValues()
  collector.visit(pvalueish, collected)
  return collected
class _ZipPValues(object):
  """Pairs each PValue in a pvalueish with a value in a parallel sibling.

  The sibling should have the same nested structure as the pvalueish. Leaves
  in the sibling are expanded across nested pvalueish lists, tuples, and dicts.
  For example

      _ZipPValues().visit({'a': pc1, 'b': (pc2, pc3)},
                          {'a': 'A', 'b': 'B'})

  will return

      [('a', pc1, 'A'), ('b', pc2, 'B'), ('b', pc3, 'B')]
  """

  def visit(self, pvalueish, sibling, pairs=None, context=None):
    """Collects ``(context, pvalue, sibling)`` triples.

    When ``pairs`` is None a fresh list is created, filled and returned;
    otherwise triples are appended to ``pairs`` and nothing is returned.
    """
    if pairs is None:
      collected = []
      self.visit(pvalueish, sibling, collected, context)
      return collected
    if isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)):
      pairs.append((context, pvalueish, sibling))
    elif isinstance(pvalueish, (list, tuple)):
      self.visit_sequence(pvalueish, sibling, pairs, context)
    elif isinstance(pvalueish, dict):
      self.visit_dict(pvalueish, sibling, pairs, context)

  def visit_sequence(self, pvalueish, sibling, pairs, context):
    """Visits each element of a list or tuple pvalueish."""
    if not isinstance(sibling, (list, tuple)):
      # A non-sequence sibling is shared by every element.
      for element in pvalueish:
        self.visit(element, sibling, pairs, context)
      return
    # Pad with None so that elements beyond the sibling's length still get
    # visited.
    padded = list(sibling) + [None] * len(pvalueish)
    for position, element in enumerate(pvalueish):
      self.visit(element, padded[position], pairs, 'position %s' % position)

  def visit_dict(self, pvalueish, sibling, pairs, context):
    """Visits each value of a dict pvalueish."""
    if not isinstance(sibling, dict):
      # A non-dict sibling is shared by every value.
      for element in pvalueish.values():
        self.visit(element, sibling, pairs, context)
      return
    for key, element in pvalueish.items():
      self.visit(element, sibling.get(key), pairs, key)
@PTransform.register_urn(python_urns.GENERIC_COMPOSITE_TRANSFORM, None)
def _create_transform(payload, unused_context):
  """Builds a bare PTransform carrying ``payload`` as its Fn API payload.

  Registered as the constructor for the generic composite transform URN.
  """
  transform = PTransform()
  transform._fn_api_payload = payload
  return transform
@PTransform.register_urn(python_urns.PICKLED_TRANSFORM, None)
def _unpickle_transform(pickled_bytes, unused_context):
  """Reconstructs a transform from its pickled representation.

  Registered as the constructor for the pickled transform URN.
  """
  # NOTE(review): unpickling can run arbitrary code; payloads are presumably
  # produced by a trusted pipeline submitter -- confirm.
  transform = pickler.loads(pickled_bytes)
  return transform
class _ChainedPTransform(PTransform):
  """A transform that applies a sequence of transforms one after another.

  Its label is the labels of its parts joined with '|'.
  """

  def __init__(self, *parts):
    super(_ChainedPTransform, self).__init__(label=self._chain_label(parts))
    self._parts = parts

  def _chain_label(self, parts):
    labels = [part.label for part in parts]
    return '|'.join(labels)

  def __or__(self, right):
    if not isinstance(right, PTransform):
      return NotImplemented
    # Create a flat list rather than a nested tree of composite
    # transforms for better monitoring, etc.
    extended_parts = self._parts + (right,)
    return _ChainedPTransform(*extended_parts)

  def expand(self, pval):
    # Pipe the input through every part in order.
    result = pval
    for part in self._parts:
      result = result | part
    return result
class PTransformWithSideInputs(PTransform):
  """A superclass for any :class:`PTransform` (e.g.
  :func:`~apache_beam.transforms.core.FlatMap` or
  :class:`~apache_beam.transforms.core.CombineFn`)
  invoking user code.

  :class:`PTransform` s like :func:`~apache_beam.transforms.core.FlatMap`
  invoke user-supplied code in some kind of package (e.g. a
  :class:`~apache_beam.transforms.core.DoFn`) and optionally provide arguments
  and side inputs to that code. This internal-use-only class contains common
  functionality for :class:`PTransform` s that fit this model.
  """

  def __init__(self, fn, *args, **kwargs):
    """Wraps ``fn`` together with its extra arguments and side inputs.

    Args:
      fn: the user code object; passed through :meth:`make_fn` before being
        stored as ``self.fn``.
      *args: positional arguments (possibly side inputs) for ``fn``.
      **kwargs: keyword arguments (possibly side inputs) for ``fn``.

    Raises:
      ValueError: if ``fn`` is a class deriving from ``WithTypeHints`` rather
        than an instance of it.
      ~apache_beam.error.SideInputError: if a ``PCollection`` is passed
        directly as an argument instead of being wrapped as a side input.
    """
    if isinstance(fn, type) and issubclass(fn, WithTypeHints):
      # Don't treat Fn class objects as callables.
      raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__))
    self.fn = self.make_fn(fn)
    # Now that we figure out the label, initialize the super-class.
    super(PTransformWithSideInputs, self).__init__()

    # dict.values() is used (not the Python 2-only itervalues()) so that this
    # check works on both Python 2 and Python 3.
    if (any(isinstance(v, pvalue.PCollection) for v in args) or
        any(isinstance(v, pvalue.PCollection) for v in kwargs.values())):
      raise error.SideInputError(
          'PCollection used directly as side input argument. Specify '
          'AsIter(pcollection) or AsSingleton(pcollection) to indicate how the '
          'PCollection is to be used.')
    self.args, self.kwargs, self.side_inputs = util.remove_objects_from_args(
        args, kwargs, pvalue.AsSideInput)
    self.raw_side_inputs = args, kwargs

    # Prevent name collisions with fns of the form '<function <lambda> at ...>'
    self._cached_fn = self.fn

    # Ensure fn and side inputs are picklable for remote execution.
    self.fn = pickler.loads(pickler.dumps(self.fn))
    self.args = pickler.loads(pickler.dumps(self.args))
    self.kwargs = pickler.loads(pickler.dumps(self.kwargs))

    # For type hints, because loads(dumps(class)) != class.
    self.fn = self._cached_fn

  def with_input_types(
      self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints):
    """Annotates the types of main inputs and side inputs for the PTransform.

    Args:
      input_type_hint: An instance of an allowed built-in type, a custom class,
        or an instance of a typehints.TypeConstraint.
      *side_inputs_arg_hints: A variable length argument composed of
        an allowed built-in type, a custom class, or a
        typehints.TypeConstraint.
      **side_input_kwarg_hints: A dictionary argument composed of
        an allowed built-in type, a custom class, or a
        typehints.TypeConstraint.

    Example of annotating the types of side-inputs::

      FlatMap().with_input_types(int, int, bool)

    Raises:
      :class:`~exceptions.TypeError`: If **type_hint** is not a valid type-hint.
        See
        :func:`~apache_beam.typehints.typehints.validate_composite_type_param`
        for further details.

    Returns:
      :class:`PTransform`: A reference to the instance of this particular
      :class:`PTransform` object. This allows chaining type-hinting related
      methods.
    """
    super(PTransformWithSideInputs, self).with_input_types(input_type_hint)

    for si in side_inputs_arg_hints:
      validate_composite_type_param(si, 'Type hints for a PTransform')
    for si in side_input_kwarg_hints.values():
      validate_composite_type_param(si, 'Type hints for a PTransform')

    self.side_inputs_types = side_inputs_arg_hints
    return WithTypeHints.with_input_types(
        self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints)

  def type_check_inputs(self, pvalueish):
    """Checks the main and side input types against the declared hints.

    Does nothing when no input type hints are declared.

    Raises:
      ~apache_beam.typehints.decorators.TypeCheckError: if the type bound to
        an argument is not consistent with the hint declared for it.
    """
    type_hints = self.get_type_hints().input_types
    if type_hints:
      args, kwargs = self.raw_side_inputs

      def element_type(side_input):
        if isinstance(side_input, pvalue.AsSideInput):
          return side_input.element_type
        return instance_to_type(side_input)

      arg_types = [pvalueish.element_type] + [element_type(v) for v in args]
      kwargs_types = {k: element_type(v) for (k, v) in kwargs.items()}
      argspec_fn = self._process_argspec_fn()
      bindings = getcallargs_forhints(argspec_fn, *arg_types, **kwargs_types)
      hints = getcallargs_forhints(argspec_fn, *type_hints[0], **type_hints[1])
      for arg, hint in hints.items():
        if arg.startswith('__unknown__'):
          continue
        if hint is None:
          continue
        # Look the bound type up once, with the same default for the check
        # and the message, so a missing binding cannot raise KeyError while
        # the error is being formatted.
        bind_type = bindings.get(arg, typehints.Any)
        if not typehints.is_consistent_with(bind_type, hint):
          raise TypeCheckError(
              'Type hint violation for \'%s\': requires %s but got %s for %s'
              % (self.label, hint, bind_type, arg))

  def _process_argspec_fn(self):
    """Returns an argspec of the function actually consuming the data.
    """
    raise NotImplementedError

  def make_fn(self, fn):
    # TODO(silviuc): Add comment describing that this is meant to be overriden
    # by methods detecting callables and wrapping them in DoFns.
    return fn

  def default_label(self):
    return '%s(%s)' % (self.__class__.__name__, self.fn.default_label())
class _PTransformFnPTransform(PTransform):
  """A class wrapper for a function-based transform."""

  def __init__(self, fn, *args, **kwargs):
    """Stores ``fn`` and the extra arguments it is later called with."""
    super(_PTransformFnPTransform, self).__init__()
    self._fn = fn
    self._args = args
    self._kwargs = kwargs

  def display_data(self):
    """Returns display data describing the wrapped callable and arguments."""
    res = {'fn': (self._fn.__name__
                  if hasattr(self._fn, '__name__')
                  else self._fn.__class__),
           'args': DisplayDataItem(str(self._args)).drop_if_default('()'),
           'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}')}
    return res

  def expand(self, pcoll):
    """Calls the wrapped function with ``pcoll`` and the stored arguments."""
    # Since the PTransform will be implemented entirely as a function
    # (once called), we need to pass through any type-hinting information that
    # may have been annotated via the .with_input_types() and
    # .with_output_types() methods.
    kwargs = dict(self._kwargs)
    args = tuple(self._args)
    # inspect.getargspec() no longer exists on recent Python 3 releases, and
    # the resulting AttributeError would escape the handler below. Prefer
    # getfullargspec() when available and fall back to getargspec() on
    # Python 2.
    getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    try:
      if 'type_hints' in getargspec(self._fn).args:
        args = (self.get_type_hints(),) + args
    except TypeError:
      # Might not be a function.
      pass
    return self._fn(pcoll, *args, **kwargs)

  def default_label(self):
    """Returns a label built from the function and its first argument."""
    if self._args:
      return '%s(%s)' % (
          label_from_callable(self._fn), label_from_callable(self._args[0]))
    return label_from_callable(self._fn)
def label_from_callable(fn):
  """Returns a human-readable label for ``fn``.

  The label is, in order of preference:

  * ``fn.default_label()`` when ``fn`` has a ``default_label`` attribute;
  * ``'<lambda at FILE:LINE>'`` for lambdas, where FILE is the base name of
    the defining file and LINE its first line number;
  * ``fn.__name__`` when ``fn`` has a name;
  * ``str(fn)`` otherwise.
  """
  if hasattr(fn, 'default_label'):
    return fn.default_label()
  elif hasattr(fn, '__name__'):
    if fn.__name__ == '<lambda>':
      return '<lambda at %s:%s>' % (
          os.path.basename(fn.__code__.co_filename),
          fn.__code__.co_firstlineno)
    return fn.__name__
  return str(fn)
class _NamedPTransform(PTransform):
  """Pairs a transform with the label it should be applied under.

  Applying it with ``|`` delegates to the wrapped transform's ``__ror__``,
  passing this object's label along.
  """

  def __init__(self, transform, label):
    super(_NamedPTransform, self).__init__(label)
    self.transform = transform

  def __ror__(self, pvalueish, _unused=None):
    wrapped = self.transform
    return wrapped.__ror__(pvalueish, self.label)

  def expand(self, pvalue):
    # Application always goes through __ror__ above.
    raise RuntimeError("Should never be expanded directly.")