# Source code for apache_beam.runners.direct.direct_runner

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DirectRunner, executing on the local machine.

The DirectRunner is a runner implementation that executes the entire
graph of transformations belonging to a pipeline on the local machine.
"""

# pytype: skip-file

from __future__ import absolute_import

import itertools
import logging
import time
import typing

from google.protobuf import wrappers_pb2

import apache_beam as beam
from apache_beam import coders
from apache_beam import typehints
from apache_beam.internal.util import ArgumentPlaceholder
from apache_beam.options.pipeline_options import DirectOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.pvalue import PCollection
from apache_beam.runners.direct.bundle_factory import BundleFactory
from apache_beam.runners.direct.clock import RealClock
from apache_beam.runners.direct.clock import TestClock
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineRunner
from apache_beam.runners.runner import PipelineState
from apache_beam.transforms.core import CombinePerKey
from apache_beam.transforms.core import CombineValuesDoFn
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.ptransform import PTransform
from apache_beam.typehints import trivial_inference

# Note that the BundleBasedDirectRunner and SwitchingDirectRunner names are
# experimental and have no backwards compatibility guarantees.
__all__ = ['BundleBasedDirectRunner', 'DirectRunner', 'SwitchingDirectRunner']

_LOGGER = logging.getLogger(__name__)


class SwitchingDirectRunner(PipelineRunner):
  """Executes a single pipeline on the local machine.

  This implementation switches between using the FnApiRunner (which has
  high throughput for batch jobs) and using the BundleBasedDirectRunner,
  which supports streaming execution and certain primitives not yet
  implemented in the FnApiRunner.
  """
  def is_fnapi_compatible(self):
    """Return whatever BundleBasedDirectRunner.is_fnapi_compatible() reports."""
    return BundleBasedDirectRunner.is_fnapi_compatible()

  def run_pipeline(self, pipeline, options):
    """Run the pipeline on the FnApiRunner or the BundleBasedDirectRunner.

    The pipeline is walked once with a visitor; if no unsupported transform is
    found the FnApiRunner is used, otherwise the BundleBasedDirectRunner.

    Args:
      pipeline: the pipeline to inspect and execute.
      options: pipeline options, forwarded unchanged to the chosen runner.

    Returns:
      The value returned by the chosen runner's ``run_pipeline``.
    """
    # Imported locally, as in the rest of this module, rather than at the top
    # of the file.
    from apache_beam.pipeline import PipelineVisitor
    from apache_beam.runners.dataflow.native_io.iobase import NativeSource
    from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite
    from apache_beam.testing.test_stream import TestStream

    class _FnApiRunnerSupportVisitor(PipelineVisitor):
      """Visitor determining if a Pipeline can be run on the FnApiRunner."""
      def accept(self, pipeline):
        """Visit the pipeline; return True if no unsupported transform is seen."""
        self.supported_by_fnapi_runner = True
        pipeline.visit(self)
        return self.supported_by_fnapi_runner

      def visit_transform(self, applied_ptransform):
        """Clear the supported flag if this transform is not supported."""
        transform = applied_ptransform.transform
        # The FnApiRunner does not support streaming execution.
        if isinstance(transform, TestStream):
          self.supported_by_fnapi_runner = False
        # The FnApiRunner does not support reads from NativeSources.
        if (isinstance(transform, beam.io.Read) and
            isinstance(transform.source, NativeSource)):
          self.supported_by_fnapi_runner = False
        # The FnApiRunner does not support the use of _NativeWrites.
        if isinstance(transform, _NativeWrite):
          self.supported_by_fnapi_runner = False
        if isinstance(transform, beam.ParDo):
          dofn = transform.dofn
          # The FnApiRunner does not support execution of CombineFns with
          # deferred side inputs.
          if isinstance(dofn, CombineValuesDoFn):
            args, kwargs = transform.raw_side_inputs
            args_to_check = itertools.chain(args, kwargs.values())
            if any(isinstance(arg, ArgumentPlaceholder)
                   for arg in args_to_check):
              self.supported_by_fnapi_runner = False

    # Check whether all transforms used in the pipeline are supported by the
    # FnApiRunner.
    # NOTE(review): an earlier wording of this comment also mentioned "the
    # pipeline was not meant to be run as streaming", but the streaming option
    # is not consulted here; only the visitor result is. Confirm that is
    # intended.
    if _FnApiRunnerSupportVisitor().accept(pipeline):
      from apache_beam.runners.portability.fn_api_runner import FnApiRunner
      runner = FnApiRunner()
    else:
      runner = BundleBasedDirectRunner()

    return runner.run_pipeline(pipeline, options)
# Type variables.
K = typing.TypeVar('K')
V = typing.TypeVar('V')


@typehints.with_input_types(typing.Tuple[K, V])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupByKeyOnly(PTransform):
  """A group by key transform, ignoring windows."""
  def infer_output_type(self, input_type):
    """Map a key/value input type to KV[key, Iterable[value]]."""
    key_type, value_type = trivial_inference.key_value_types(input_type)
    return typehints.KV[key_type, typehints.Iterable[value_type]]

  def expand(self, pcoll):
    """Check the input PCollection and return ``PCollection.from_(pcoll)``."""
    self._check_pcollection(pcoll)
    return PCollection.from_(pcoll)


@typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupAlsoByWindow(ParDo):
  """The GroupAlsoByWindow transform."""
  def __init__(self, windowing):
    """Wrap a _GroupAlsoByWindowDoFn built from ``windowing``."""
    super(_GroupAlsoByWindow, self).__init__(_GroupAlsoByWindowDoFn(windowing))
    self.windowing = windowing

  def expand(self, pcoll):
    """Check the input PCollection and return ``PCollection.from_(pcoll)``."""
    self._check_pcollection(pcoll)
    return PCollection.from_(pcoll)


class _GroupAlsoByWindowDoFn(DoFn):
  """DoFn that hands each (key, values) element to a trigger driver."""

  # TODO(robertwb): Support combiner lifting.

  def __init__(self, windowing):
    super(_GroupAlsoByWindowDoFn, self).__init__()
    self.windowing = windowing

  def infer_output_type(self, input_type):
    """Derive Iterable[KV[key, Iterable[value]]] from the input type."""
    key_type, windowed_value_iter_type = trivial_inference.key_value_types(
        input_type)
    # Two levels of ``inner_type`` are unwrapped to reach the value type.
    value_type = windowed_value_iter_type.inner_type.inner_type
    return typehints.Iterable[typehints.KV[key_type,
                                           typehints.Iterable[value_type]]]

  def start_bundle(self):
    """Create the trigger driver used by ``process`` for this bundle."""
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.transforms.trigger import create_trigger_driver
    # pylint: enable=wrong-import-order, wrong-import-position
    self.driver = create_trigger_driver(self.windowing, True)

  def process(self, element):
    """Unpack a (key, values) element and delegate to the trigger driver."""
    k, vs = element
    return self.driver.process_entire_key(k, vs)


@typehints.with_input_types(typing.Tuple[K, V])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _StreamingGroupByKeyOnly(_GroupByKeyOnly):
  """Streaming GroupByKeyOnly placeholder for overriding in DirectRunner."""
  urn = "direct_runner:streaming_gbko:v0.1"

  # These are needed due to apply overloads.
  def to_runner_api_parameter(self, unused_context):
    """Return this transform's URN with no payload."""
    return _StreamingGroupByKeyOnly.urn, None

  @staticmethod
  @PTransform.register_urn(urn, None)
  def from_runner_api_parameter(
      unused_ptransform, unused_payload, unused_context):
    """Rebuild the transform; it carries no state."""
    return _StreamingGroupByKeyOnly()


@typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _StreamingGroupAlsoByWindow(_GroupAlsoByWindow):
  """Streaming GroupAlsoByWindow placeholder for overriding in DirectRunner."""
  urn = "direct_runner:streaming_gabw:v0.1"

  # These are needed due to apply overloads.
  def to_runner_api_parameter(self, context):
    """Return the URN and the windowing strategy id wrapped in a BytesValue."""
    # NOTE(review): ``get_id`` presumably returns a text id while BytesValue
    # carries bytes; confirm this round-trips on Python 3.
    return (
        _StreamingGroupAlsoByWindow.urn,
        wrappers_pb2.BytesValue(
            value=context.windowing_strategies.get_id(self.windowing)))

  @staticmethod
  @PTransform.register_urn(urn, wrappers_pb2.BytesValue)
  def from_runner_api_parameter(unused_ptransform, payload, context):
    """Rebuild the transform from the windowing strategy id in ``payload``."""
    return _StreamingGroupAlsoByWindow(
        context.windowing_strategies.get_by_id(payload.value))


@typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupByKey(PTransform):
  """The DirectRunner GroupByKey implementation."""
  def expand(self, pcoll):
    """Expand into ReifyWindows | _GroupByKeyOnly | _GroupAlsoByWindow.

    When the input has an element type, the key coder is verified to be
    deterministic and explicit type hints are attached to each stage.
    """
    # Imported here to avoid circular dependencies.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.coders import typecoders

    input_type = pcoll.element_type
    if input_type is not None:
      # Initialize type-hints used below to enforce type-checking and to
      # pass downstream to further PTransforms.
      key_type, value_type = trivial_inference.key_value_types(input_type)
      # Enforce the input to a GBK has a KV element type.
      pcoll.element_type = typehints.typehints.coerce_to_kv_type(
          pcoll.element_type)
      typecoders.registry.verify_deterministic(
          typecoders.registry.get_coder(key_type),
          'GroupByKey operation "%s"' % self.label)

      reify_output_type = typehints.KV[
          key_type, typehints.WindowedValue[value_type]]  # type: ignore[misc]
      gbk_input_type = (
          typehints.KV[
              key_type,
              typehints.Iterable[typehints.WindowedValue[  # type: ignore[misc]
                  value_type]]])
      gbk_output_type = typehints.KV[key_type, typehints.Iterable[value_type]]

      # pylint: disable=bad-continuation
      return (
          pcoll
          | 'ReifyWindows' >> (
              ParDo(beam.GroupByKey.ReifyWindows()).with_output_types(
                  reify_output_type))
          | 'GroupByKey' >> (
              _GroupByKeyOnly().with_input_types(
                  reify_output_type).with_output_types(gbk_input_type))
          | (
              'GroupByWindow' >>
              _GroupAlsoByWindow(pcoll.windowing).with_input_types(
                  gbk_input_type).with_output_types(gbk_output_type)))
    else:
      # The input_type is None, run the default
      return (
          pcoll
          | 'ReifyWindows' >> ParDo(beam.GroupByKey.ReifyWindows())
          | 'GroupByKey' >> _GroupByKeyOnly()
          | 'GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing))


def _get_transform_overrides(pipeline_options):
  """Build the ordered list of PTransformOverrides for the DirectRunner.

  Args:
    pipeline_options: options consulted for the streaming flag and passed on
      to the PubSub overrides.

  Returns:
    A list of PTransformOverride instances, beginning and ending with a
    GroupByKeyPTransformOverride.
  """
  # A list of PTransformOverride objects to be applied before running a pipeline
  # using DirectRunner.
  # Currently this only works for overrides where the input and output types do
  # not change.
  # For internal use only; no backwards-compatibility guarantees.

  # Importing following locally to avoid a circular dependency.
  from apache_beam.pipeline import PTransformOverride
  from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey
  from apache_beam.runners.direct.sdf_direct_runner import ProcessKeyedElementsViaKeyedWorkItemsOverride
  from apache_beam.runners.direct.sdf_direct_runner import SplittableParDoOverride

  class CombinePerKeyOverride(PTransformOverride):
    """Replaces CombinePerKey on default windowing with LiftedCombinePerKey."""
    def matches(self, applied_ptransform):
      # Return an explicit bool for every input instead of falling through
      # with an implicit None.
      return bool(
          isinstance(applied_ptransform.transform, CombinePerKey) and
          applied_ptransform.inputs[0].windowing.is_default())

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems
      # with resolving imports when they are at top.
      # pylint: disable=wrong-import-position
      # Bound outside the ``try`` so the fallback below never depends on
      # where inside the ``try`` the exception was raised.
      transform = applied_ptransform.transform
      try:
        return LiftedCombinePerKey(
            transform.fn, transform.args, transform.kwargs)
      except NotImplementedError:
        # Keep the original transform when lifting is not implemented for it.
        return transform

  class StreamingGroupByKeyOverride(PTransformOverride):
    """Replaces _GroupByKeyOnly with its streaming placeholder."""
    def matches(self, applied_ptransform):
      # Note: we match the exact class, since we replace it with a subclass.
      return applied_ptransform.transform.__class__ == _GroupByKeyOnly

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      # Use specialized streaming implementation.
      transform = _StreamingGroupByKeyOnly()
      return transform

  class StreamingGroupAlsoByWindowOverride(PTransformOverride):
    """Replaces _GroupAlsoByWindow ParDos with the streaming placeholder."""
    def matches(self, applied_ptransform):
      # Note: we match the exact class, since we replace it with a subclass.
      transform = applied_ptransform.transform
      return (
          isinstance(applied_ptransform.transform, ParDo) and
          isinstance(transform.dofn, _GroupAlsoByWindowDoFn) and
          transform.__class__ != _StreamingGroupAlsoByWindow)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      # Use specialized streaming implementation.
      transform = _StreamingGroupAlsoByWindow(
          applied_ptransform.transform.dofn.windowing)
      return transform

  class TestStreamOverride(PTransformOverride):
    """Replaces TestStream with _ExpandableTestStream."""
    def matches(self, applied_ptransform):
      from apache_beam.testing.test_stream import TestStream
      self.applied_ptransform = applied_ptransform
      return isinstance(applied_ptransform.transform, TestStream)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      from apache_beam.runners.direct.test_stream_impl import _ExpandableTestStream
      return _ExpandableTestStream(applied_ptransform.transform)

  class GroupByKeyPTransformOverride(PTransformOverride):
    """A ``PTransformOverride`` for ``GroupByKey``.

    This replaces the Beam implementation as a primitive.
    """
    def matches(self, applied_ptransform):
      # Imported here to avoid circular dependencies.
      # pylint: disable=wrong-import-order, wrong-import-position
      from apache_beam.transforms.core import GroupByKey
      return isinstance(applied_ptransform.transform, GroupByKey)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      return _GroupByKey()

  overrides = [
      # This needs to be the first and the last override. Other overrides depend
      # on the GroupByKey implementation to be composed of _GroupByKeyOnly and
      # _GroupAlsoByWindow.
      GroupByKeyPTransformOverride(),
      SplittableParDoOverride(),
      ProcessKeyedElementsViaKeyedWorkItemsOverride(),
      CombinePerKeyOverride(),
      TestStreamOverride(),
  ]

  # Add streaming overrides, if necessary.
  if pipeline_options.view_as(StandardOptions).streaming:
    overrides.append(StreamingGroupByKeyOverride())
    overrides.append(StreamingGroupAlsoByWindowOverride())

  # Add PubSub overrides, if PubSub is available.
  try:
    from apache_beam.io.gcp import pubsub as unused_pubsub
    overrides += _get_pubsub_transform_overrides(pipeline_options)
  except ImportError:
    pass

  # This also needs to be last because other transforms apply GBKs which need to
  # be translated into a DirectRunner-compatible transform.
  overrides.append(GroupByKeyPTransformOverride())

  return overrides


class _DirectReadFromPubSub(PTransform):
  """DirectRunner replacement transform for reading from PubSub."""
  def __init__(self, source):
    self._source = source

  def _infer_output_coder(
      self, unused_input_type=None, unused_input_coder=None):
    # type: (...) -> typing.Optional[coders.Coder]
    return coders.BytesCoder()

  def get_windowing(self, inputs):
    """Return a Windowing over GlobalWindows regardless of ``inputs``."""
    return beam.Windowing(beam.window.GlobalWindows())

  def expand(self, pvalue):
    # This is handled as a native transform.
    return PCollection(self.pipeline, is_bounded=self._source.is_bounded())


class _DirectWriteToPubSubFn(DoFn):
  """DoFn that buffers elements and publishes them to a PubSub topic."""
  BUFFER_SIZE_ELEMENTS = 100
  FLUSH_TIMEOUT_SECS = BUFFER_SIZE_ELEMENTS * 0.5

  def __init__(self, sink):
    """Copy the sink's settings.

    Raises:
      NotImplementedError: if the sink sets ``id_label`` or
        ``timestamp_attribute``.
    """
    self.project = sink.project
    self.short_topic_name = sink.topic_name
    self.id_label = sink.id_label
    self.timestamp_attribute = sink.timestamp_attribute
    self.with_attributes = sink.with_attributes

    # TODO(BEAM-4275): Add support for id_label and timestamp_attribute.
    if sink.id_label:
      raise NotImplementedError(
          'DirectRunner: id_label is not supported for '
          'PubSub writes')
    if sink.timestamp_attribute:
      raise NotImplementedError(
          'DirectRunner: timestamp_attribute is not '
          'supported for PubSub writes')

  def start_bundle(self):
    self._buffer = []

  def process(self, elem):
    """Buffer ``elem``; flush once BUFFER_SIZE_ELEMENTS are buffered."""
    self._buffer.append(elem)
    if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS:
      self._flush()

  def finish_bundle(self):
    self._flush()

  def _flush(self):
    """Publish all buffered elements, wait for the results, clear the buffer."""
    if not self._buffer:
      # Nothing to publish: do not construct a publisher client for an empty
      # buffer (e.g. finish_bundle right after a size-triggered flush).
      return

    from google.cloud import pubsub
    pub_client = pubsub.PublisherClient()
    topic = pub_client.topic_path(self.project, self.short_topic_name)

    if self.with_attributes:
      futures = [
          pub_client.publish(topic, elem.data, **elem.attributes)
          for elem in self._buffer
      ]
    else:
      futures = [pub_client.publish(topic, elem) for elem in self._buffer]

    # FLUSH_TIMEOUT_SECS is one budget shared by all futures of this flush:
    # each wait only gets whatever is left since timer_start.
    timer_start = time.time()
    for future in futures:
      remaining = self.FLUSH_TIMEOUT_SECS - (time.time() - timer_start)
      future.result(remaining)
    self._buffer = []


def _get_pubsub_transform_overrides(pipeline_options):
  """Return the overrides replacing PubSub read and write transforms.

  Args:
    pipeline_options: options consulted for the streaming flag when a
      replacement transform is requested.

  Returns:
    A two-element list: a read override followed by a write override.
  """
  from apache_beam.io.gcp import pubsub as beam_pubsub
  from apache_beam.pipeline import PTransformOverride

  class ReadFromPubSubOverride(PTransformOverride):
    """Replaces ReadFromPubSub with _DirectReadFromPubSub."""
    def matches(self, applied_ptransform):
      return isinstance(
          applied_ptransform.transform, beam_pubsub.ReadFromPubSub)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      if not pipeline_options.view_as(StandardOptions).streaming:
        raise Exception(
            'PubSub I/O is only available in streaming mode '
            '(use the --streaming flag).')
      return _DirectReadFromPubSub(applied_ptransform.transform._source)

  class WriteToPubSubOverride(PTransformOverride):
    """Replaces PubSub write transforms with a ParDo of _DirectWriteToPubSubFn."""
    def matches(self, applied_ptransform):
      return isinstance(
          applied_ptransform.transform,
          (beam_pubsub.WriteToPubSub, beam_pubsub._WriteStringsToPubSub))

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      if not pipeline_options.view_as(StandardOptions).streaming:
        raise Exception(
            'PubSub I/O is only available in streaming mode '
            '(use the --streaming flag).')
      return beam.ParDo(
          _DirectWriteToPubSubFn(applied_ptransform.transform._sink))

  return [ReadFromPubSubOverride(), WriteToPubSubOverride()]
class BundleBasedDirectRunner(PipelineRunner):
  """Executes a single pipeline on the local machine."""
  @staticmethod
  def is_fnapi_compatible():
    """Always return False for this runner."""
    return False

  def run_pipeline(self, pipeline, options):
    """Execute the entire pipeline and return a DirectPipelineResult.

    Args:
      pipeline: the pipeline to run; the DirectRunner transform overrides are
        applied to it in place before execution.
      options: pipeline options used to build the overrides, the bundle
        factory and the evaluation context.

    Returns:
      A DirectPipelineResult wrapping the started executor. The executor is
      started in the background, so the result is returned while the pipeline
      may still be running.
    """

    # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems
    # with resolving imports when they are at top.
    # pylint: disable=wrong-import-position
    from apache_beam.pipeline import PipelineVisitor
    from apache_beam.runners.direct.consumer_tracking_pipeline_visitor import (
        ConsumerTrackingPipelineVisitor)
    from apache_beam.runners.direct.evaluation_context import EvaluationContext
    from apache_beam.runners.direct.executor import Executor
    from apache_beam.runners.direct.transform_evaluator import (
        TransformEvaluatorRegistry)
    from apache_beam.testing.test_stream import TestStream

    # If the TestStream I/O is used, use a mock test clock.
    class TestStreamUsageVisitor(PipelineVisitor):
      """Visitor determining whether a Pipeline uses a TestStream."""
      def __init__(self):
        self.uses_test_stream = False

      def visit_transform(self, applied_ptransform):
        if isinstance(applied_ptransform.transform, TestStream):
          self.uses_test_stream = True

    visitor = TestStreamUsageVisitor()
    pipeline.visit(visitor)
    clock = TestClock() if visitor.uses_test_stream else RealClock()

    # Performing configured PTransform overrides.
    pipeline.replace_all(_get_transform_overrides(options))

    _LOGGER.info('Running pipeline with DirectRunner.')
    # The tracking visitor runs after the overrides, so it sees the replaced
    # transforms.
    self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor()
    pipeline.visit(self.consumer_tracking_visitor)

    evaluation_context = EvaluationContext(
        options,
        BundleFactory(
            stacked=options.view_as(
                DirectOptions).direct_runner_use_stacked_bundle),
        self.consumer_tracking_visitor.root_transforms,
        self.consumer_tracking_visitor.value_to_consumers,
        self.consumer_tracking_visitor.step_names,
        self.consumer_tracking_visitor.views,
        clock)

    executor = Executor(
        self.consumer_tracking_visitor.value_to_consumers,
        TransformEvaluatorRegistry(evaluation_context),
        evaluation_context)
    # DirectRunner does not support injecting
    # PipelineOptions values at runtime
    RuntimeValueProvider.set_runtime_options({})
    # Start the executor. This is a non-blocking call, it will start the
    # execution in background threads and return.
    executor.start(self.consumer_tracking_visitor.root_transforms)
    result = DirectPipelineResult(executor, evaluation_context)

    return result
# Use the SwitchingDirectRunner as the default.
DirectRunner = SwitchingDirectRunner


class DirectPipelineResult(PipelineResult):
  """A DirectPipelineResult provides access to info about a pipeline."""
  def __init__(self, executor, evaluation_context):
    """Start in the RUNNING state and keep the executor and context.

    Args:
      executor: object used to await completion of, or shut down, the run.
      evaluation_context: object queried for aggregator values and metrics.
    """
    super(DirectPipelineResult, self).__init__(PipelineState.RUNNING)
    self._executor = executor
    self._evaluation_context = evaluation_context

  def __del__(self):
    # getattr guards against a partially constructed instance: __del__ can run
    # even if __init__ raised before ``_state`` was assigned.
    if getattr(self, '_state', None) == PipelineState.RUNNING:
      _LOGGER.warning(
          'The DirectPipelineResult is being garbage-collected while the '
          'DirectRunner is still running the corresponding pipeline. This may '
          'lead to incomplete execution of the pipeline if the main thread '
          'exits before pipeline completion. Consider using '
          'result.wait_until_finish() to wait for completion of pipeline '
          'execution.')

  def wait_until_finish(self, duration=None):
    """Block until the executor completes and return the resulting state.

    Args:
      duration: must be falsy; a truthy value is rejected.

    Returns:
      The pipeline state: DONE after a successful wait, or the current state
      if it was already terminal.

    Raises:
      NotImplementedError: if ``duration`` is truthy and the state is not yet
        terminal.
      BaseException: whatever the executor raised while awaiting completion;
        the state is set to FAILED before it is re-raised.
    """
    if not PipelineState.is_terminal(self.state):
      if duration:
        raise NotImplementedError(
            'DirectRunner does not support duration argument.')
      try:
        self._executor.await_completion()
        self._state = PipelineState.DONE
      except BaseException:  # pylint: disable=broad-except
        # Deliberately as broad as a bare ``except``: record the failure for
        # any exception, then let it propagate unchanged.
        self._state = PipelineState.FAILED
        raise
    return self._state

  def aggregated_values(self, aggregator_or_name):
    """Return the evaluation context's values for the given aggregator."""
    return self._evaluation_context.get_aggregator_values(aggregator_or_name)

  def metrics(self):
    """Return the evaluation context's metrics."""
    return self._evaluation_context.metrics()

  def cancel(self):
    """Shuts down pipeline workers.

    For testing use only. Does not properly wait for pipeline workers to shut
    down.
    """
    self._state = PipelineState.CANCELLING
    self._executor.shutdown()
    self._state = PipelineState.CANCELLED