Source code for apache_beam.runners.interactive.caching.write_cache

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module to write cache for PCollections being computed.

For internal use only; no backward-compatibility guarantees.
"""
# pytype: skip-file

from typing import Tuple

import apache_beam as beam
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive.caching.cacheable import Cacheable
from apache_beam.runners.pipeline_context import PipelineContext
from apache_beam.testing import test_stream
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.window import WindowedValue


class WriteCache:
  """Adds a cache-writing subgraph for a cacheable PCollection to a pipeline.

  The subgraph is built in a temporary pipeline against a placeholder input
  and then merged into the given pipeline proto, where the placeholder is
  swapped for the real cacheable PCollection.
  """
  def __init__(
      self,
      pipeline: beam_runner_api_pb2.Pipeline,
      context: PipelineContext,
      cache_manager: cache.CacheManager,
      cacheable: Cacheable):
    """Stores the collaborators and derives the cache key and label.

    Args:
      pipeline: the pipeline proto that receives the cache-writing subgraph.
      context: the pipeline context used to resolve PCollection ids; its
        component_id_map is shared with the temporary pipeline.
      cache_manager: the cache manager handed to the write transform.
      cacheable: the cacheable whose PCollection is written to cache.
    """
    self._pipeline = pipeline
    self._context = context
    self._cache_manager = cache_manager
    self._cacheable = cacheable
    # The repr of the cacheable's key names the cache entry and is embedded
    # in every label of the generated subgraph.
    self._key = repr(cacheable.to_key())
    self._label = f'_cache_{self._key}'

  def write_cache(self) -> None:
    """Writes cache for the cacheable PCollection that is being computed.

    Steps:

    1. Build a template pipeline proto in which a placeholder PCollection
       feeds the cache-writing transform. The temporary pipeline reuses the
       component_id_map of the context, so ids do not conflict.
    2. Copy PCollections, coders, windowing strategies and transforms from
       the template into self._pipeline, skipping every component that is
       already present in self._pipeline or that belongs to the placeholder's
       ignorable components.
    3. Append the non-ignorable top-level transforms of the template to the
       subtransforms of self._pipeline's root transform.
    4. Rewire every transform input that points at the placeholder so that
       it consumes the cacheable PCollection instead.
    """
    template, placeholder = self._build_runner_api_template()
    placeholder_id = self._context.pcollections.get_id(
        placeholder.placeholder_pcoll)
    cacheable_id = self._context.pcollections.get_id(self._cacheable.pcoll)

    source = template.components
    target = self._pipeline.components
    ignorable = placeholder.ignorable_components

    def copy_new_components(kind: str) -> None:
      """Copies components of one kind that target lacks and are not ignorable.
      """
      source_map = getattr(source, kind)
      target_map = getattr(target, kind)
      ignorable_map = getattr(ignorable, kind)
      for component_id in source_map:
        if component_id not in target_map and component_id not in ignorable_map:
          target_map[component_id].CopyFrom(source_map[component_id])

    # Copy cache writing subgraph from the template to the pipeline proto.
    for kind in ('pcollections', 'coders', 'windowing_strategies'):
      copy_new_components(kind)
    template_root_id = template.root_transform_ids[0]
    pipeline_root_id = self._pipeline.root_transform_ids[0]
    copy_new_components('transforms')

    # Hook the copied top-level transforms under the pipeline's root. The root
    # entry is only touched when there is something to add, because indexing a
    # message map with a missing key would create that entry.
    new_top_level = [
        transform_id
        for transform_id in source.transforms[template_root_id].subtransforms
        if transform_id not in ignorable.transforms
    ]
    if new_top_level:
      target.transforms[pipeline_root_id].subtransforms.extend(new_top_level)

    # Point every consumer of the placeholder at the cacheable PCollection.
    # Matching tags are collected first so the map is not mutated while it is
    # being iterated.
    for transform in target.transforms.values():
      inputs = transform.inputs
      stale_tags = [
          tag for tag, pcoll_id in inputs.items() if pcoll_id == placeholder_id
      ]
      for tag in stale_tags:
        inputs[tag] = cacheable_id

  def _build_runner_api_template(
      self) -> Tuple[beam_runner_api_pb2.Pipeline, '_PCollectionPlaceHolder']:
    """Builds the temporary pipeline proto holding the cache-writing subgraph.

    Returns:
      A tuple of the runner API proto of the temporary pipeline and the
      placeholder object whose PCollection is the input of the write transform.
    """
    placeholder = _PCollectionPlaceHolder(self._cacheable.pcoll, self._context)
    write_transform = _WriteCacheTransform(
        self._cache_manager, self._key, self._label)
    sink_label = 'sink' + self._label
    # Applying the transform registers it in the placeholder's pipeline; the
    # resulting PCollection itself is not needed.
    _ = placeholder.placeholder_pcoll | sink_label >> write_transform
    template = placeholder.placeholder_pcoll.pipeline.to_runner_api()
    return template, placeholder
class _WriteCacheTransform(PTransform):
  """Composite transform that reifies elements and writes them to cache."""
  def __init__(self, cache_manager: cache.CacheManager, key: str, label: str):
    """Stores what is needed to build the write step.

    Args:
      cache_manager: the cache manager passed to cache.WriteCache.
      key: the cache key passed to cache.WriteCache.
      label: suffix appended to the labels of the inner transforms.
    """
    self._cache_manager = cache_manager
    self._key = key
    self._label = label

  def expand(self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
    """Wraps each element with its window, pane and timestamp, then caches it.

    Args:
      pcoll: the PCollection to be written to cache.

    Returns:
      The result of applying cache.WriteCache to the reified elements.
    """
    class Reify(beam.DoFn):
      """Emits each element as a WindowedValueHolder of a WindowedValue."""
      def process(
          self,
          e,
          w=beam.DoFn.WindowParam,
          p=beam.DoFn.PaneInfoParam,
          t=beam.DoFn.TimestampParam):
        windowed_value = WindowedValue(e, t, [w], p)
        yield test_stream.WindowedValueHolder(windowed_value)

    reified = pcoll | ('reify' + self._label) >> beam.ParDo(Reify())
    return reified | ('write' + self._label) >> cache.WriteCache(
        self._cache_manager, self._key, is_capture=False)


class _PCollectionPlaceHolder:
  """Stand-in PCollection used as the input of the cache writing transform.

  The placeholder lives in a temporary pipeline that shares the
  component_id_map of the given context and copies the tag, element type,
  boundedness and windowing of the PCollection it stands in for.
  """
  def __init__(self, pcoll: beam.pvalue.PCollection, context: PipelineContext):
    """Creates the placeholder and snapshots the components it generates.

    Args:
      pcoll: the PCollection whose attributes the placeholder mimics.
      context: the pipeline context whose component_id_map is reused by the
        temporary pipeline.
    """
    tmp_pipeline = beam.Pipeline()
    tmp_pipeline.component_id_map = context.component_id_map
    placeholder = tmp_pipeline | 'CreatePInput' >> beam.Create(
        [], reshuffle=False)
    for attribute in ('tag', 'element_type', 'is_bounded'):
      setattr(placeholder, attribute, getattr(pcoll, attribute))
    placeholder._windowing = pcoll.windowing
    self._input_placeholder = placeholder
    # Snapshot taken before any consumer is attached: at this point the
    # temporary pipeline holds only the placeholder's own subgraph.
    self._ignorable_components = tmp_pipeline.to_runner_api().components

  @property
  def placeholder_pcoll(self) -> beam.pvalue.PCollection:
    """The placeholder PCollection in the temporary pipeline."""
    return self._input_placeholder

  @property
  def ignorable_components(self) -> beam_runner_api_pb2.Components:
    """Components generated by the placeholder itself.

    These can be ignored when building the final pipeline proto.
    """
    return self._ignorable_components