# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
import json
import logging
import pprint
import re
import uuid
from typing import Iterable
from typing import Mapping

import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.yaml import yaml_provider
# Names exported by a star-import of this module.
__all__ = ["YamlTransform"]
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
def memoize_method(func):
  """Decorate an instance method so each distinct call is computed only once.

  Results are stored on the instance itself, in a ``_cache`` dict created on
  first use, keyed by the wrapped function's name and its positional
  arguments. Only positional arguments are accepted, and they must be
  hashable.

  Args:
    func: the method to wrap; it is invoked as ``func(self, *args)``.

  Returns:
    A wrapper that carries the name and docstring of ``func``.
  """
  @functools.wraps(func)
  def wrapper(self, *args):
    if not hasattr(self, '_cache'):
      self._cache = {}
    key = func.__name__, args
    if key not in self._cache:
      self._cache[key] = func(self, *args)
    return self._cache[key]

  return wrapper
def only_element(xs):
  """Return the single item held by ``xs``.

  Args:
    xs: any iterable expected to yield exactly one item.

  Returns:
    That one item.

  Raises:
    ValueError: if ``xs`` yields no items or more than one.
  """
  try:
    x, = xs
  except ValueError as exn:
    # The bare unpacking message gives no hint of what was expected.
    raise ValueError(f'Expected exactly one element: {exn}') from exn
  return x
class SafeLineLoader(SafeLoader):
  """A yaml loader that attaches line information to mappings and strings."""
  class TaggedString(str):
    """A string class to which we can attach metadata.

    This is primarily used to trace a string's origin back to its place in a
    yaml file.
    """
    def __reduce__(self):
      # Pickle as an ordinary string.
      return str, (str(self), )

  def construct_scalar(self, node):
    """Build a scalar; string values remember the line they came from."""
    value = super().construct_scalar(node)
    if isinstance(value, str):
      value = SafeLineLoader.TaggedString(value)
      value._line_ = node.start_mark.line + 1
    return value

  def construct_mapping(self, node, deep=False):
    """Build a mapping and add ``__line__`` and ``__uuid__`` entries to it."""
    mapping = super().construct_mapping(node, deep=deep)
    mapping['__line__'] = node.start_mark.line + 1
    mapping['__uuid__'] = self.create_uuid()
    return mapping

  @classmethod
  def create_uuid(cls):
    """Return a freshly generated random uuid as a string."""
    return str(uuid.uuid4())

  @classmethod
  def strip_metadata(cls, spec, tagged_str=True):
    """Return a copy of ``spec`` without the loader's bookkeeping.

    Args:
      spec: a (possibly nested) structure of mappings, iterables and scalars.
      tagged_str: when true, TaggedString values are converted to plain
        ``str``; when false they are returned unchanged.

    Returns:
      Mappings become plain dicts without the ``__line__`` and ``__uuid__``
      keys, other non-string iterables become lists, and everything else is
      returned as is (subject to ``tagged_str``).
    """
    if isinstance(spec, Mapping):
      return {
          key: cls.strip_metadata(value, tagged_str)
          for key,
          value in spec.items() if key not in ('__line__', '__uuid__')
      }
    elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)):
      return [cls.strip_metadata(value, tagged_str) for value in spec]
    elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str:
      return str(spec)
    else:
      return spec

  @staticmethod
  def get_line(obj):
    """Return the line recorded for ``obj``, or ``'unknown'`` if none is."""
    if isinstance(obj, dict):
      return obj.get('__line__', 'unknown')
    else:
      return getattr(obj, '_line_', 'unknown')
class LightweightScope(object):
  """Resolves transform references (uuid, name or type) to transform uuids."""
  def __init__(self, transforms):
    """Index ``transforms`` by uuid and by their ``name`` and ``type``.

    Args:
      transforms: transform specs; each must carry a ``__uuid__`` entry.
    """
    self._transforms = transforms
    self._transforms_by_uuid = {t['__uuid__']: t for t in self._transforms}
    # A transform can be referred to by its name or by its type, so both are
    # registered; several transforms may share either one.
    self._uuid_by_name = collections.defaultdict(list)
    for spec in self._transforms:
      if 'name' in spec:
        self._uuid_by_name[spec['name']].append(spec['__uuid__'])
      if 'type' in spec:
        self._uuid_by_name[spec['type']].append(spec['__uuid__'])

  def get_transform_id_and_output_name(self, name):
    """Split a ``transform[.output]`` reference.

    Returns:
      A ``(transform_uuid, output_name)`` pair; ``output_name`` is None when
      ``name`` contains no dot.

    Raises:
      ValueError: if the transform part is unknown or ambiguous.
    """
    if '.' in name:
      transform_name, output = name.rsplit('.', 1)
    else:
      transform_name, output = name, None
    return self.get_transform_id(transform_name), output

  def get_transform_id(self, transform_name):
    """Return the uuid of the transform referred to by ``transform_name``.

    A value that already is a known uuid is returned unchanged.

    Raises:
      ValueError: if no transform, or more than one, matches.
    """
    if transform_name in self._transforms_by_uuid:
      return transform_name
    else:
      candidates = self._uuid_by_name[transform_name]
      if not candidates:
        raise ValueError(
            f'Unknown transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      elif len(candidates) > 1:
        raise ValueError(
            f'Ambiguous transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      else:
        return only_element(candidates)
class Scope(LightweightScope):
  """To look up PCollections (typically outputs of prior transforms) by name."""
  def __init__(self, root, inputs, transforms, providers):
    """Create a scope.

    Args:
      root: what a transform without inputs is applied to.
      inputs: mapping from input name to the value returned for that name by
        get_pcollection().
      transforms: the transform specs that live in this scope.
      providers: mapping from transform type to the providers offering it.
    """
    super().__init__(transforms)
    self.root = root
    self._inputs = inputs
    self.providers = providers
    self._seen_names = set()

  def compute_all(self):
    """Expand every transform of this scope."""
    for transform_id in self._transforms_by_uuid.keys():
      self.compute_outputs(transform_id)

  def get_pcollection(self, name):
    """Resolve ``name`` to one of this scope's inputs or a transform output.

    ``name`` may be an input name, ``transform.output``, or a bare transform
    reference when that transform has exactly one output.

    Raises:
      ValueError: if the named output does not exist, or a bare transform
        reference has other than exactly one output.
    """
    if name in self._inputs:
      return self._inputs[name]
    elif '.' in name:
      transform, output = name.rsplit('.', 1)
      outputs = self.get_outputs(transform)
      if output in outputs:
        return outputs[output]
      else:
        raise ValueError(
            f'Unknown output {repr(output)} '
            f'at line {SafeLineLoader.get_line(name)}: '
            f'{transform} only has outputs {list(outputs.keys())}')
    else:
      outputs = self.get_outputs(name)
      if len(outputs) == 1:
        return only_element(outputs.values())
      else:
        raise ValueError(
            f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
            f'{name} has outputs {list(outputs.keys())}')

  def get_outputs(self, transform_name):
    """Return the outputs of the transform referred to by ``transform_name``."""
    return self.compute_outputs(self.get_transform_id(transform_name))

  # Memoized: compute_all() and every get_pcollection() lookup come through
  # here, and a transform must only be expanded once.
  @memoize_method
  def compute_outputs(self, transform_id):
    """Expand the transform with the given uuid and return its outputs."""
    return expand_transform(self._transforms_by_uuid[transform_id], self)

  # A method on scope as providers may be scoped...
  def create_ptransform(self, spec):
    """Instantiate the transform described by ``spec``.

    The first available provider registered for ``spec['type']`` is used.
    Arguments come from ``spec['args']`` when present, otherwise from every
    entry of ``spec`` other than type, name, input and output.

    Raises:
      ValueError: if the type is missing or unknown, no provider is
        available, ``args`` is not a mapping, or the provider fails to build
        the transform.
    """
    if 'type' not in spec:
      raise ValueError(f'Missing transform type: {identify_object(spec)}')

    if spec['type'] not in self.providers:
      raise ValueError(
          'Unknown transform type %r at %s' %
          (spec['type'], identify_object(spec)))

    for provider in self.providers.get(spec['type']):
      if provider.available():
        break
    else:
      raise ValueError(
          'No available provider for type %r at %s' %
          (spec['type'], identify_object(spec)))

    if 'args' in spec:
      args = spec['args']
      if not isinstance(args, dict):
        raise ValueError(
            'Arguments for transform at %s must be a mapping.' %
            identify_object(spec))
    else:
      args = {
          key: value
          for (key, value) in spec.items()
          if key not in ('type', 'name', 'input', 'output')
      }
    real_args = SafeLineLoader.strip_metadata(args)
    try:
      # pylint: disable=undefined-loop-variable
      ptransform = provider.create_transform(spec['type'], real_args)
      # TODO(robertwb): Should we have a better API for adding annotations
      # than this?
      # NOTE(review): the keyword arguments of this dict() call were missing
      # from the reviewed source; they are reconstructed here and assume the
      # provider exposes to_json() -- confirm against yaml_provider.
      annotations = dict(
          yaml_type=spec['type'],
          yaml_args=json.dumps(real_args),
          yaml_provider=json.dumps(provider.to_json()),
          **ptransform.annotations())
      ptransform.annotations = lambda: annotations
      return ptransform
    except Exception as exn:
      if isinstance(exn, TypeError):
        # Create a slightly more generic error message for argument errors.
        msg = str(exn).replace('positional', '').replace('keyword', '')
        msg = re.sub(r'\S+lambda\S+', '', msg)
        msg = re.sub(' +', ' ', msg).strip()
      else:
        msg = str(exn)
      raise ValueError(
          f'Invalid transform specification at {identify_object(spec)}: {msg}'
      ) from exn

  def unique_name(self, spec, ptransform, strictness=0):
    """Return a label for ``spec`` that was not handed out by this scope yet.

    An explicit ``spec['name']`` is preferred over ``ptransform.label`` and
    raises the strictness by one. A clash is resolved by appending the spec's
    line number, unless the strictness has reached 2.

    Raises:
      ValueError: on a duplicate name when strictness is 2 or more.
    """
    if 'name' in spec:
      name = spec['name']
      strictness += 1
    else:
      name = ptransform.label
    if name in self._seen_names:
      if strictness >= 2:
        raise ValueError(f'Duplicate name at {identify_object(spec)}: {name}')
      else:
        name = f'{name}@{SafeLineLoader.get_line(spec)}'
    # Remember the name so later clashes can be detected.
    self._seen_names.add(name)
    return name
def expand_transform(spec, scope):
  """Expand ``spec`` within ``scope`` and return its outputs.

  Specs of type ``composite`` go to expand_composite_transform(); every
  other type is treated as a leaf transform.

  Raises:
    TypeError: if ``spec`` has no ``type`` entry.
  """
  if 'type' not in spec:
    raise TypeError(
        f'Missing type parameter for transform at {identify_object(spec)}')
  if spec['type'] == 'composite':
    return expand_composite_transform(spec, scope)
  else:
    return expand_leaf_transform(spec, scope)
def expand_leaf_transform(spec, scope):
  """Apply a single (non-composite) transform and return its named outputs.

  ``spec['input_type']`` selects how the resolved inputs are handed to the
  transform: ``'list'`` passes a tuple, ``'map'`` passes the dict, and the
  default passes ``scope.root`` for no inputs, the bare value for exactly one
  input and the dict otherwise.

  Returns:
    A dict of outputs: the transform's own dict, ``out0``, ``out1``, ... for
    a tuple or list result, or ``{'out': ...}`` for a single PCollection.

  Raises:
    ValueError: if applying the transform fails or it returns anything else.
  """
  spec = normalize_inputs_outputs(spec)
  inputs_dict = {
      key: scope.get_pcollection(value)
      for (key, value) in spec['input'].items()
  }
  input_type = spec.get('input_type', 'default')
  if input_type == 'list':
    inputs = tuple(inputs_dict.values())
  elif input_type == 'map':
    inputs = inputs_dict
  else:
    if len(inputs_dict) == 0:
      inputs = scope.root
    elif len(inputs_dict) == 1:
      inputs = next(iter(inputs_dict.values()))
    else:
      inputs = inputs_dict
  _LOGGER.info("Expanding %s ", identify_object(spec))
  ptransform = scope.create_ptransform(spec)
  try:
    # TODO: Move validation to construction?
    with FullyQualifiedNamedTransform.with_filter('*'):
      outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform
  except Exception as exn:
    raise ValueError(
        f"Error apply transform {identify_object(spec)}: {exn}") from exn
  if isinstance(outputs, dict):
    # TODO: Handle (or at least reject) nested case.
    return outputs
  elif isinstance(outputs, (tuple, list)):
    # The key must be formatted per element; a constant key would keep only
    # the last output.
    return {f'out{ix}': pcoll for (ix, pcoll) in enumerate(outputs)}
  elif isinstance(outputs, beam.PCollection):
    return {'out': outputs}
  else:
    raise ValueError(
        f'Transform {identify_object(spec)} returned an unexpected type '
        f'{type(outputs)}')
def expand_composite_transform(spec, scope):
  """Expand a composite spec: an inner scope of transforms with named outputs.

  The composite's inputs are resolved in ``scope`` and become the inputs of
  an inner Scope holding ``spec['transforms']``. A composite whose name is
  None (the top-level pipeline) is expanded directly instead of being
  wrapped in its own PTransform.

  Returns:
    A dict from each key of ``spec['output']`` to the value it resolves to in
    the inner scope.
  """
  spec = normalize_inputs_outputs(normalize_source_sink(spec))

  # NOTE(review): the transforms and providers arguments of this call were
  # missing from the reviewed source. They are reconstructed from Scope's
  # signature and assume yaml_provider.merge_providers combines the locally
  # declared providers with the enclosing scope's -- confirm.
  inner_scope = Scope(
      scope.root, {
          key: scope.get_pcollection(value)
          for key,
          value in spec['input'].items()
      },
      spec['transforms'],
      yaml_provider.merge_providers(
          yaml_provider.parse_providers(spec.get('providers', [])),
          scope.providers))

  class CompositePTransform(beam.PTransform):
    # Static so it can also be invoked directly on the class (see the
    # top-level case below), without an instance.
    @staticmethod
    def expand(inputs):
      inner_scope.compute_all()
      return {
          key: inner_scope.get_pcollection(value)
          for (key, value) in spec['output'].items()
      }

  if 'name' not in spec:
    spec['name'] = 'Composite'
  if spec['name'] is None:  # top-level pipeline, don't nest
    return CompositePTransform.expand(None)
  else:
    _LOGGER.info("Expanding %s ", identify_object(spec))
    return ({
        key: scope.get_pcollection(value)
        for key,
        value in spec['input'].items()
    } or scope.root) | scope.unique_name(spec, None) >> CompositePTransform()
def expand_chain_transform(spec, scope):
  """Expand a chain spec by first rewriting it as a composite."""
  as_composite = chain_as_composite(spec)
  return expand_composite_transform(as_composite, scope)
def chain_as_composite(spec):
  """Rewrite a chain spec as the equivalent composite spec.

  The first transform receives the chain's inputs, each later one consumes
  the transform before it, and the chain's output is taken from the last
  transform. The transform specs of the given chain are copied, not modified.

  Raises:
    TypeError: if the chain has no ``transforms`` entry.
    ValueError: if any transform in the chain declares ``input`` or
      ``output`` itself.
  """
  # A chain is simply a composite transform where all inputs and outputs
  # are implicit.
  spec = normalize_source_sink(spec)
  if 'transforms' not in spec:
    raise TypeError(
        f"Chain at {identify_object(spec)} missing transforms property.")
  has_explicit_outputs = 'output' in spec
  composite_spec = normalize_inputs_outputs(spec)
  new_transforms = []
  for ix, transform in enumerate(composite_spec['transforms']):
    if any(io in transform for io in ('input', 'output')):
      raise ValueError(
          f'Transform {identify_object(transform)} is part of a chain, '
          'must have implicit inputs and outputs.')
    if ix == 0:
      implicit_input = {key: key for key in composite_spec['input'].keys()}
    else:
      implicit_input = new_transforms[-1]['__uuid__']
    # Copy rather than mutate: the dicts are shared with the caller's spec,
    # and a second pass over it would otherwise see an explicit input.
    new_transforms.append(dict(transform, input=implicit_input))
  composite_spec['transforms'] = new_transforms

  last_transform = new_transforms[-1]['__uuid__']
  if has_explicit_outputs:
    composite_spec['output'] = {
        key: f'{last_transform}.{value}'
        for (key, value) in composite_spec['output'].items()
    }
  else:
    composite_spec['output'] = last_transform
  if 'name' not in composite_spec:
    composite_spec['name'] = 'Chain'
  composite_spec['type'] = 'composite'
  return composite_spec
def preprocess_chain(spec):
  """Rewrite ``spec`` as a composite if it is a chain; otherwise return it."""
  if spec['type'] == 'chain':
    return chain_as_composite(spec)
  else:
    return spec
def pipeline_as_composite(spec):
  """Wrap a pipeline spec as an unnamed top-level composite.

  Args:
    spec: either a list of transform specs or a mapping describing the
      pipeline.

  Returns:
    A new dict whose ``name`` is None. A list becomes the ``transforms`` of a
    fresh composite, taking its line from the first transform; a mapping
    keeps its entries and defaults ``type`` to ``'composite'``.
  """
  if isinstance(spec, list):
    return {
        'type': 'composite',
        'name': None,
        'transforms': spec,
        # An empty list has no first transform to take the line from.
        '__line__': SafeLineLoader.get_line(spec[0]) if spec else 'unknown',
        '__uuid__': SafeLineLoader.create_uuid(),
    }
  else:
    return dict(spec, name=None, type=spec.get('type', 'composite'))
def normalize_source_sink(spec):
  """Fold ``source`` and ``sink`` entries into the ``transforms`` list.

  Returns:
    ``spec`` itself when it has neither entry; otherwise a shallow copy in
    which the source has become the first transform and the sink the last.
  """
  if 'source' not in spec and 'sink' not in spec:
    return spec
  spec = dict(spec)
  # Copy the list too, so the caller's transforms are left untouched.
  spec['transforms'] = list(spec.get('transforms', []))
  if 'source' in spec:
    spec['transforms'].insert(0, spec.pop('source'))
  if 'sink' in spec:
    spec['transforms'].append(spec.pop('sink'))
  return spec
def preprocess_source_sink(spec):
  """Apply normalize_source_sink() to chain and composite specs only."""
  if spec['type'] in ('chain', 'composite'):
    return normalize_source_sink(spec)
  else:
    return spec
def normalize_inputs_outputs(spec):
  """Return a copy of ``spec`` whose ``input`` and ``output`` are mappings.

  A string or list value is wrapped as ``{tag: value}``; a missing entry
  becomes an empty mapping; anything else has the loader metadata stripped
  from it while keeping tagged strings.
  """
  def normalize_io(tag):
    io = spec.get(tag, {})
    if isinstance(io, (str, list)):
      return {tag: io}
    else:
      return SafeLineLoader.strip_metadata(io, tagged_str=False)

  # dict(spec, ...) already builds a new dict, so no separate copy is needed.
  return dict(spec, input=normalize_io('input'), output=normalize_io('output'))
def identify_object(spec):
  """Describe ``spec`` for messages: ``"<name>" at line N`` or ``at line N``."""
  line = SafeLineLoader.get_line(spec)
  name = extract_name(spec)
  if name:
    return f'"{name}" at line {line}'
  else:
    return f'at line {line}'
def extract_name(spec):
  """Return the best available label for ``spec``.

  The first of ``name``, ``id`` and ``type`` that is present wins. A mapping
  with a single entry is searched through its value. Anything else,
  including a value that is not a mapping, yields ``''``.
  """
  # The recursion below can arrive at a plain string or number; this helper
  # is used while building error messages, so it should not raise on those.
  if not isinstance(spec, Mapping):
    return ''
  for key in ('name', 'id', 'type'):
    if key in spec:
      return spec[key]
  if len(spec) == 1:
    return extract_name(next(iter(spec.values())))
  return ''
def push_windowing_to_roots(spec):
  """Copy a composite's windowing onto its input-less transforms.

  Each transform of ``spec`` that has no inputs and no windowing of its own
  receives ``spec['windowing']`` together with a ``__consumed_outputs`` entry
  holding the output names its siblings read from it. The transform dicts
  are modified in place.

  Returns:
    ``spec`` itself.
  """
  scope = LightweightScope(spec['transforms'])
  consumed_outputs_by_transform = collections.defaultdict(set)
  for transform in spec['transforms']:
    for input_ref in transform['input'].values():
      try:
        transform_id, output = scope.get_transform_id_and_output_name(input_ref)
        consumed_outputs_by_transform[transform_id].add(output)
      except ValueError:
        # Could be an input or an ambiguity we'll raise later.
        pass

  for transform in spec['transforms']:
    if not transform['input'] and 'windowing' not in transform:
      transform['windowing'] = spec['windowing']
      transform['__consumed_outputs'] = consumed_outputs_by_transform[
          transform['__uuid__']]

  return spec
def preprocess_windowing(spec):
  """Turn a ``windowing`` entry on ``spec`` into explicit WindowInto steps.

  * WindowInto transforms and specs without ``windowing`` are returned as is.
  * Composites first push the windowing to their input-less transforms.
  * A spec with inputs is wrapped in a composite that windows every input
    before handing it to the original transform.
  * A non-composite spec without inputs is wrapped in a composite that
    windows each of its consumed outputs.

  The given spec is not modified, apart from what push_windowing_to_roots()
  does to a composite's transforms.
  """
  if spec['type'] == 'WindowInto':
    # This is the transform where it is actually applied.
    return spec
  elif 'windowing' not in spec:
    # Nothing to do.
    return spec

  if spec['type'] == 'composite':
    # Apply the windowing to any reads, creates, etc. in this transform
    # TODO(robertwb): Better handle the case where a read is followed by a
    # setting of the timestamps. We should be careful of sliding windows
    # in particular.
    spec = push_windowing_to_roots(spec)

  # Pop from a shallow copy so the caller's dict keeps its entries.
  spec = dict(spec)
  windowing = spec.pop('windowing')
  if spec['input']:
    # Apply the windowing to all inputs by wrapping it in a transform that
    # first applies windowing and then applies the original transform.
    original_inputs = spec['input']
    windowing_transforms = [{
        'type': 'WindowInto',
        'name': f'WindowInto[{key}]',
        'windowing': windowing,
        'input': key,
        '__line__': spec['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    } for key in original_inputs.keys()]
    windowed_inputs = {
        key: t['__uuid__']
        for (key, t) in zip(original_inputs.keys(), windowing_transforms)
    }
    modified_spec = dict(
        spec, input=windowed_inputs, __uuid__=SafeLineLoader.create_uuid())
    return {
        'type': 'composite',
        'name': spec.get('name', None) or spec['type'],
        'transforms': [modified_spec] + windowing_transforms,
        'input': spec['input'],
        'output': modified_spec['__uuid__'],
        '__line__': spec['__line__'],
        '__uuid__': spec['__uuid__'],
    }

  elif spec['type'] == 'composite':
    # Pushing the windowing down was sufficient.
    return spec

  else:
    # No inputs, apply the windowing to all outputs.
    consumed_outputs = list(spec.pop('__consumed_outputs', {None}))
    modified_spec = dict(spec, __uuid__=SafeLineLoader.create_uuid())
    windowing_transforms = [{
        'type': 'WindowInto',
        'name': f'WindowInto[{out}]',
        'windowing': windowing,
        'input': modified_spec['__uuid__'] + ('.' + out if out else ''),
        '__line__': spec['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    } for out in consumed_outputs]
    if consumed_outputs == [None]:
      windowed_outputs = only_element(windowing_transforms)['__uuid__']
    else:
      windowed_outputs = {
          out: t['__uuid__']
          for (out, t) in zip(consumed_outputs, windowing_transforms)
      }
    return {
        'type': 'composite',
        'name': spec.get('name', None) or spec['type'],
        'transforms': [modified_spec] + windowing_transforms,
        'output': windowed_outputs,
        '__line__': spec['__line__'],
        '__uuid__': spec['__uuid__'],
    }
def preprocess_flattened_inputs(spec):
  """Replace list-valued inputs inside a composite by explicit Flattens.

  For a transform of type Flatten, each list input is expanded into one
  singleton input per element (keys get ``_`` appended until unique). For
  any other transform, each list input is replaced by a reference to a new
  sibling Flatten transform over the list's elements.

  Returns:
    ``spec`` itself if it is not a composite; otherwise a copy with the
    rewritten ``transforms`` list. The given transform dicts are not
    modified.
  """
  if spec['type'] != 'composite':
    return spec

  # Prefer to add the flattens as sibling operations rather than nesting
  # to keep graph shape consistent when the number of inputs goes from
  # one to multiple.
  new_transforms = []
  for t in spec['transforms']:
    if t['type'] == 'Flatten':
      # Don't flatten before explicit flatten.
      # But we do have to expand list inputs into singleton inputs.
      def all_inputs(t):
        for key, values in t.get('input', {}).items():
          if isinstance(values, list):
            for ix, value in enumerate(values):
              yield f'{key}{ix}', value
          else:
            yield key, values

      inputs_dict = {}
      for key, value in all_inputs(t):
        while key in inputs_dict:
          key += '_'
        inputs_dict[key] = value
      t = dict(t, input=inputs_dict)
    else:
      replaced_inputs = {}
      for key, values in t.get('input', {}).items():
        if isinstance(values, list):
          flatten_id = SafeLineLoader.create_uuid()
          new_transforms.append({
              'type': 'Flatten',
              'name': '%s-Flatten[%s]' % (t.get('name', t['type']), key),
              'input': {
                  f'input{ix}': value
                  for (ix, value) in enumerate(values)
              },
              '__line__': spec['__line__'],
              '__uuid__': flatten_id,
          })
          replaced_inputs[key] = flatten_id
      if replaced_inputs:
        t = dict(t, input={**t['input'], **replaced_inputs})
    new_transforms.append(t)
  return dict(spec, transforms=new_transforms)
def preprocess(spec, verbose=False):
  """Run every preprocessing phase over ``spec`` and its nested transforms.

  Each phase is applied to a spec and then, for composites and chains,
  recursively to the specs in its ``transforms`` list.

  Args:
    spec: the transform spec to rewrite.
    verbose: when true, print the spec before the first phase and after
      each one.

  Returns:
    The rewritten spec.
  """
  if verbose:
    pprint.pprint(spec)

  def apply(phase, spec):
    spec = phase(spec)
    if spec['type'] in {'composite', 'chain'}:
      spec = dict(
          spec, transforms=[apply(phase, t) for t in spec['transforms']])
    return spec

  # NOTE(review): only the first entry of this list was present in the
  # reviewed source. The rest is reconstructed from the phase functions
  # defined in this module, ordered so that inputs are normalized before the
  # phases that read them as mappings -- confirm.
  for phase in [preprocess_source_sink,
                preprocess_chain,
                normalize_inputs_outputs,
                preprocess_flattened_inputs,
                preprocess_windowing]:
    spec = apply(phase, spec)
    if verbose:
      print('=' * 20, phase, '=' * 20)
      pprint.pprint(spec)
  return spec
def expand_pipeline(pipeline, pipeline_spec, providers=None):
if isinstance(pipeline_spec, str):
pipeline_spec = yaml.load(pipeline_spec, Loader=SafeLineLoader)
# Calling expand directly to avoid outer layer of nesting.
return YamlTransform(
**yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
key: yaml_provider.as_provider_list(key, value)
for (key, value) in (providers or {}).items()