#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
from typing import Mapping
from typing import Tuple
from typing import TypeVar
from typing import Union
import pandas as pd
import apache_beam as beam
from apache_beam import transforms
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frames # pylint: disable=unused-import
from apache_beam.dataframe import partitionings
if TYPE_CHECKING:
  # Only imported while type checking; in the code visible here PCollection
  # is referenced solely from type comments.
  # pylint: disable=ungrouped-imports
  from apache_beam.pvalue import PCollection

# Generic placeholder for the leaf type named in the type comments of _flatten.
T = TypeVar('T')
class _DataframeExpressionsTransform(transforms.PTransform):
  """A PTransform that evaluates a fixed set of deferred expressions.

  The output expressions are stored at construction time. When the transform
  is expanded they are handed, together with the inputs, to
  `_apply_deferred_ops`, which builds the Beam graph.
  """
  def __init__(self, outputs):
    # Stored as-is; _apply_deferred_ops later reads it through .values() and
    # .items(), i.e. as a mapping of keys to expressions.
    self._outputs = outputs

  def expand(self, inputs):
    # All of the work happens in _apply_deferred_ops; see its docstring for
    # the expected form of `inputs` and of the returned value.
    return self._apply_deferred_ops(inputs, self._outputs)

  def _apply_deferred_ops(
      self,
      inputs,  # type: Dict[expressions.Expression, PCollection]
      outputs,  # type: Dict[Any, expressions.Expression]
  ):  # -> Dict[Any, PCollection]
    """Construct a Beam graph that evaluates a set of expressions on a set of
    input PCollections.

    :param inputs: A mapping of placeholder expressions to PCollections.
    :param outputs: A mapping of keys to expressions defined in terms of the
        placeholders of inputs.

    Returns a dictionary whose keys are those of outputs, and whose values are
    PCollections corresponding to the values of outputs evaluated at the
    values of inputs.

    Logically, `_apply_deferred_ops({x: a, y: b}, {f: F(x, y), g: G(x, y)})`
    returns `{f: F(a, b), g: G(a, b)}`.

    Internally, every non-input expression is assigned to a `Stage`; each
    stage is applied to its input PCollections as one `ComputeStage`
    transform, and the requested outputs are looked up in the stage results
    by `expr._id`.
    """
    class ComputeStage(beam.PTransform):
      """A helper transform that computes a single stage of operations.

      It is applied to a dict of tag -> PCollection and produces a
      multi-output result with one tagged output per expression in
      `stage.outputs`.
      """
      def __init__(self, stage):
        self.stage = stage

      def default_label(self):
        # id(self) is part of the label, presumably so that two transforms
        # whose stages list the same ops still get different labels.
        return '%s:%s' % (self.stage.ops, id(self))

      def expand(self, pcolls):
        # `pcolls` is a dict of tag -> PCollection; stage_to_result (below)
        # builds it using expr._id as the tag of each stage input.
        if self.stage.partitioning != partitionings.Nothing():
          # Arrange such that partitioned_pcoll is properly partitioned.
          # Each input goes through the partitioning's partition_fn
          # (presumably emitting keyed pieces of each frame -- TODO confirm in
          # partitionings), the keyed inputs are co-grouped, and the pieces
          # that ended up under the same tag are concatenated again.
          input_pcolls = {
              tag: pcoll | 'Flat%s' % tag >> beam.FlatMap(
                  self.stage.partitioning.partition_fn)
              for (tag, pcoll) in pcolls.items()
          }
          # The grouping key itself is discarded (`_`); only the per-tag
          # values are kept, as a dict of tag -> concatenated frame.
          partitioned_pcoll = input_pcolls | beam.CoGroupByKey(
          ) | beam.MapTuple(
              lambda _,
              inputs: {tag: pd.concat(vs)
                       for tag, vs in inputs.items()})
        else:
          # Already partitioned, or no partitioning needed.
          # The unpacking below requires exactly one input and raises
          # ValueError otherwise. Each element is wrapped in a one-entry dict
          # so both branches hand `evaluate` the same tag -> frame mapping.
          (tag, pcoll), = pcolls.items()
          partitioned_pcoll = pcoll | beam.Map(lambda df: {tag: df})

        # Actually evaluate the expressions.
        # The stage is bound as a default argument, so `evaluate` reads it
        # from its own defaults rather than through `self`.
        def evaluate(partition, stage=self.stage):
          # Bind every stage input to the value found under its id in this
          # partition, then emit one tagged value per stage output.
          session = expressions.Session(
              {expr: partition[expr._id]
               for expr in stage.inputs})
          for expr in stage.outputs:
            yield beam.pvalue.TaggedOutput(expr._id, expr.evaluate_at(session))

        return partitioned_pcoll | beam.FlatMap(evaluate).with_outputs()

    class Stage(object):
      """Used to build up a set of operations that can be fused together.

      The attributes are described by the comments in `__init__`.
      """
      def __init__(self, inputs, partitioning):
        # Expressions this stage consumes (duplicates collapse in the set).
        self.inputs = set(inputs)
        if len(self.inputs) > 1 and partitioning == partitionings.Nothing():
          # We have to shuffle to co-locate, might as well partition.
          self.partitioning = partitionings.Index()
        else:
          self.partitioning = partitioning
        # Expressions assigned to this stage, in the order they were added.
        self.ops = []
        # Expressions whose values ComputeStage emits for this stage.
        self.outputs = set()

    # First define some helper functions.
    def output_is_partitioned_by(expr, stage, partitioning):
      """Return True if `expr`, as available within `stage`, can be treated
      as satisfying `partitioning`.

      :param expr: The expression whose value is being considered.
      :param stage: The Stage in which `expr` would be consumed.
      :param partitioning: The partitioning that is required.
      """
      if partitioning == partitionings.Nothing():
        # Always satisfied.
        return True
      elif stage.partitioning == partitionings.Singleton():
        # Within a stage, the singleton partitioning is trivially preserved.
        return True
      elif expr in stage.inputs:
        # Inputs are all partitioned by stage.partitioning.
        return stage.partitioning.is_subpartitioning_of(partitioning)
      elif expr.preserves_partition_by().is_subpartitioning_of(partitioning):
        # Here expr preserves at least the requested partitioning; its outputs
        # will also have this partitioning iff its inputs do.
        if expr.requires_partition_by().is_subpartitioning_of(partitioning):
          # If expr requires at least this partitioning, we will arrange such
          # that its inputs satisfy this.
          return True
        else:
          # Otherwise, recursively check all the inputs.
          return all(
              output_is_partitioned_by(arg, stage, partitioning)
              for arg in expr.args())
      else:
        # expr does not preserve the requested partitioning.
        return False

    def common_stages(stage_lists):
      """Yield, in order, the stages of the first list that also occur in
      every other list. Yields nothing when `stage_lists` is empty.
      """
      # Set intersection, with a preference for earlier items in the list.
      if stage_lists:
        for stage in stage_lists[0]:
          if all(stage in other for other in stage_lists[1:]):
            yield stage

    @memoize
    def expr_to_stages(expr):
      """Return the list of stages in which `expr` is available.

      The first call for an expression assigns it to a stage (an existing one
      shared by its arguments, or a newly created one). Because the result is
      memoized, later calls return that same list object, which is how the
      `append` below makes an expression available in further stages.

      Must not be called with a placeholder from `inputs` (asserted).
      """
      assert expr not in inputs
      # First attempt to compute this expression as part of an existing stage,
      # if possible.
      #
      # If expr does not require partitioning, just grab any stage, else grab
      # the first stage where all of expr's inputs are partitioned as required.
      # In either case, use the first such stage because earlier stages are
      # closer to the inputs (have fewer intermediate stages).
      required_partitioning = expr.requires_partition_by()
      # Only non-input arguments have stages. If there are none, the candidate
      # list is empty, the loop body never runs, and the `else` clause below
      # creates a new stage.
      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
                                  if arg not in inputs]):
        if all(output_is_partitioned_by(arg, stage, required_partitioning)
               for arg in expr.args()):
          break
      else:
        # Otherwise, compute this expression as part of a new stage.
        stage = Stage(expr.args(), required_partitioning)
        for arg in expr.args():
          if arg not in inputs:
            # For each non-input argument, declare that it is also available in
            # this new stage.
            expr_to_stages(arg).append(stage)
            # It also must be declared as an output of the producing stage.
            expr_to_stage(arg).outputs.add(arg)
      stage.ops.append(expr)
      # This is a list as given expression may be available in many stages.
      return [stage]

    def expr_to_stage(expr):
      """Return the stage that computes `expr` (first entry of its list)."""
      # Any will do; the first requires the fewest intermediate stages.
      return expr_to_stages(expr)[0]

    # Ensure each output is computed.
    # Outputs that are themselves inputs need no stage; they are returned
    # directly from `inputs` by expr_to_pcoll.
    for expr in outputs.values():
      if expr not in inputs:
        expr_to_stage(expr).outputs.add(expr)

    @memoize
    def stage_to_result(stage):
      """Apply ComputeStage to the stage's inputs, keyed by expression id.

      Memoized, so each stage is applied at most once however many of its
      outputs are requested.
      """
      return {expr._id: expr_to_pcoll(expr)
              for expr in stage.inputs} | ComputeStage(stage)

    @memoize
    def expr_to_pcoll(expr):
      """Return the PCollection for `expr`: either the given input, or the
      output tagged `expr._id` of the stage that computes it.
      """
      if expr in inputs:
        return inputs[expr]
      else:
        return stage_to_result(expr_to_stage(expr))[expr._id]

    # Now we can compute and return the result.
    return {k: expr_to_pcoll(expr) for k, expr in outputs.items()}
def memoize(f):
  """Return a caching wrapper around `f`, keyed on its positional arguments.

  The first call with a given argument tuple invokes `f` and stores the
  result. Later calls with equal arguments return the stored object itself
  (not a copy), so a caller that mutates a returned value will observe that
  mutation on subsequent calls.

  :param f: The function to wrap. The wrapper accepts positional arguments
      only, and they must be hashable because they form the cache key.

  Returns the wrapper, which carries the name and docstring of `f`. The cache
  lives as long as the wrapper does and entries are never evicted. If `f`
  raises, nothing is cached and the exception propagates.
  """
  # Imported locally so this block does not rely on the module's import list.
  import functools

  cache = {}

  @functools.wraps(f)
  def wrapper(*args):
    # Membership is tested explicitly (rather than using cache.get) so that
    # results which are None or otherwise falsy are cached as well.
    if args not in cache:
      cache[args] = f(*args)
    return cache[args]

  return wrapper
def _dict_union(dicts):
result = {}
for d in dicts:
result.update(d)
return result
def _flatten(
valueish, # type: Union[T, List[T], Tuple[T], Dict[Any, T]]
root=(), # type: Tuple[Any, ...]
):
# type: (...) -> Mapping[Tuple[Any, ...], T]
"""Given a nested structure of dicts, tuples, and lists, return a flat
dictionary where the values are the leafs and the keys are the "paths" to
these leaves.
For example `{a: x, b: (y, z)}` becomes `{(a,): x, (b, 0): y, (b, 1): c}`.
"""
if isinstance(valueish, dict):
return _dict_union(_flatten(v, root + (k, )) for k, v in valueish.items())
elif isinstance(valueish, (tuple, list)):
return _dict_union(
_flatten(v, root + (ix, )) for ix, v in enumerate(valueish))
else:
return {root: valueish}
def _substitute(valueish, replacements, root=()):
"""Substitutes the values in valueish with those in replacements where the
keys are as in _flatten.
For example,
```
_substitute(
{a: x, b: (y, z)},
{(a,): X, (b, 0): Y, (b, 1): Z})
```
returns `{a: X, b: (Y, Z)}`.
"""
if isinstance(valueish, dict):
return type(valueish)({
k: _substitute(v, replacements, root + (k, ))
for (k, v) in valueish.items()
})
elif isinstance(valueish, (tuple, list)):
return type(valueish)((
_substitute(v, replacements, root + (ix, ))
for (ix, v) in enumerate(valueish)))
else:
return replacements[root]