#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module to write cache for PCollections being computed.
For internal use only; no backward-compatibility guarantees.
"""
# pytype: skip-file
from typing import Tuple
import apache_beam as beam
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive.caching.cacheable import Cacheable
from apache_beam.runners.pipeline_context import PipelineContext
from apache_beam.testing import test_stream
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.window import WindowedValue
class WriteCache:
  """Class that facilitates writing cache for PCollections being computed.

  Builds a cache-writing subgraph in a temporary pipeline and splices it
  into an existing pipeline proto, rewiring the subgraph's input to the
  cacheable PCollection.
  """
  def __init__(
      self,
      pipeline: beam_runner_api_pb2.Pipeline,
      context: PipelineContext,
      cache_manager: cache.CacheManager,
      cacheable: Cacheable):
    """Initializes the cache writer.

    Args:
      pipeline: the pipeline proto to splice the cache-writing subgraph into.
      context: the pipeline context shared with `pipeline`, used to resolve
        component ids.
      cache_manager: manages where/how the cached values are written.
      cacheable: metadata about the PCollection whose values are cached.
    """
    self._pipeline = pipeline
    self._context = context
    self._cache_manager = cache_manager
    self._cacheable = cacheable
    # The cache key is the repr of the cacheable's key so that it is a
    # plain string usable as both a cache identifier and a label suffix.
    self._key = repr(cacheable.to_key())
    self._label = '{}{}'.format('_cache_', self._key)

  def write_cache(self) -> None:
    """Writes cache for the cacheable PCollection that is being computed.

    First, it creates a temporary pipeline instance on top of the existing
    component_id_map from self._pipeline's context so that both pipelines
    share the context and have no conflict component ids.
    Second, it creates a _PCollectionPlaceHolder in the temporary pipeline that
    mimics the attributes of the cacheable PCollection to be written into cache.
    It also marks all components in the current temporary pipeline as
    ignorable when later copying components to self._pipeline.
    Third, it instantiates a _WriteCacheTransform that uses the
    _PCollectionPlaceHolder as the input. This adds a subgraph under top level
    transforms that writes the _PCollectionPlaceHolder into cache.
    Fourth, it copies components of the subgraph from the temporary pipeline to
    self._pipeline, skipping components that are ignored in the temporary
    pipeline and components that are not in the temporary pipeline but present
    in the component_id_map of self._pipeline.
    Last, it replaces inputs of all transforms that consume the
    _PCollectionPlaceHolder with the cacheable PCollection to be written to
    cache.
    """
    template, write_input_placeholder = self._build_runner_api_template()
    input_placeholder_id = self._context.pcollections.get_id(
        write_input_placeholder.placeholder_pcoll)
    input_id = self._context.pcollections.get_id(self._cacheable.pcoll)

    # Copy the cache-writing subgraph from the template to the pipeline
    # proto. Every component kind follows the same rule: skip anything the
    # pipeline already has or that the placeholder marked as ignorable.
    ignorable = write_input_placeholder.ignorable_components
    self._copy_new_components(
        template.components.pcollections,
        self._pipeline.components.pcollections,
        ignorable.pcollections)
    self._copy_new_components(
        template.components.coders,
        self._pipeline.components.coders,
        ignorable.coders)
    self._copy_new_components(
        template.components.windowing_strategies,
        self._pipeline.components.windowing_strategies,
        ignorable.windowing_strategies)
    self._copy_new_components(
        template.components.transforms,
        self._pipeline.components.transforms,
        ignorable.transforms)

    # Attach the template's non-ignorable top-level transforms under the
    # pipeline's root so the copied subgraph is reachable.
    template_root_transform_id = template.root_transform_ids[0]
    root_transform_id = self._pipeline.root_transform_ids[0]
    for top_level_transform in template.components.transforms[
        template_root_transform_id].subtransforms:
      if top_level_transform in ignorable.transforms:
        continue
      self._pipeline.components.transforms[
          root_transform_id].subtransforms.append(top_level_transform)

    # Replace every input referencing the placeholder pcoll with the
    # cacheable pcoll of input_id. Keys are collected first because proto
    # map fields must not be mutated while being iterated.
    for transform in self._pipeline.components.transforms.values():
      inputs = transform.inputs
      keys_need_replacement = {
          tag
          for tag, pcoll_id in inputs.items()
          if pcoll_id == input_placeholder_id
      }
      for tag in keys_need_replacement:
        inputs[tag] = input_id

  @staticmethod
  def _copy_new_components(source, destination, ignorable) -> None:
    """Copies entries from one proto map field to another.

    Skips ids already present in `destination` and ids listed in the
    `ignorable` map (components generated by the placeholder itself).
    """
    for component_id in source:
      if component_id in destination or component_id in ignorable:
        continue
      destination[component_id].CopyFrom(source[component_id])

  def _build_runner_api_template(
      self) -> Tuple[beam_runner_api_pb2.Pipeline, '_PCollectionPlaceHolder']:
    """Builds a temporary pipeline proto containing the cache-writing subgraph.

    Returns a tuple of the temporary pipeline's runner API proto and the
    placeholder whose pcoll feeds the subgraph.
    """
    pph = _PCollectionPlaceHolder(self._cacheable.pcoll, self._context)
    transform = _WriteCacheTransform(
        self._cache_manager, self._key, self._label)
    _ = pph.placeholder_pcoll | 'sink' + self._label >> transform
    return pph.placeholder_pcoll.pipeline.to_runner_api(), pph
class _WriteCacheTransform(PTransform):
  """A composite transform that encapsulates writing cache for PCollections.

  Reifies each element with its window, pane and timestamp before writing,
  so the cached values can later be replayed with full windowing metadata.
  """
  def __init__(self, cache_manager: cache.CacheManager, key: str, label: str):
    """Initializes the transform.

    Args:
      cache_manager: manages where/how the cached values are written.
      key: the cache key identifying the PCollection being cached.
      label: suffix appended to sub-transform labels to keep them unique.
    """
    self._cache_manager = cache_manager
    self._key = key
    self._label = label

  def expand(self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
    class Reify(beam.DoFn):
      # Wraps each element together with its window, pane info and
      # timestamp so that cached values preserve windowing metadata.
      def process(
          self,
          e,
          w=beam.DoFn.WindowParam,
          p=beam.DoFn.PaneInfoParam,
          t=beam.DoFn.TimestampParam):
        yield test_stream.WindowedValueHolder(WindowedValue(e, t, [w], p))

    return (
        pcoll
        | 'reify' + self._label >> beam.ParDo(Reify())
        | 'write' + self._label >> cache.WriteCache(
            self._cache_manager, self._key, is_capture=False))
class _PCollectionPlaceHolder:
  """A placeholder as an input to the cache writing transform.

  Creates an empty PCollection in a temporary pipeline that mimics the
  attributes of a real PCollection, so a cache-writing subgraph can be
  built against it and later rewired to the real PCollection.
  """
  def __init__(self, pcoll: beam.pvalue.PCollection, context: PipelineContext):
    """Initializes the placeholder.

    Args:
      pcoll: the real PCollection whose attributes the placeholder mimics.
      context: pipeline context whose component_id_map is shared with the
        temporary pipeline to avoid component id conflicts.
    """
    tmp_pipeline = beam.Pipeline()
    # Share the component_id_map so ids generated in the temporary pipeline
    # never collide with ids in the pipeline being augmented.
    tmp_pipeline.component_id_map = context.component_id_map
    self._input_placeholder = tmp_pipeline | 'CreatePInput' >> beam.Create(
        [], reshuffle=False)
    # Mirror the real pcoll's attributes so downstream transforms built on
    # the placeholder see the same tag, type, boundedness and windowing.
    self._input_placeholder.tag = pcoll.tag
    self._input_placeholder.element_type = pcoll.element_type
    self._input_placeholder.is_bounded = pcoll.is_bounded
    self._input_placeholder._windowing = pcoll.windowing
    # Snapshot the components generated so far; everything here belongs to
    # the placeholder and must be skipped when copying to the real pipeline.
    self._ignorable_components = tmp_pipeline.to_runner_api().components

  @property
  def placeholder_pcoll(self) -> beam.pvalue.PCollection:
    """The placeholder PCollection to feed the cache-writing transform."""
    return self._input_placeholder

  @property
  def ignorable_components(self) -> beam_runner_api_pb2.Components:
    """Subgraph generated by the placeholder that can be ignored in the final
    pipeline proto.
    """
    return self._ignorable_components