Source code for apache_beam.runners.pipeline_context

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility class for serializing pipelines via the runner API.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file
# mypy: disallow-untyped-defs

from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union

from typing_extensions import Protocol

from apache_beam import coders
from apache_beam import pipeline
from apache_beam import pvalue
from apache_beam.internal import pickler
from apache_beam.pipeline import ComponentIdMap
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.transforms.resources import merge_resource_hints
from apache_beam.typehints import native_type_compatibility

  from google.protobuf import message  # pylint: disable=ungrouped-imports
  from apache_beam.coders.coder_impl import IterableStateReader
  from apache_beam.coders.coder_impl import IterableStateWriter
  from apache_beam.transforms import ptransform

# Type variable ranging over objects that satisfy the PortableObject protocol
# below; it parameterizes the generic _PipelineContextMap.
PortableObjectT = TypeVar('PortableObjectT', bound='PortableObject')


class PortableObject(Protocol):
  """Structural type for objects that round-trip through runner API protos.

  A class satisfies this protocol by providing an instance method
  ``to_runner_api(context)`` and a classmethod
  ``from_runner_api(proto, context)``.  The bodies declared here are
  placeholders only and return ``None``.
  """
  def to_runner_api(self, __context):
    # type: (PipelineContext) -> Any
    """Converts this object into its runner API representation."""
    ...

  @classmethod
  def from_runner_api(cls, __proto, __context):
    # type: (Any, PipelineContext) -> Any
    """Builds an object back from its runner API representation."""
    ...
class _PipelineContextMap(Generic[PortableObjectT]):
  """This is a bi-directional map between objects and ids.

  Under the hood it encodes and decodes these objects into runner API
  representations.

  State:
    * ``_obj_to_id`` / ``_id_to_obj`` cache the Python objects seen so far.
    * ``_id_to_proto`` holds the runner API proto of every known id, including
      ids seeded from ``proto_map`` that have not been decoded yet (decoding
      happens lazily in ``get_by_id``).
  """
  def __init__(self,
               context,  # type: PipelineContext
               obj_type,  # type: Type[PortableObjectT]
               namespace,  # type: str
               proto_map=None  # type: Optional[Mapping[str, message.Message]]
              ):
    # type: (...) -> None
    self._pipeline_context = context
    self._obj_type = obj_type
    self._namespace = namespace
    self._obj_to_id = {}  # type: Dict[Any, str]
    self._id_to_obj = {}  # type: Dict[str, Any]
    # Copy the seed mapping so later put_proto calls never mutate the caller's
    # container.
    self._id_to_proto = dict(proto_map) if proto_map else {}

  def populate_map(self, proto_map):
    # type: (Mapping[str, message.Message]) -> None
    """Copies every known proto into ``proto_map`` via ``CopyFrom``."""
    for component_id, proto in self._id_to_proto.items():
      proto_map[component_id].CopyFrom(proto)

  def get_id(self, obj, label=None):
    # type: (PortableObjectT, Optional[str]) -> str
    """Returns the id of ``obj``, registering it on first use.

    On first use an id is obtained from the context's ``component_id_map``
    (optionally using ``label``), both object caches are updated, and the
    object's runner API proto is stored under that id.
    """
    if obj not in self._obj_to_id:
      component_id = self._pipeline_context.component_id_map.get_or_assign(
          obj, self._obj_type, label)
      # Caches are filled before serializing, preserving the original order.
      self._id_to_obj[component_id] = obj
      self._obj_to_id[obj] = component_id
      self._id_to_proto[component_id] = obj.to_runner_api(
          self._pipeline_context)
    return self._obj_to_id[obj]

  def get_proto(self, obj, label=None):
    # type: (PortableObjectT, Optional[str]) -> message.Message
    """Returns the runner API proto of ``obj``, registering it if needed."""
    return self._id_to_proto[self.get_id(obj, label)]

  def get_by_id(self, id):
    # type: (str) -> PortableObjectT
    """Returns the object for ``id``, decoding its proto on first access.

    Raises:
      KeyError: if ``id`` has neither a cached object nor a stored proto.
    """
    if id not in self._id_to_obj:
      self._id_to_obj[id] = self._obj_type.from_runner_api(
          self._id_to_proto[id], self._pipeline_context)
    return self._id_to_obj[id]

  def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False):
    # type: (message.Message, Optional[str], bool) -> str
    """Returns an id for ``maybe_new_proto``, registering it if necessary.

    Args:
      maybe_new_proto: an ``Environment`` proto (the only supported kind).
      label: optional label used when a new id has to be assigned.
      deduplicate: when True, reuse the id of an already-registered object
        equal to the decoded proto, or of an identical stored proto.

    Returns:
      The existing or newly registered id.

    Raises:
      ValueError: from ``put_proto`` if the assigned id is already taken.
    """
    # TODO: this method may not be safe for arbitrary protos due to
    # xlang concerns, hence limiting usage to the only current use-case it has.
    assert isinstance(maybe_new_proto, beam_runner_api_pb2.Environment)
    obj = self._obj_type.from_runner_api(
        maybe_new_proto, self._pipeline_context)
    if deduplicate:
      if obj in self._obj_to_id:
        return self._obj_to_id[obj]
      for component_id, proto in self._id_to_proto.items():
        if proto == maybe_new_proto:
          return component_id
    return self.put_proto(
        self._pipeline_context.component_id_map.get_or_assign(
            obj=obj, obj_type=self._obj_type, label=label),
        maybe_new_proto)

  def get_id_to_proto_map(self):
    # type: () -> Dict[str, message.Message]
    """Returns the live (not copied) id -> proto dictionary."""
    return self._id_to_proto

  def get_proto_from_id(self, id):
    # type: (str) -> message.Message
    """Returns the stored proto for ``id``; raises KeyError if unknown."""
    return self.get_id_to_proto_map()[id]

  def put_proto(self, id, proto, ignore_duplicates=False):
    # type: (str, message.Message, bool) -> str
    """Stores ``proto`` under ``id`` and returns ``id``.

    Raises:
      ValueError: if ``id`` is already present and either
        ``ignore_duplicates`` is False, or the stored proto differs from
        ``proto``.
    """
    if id in self._id_to_proto:
      if not ignore_duplicates:
        raise ValueError("Id '%s' is already taken." % id)
      if self._id_to_proto[id] != proto:
        # Interpolate the values; passing them as extra ValueError args
        # would leave the %r placeholders unformatted in the message.
        raise ValueError(
            'Cannot insert different protos %r and %r with the same ID %r' %
            (self._id_to_proto[id], proto, id))
    self._id_to_proto[id] = proto
    return id

  def __getitem__(self, id):
    # type: (str) -> Any
    return self.get_by_id(id)

  def __contains__(self, id):
    # type: (str) -> bool
    # Membership is by stored proto, so not-yet-decoded ids count as present.
    return id in self._id_to_proto
class PipelineContext(object):
  """For internal use only; no backwards-compatibility guarantees.

  Used for accessing and constructing the referenced objects of a Pipeline.

  One ``_PipelineContextMap`` is kept per component kind (``transforms``,
  ``pcollections``, ``coders``, ``windowing_strategies`` and
  ``environments``); each maps objects to ids and ids to runner API protos.
  The context also records a set of requirement strings.
  """
  def __init__(
      self,
      proto=None,  # type: Optional[Union[beam_runner_api_pb2.Components, beam_fn_api_pb2.ProcessBundleDescriptor]]
      component_id_map=None,  # type: Optional[pipeline.ComponentIdMap]
      default_environment=None,  # type: Optional[environments.Environment]
      use_fake_coders=False,  # type: bool
      iterable_state_read=None,  # type: Optional[IterableStateReader]
      iterable_state_write=None,  # type: Optional[IterableStateWriter]
      namespace='ref',  # type: str
      requirements=(),  # type: Iterable[str]
  ):
    # type: (...) -> None
    """Creates a context, optionally seeded from existing component protos.

    Args:
      proto: components used to seed the per-kind maps. For a
        ProcessBundleDescriptor only its coders, windowing strategies and
        environments are used.
      component_id_map: id allocator; a new ``ComponentIdMap(namespace)`` is
        created when not given. Its namespace must equal ``namespace``.
      default_environment: environment registered under the label
        ``default_environment``; ``environments.DefaultEnvironment()`` when
        None.
      use_fake_coders: when True, coder ids are pickled element types (see
        ``coder_id_from_element_type``).
      iterable_state_read: stored as ``self.iterable_state_read``.
      iterable_state_write: stored as ``self.iterable_state_write``.
      namespace: namespace passed to the id allocator and every map.
      requirements: initial requirement strings.
    """
    if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
      # Only these three component kinds are carried over; the descriptor's
      # transforms and pcollections are not copied.
      proto = beam_runner_api_pb2.Components(
          coders=dict(proto.coders.items()),
          windowing_strategies=dict(proto.windowing_strategies.items()),
          environments=dict(proto.environments.items()))

    self.component_id_map = component_id_map or ComponentIdMap(namespace)
    assert self.component_id_map.namespace == namespace

    # TODO: Initialize component_id_map with objects from proto.
    self.transforms = _PipelineContextMap(
        self,
        pipeline.AppliedPTransform,
        namespace,
        proto.transforms if proto is not None else None)
    self.pcollections = _PipelineContextMap(
        self,
        pvalue.PCollection,
        namespace,
        proto.pcollections if proto is not None else None)
    self.coders = _PipelineContextMap(
        self,
        coders.Coder,
        namespace,
        proto.coders if proto is not None else None)
    self.windowing_strategies = _PipelineContextMap(
        self,
        core.Windowing,
        namespace,
        proto.windowing_strategies if proto is not None else None)
    self.environments = _PipelineContextMap(
        self,
        environments.Environment,
        namespace,
        proto.environments if proto is not None else None)

    if default_environment is None:
      default_environment = environments.DefaultEnvironment()

    self._default_environment_id = self.environments.get_id(
        default_environment, label='default_environment')  # type: str

    self.use_fake_coders = use_fake_coders
    # Cache of coder -> deterministic coder. It is mutated by
    # deterministic_coder, so it is typed as a Dict, not a read-only Mapping.
    self.deterministic_coder_map = {
    }  # type: Dict[coders.Coder, coders.Coder]
    self.iterable_state_read = iterable_state_read
    self.iterable_state_write = iterable_state_write
    self._requirements = set(requirements)

  def add_requirement(self, requirement):
    # type: (str) -> None
    """Records ``requirement``; duplicates are collapsed."""
    self._requirements.add(requirement)

  def requirements(self):
    # type: () -> FrozenSet[str]
    """Returns an immutable snapshot of the recorded requirements."""
    return frozenset(self._requirements)

  # If fake coders are requested, return a pickled version of the element type
  # rather than an actual coder. The element type is required for some runners,
  # as well as performing a round-trip through protos.
  # TODO: Remove once this is no longer needed.
  def coder_id_from_element_type(
      self, element_type, requires_deterministic_key_coder=None):
    # type: (Any, Optional[str]) -> str
    """Returns a coder id for ``element_type``.

    Args:
      element_type: the type whose coder is looked up in ``coders.registry``.
      requires_deterministic_key_coder: if truthy, the looked-up coder must
        expose ``key_coder()``/``value_coder()``; its key coder is replaced by
        a deterministic version, and this string is passed as the message to
        ``as_deterministic_coder``. Ignored when fake coders are in use.

    Returns:
      The registered coder id, or the ASCII-decoded pickled element type when
      ``use_fake_coders`` is set.
    """
    if self.use_fake_coders:
      return pickler.dumps(element_type).decode('ascii')
    coder = coders.registry.get_coder(element_type)
    if requires_deterministic_key_coder:
      coder = coders.TupleCoder([
          self.deterministic_coder(
              coder.key_coder(), requires_deterministic_key_coder),
          coder.value_coder()
      ])
    return self.coders.get_id(coder)

  def deterministic_coder(self, coder, msg):
    # type: (coders.Coder, str) -> coders.Coder
    """Returns the deterministic version of ``coder``, cached per coder.

    ``msg`` is only used the first time a given coder is seen.
    """
    if coder not in self.deterministic_coder_map:
      self.deterministic_coder_map[coder] = coder.as_deterministic_coder(msg)
    return self.deterministic_coder_map[coder]

  def element_type_from_coder_id(self, coder_id):
    # type: (str) -> Any
    """Inverse of ``coder_id_from_element_type``.

    When fake coders are in use, or ``coder_id`` is not a registered coder,
    ``coder_id`` is treated as a pickled element type.
    """
    if self.use_fake_coders or coder_id not in self.coders:
      # NOTE(review): this unpickles coder_id; only safe when the pipeline
      # protos it came from are trusted.
      return pickler.loads(coder_id)
    return native_type_compatibility.convert_to_beam_type(
        self.coders[coder_id].to_type_hint())

  @staticmethod
  def from_runner_api(proto):
    # type: (beam_runner_api_pb2.Components) -> PipelineContext
    """Builds a context seeded from a ``Components`` proto."""
    return PipelineContext(proto)

  def to_runner_api(self):
    # type: () -> beam_runner_api_pb2.Components
    """Serializes every registered component into a ``Components`` proto."""
    context_proto = beam_runner_api_pb2.Components()
    self.transforms.populate_map(context_proto.transforms)
    self.pcollections.populate_map(context_proto.pcollections)
    self.coders.populate_map(context_proto.coders)
    self.windowing_strategies.populate_map(context_proto.windowing_strategies)
    self.environments.populate_map(context_proto.environments)
    return context_proto

  def default_environment_id(self):
    # type: () -> str
    """Returns the id the default environment was registered under."""
    return self._default_environment_id

  def get_environment_id_for_resource_hints(self, hints):
    # type: (Dict[str, bytes]) -> str
    """Returns an environment id that has necessary resource hints.

    With empty (or None) ``hints`` the default environment id is returned.
    Otherwise ``hints`` are merged (as ``inner_hints``) with the default
    environment's own hints (as ``outer_hints``). A copy of the default
    environment carrying the merged hints is then looked up or registered,
    reusing an identical existing environment if there is one.
    """
    if not hints:
      return self.default_environment_id()

    def get_or_create_environment_with_resource_hints(
        template_env_id,
        resource_hints,
    ):  # type: (str, Dict[str, bytes]) -> str
      """Creates an environment that has necessary hints and returns its id."""
      template_env = self.environments.get_proto_from_id(template_env_id)
      cloned_env = beam_runner_api_pb2.Environment()
      # TODO: Remove the suppress warning for type once mypy is updated to
      # 0.941 or higher.
      # mypy 0.790 throws the warning below but 0.941 doesn't.
      # error: Argument 1 to "CopyFrom" of "Message" has incompatible type
      # "Message"; expected "Environment" [arg-type]
      # Here, Environment is a subclass of Message but mypy still
      # throws an error.
      cloned_env.CopyFrom(template_env)  # type: ignore[arg-type]
      cloned_env.resource_hints.clear()
      cloned_env.resource_hints.update(resource_hints)

      return self.environments.get_by_proto(
          cloned_env, label='environment_with_resource_hints', deduplicate=True)

    default_env_id = self.default_environment_id()
    env_hints = self.environments.get_by_id(default_env_id).resource_hints()
    hints = merge_resource_hints(outer_hints=env_hints, inner_hints=hints)
    maybe_new_env_id = get_or_create_environment_with_resource_hints(
        default_env_id, hints)

    return maybe_new_env_id