#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module defines Providers usable from yaml, which is a specification
for where to find and how to invoke services that vend implementations of
various PTransforms."""
import collections
import hashlib
import inspect
import json
import logging
import os
import re
import subprocess
import sys
import urllib.parse
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional
import docstring_parser
import yaml
from yaml.loader import SafeLoader
import apache_beam as beam
import apache_beam.dataframe.io
import apache_beam.io
import apache_beam.transforms.util
from apache_beam.portability.api import schema_pb2
from apache_beam.transforms import external
from apache_beam.transforms import window
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.schemas import named_tuple_to_schema
from apache_beam.typehints.schemas import typing_to_runner_api
from apache_beam.utils import python_callable
from apache_beam.utils import subprocess_server
from apache_beam.version import __version__ as beam_version
class Provider:
  """Maps transform types names and args to concrete PTransform instances.

  This base class supplies neutral defaults (no schema, no description) and a
  simple affinity heuristic; subclasses must implement ``available`` and
  ``cache_artifacts``.
  """
  def available(self) -> bool:
    """Returns whether this provider is available to use in this environment."""
    raise NotImplementedError(type(self))

  def cache_artifacts(self) -> Optional[Iterable[str]]:
    """Returns the artifacts (if any) this provider wants cached up front.

    Must be overridden by subclasses.
    """
    raise NotImplementedError(type(self))

  def config_schema(self, type):
    """Returns the config schema of ``type``; unknown (None) by default."""
    return None

  def description(self, type):
    """Returns a description of ``type``; unknown (None) by default."""
    return None

  def underlying_provider(self):
    """If this provider is simply a proxy to another provider, return the
    provider that should actually be used for affinity checking.
    """
    return self

  def affinity(self, other: "Provider"):
    """Returns a value approximating how good it would be for this provider
    to be used immediately following a transform from the other provider
    (e.g. to encourage fusion).

    The score is symmetric: the one-directional ``_affinity`` of each
    underlying provider towards the other, summed.
    """
    # TODO(yaml): This is a very rough heuristic. Consider doing better.
    # E.g. we could look at the the expected environments themselves.
    # Possibly, we could provide multiple expansions and have the runner itself
    # choose the actual implementation based on fusion (and other) criteria.
    mine = self.underlying_provider()
    theirs = other.underlying_provider()
    return mine._affinity(theirs) + theirs._affinity(mine)

  def _affinity(self, other: "Provider"):
    """One-directional score: 100 for the same provider, 10 for a provider
    of the exact same class, 0 otherwise."""
    if self is other or self == other:
      return 100
    if type(self) is type(other):
      return 10
    return 0
def as_provider(name, provider_or_constructor):
  """Coerces ``provider_or_constructor`` into a Provider.

  Provider instances are returned unchanged; any other value is treated as a
  transform constructor and registered under ``name`` in a new
  InlineProvider.
  """
  if not isinstance(provider_or_constructor, Provider):
    return InlineProvider({name: provider_or_constructor})
  return provider_or_constructor
def as_provider_list(name, lst):
  """Converts ``lst`` into a list of Providers via ``as_provider``.

  A non-list argument is treated as a single-element list.
  """
  items = lst if isinstance(lst, list) else [lst]
  return [as_provider(name, item) for item in items]
class ExternalProvider(Provider):
  """A Provider implemented via the cross language transform service."""
  # Registry of provider type names (as written in YAML provider specs) to
  # the constructors that build them; filled by register_provider_type.
  _provider_types: Dict[str, Callable[..., Provider]] = {}

  def __init__(self, urns, service):
    # Mapping from transform type name to the identifier the service uses.
    self._urns = urns
    self._service = service
    self._schema_transforms = None

  def config_schema(self, type):
    """Returns the config schema of ``type`` if the service advertises it as
    a schema transform, else None."""
    urn = self._urns[type]
    if urn not in self.schema_transforms():
      return None
    return named_tuple_to_schema(
        self.schema_transforms()[urn].configuration_schema)

  def description(self, type):
    """Returns the service-provided description of ``type``, if any."""
    urn = self._urns[type]
    if urn not in self.schema_transforms():
      return None
    return self.schema_transforms()[urn].description

  @classmethod
  def provider_from_spec(cls, spec):
    """Builds a Provider from a YAML provider spec.

    The spec must have ``type`` and ``transforms`` keys and may have
    ``config``; any other (non-metadata) key is rejected. A config
    ``version`` of ``'BEAM_VERSION'`` is replaced by the running Beam version.

    Raises:
      ValueError: on missing or unexpected keys, or if the registered
        constructor fails.
      NotImplementedError: if ``type`` is not a registered provider type.
    """
    from apache_beam.yaml.yaml_transform import SafeLineLoader
    for required_key in ('type', 'transforms'):
      if required_key not in spec:
        raise ValueError(
            f'Missing {required_key} in provider '
            f'at line {SafeLineLoader.get_line(spec)}')
    urns = spec['transforms']
    provider_type = spec['type']
    config = SafeLineLoader.strip_metadata(spec.get('config', {}))
    allowed_keys = {'transforms', 'type', 'config'}
    unexpected = set(SafeLineLoader.strip_metadata(spec).keys()) - allowed_keys
    if unexpected:
      raise ValueError(
          f'Unexpected parameters in provider of type {provider_type} '
          f'at line {SafeLineLoader.get_line(spec)}: {unexpected}')
    if config.get('version', None) == 'BEAM_VERSION':
      config['version'] = beam_version
    if provider_type not in cls._provider_types:
      raise NotImplementedError(
          f'Unknown provider type: {provider_type} '
          f'at line {SafeLineLoader.get_line(spec)}.')
    try:
      return cls._provider_types[provider_type](urns, **config)
    except Exception as exn:
      raise ValueError(
          f'Unable to instantiate provider of type {provider_type} '
          f'at line {SafeLineLoader.get_line(spec)}: {exn}') from exn

  @classmethod
  def register_provider_type(cls, type_name):
    """Returns a decorator that registers a constructor under ``type_name``
    and returns the constructor unchanged."""
    def register(constructor):
      cls._provider_types[type_name] = constructor
      return constructor

    return register
@ExternalProvider.register_provider_type('javaJar')
def java_jar(urns, jar: str):
  """Provider for transforms served by a Java jar given as a path or URL.

  Raises:
    ValueError: if ``jar`` is not an existing local path and does not parse
      as a URL with both a scheme and a network location.
  """
  if not os.path.exists(jar):
    url = urllib.parse.urlparse(jar)
    if not (url.scheme and url.netloc):
      raise ValueError(f'Invalid path or url: {jar}')
  return ExternalJavaProvider(urns, lambda: jar)
@ExternalProvider.register_provider_type('mavenJar')
def maven_jar(
    urns,
    *,
    artifact_id,
    group_id,
    version,
    repository=subprocess_server.JavaJarServer.MAVEN_CENTRAL_REPOSITORY,
    classifier=None,
    appendix=None):
  """Provider for transforms served by a jar fetched from a Maven repository.

  The jar is identified by its Maven coordinates. It is resolved only when
  the returned provider invokes its jar callable.
  """
  return ExternalJavaProvider(
      urns,
      lambda: subprocess_server.JavaJarServer.path_to_maven_jar(
          artifact_id=artifact_id,
          # group_id is part of the Maven coordinates; previously it was
          # accepted here but never forwarded.
          group_id=group_id,
          version=version,
          repository=repository,
          classifier=classifier,
          appendix=appendix))
@ExternalProvider.register_provider_type('beamJar')
def beam_jar(
    urns,
    *,
    gradle_target,
    appendix=None,
    version=beam_version,
    artifact_id=None):
  """Provider for transforms served by a released Beam jar.

  The jar is located from ``gradle_target``, ``appendix``, ``version`` and
  ``artifact_id``. It is resolved only when the returned provider invokes
  its jar callable.
  """
  return ExternalJavaProvider(
      urns,
      lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
          gradle_target=gradle_target,
          # Previously accepted but silently dropped.
          appendix=appendix,
          version=version,
          artifact_id=artifact_id))
@ExternalProvider.register_provider_type('docker')
def docker(urns, **config):
  """Placeholder for docker-based providers, which are not supported yet.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError
@ExternalProvider.register_provider_type('remote')
class RemoteProvider(ExternalProvider):
  """Provider backed by an expansion service reachable at ``address``."""
  # Result of the availability probe; None until first checked, then cached
  # on the instance.
  _is_available = None

  def __init__(self, urns, address: str):
    super().__init__(urns, service=address)

  def available(self):
    """Probes the service once via ``service.ready(1)`` and caches whether
    that succeeded; any exception counts as unavailable."""
    if self._is_available is not None:
      return self._is_available
    try:
      with external.ExternalTransform.service(self._service) as service:
        service.ready(1)
        self._is_available = True
    except Exception:
      self._is_available = False
    return self._is_available

  def cache_artifacts(self):
    # Nothing local to cache for a remote service.
    return None
class ExternalJavaProvider(ExternalProvider):
  """Provider whose transforms are expanded by a Java jar expansion service.

  Args:
    urns: mapping of transform type names to their URNs.
    jar_provider: zero-argument callable returning the jar path or URL. The
      service is wrapped in a lambda so the jar is only resolved when that
      callable is invoked (or when artifacts are cached).
  """
  def __init__(self, urns, jar_provider):
    super().__init__(
        urns, lambda: external.JavaJarExpansionService(jar_provider()))
    self._jar_provider = jar_provider

  def available(self):
    """Returns whether a ``java`` executable can be found on the PATH."""
    # shutil.which is portable: shelling out to `which` raises
    # FileNotFoundError on systems (e.g. Windows) that lack that command.
    import shutil
    return shutil.which('java') is not None

  def cache_artifacts(self):
    return [self._jar_provider()]
@ExternalProvider.register_provider_type('python')
def python(urns, packages=()):
  """Provider for Python transforms identified by fully qualified name.

  With ``packages``, transforms are expanded by an ExternalPythonProvider
  built from those packages. Otherwise each constructor is loaded in this
  process and served through an InlineProvider.
  """
  if not packages:
    load = (
        python_callable.PythonCallableWithSource.
        load_from_fully_qualified_name)
    return InlineProvider({
        transform_name: load(fq_name)
        for transform_name, fq_name in urns.items()
    })
  return ExternalPythonProvider(urns, packages)
@ExternalProvider.register_provider_type('pythonPackage')
class ExternalPythonProvider(ExternalProvider):
  """Provider expanding Python transforms in a virtual environment that has
  the given ``packages`` installed (via PypiExpansionService)."""
  def __init__(self, urns, packages):
    super().__init__(urns, PypiExpansionService(packages))

  def available(self):
    # This code is itself running under Python, so an interpreter exists.
    return True

  def cache_artifacts(self):
    """Builds the virtual environment if needed and returns [its path]."""
    return [self._service._venv()]

  def _affinity(self, other: "Provider"):
    # Moderately prefer being placed next to in-process Python transforms.
    if not isinstance(other, InlineProvider):
      return super()._affinity(other)
    return 50
# This is needed because type inference can't handle *args, **kwargs forwarding.
# TODO(BEAM-24755): Add support for type inference through kwargs calls.
def fix_pycallable():
  """Monkey-patches Beam so PythonCallableWithSource behaves like the plain
  callable it wraps.

  Specifically, this:
    * gives PythonCallableWithSource a readable ``default_label``,
    * exposes the wrapped callable as its ``_argspec_fn`` property,
    * unwraps PythonCallableWithSource before ``infer_return_type`` and
      ``fn_takes_side_inputs`` inspect it.
  """
  from apache_beam.transforms.ptransform import label_from_callable

  def default_label(self):
    src = self._source.strip()
    last_line = src.split('\n')[-1]
    # Use the last source line as the label when it is short and not
    # indented (an indented last line belongs to a multi-line body).
    # The emptiness check avoids an IndexError for empty source.
    if last_line and not last_line[0].isspace() and len(last_line) < 72:
      return last_line
    return label_from_callable(self._callable)

  def _argspec_fn(self):
    return self._callable

  python_callable.PythonCallableWithSource.default_label = default_label
  python_callable.PythonCallableWithSource._argspec_fn = property(_argspec_fn)

  original_infer_return_type = trivial_inference.infer_return_type

  def infer_return_type(fn, *args, **kwargs):
    if isinstance(fn, python_callable.PythonCallableWithSource):
      fn = fn._callable
    return original_infer_return_type(fn, *args, **kwargs)

  trivial_inference.infer_return_type = infer_return_type

  original_fn_takes_side_inputs = (
      apache_beam.transforms.util.fn_takes_side_inputs)

  def fn_takes_side_inputs(fn):
    if isinstance(fn, python_callable.PythonCallableWithSource):
      fn = fn._callable
    return original_fn_takes_side_inputs(fn)

  apache_beam.transforms.util.fn_takes_side_inputs = fn_takes_side_inputs
class InlineProvider(Provider):
  """Provider whose transforms are built directly in this process.

  Args:
    transform_factories: mapping of transform type name to the callable
      (possibly a PTransform subclass) that constructs it.
    no_input_transforms: names of transforms that take no input.
  """
  def __init__(self, transform_factories, no_input_transforms=()):
    self._transform_factories = transform_factories
    self._no_input_transforms = set(no_input_transforms)

  def available(self):
    return True

  def cache_artifacts(self):
    return None

  def config_schema(self, typ):
    """Derives a config schema from the factory's signature and docstring.

    Each parameter becomes a field typed by its annotation (Any when
    unannotated) and described by the matching Google-style docstring entry.
    """
    factory = self._transform_factories[typ]
    if isinstance(factory, type) and issubclass(factory, beam.PTransform):
      # Inspect __init__ for classes and drop `self`; see
      # https://bugs.python.org/issue40897
      params = dict(inspect.signature(factory.__init__).parameters)
      params.pop('self', None)
    else:
      params = inspect.signature(factory).parameters

    param_docs = {
        doc_param.arg_name: doc_param.description
        for doc_param in self.get_docs(typ).params
    }

    def annotation_or_any(param):
      return Any if param.annotation == param.empty else param.annotation

    fields = []
    for field_name, param in params.items():
      field_type = typing_to_runner_api(annotation_or_any(param))
      fields.append(
          schema_pb2.Field(
              name=field_name,
              type=field_type,
              description=param_docs.get(field_name)))
    return schema_pb2.Schema(fields=fields)

  def description(self, typ):
    """Returns the factory's docstring summary and body, or None if empty."""
    docs = self.get_docs(typ)
    text = '\n\n'.join(
        [docs.short_description or '', docs.long_description or ''])
    return text.strip() or None

  def get_docs(self, typ):
    """Parses the factory's docstring as a Google-style docstring."""
    docstring = self._transform_factories[typ].__doc__ or ''
    # Strip the "Pandas Parameters" section: it is not relevant for YAML and
    # confuses the docstring parser.
    docstring = re.sub(
        r'Pandas Parameters\s+-----.*', '', docstring, flags=re.S)
    return docstring_parser.parse(
        docstring, docstring_parser.DocstringStyle.GOOGLE)

  def to_json(self):
    return {'type': "InlineProvider"}
class SqlBackedProvider(Provider):
  """Provider for transforms that are implemented on top of a SQL provider.

  Availability, artifact caching and affinity all delegate to the SQL
  provider. When none is given, it defaults to the Beam Java SQL expansion
  service jar.
  """
  def __init__(
      self,
      transforms: Mapping[str, Callable[..., beam.PTransform]],
      sql_provider: Optional[Provider] = None):
    self._transforms = transforms
    if sql_provider is not None:
      self._sql_provider = sql_provider
    else:
      self._sql_provider = beam_jar(
          urns={'Sql': 'beam:external:java:sql:v1'},
          gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')

  def sql_provider(self):
    return self._sql_provider

  def available(self):
    return self.sql_provider().available()

  def cache_artifacts(self):
    return self.sql_provider().cache_artifacts()

  def underlying_provider(self):
    return self.sql_provider()

  def to_json(self):
    return {'type': "SqlBackedProvider"}
# Maps the __name__ of each primitive Python type in
# schemas.PRIMITIVE_TO_ATOMIC_TYPE (skipping types defined in the `typing`
# module) to its atomic schema type.
PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
    python_type.__name__: atomic_type
    for python_type, atomic_type in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
    if python_type.__module__ != 'typing'
}
def element_to_rows(e):
  """Converts a YAML/JSON-style element into a beam.Row.

  Mappings are converted directly via ``dicts_to_rows``; any other value is
  converted and wrapped in a single-field row named ``element``.
  """
  if not isinstance(e, dict):
    return beam.Row(element=dicts_to_rows(e))
  return dicts_to_rows(e)
def dicts_to_rows(o):
  """Recursively converts dicts into beam.Row objects.

  Lists are converted element-wise; all other values are returned as is.
  """
  if isinstance(o, list):
    return [dicts_to_rows(item) for item in o]
  if isinstance(o, dict):
    return beam.Row(**{key: dicts_to_rows(value) for key, value in o.items()})
  return o
def create_builtin_provider():
  """Returns an InlineProvider with the built-in YAML transforms.

  These are Create, LogForTesting, PyTransform, Flatten and WindowInto.
  Create is registered as a transform that takes no input.
  """
  def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
    """Creates a collection containing a specified set of elements.

    YAML/JSON-style mappings will be interpreted as Beam rows. For example::

        type: Create
        elements:
          - {first: 0, second: {str: "foo", values: [1, 2, 3]}}

    will result in a schema of the form (int, Row(string, List[int])).

    Args:
      elements: The set of elements that should belong to the PCollection.
          YAML/JSON-style mappings will be interpreted as Beam rows.
      reshuffle (optional): Whether to introduce a reshuffle (to possibly
          redistribute the work) if there is more than one element in the
          collection. Defaults to True.
    """
    return beam.Create([element_to_rows(e) for e in elements],
                       reshuffle=reshuffle is not False)

  # Or should this be posargs, args?
  # pylint: disable=dangerous-default-value
  def fully_qualified_named_transform(
      constructor: str,
      args: Optional[Iterable[Any]] = (),
      kwargs: Optional[Mapping[str, Any]] = {}):
    """A Python PTransform identified by fully qualified name.

    This allows one to import, construct, and apply any Beam Python transform.
    This can be useful for using transforms that have not yet been exposed
    via a YAML interface. Note, however, that conversion may be required if this
    transform does not accept or produce Beam Rows.

    For example::

        type: PyTransform
        config:
          constructor: apache_beam.pkg.mod.SomeClass
          args: [1, 'foo']
          kwargs:
            baz: 3

    can be used to access the transform
    `apache_beam.pkg.mod.SomeClass(1, 'foo', baz=3)`.

    Args:
      constructor: Fully qualified name of a callable used to construct the
          transform. Often this is a class such as
          `apache_beam.pkg.mod.SomeClass` but it can also be a function or
          any other callable that returns a PTransform.
      args: A list of parameters to pass to the callable as positional
          arguments.
      kwargs: A mapping of parameters to pass to the callable as keyword
          arguments.
    """
    # The mutable default for kwargs is never mutated here, hence the
    # pylint suppression above.
    with FullyQualifiedNamedTransform.with_filter('*'):
      return constructor >> FullyQualifiedNamedTransform(
          constructor, args, kwargs)

  # This intermediate is needed because there is no way to specify a tuple of
  # exactly zero or one PCollection in yaml (as they would be interpreted as
  # PBegin and the PCollection itself respectively).
  class Flatten(beam.PTransform):
    """Flattens multiple PCollections into a single PCollection.

    The elements of the resulting PCollection will be the (disjoint) union of
    all the elements of all the inputs.

    Note that in YAML transforms can always take a list of inputs which will
    be implicitly flattened.
    """
    def __init__(self):
      # Suppress the "label" argument from the superclass for better docs.
      # pylint: disable=useless-parent-delegation
      super().__init__()

    def expand(self, pcolls):
      if isinstance(pcolls, beam.PCollection):
        pipeline_arg = {}
        pcolls = (pcolls, )
      elif isinstance(pcolls, dict):
        pipeline_arg = {}
        pcolls = tuple(pcolls.values())
      else:
        # No inputs (e.g. PBegin): beam.Flatten needs the pipeline explicitly.
        pipeline_arg = {'pipeline': pcolls.pipeline}
        pcolls = ()
      return pcolls | beam.Flatten(**pipeline_arg)

  class WindowInto(beam.PTransform):
    # pylint: disable=line-too-long
    """A window transform assigning windows to each element of a PCollection.

    The assigned windows will affect all downstream aggregating operations,
    which will aggregate by window as well as by key.

    See [the Beam documentation on windowing](https://beam.apache.org/documentation/programming-guide/#windowing)
    for more details.

    Note that any Yaml transform can have a
    [windowing parameter](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/yaml/README.md#windowing),
    which is applied to its inputs (if any) or outputs (if there are no inputs)
    which means that explicit WindowInto operations are not typically needed.

    Args:
      windowing: the type and parameters of the windowing to perform. The
          type is one of global, fixed (size, offset), sliding (size, period,
          offset) or sessions (gap).
    """
    def __init__(self, windowing):
      self._window_transform = self._parse_window_spec(windowing)

    def expand(self, pcoll):
      return pcoll | self._window_transform

    @staticmethod
    def _parse_window_spec(spec):
      """Builds a beam.WindowInto from a windowing spec mapping.

      Raises:
        ValueError: for an unknown window type or leftover parameters.
      """
      spec = dict(spec)
      window_type = spec.pop('type')
      # TODO: These are in seconds, perhaps parse duration strings meaningfully?
      if window_type == 'global':
        window_fn = window.GlobalWindows()
      elif window_type == 'fixed':
        window_fn = window.FixedWindows(spec.pop('size'), spec.pop('offset', 0))
      elif window_type == 'sliding':
        window_fn = window.SlidingWindows(
            spec.pop('size'), spec.pop('period'), spec.pop('offset', 0))
      elif window_type == 'sessions':
        # Session windows are defined by an inactivity gap, not a fixed size.
        window_fn = window.Sessions(spec.pop('gap'))
      else:
        # Previously fell through and failed with UnboundLocalError.
        raise ValueError(f'Unknown window type {window_type!r}')
      if spec:
        raise ValueError(f'Unknown parameters {spec.keys()}')
      # TODO: Triggering, etc.
      return beam.WindowInto(window_fn)

  def LogForTesting():
    """Logs each element of its input PCollection.

    The output of this transform is a copy of its input for ease of use in
    chain-style pipelines.
    """
    def log_and_return(x):
      logging.info(x)
      return x

    return beam.Map(log_and_return)

  return InlineProvider({
      'Create': create,
      'LogForTesting': LogForTesting,
      'PyTransform': fully_qualified_named_transform,
      'Flatten': Flatten,
      'WindowInto': WindowInto,
  },
                        no_input_transforms=('Create', ))
class PypiExpansionService:
  """Expands transforms by fully qualified name in a virtual environment
  with the given dependencies.

  Virtual environments live under VENV_CACHE in a directory named by a hash
  of the base interpreter and the sorted package list. An existing directory
  is reused as is.
  """
  VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/venvs")

  def __init__(self, packages, base_python=sys.executable):
    self._packages = packages
    self._base_python = base_python

  @classmethod
  def _key(cls, base_python, packages):
    """Returns a canonical JSON string identifying a venv configuration."""
    return json.dumps({
        'binary': base_python, 'packages': sorted(packages)
    },
                      sort_keys=True)

  @classmethod
  def _path(cls, base_python, packages):
    """Returns the cache directory of the venv for these arguments."""
    return os.path.join(
        cls.VENV_CACHE,
        hashlib.sha256(cls._key(base_python,
                                packages).encode('utf-8')).hexdigest())

  @staticmethod
  def _discard_partial_venv(venv):
    """Removes a venv (and its requirements file) whose build failed.

    Otherwise the half-built directory would pass the os.path.exists check
    on the next run and be reused as if it were complete.
    """
    import shutil
    shutil.rmtree(venv, ignore_errors=True)
    requirements = venv + '-requirements.txt'
    if os.path.exists(requirements):
      os.remove(requirements)

  @classmethod
  def _create_venv_from_scratch(cls, base_python, packages):
    """Creates (unless cached) a fresh venv with ``packages`` installed.

    Returns:
      The path of the venv directory.
    """
    venv = cls._path(base_python, packages)
    if not os.path.exists(venv):
      try:
        subprocess.run([base_python, '-m', 'venv', venv], check=True)
        venv_python = os.path.join(venv, 'bin', 'python')
        subprocess.run([venv_python, '-m', 'ensurepip'], check=True)
        # list() so that tuple package specs work too.
        subprocess.run([venv_python, '-m', 'pip', 'install'] + list(packages),
                       check=True)
        with open(venv + '-requirements.txt', 'w') as fout:
          fout.write('\n'.join(packages))
      except BaseException:
        cls._discard_partial_venv(venv)
        raise
    return venv

  @classmethod
  def _create_venv_from_clone(cls, base_python, packages):
    """Creates (unless cached) a venv with ``packages`` by cloning a cached
    base venv that already has Beam installed.

    Returns:
      The path of the venv directory.
    """
    venv = cls._path(base_python, packages)
    if not os.path.exists(venv):
      clonable_venv = cls._create_venv_to_clone(base_python)
      clonable_python = os.path.join(clonable_venv, 'bin', 'python')
      try:
        subprocess.run(
            [clonable_python, '-m', 'clonevirtualenv', clonable_venv, venv],
            check=True)
        venv_binary = os.path.join(venv, 'bin', 'python')
        subprocess.run([venv_binary, '-m', 'pip', 'install'] + list(packages),
                       check=True)
        with open(venv + '-requirements.txt', 'w') as fout:
          fout.write('\n'.join(packages))
      except BaseException:
        cls._discard_partial_venv(venv)
        raise
    return venv

  @classmethod
  def _create_venv_to_clone(cls, base_python):
    """Returns the base venv (Beam plus virtualenv-clone) used for cloning."""
    return cls._create_venv_from_scratch(
        base_python, [
            'apache_beam[dataframe,gcp,test]==' + beam_version,
            'virtualenv-clone'
        ])

  def _venv(self):
    """Returns the path of this service's venv, building it if needed."""
    return self._create_venv_from_clone(self._base_python, self._packages)

  def __enter__(self):
    """Starts an expansion service subprocess with the venv's interpreter and
    returns what SubprocessServer.__enter__ returns."""
    venv = self._venv()
    self._service_provider = subprocess_server.SubprocessServer(
        external.ExpansionAndArtifactRetrievalStub,
        [
            os.path.join(venv, 'bin', 'python'),
            '-m',
            'apache_beam.runners.portability.expansion_service_main',
            '--port',
            '{{PORT}}',
            '--fully_qualified_name_glob=*',
            '--pickle_library=cloudpickle',
            '--requirements_file=' + venv + '-requirements.txt'
        ])
    self._service = self._service_provider.__enter__()
    return self._service

  def __exit__(self, *args):
    self._service_provider.__exit__(*args)
    self._service = None
@ExternalProvider.register_provider_type('renaming')
class RenamingProvider(Provider):
  """Exposes an underlying provider's transforms under new names, with
  renamed config parameters.

  Args:
    transforms: mapping of exposed transform name -> underlying transform
      name.
    mappings: mapping of exposed transform name -> {exposed parameter:
      underlying parameter}. A string value names another transform whose
      mapping is reused.
    underlying_provider: the Provider to delegate to, or a provider spec
      dict that is built via ExternalProvider.provider_from_spec.
    defaults: optional mapping of exposed transform name -> {underlying
      parameter: value}. Parameters listed there are hidden from the exposed
      config schema.
  """
  def __init__(self, transforms, mappings, underlying_provider, defaults=None):
    if isinstance(underlying_provider, dict):
      underlying_provider = ExternalProvider.provider_from_spec(
          underlying_provider)
    self._transforms = transforms
    self._underlying_provider = underlying_provider
    for transform in transforms.keys():
      if transform not in mappings:
        raise ValueError(f'Missing transform {transform} in mappings.')
    self._mappings = self.expand_mappings(mappings)
    self._defaults = defaults or {}

  @staticmethod
  def expand_mappings(mappings):
    """Returns a copy of ``mappings`` with string aliases resolved.

    A string value names another transform whose mapping should be reused.
    Chains of aliases are followed. The input dict is not modified.

    Raises:
      ValueError: if ``mappings`` is not a dict, an alias names a transform
        that has no entry in ``mappings``, or aliases form a cycle.
    """
    if not isinstance(mappings, dict):
      raise ValueError(
          "RenamingProvider mappings must be dict of transform "
          "mappings.")
    expanded = {}
    for key, value in mappings.items():
      seen = {key}
      while isinstance(value, str):
        if value not in mappings.keys():
          raise ValueError(
              "RenamingProvider transform mappings must be dict or "
              "specify transform that has mappings within same "
              "provider.")
        if value in seen:
          raise ValueError(
              f"RenamingProvider transform mappings contain a cycle of "
              f"aliases involving {key}.")
        seen.add(value)
        value = mappings[value]
      expanded[key] = value
    return expanded

  def available(self) -> bool:
    return self._underlying_provider.available()

  def config_schema(self, type):
    """Returns the underlying schema with fields renamed per the mapping.

    Fields whose underlying parameter has a default are omitted.

    Raises:
      ValueError: if a mapping targets a field absent from the underlying
        schema.
    """
    underlying_schema = self._underlying_provider.config_schema(
        self._transforms[type])
    if underlying_schema is None:
      return None
    defaults = self._defaults.get(type, {})
    underlying_schema_fields = {f.name: f for f in underlying_schema.fields}
    missing = set(self._mappings[type].values()) - set(
        underlying_schema_fields.keys())
    if missing:
      raise ValueError(
          f"Mapping destinations {missing} for {type} are not in the "
          f"underlying config schema {list(underlying_schema_fields.keys())}")

    def with_name(
        original: schema_pb2.Field, new_name: str) -> schema_pb2.Field:
      result = schema_pb2.Field()
      result.CopyFrom(original)
      result.name = new_name
      return result

    return schema_pb2.Schema(
        fields=[
            with_name(underlying_schema_fields[dest], src)
            for (src, dest) in self._mappings[type].items()
            if dest not in defaults
        ])

  def description(self, typ):
    # The underlying provider knows this transform by its own name.
    return self._underlying_provider.description(self._transforms[typ])

  def _affinity(self, other):
    raise NotImplementedError(
        'Should not be calling _affinity directly on this provider.')

  def underlying_provider(self):
    return self._underlying_provider.underlying_provider()

  def cache_artifacts(self):
    # Previously the underlying result was computed but not returned.
    return self._underlying_provider.cache_artifacts()
def parse_providers(provider_specs):
  """Builds Providers from YAML provider specs, grouped by transform type.

  Returns:
    A defaultdict(list) mapping each transform type to the Providers, in
    spec order, that offer it. Each provider's ``to_json`` is overridden to
    return the spec it was built from.
  """
  by_type = collections.defaultdict(list)
  for spec in provider_specs:
    provider = ExternalProvider.provider_from_spec(spec)
    for transform_type in provider.provided_transforms():
      by_type[transform_type].append(provider)
      # TODO: Do this better.
      # Bind spec as a default argument so each lambda keeps its own spec.
      provider.to_json = lambda result=spec: result
  return by_type
def merge_providers(*provider_sets):
  """Merges providers into a single mapping of transform type -> providers.

  Each argument may be a single Provider (registered under every transform
  it provides), a list (merged recursively), or a mapping of transform type
  to a list of providers. Providers are concatenated in argument order.
  """
  def as_mapping(item):
    if isinstance(item, Provider):
      return {
          transform_type: [item]
          for transform_type in item.provided_transforms()
      }
    if isinstance(item, list):
      return merge_providers(*item)
    return item

  merged = collections.defaultdict(list)
  for provider_set in provider_sets:
    for transform_type, providers in as_mapping(provider_set).items():
      merged[transform_type].extend(providers)
  return merged
def standard_providers():
  """Returns all standard providers, merged by transform type.

  This combines the built-in, mapping, combine and IO providers with the
  providers listed in the standard_providers.yaml file next to this module.
  """
  from apache_beam.yaml.yaml_combine import create_combine_providers
  from apache_beam.yaml.yaml_mapping import create_mapping_providers
  from apache_beam.yaml.yaml_io import io_providers
  spec_path = os.path.join(
      os.path.dirname(__file__), 'standard_providers.yaml')
  with open(spec_path) as spec_file:
    provider_specs = yaml.load(spec_file, Loader=SafeLoader)
  return merge_providers(
      create_builtin_provider(),
      create_mapping_providers(),
      create_combine_providers(),
      io_providers(),
      parse_providers(provider_specs))