#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Pipeline manipulation utilities useful for many runners.
For internal use only; no backwards-compatibility guarantees.
"""
# pytype: skip-fileimport collections
import collections
import copy
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import environments
from apache_beam.typehints import typehints
def validate_pipeline_graph(pipeline_proto):
  """Ensures this is a correctly constructed Beam pipeline.
  """
  def get_coder(pcoll_id):
    return pipeline_proto.components.coders[
        pipeline_proto.components.pcollections[pcoll_id].coder_id]

  def validate_transform(transform_id):
    transform_proto = pipeline_proto.components.transforms[transform_id]

    # Currently the only validation we perform is that GBK operations have
    # their coders set properly.
    if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
      if len(transform_proto.inputs) != 1:
        raise ValueError("Unexpected number of inputs: %s" % transform_proto)
      if len(transform_proto.outputs) != 1:
        raise ValueError("Unexpected number of outputs: %s" % transform_proto)
      input_coder = get_coder(next(iter(transform_proto.inputs.values())))
      output_coder = get_coder(next(iter(transform_proto.outputs.values())))
      if input_coder.spec.urn != common_urns.coders.KV.urn:
        raise ValueError(
            "Bad coder for input of %s: %s" % (transform_id, input_coder))
      if output_coder.spec.urn != common_urns.coders.KV.urn:
        raise ValueError(
            "Bad coder for output of %s: %s" % (transform_id, output_coder))
      output_values_coder = pipeline_proto.components.coders[
          output_coder.component_coder_ids[1]]
      if (input_coder.component_coder_ids[0] !=
          output_coder.component_coder_ids[0] or
          output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or
          output_values_coder.component_coder_ids[0] !=
          input_coder.component_coder_ids[1]):
        raise ValueError(
            "Incompatible input coder %s and output coder %s for transform %s"
            % (input_coder, output_coder, transform_id))
    elif transform_proto.spec.urn == common_urns.primitives.ASSIGN_WINDOWS.urn:
      if not transform_proto.inputs:
        raise ValueError("Missing input for transform: %s" % transform_proto)
    elif transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
      if not transform_proto.inputs:
        raise ValueError("Missing input for transform: %s" % transform_proto)

    for t in transform_proto.subtransforms:
      validate_transform(t)

  for t in pipeline_proto.root_transform_ids:
    validate_transform(t)
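# A hedged usage sketch, not part of the original module: it assumes only the
# public apache_beam entry points (beam.Pipeline, beam.Create, beam.GroupByKey
# and Pipeline.to_runner_api()). The helper name below is made up for
# illustration.
def _example_validate_pipeline_graph():
  import apache_beam as beam

  p = beam.Pipeline()
  _ = p | beam.Create([('k', 1), ('k', 2)]) | beam.GroupByKey()
  # Raises ValueError if, for example, a GroupByKey's input or output coder
  # is not the expected KV / KV-of-iterable structure.
  validate_pipeline_graph(p.to_runner_api())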
def _dep_key(dep):
  if dep.type_urn == common_urns.artifact_types.FILE.urn:
    payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
        dep.type_payload)
    if payload.sha256:
      type_info = 'sha256', payload.sha256
    else:
      type_info = 'path', payload.path
  elif dep.type_urn == common_urns.artifact_types.URL.urn:
    payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString(
        dep.type_payload)
    if payload.sha256:
      type_info = 'sha256', payload.sha256
    else:
      type_info = 'url', payload.url
  else:
    type_info = dep.type_urn, dep.type_payload
  return type_info, dep.role_urn, dep.role_payload
def _expanded_dep_keys(dep):
  if (dep.type_urn == common_urns.artifact_types.FILE.urn and
      dep.role_urn == common_urns.artifact_roles.STAGING_TO.urn):
    payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
        dep.type_payload)
    role = beam_runner_api_pb2.ArtifactStagingToRolePayload.FromString(
        dep.role_payload)
    if role.staged_name == 'submission_environment_dependencies.txt':
      return
    elif role.staged_name == 'requirements.txt':
      with open(payload.path) as fin:
        for line in fin:
          yield 'requirements.txt', line.strip()
      return

  yield _dep_key(dep)
def _base_env_key(env, include_deps=True):
  return (
      env.urn,
      env.payload,
      tuple(sorted(env.capabilities)),
      tuple(sorted(env.resource_hints.items())),
      tuple(sorted(_dep_key(dep)
                   for dep in env.dependencies)) if include_deps else None)
def _env_key(env):
  return tuple(
      sorted(
          _base_env_key(e)
          for e in environments.expand_anyof_environments(env)))
def merge_common_environments(pipeline_proto, inplace=False):
  """Merges environments that are identical (same urn, payload, capabilities,
  resource hints and dependencies) into a single environment.
  """
  canonical_environments = collections.defaultdict(list)
  for env_id, env in pipeline_proto.components.environments.items():
    canonical_environments[_env_key(env)].append(env_id)

  if len(canonical_environments) == len(
      pipeline_proto.components.environments):
    # All environments are already sufficiently distinct.
    return pipeline_proto

  environment_remappings = {
      e: es[0]
      for es in canonical_environments.values() for e in es
  }

  return update_environments(pipeline_proto, environment_remappings, inplace)
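# A hedged sketch, not part of the original module: it hand-builds a pipeline
# proto with two byte-for-byte identical environments and checks that
# merge_common_environments collapses them into one. The ids 'env1'/'env2'
# and the container image name are made up for illustration.
def _example_merge_common_environments():
  proto = beam_runner_api_pb2.Pipeline()
  for env_id in ('env1', 'env2'):
    proto.components.environments[env_id].CopyFrom(
        beam_runner_api_pb2.Environment(
            urn=common_urns.environments.DOCKER.urn,
            payload=beam_runner_api_pb2.DockerPayload(
                container_image='some/image:tag').SerializeToString()))
  merged = merge_common_environments(proto, inplace=True)
  assert len(merged.components.environments) == 1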
def merge_superset_dep_environments(pipeline_proto):
  """Merges all environments A and B where A and B are equivalent except that
  A has a superset of the dependencies of B.
  """
  docker_envs = {}
  for env_id, env in pipeline_proto.components.environments.items():
    docker_env = environments.resolve_anyof_environment(
        env, common_urns.environments.DOCKER.urn)
    if docker_env.urn == common_urns.environments.DOCKER.urn:
      docker_envs[env_id] = docker_env

  has_base_and_dep = collections.defaultdict(set)
  env_scores = {
      env_id: (len(env.dependencies), env_id)
      for (env_id, env) in docker_envs.items()
  }
  for env_id, env in docker_envs.items():
    base_key = _base_env_key(env, include_deps=False)
    has_base_and_dep[base_key, None].add(env_id)
    for dep in env.dependencies:
      for dep_key in _expanded_dep_keys(dep):
        has_base_and_dep[base_key, dep_key].add(env_id)

  environment_remappings = {}
  for env_id, env in docker_envs.items():
    base_key = _base_env_key(env, include_deps=False)
    # This is the set of all environments that have at least all of env's deps.
    candidates = set.intersection(
        has_base_and_dep[base_key, None],
        *[
            has_base_and_dep[base_key, dep_key] for dep in env.dependencies
            for dep_key in _expanded_dep_keys(dep)
        ])
    # Choose the maximal one.
    best = max(candidates, key=env_scores.get)
    if best != env_id:
      environment_remappings[env_id] = best

  return update_environments(pipeline_proto, environment_remappings)
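# A hedged sketch, not part of the original module: 'envA' and 'envB' share
# the same Docker image, but 'envA' declares a superset of 'envB's
# dependencies, so 'envB' should be folded into 'envA'. All ids, the image
# name and the artifact URLs are made up for illustration.
def _example_merge_superset_dep_environments():
  def docker_env(*urls):
    return beam_runner_api_pb2.Environment(
        urn=common_urns.environments.DOCKER.urn,
        payload=beam_runner_api_pb2.DockerPayload(
            container_image='some/image:tag').SerializeToString(),
        dependencies=[
            beam_runner_api_pb2.ArtifactInformation(
                type_urn=common_urns.artifact_types.URL.urn,
                type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                    url=url).SerializeToString()) for url in urls
        ])

  proto = beam_runner_api_pb2.Pipeline()
  proto.components.environments['envA'].CopyFrom(
      docker_env('gs://bucket/a.whl', 'gs://bucket/b.whl'))
  proto.components.environments['envB'].CopyFrom(
      docker_env('gs://bucket/a.whl'))
  merged = merge_superset_dep_environments(proto)
  assert set(merged.components.environments) == {'envA'}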
def update_environments(pipeline_proto, environment_remappings, inplace=False):
  """Rewrites environment_id references according to environment_remappings
  and deletes the environments that were remapped away.
  """
  if not environment_remappings:
    return pipeline_proto
  if not inplace:
    pipeline_proto = copy.copy(pipeline_proto)

  for t in pipeline_proto.components.transforms.values():
    if t.environment_id not in pipeline_proto.components.environments:
      # TODO(https://github.com/apache/beam/issues/30876): Remove this
      # workaround.
      continue
    if t.environment_id and t.environment_id in environment_remappings:
      t.environment_id = environment_remappings[t.environment_id]
  for w in pipeline_proto.components.windowing_strategies.values():
    if w.environment_id not in pipeline_proto.components.environments:
      # TODO(https://github.com/apache/beam/issues/30876): Remove this
      # workaround.
      continue
    if w.environment_id and w.environment_id in environment_remappings:
      w.environment_id = environment_remappings[w.environment_id]
  for e in set(environment_remappings.keys()) - set(
      environment_remappings.values()):
    del pipeline_proto.components.environments[e]
  return pipeline_proto
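# A hedged sketch, not part of the original module: a transform that points at
# 'old_env' is rewired to 'new_env', after which the now-unreferenced
# 'old_env' is dropped. The ids and the transform name are made up for
# illustration.
def _example_update_environments():
  proto = beam_runner_api_pb2.Pipeline()
  proto.components.environments['old_env'].urn = (
      common_urns.environments.DOCKER.urn)
  proto.components.environments['new_env'].urn = (
      common_urns.environments.DOCKER.urn)
  proto.components.transforms['t1'].environment_id = 'old_env'
  updated = update_environments(proto, {'old_env': 'new_env'}, inplace=True)
  assert updated.components.transforms['t1'].environment_id == 'new_env'
  assert set(updated.components.environments) == {'new_env'}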