Source code for apache_beam.runners.pipeline_context

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utility class for serializing pipelines via the runner API.

For internal use only; no backwards-compatibility guarantees.

# pytype: skip-file
# mypy: disallow-untyped-defs

from __future__ import absolute_import

from builtins import object
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union

from typing_extensions import Protocol

from apache_beam import coders
from apache_beam import pipeline
from apache_beam import pvalue
from apache_beam.internal import pickler
from apache_beam.pipeline import ComponentIdMap
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.typehints import native_type_compatibility

if TYPE_CHECKING:
  from google.protobuf import message  # pylint: disable=ungrouped-imports
  from apache_beam.coders.coder_impl import IterableStateReader
  from apache_beam.coders.coder_impl import IterableStateWriter

# Type variable bound to PortableObject; it parameterizes _PipelineContextMap
# so that each map is typed by the kind of object it stores.
PortableObjectT = TypeVar('PortableObjectT', bound='PortableObject')

class PortableObject(Protocol):
  """Structural type for objects that convert to and from runner API protos.

  A conforming class offers an instance method ``to_runner_api`` and a
  classmethod ``from_runner_api``; both are handed a ``PipelineContext``.
  The double-underscore parameter names mark the arguments as positional-only
  for type checkers.
  """
  def to_runner_api(self, __context):
    # type: (PipelineContext) -> Any
    """Encodes this object, using ``__context`` for referenced components."""
    ...

  @classmethod
  def from_runner_api(cls, __proto, __context):
    # type: (Any, PipelineContext) -> Any
    """Decodes ``__proto`` back into an object, resolving via ``__context``."""
    ...
class _PipelineContextMap(Generic[PortableObjectT]):
  """This is a bi-directional map between objects and ids.

  Under the hood it encodes and decodes these objects into runner API
  representations.
  """
  def __init__(self,
               context,  # type: PipelineContext
               obj_type,  # type: Type[PortableObjectT]
               namespace,  # type: str
               proto_map=None  # type: Optional[Mapping[str, message.Message]]
              ):
    # type: (...) -> None
    """Creates a map, optionally seeded from ``proto_map``.

    Args:
      context: the owning PipelineContext; it is passed to ``to_runner_api``
        and ``from_runner_api`` and supplies ``component_id_map``.
      obj_type: the class whose ``from_runner_api`` decodes stored protos.
      namespace: stored on the instance; not otherwise read by this class.
      proto_map: existing id -> proto entries to start from.
    """
    self._pipeline_context = context
    self._obj_type = obj_type
    self._namespace = namespace
    self._obj_to_id = {}  # type: Dict[Any, str]
    self._id_to_obj = {}  # type: Dict[str, Any]
    # Copied into a plain dict so insertions here never touch the caller's map.
    self._id_to_proto = dict(proto_map) if proto_map else {}

  def populate_map(self, proto_map):
    # type: (Mapping[str, message.Message]) -> None
    """Copies every stored proto into ``proto_map`` under the same id."""
    for component_id, proto in self._id_to_proto.items():
      proto_map[component_id].CopyFrom(proto)

  def get_id(self, obj, label=None):
    # type: (PortableObjectT, Optional[str]) -> str
    """Returns the id of ``obj``, assigning one and encoding it on first use."""
    if obj not in self._obj_to_id:
      new_id = self._pipeline_context.component_id_map.get_or_assign(
          obj, self._obj_type, label)
      self._id_to_obj[new_id] = obj
      # The id is recorded before encoding, so a lookup of the same object
      # made while to_runner_api runs finds it instead of re-entering here.
      self._obj_to_id[obj] = new_id
      self._id_to_proto[new_id] = obj.to_runner_api(self._pipeline_context)
    return self._obj_to_id[obj]

  def get_proto(self, obj, label=None):
    # type: (PortableObjectT, Optional[str]) -> message.Message
    """Returns the proto stored for ``obj``, registering it if necessary."""
    return self._id_to_proto[self.get_id(obj, label)]

  def get_by_id(self, id):
    # type: (str) -> PortableObjectT
    """Returns the object for ``id``, decoding its proto on first access.

    Raises:
      KeyError: if ``id`` has neither a cached object nor a stored proto.
    """
    if id not in self._id_to_obj:
      self._id_to_obj[id] = self._obj_type.from_runner_api(
          self._id_to_proto[id], self._pipeline_context)
    return self._id_to_obj[id]

  def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False):
    # type: (message.Message, Optional[str], bool) -> str
    """Returns an id for ``maybe_new_proto``, storing it under a new id.

    With ``deduplicate`` set, the id of an equal proto that is already stored
    is returned instead of inserting a second copy.
    """
    if deduplicate:
      for existing_id, proto in self._id_to_proto.items():
        if proto == maybe_new_proto:
          return existing_id
    # NOTE(review): label is passed as the first positional argument here,
    # whereas get_id passes it third -- confirm against
    # ComponentIdMap.get_or_assign that this is intended.
    return self.put_proto(
        self._pipeline_context.component_id_map.get_or_assign(
            label, obj_type=self._obj_type),
        maybe_new_proto)

  def get_id_to_proto_map(self):
    # type: () -> Dict[str, message.Message]
    """Returns the internal id -> proto dict itself (not a copy)."""
    return self._id_to_proto

  def get_proto_from_id(self, id):
    # type: (str) -> message.Message
    """Returns the proto stored under ``id``; raises KeyError if absent."""
    return self.get_id_to_proto_map()[id]

  def put_proto(self, id, proto, ignore_duplicates=False):
    # type: (str, message.Message, bool) -> str
    """Stores ``proto`` under ``id`` and returns ``id``.

    Raises:
      ValueError: if ``id`` is already present and ``ignore_duplicates`` is
        false, or if it is present with a proto that differs from ``proto``.
    """
    if id in self._id_to_proto:
      if not ignore_duplicates:
        raise ValueError("Id '%s' is already taken." % id)
      if self._id_to_proto[id] != proto:
        # Interpolate explicitly: passing the values as extra ValueError
        # arguments would leave the %r placeholders unformatted.
        raise ValueError(
            'Cannot insert different protos %r and %r with the same ID %r' %
            (self._id_to_proto[id], proto, id))
    self._id_to_proto[id] = proto
    return id

  def __getitem__(self, id):
    # type: (str) -> Any
    return self.get_by_id(id)

  def __contains__(self, id):
    # type: (str) -> bool
    # Membership is defined by stored protos, not by decoded objects.
    return id in self._id_to_proto
class PipelineContext(object):
  """For internal use only; no backwards-compatibility guarantees.

  Holds one id <-> object map each for transforms, pcollections, coders,
  windowing strategies and environments, which is how the objects referenced
  by a Pipeline are looked up and constructed.
  """
  def __init__(self,
               proto=None,  # type: Optional[Union[beam_runner_api_pb2.Components, beam_fn_api_pb2.ProcessBundleDescriptor]]
               component_id_map=None,  # type: Optional[pipeline.ComponentIdMap]
               default_environment=None,  # type: Optional[environments.Environment]
               use_fake_coders=False,  # type: bool
               iterable_state_read=None,  # type: Optional[IterableStateReader]
               iterable_state_write=None,  # type: Optional[IterableStateWriter]
               namespace='ref',  # type: str
               requirements=(),  # type: Iterable[str]
              ):
    # type: (...) -> None
    """Builds the component maps, seeding them from ``proto`` when given.

    Args:
      proto: a Components message, or a ProcessBundleDescriptor whose coders,
        windowing strategies and environments are used.
      component_id_map: id allocator to share; a new ComponentIdMap for
        ``namespace`` is created when this is falsy.
      default_environment: if truthy, registered in ``environments`` under
        the label 'default_environment'.
      use_fake_coders: selects pickled element types instead of real coder
        ids in the coder/element-type conversion methods.
      iterable_state_read: stored as-is on the context.
      iterable_state_write: stored as-is on the context.
      namespace: must equal the namespace of the component id map in use.
      requirements: initial contents of the requirement set.
    """
    if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
      # Repackage the descriptor's relevant fields as a Components message so
      # that the rest of the constructor deals with a single proto type.
      descriptor = proto
      proto = beam_runner_api_pb2.Components(
          coders=dict(descriptor.coders.items()),
          windowing_strategies=dict(descriptor.windowing_strategies.items()),
          environments=dict(descriptor.environments.items()))

    def new_map(obj_type, field_name):
      # type: (Type[PortableObjectT], str) -> _PipelineContextMap[PortableObjectT]
      # Seeds the map from the matching Components field when a proto exists.
      seed = None if proto is None else getattr(proto, field_name)
      return _PipelineContextMap(self, obj_type, namespace, seed)

    self.component_id_map = component_id_map or ComponentIdMap(namespace)
    assert self.component_id_map.namespace == namespace

    self.transforms = new_map(pipeline.AppliedPTransform, 'transforms')
    self.pcollections = new_map(pvalue.PCollection, 'pcollections')
    self.coders = new_map(coders.Coder, 'coders')
    self.windowing_strategies = new_map(core.Windowing, 'windowing_strategies')
    self.environments = new_map(environments.Environment, 'environments')

    self._default_environment_id = (
        self.environments.get_id(
            default_environment, label='default_environment')
        if default_environment else None)  # type: Optional[str]

    self.use_fake_coders = use_fake_coders
    self.iterable_state_read = iterable_state_read
    self.iterable_state_write = iterable_state_write
    self._requirements = set(requirements)

  def add_requirement(self, requirement):
    # type: (str) -> None
    """Adds ``requirement`` to this context's requirement set."""
    self._requirements.add(requirement)

  def requirements(self):
    # type: () -> FrozenSet[str]
    """Returns a frozen copy of the requirements recorded so far."""
    return frozenset(self._requirements)

  # If fake coders are requested, return a pickled version of the element type
  # rather than an actual coder. The element type is required for some runners,
  # as well as performing a round-trip through protos.
  # TODO(BEAM-2717): Remove once this is no longer needed.
  def coder_id_from_element_type(self, element_type):
    # type: (Any) -> str
    """Returns a coder id (or pickled element type) for ``element_type``."""
    if not self.use_fake_coders:
      return self.coders.get_id(coders.registry.get_coder(element_type))
    return pickler.dumps(element_type).decode('ascii')

  def element_type_from_coder_id(self, coder_id):
    # type: (str) -> Any
    """Returns the element type that ``coder_id`` stands for.

    Ids of registered coders are resolved through the coder's type hint;
    anything else is treated as a pickled element type.
    """
    if not self.use_fake_coders and coder_id in self.coders:
      coder = self.coders[coder_id]
      return native_type_compatibility.convert_to_beam_type(
          coder.to_type_hint())
    return pickler.loads(coder_id)

  @staticmethod
  def from_runner_api(proto):
    # type: (beam_runner_api_pb2.Components) -> PipelineContext
    """Creates a context seeded from a Components proto."""
    return PipelineContext(proto)

  def to_runner_api(self):
    # type: () -> beam_runner_api_pb2.Components
    """Returns a new Components proto filled from all five component maps."""
    components = beam_runner_api_pb2.Components()
    for context_map, target in (
        (self.transforms, components.transforms),
        (self.pcollections, components.pcollections),
        (self.coders, components.coders),
        (self.windowing_strategies, components.windowing_strategies),
        (self.environments, components.environments)):
      context_map.populate_map(target)
    return components

  def default_environment_id(self):
    # type: () -> Optional[str]
    """Returns the id of the default environment, or None if none was given."""
    return self._default_environment_id