#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module to build pipeline fragment that produces given PCollections.
For internal use only; no backwards-compatibility guarantees.
"""
import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.testing.test_stream import TestStream
[docs]class PipelineFragment(object):
  """A fragment of a pipeline definition.
  A pipeline fragment is built from the original pipeline definition to include
  only PTransforms that are necessary to produce the given PCollections.
  """
  def __init__(self, pcolls, options=None):
    """Constructor of PipelineFragment.
    Args:
      pcolls: (List[PCollection]) a list of PCollections to build pipeline
          fragment for.
      options: (PipelineOptions) the pipeline options for the implicit
          pipeline run.
    """
    assert len(pcolls) > 0, (
        'Need at least 1 PCollection as the target data to build a pipeline '
        'fragment that produces it.')
    for pcoll in pcolls:
      assert isinstance(pcoll, beam.pvalue.PCollection), (
          '{} is not an apache_beam.pvalue.PCollection.'.format(pcoll))
    # No modification to self._user_pipeline is allowed.
    self._user_pipeline = pcolls[0].pipeline
    # These are user PCollections. Do not use them to deduce anything that
    # will be executed by any runner. Instead, use
    # `self._runner_pcolls_to_user_pcolls.keys()` to get copied PCollections.
    self._pcolls = set(pcolls)
    for pcoll in self._pcolls:
      assert pcoll.pipeline is self._user_pipeline, (
          '{} belongs to a different user pipeline than other PCollections '
          'given and cannot be used to build a pipeline fragment that produces '
          'the given PCollections.'.format(pcoll))
    self._options = options
    # A copied pipeline instance for modification without changing the user
    # pipeline instance held by the end user. This instance can be processed
    # into a pipeline fragment that later run by the underlying runner.
    self._runner_pipeline = self._build_runner_pipeline()
    _, self._context = self._runner_pipeline.to_runner_api(return_context=True)
    from apache_beam.runners.interactive import pipeline_instrument as instr
    self._runner_pcoll_to_id = instr.pcoll_to_pcoll_id(
        self._runner_pipeline, self._context)
    # Correlate components in the runner pipeline to components in the user
    # pipeline. The target pcolls are the pcolls given and defined in the user
    # pipeline.
    self._id_to_target_pcoll = self._calculate_target_pcoll_ids()
    self._label_to_user_transform = self._calculate_user_transform_labels()
    # Below will give us the 1:1 correlation between
    # PCollections/AppliedPTransforms from the copied runner pipeline and
    # PCollections/AppliedPTransforms from the user pipeline.
    # (Dict[PCollection, PCollection])
    (
        self._runner_pcolls_to_user_pcolls,
        # (Dict[AppliedPTransform, AppliedPTransform])
        self._runner_transforms_to_user_transforms
    ) = self._build_correlation_between_pipelines(
        self._runner_pcoll_to_id,
        self._id_to_target_pcoll,
        self._label_to_user_transform)
    # Below are operated on the runner pipeline.
    (self._necessary_transforms,
     self._necessary_pcollections) = self._mark_necessary_transforms_and_pcolls(
         self._runner_pcolls_to_user_pcolls)
    self._runner_pipeline = self._prune_runner_pipeline_to_fragment(
        self._runner_pipeline, self._necessary_transforms)
[docs]  def deduce_fragment(self):
    """Deduce the pipeline fragment as an apache_beam.Pipeline instance."""
    fragment = beam.pipeline.Pipeline.from_runner_api(
        self._runner_pipeline.to_runner_api(),
        self._runner_pipeline.runner,
        self._options)
    ie.current_env().add_derived_pipeline(self._runner_pipeline, fragment)
    return fragment 
[docs]  def run(self, display_pipeline_graph=False, use_cache=True, blocking=False):
    """Shorthand to run the pipeline fragment."""
    try:
      preserved_skip_display = self._runner_pipeline.runner._skip_display
      preserved_force_compute = self._runner_pipeline.runner._force_compute
      preserved_blocking = self._runner_pipeline.runner._blocking
      self._runner_pipeline.runner._skip_display = not display_pipeline_graph
      self._runner_pipeline.runner._force_compute = not use_cache
      self._runner_pipeline.runner._blocking = blocking
      return self.deduce_fragment().run()
    finally:
      self._runner_pipeline.runner._skip_display = preserved_skip_display
      self._runner_pipeline.runner._force_compute = preserved_force_compute
      self._runner_pipeline.runner._blocking = preserved_blocking 
  def _build_runner_pipeline(self):
    runner_pipeline = beam.pipeline.Pipeline.from_runner_api(
        self._user_pipeline.to_runner_api(),
        self._user_pipeline.runner,
        self._options)
    ie.current_env().add_derived_pipeline(self._user_pipeline, runner_pipeline)
    return runner_pipeline
  def _calculate_target_pcoll_ids(self):
    pcoll_id_to_target_pcoll = {}
    for pcoll in self._pcolls:
      pcoll_id_to_target_pcoll[self._runner_pcoll_to_id.get(str(pcoll),
                                                            '')] = pcoll
    return pcoll_id_to_target_pcoll
  def _calculate_user_transform_labels(self):
    label_to_user_transform = {}
    class UserTransformVisitor(PipelineVisitor):
      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)
      def visit_transform(self, transform_node):
        if transform_node is not None:
          label_to_user_transform[transform_node.full_label] = transform_node
    v = UserTransformVisitor()
    self._runner_pipeline.visit(v)
    return label_to_user_transform
  def _build_correlation_between_pipelines(
      self, runner_pcoll_to_id, id_to_target_pcoll, label_to_user_transform):
    runner_pcolls_to_user_pcolls = {}
    runner_transforms_to_user_transforms = {}
    class CorrelationVisitor(PipelineVisitor):
      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)
      def visit_transform(self, transform_node):
        self._process_transform(transform_node)
        for in_pcoll in transform_node.inputs:
          self._process_pcoll(in_pcoll)
        for out_pcoll in transform_node.outputs.values():
          self._process_pcoll(out_pcoll)
      def _process_pcoll(self, pcoll):
        pcoll_id = runner_pcoll_to_id.get(str(pcoll), '')
        if pcoll_id in id_to_target_pcoll:
          runner_pcolls_to_user_pcolls[pcoll] = (id_to_target_pcoll[pcoll_id])
      def _process_transform(self, transform_node):
        if transform_node.full_label in label_to_user_transform:
          runner_transforms_to_user_transforms[transform_node] = (
              label_to_user_transform[transform_node.full_label])
    v = CorrelationVisitor()
    self._runner_pipeline.visit(v)
    return runner_pcolls_to_user_pcolls, runner_transforms_to_user_transforms
  def _mark_necessary_transforms_and_pcolls(self, runner_pcolls_to_user_pcolls):
    necessary_transforms = set()
    all_inputs = set()
    updated_all_inputs = set(runner_pcolls_to_user_pcolls.keys())
    # Do this until no more new PCollection is recorded.
    while len(updated_all_inputs) != len(all_inputs):
      all_inputs = set(updated_all_inputs)
      for pcoll in all_inputs:
        producer = pcoll.producer
        while producer:
          if producer in necessary_transforms:
            break
          # Mark the AppliedPTransform as necessary.
          necessary_transforms.add(producer)
          # Also mark composites that are not the root transform. If the root
          # transform is added, then all transforms are incorrectly marked as
          # necessary. If composites are not handled, then there will be
          # orphaned PCollections.
          if producer.parent is not None:
            necessary_transforms.update(producer.parts)
            # This will recursively add all the PCollections in this composite.
            for part in producer.parts:
              updated_all_inputs.update(part.outputs.values())
          # Record all necessary input and side input PCollections.
          updated_all_inputs.update(producer.inputs)
          # pylint: disable=bad-option-value
          side_input_pvalues = set(
              map(lambda side_input: side_input.pvalue, producer.side_inputs))
          updated_all_inputs.update(side_input_pvalues)
          # Go to its parent AppliedPTransform.
          producer = producer.parent
    return necessary_transforms, all_inputs
  def _prune_runner_pipeline_to_fragment(
      self, runner_pipeline, necessary_transforms):
    class PruneVisitor(PipelineVisitor):
      def enter_composite_transform(self, transform_node):
        if isinstance(transform_node.transform, TestStream):
          return
        pruned_parts = list(transform_node.parts)
        for part in transform_node.parts:
          if part not in necessary_transforms:
            pruned_parts.remove(part)
        transform_node.parts = tuple(pruned_parts)
        self.visit_transform(transform_node)
      def visit_transform(self, transform_node):
        if transform_node not in necessary_transforms:
          transform_node.parent = None
    v = PruneVisitor()
    runner_pipeline.visit(v)
    return runner_pipeline