#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import threading
import time
import warnings
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import pandas as pd
import apache_beam as beam
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import background_caching_job as bcj
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import interactive_runner as ir
from apache_beam.runners.interactive import pipeline_fragment as pf
from apache_beam.runners.interactive import utils
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.runner import PipelineState
_LOGGER = logging.getLogger(__name__)
[docs]class ElementStream:
"""A stream of elements from a given PCollection."""
def __init__(
self,
pcoll: beam.pvalue.PCollection,
var: str,
cache_key: str,
max_n: int,
max_duration_secs: float):
self._pcoll = pcoll
self._cache_key = cache_key
self._pipeline = ie.current_env().user_pipeline(pcoll.pipeline)
self._var = var
self._n = max_n
self._duration_secs = max_duration_secs
# A small state variable that when True, indicates that no more new elements
# will be yielded if read() is called again.
self._done = False
@property
def var(self) -> str:
"""Returns the variable named that defined this PCollection."""
return self._var
@property
def pcoll(self) -> beam.pvalue.PCollection:
"""Returns the PCollection that supplies this stream with data."""
return self._pcoll
@property
def cache_key(self) -> str:
"""Returns the cache key for this stream."""
return self._cache_key
[docs] def display_id(self, suffix: str) -> str:
"""Returns a unique id able to be displayed in a web browser."""
return utils.obfuscate(self._cache_key, suffix)
[docs] def is_computed(self) -> bool:
# noqa: F821
"""Returns True if no more elements will be recorded."""
return self._pcoll in ie.current_env().computed_pcollections
[docs] def is_done(self) -> bool:
# noqa: F821
"""Returns True if no more new elements will be yielded."""
return self._done
[docs] def read(self, tail: bool = True) -> Any:
"""Reads the elements currently recorded."""
# Get the cache manager and wait until the file exists.
cache_manager = ie.current_env().get_cache_manager(self._pipeline)
# Retrieve the coder for the particular PCollection which will be used to
# decode elements read from cache.
coder = cache_manager.load_pcoder('full', self._cache_key)
# Read the elements from the cache.
# Import limiters here to prevent a circular import.
from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter
reader, _ = cache_manager.read('full', self._cache_key, tail=tail)
# Because a single TestStreamFileRecord can yield multiple elements, we
# limit the count again here in the to_element_list call.
#
# There are two ways of exiting this loop either a limiter was triggered or
# all elements from the cache were read. In the latter situation, it may be
# the case that the pipeline was still running. Thus, another invocation of
# `read` will yield new elements.
count_limiter = CountLimiter(self._n)
time_limiter = ProcessingTimeLimiter(self._duration_secs)
limiters = (count_limiter, time_limiter)
for e in utils.to_element_list(reader,
coder,
include_window_info=True,
n=self._n,
include_time_events=True):
# From the to_element_list we either get TestStreamPayload.Events if
# include_time_events or decoded elements from the reader. Make sure we
# only count the decoded elements to break early.
if isinstance(e, beam_runner_api_pb2.TestStreamPayload.Event):
time_limiter.update(e)
else:
count_limiter.update(e)
yield e
if any(l.is_triggered() for l in limiters):
break
# A limiter being triggered means that we have fulfilled the user's request.
# This implies that reading from the cache again won't yield any new
# elements. WLOG, this applies to the user pipeline being terminated.
if any(l.is_triggered()
for l in limiters) or ie.current_env().is_terminated(self._pipeline):
self._done = True
[docs]class Recording:
"""A group of PCollections from a given pipeline run."""
def __init__(
self,
user_pipeline: beam.Pipeline,
pcolls: List[beam.pvalue.PCollection], # noqa: F821
result: 'beam.runner.PipelineResult',
max_n: int,
max_duration_secs: float,
):
self._user_pipeline = user_pipeline
self._result = result
self._result_lock = threading.Lock()
self._pcolls = pcolls
pcoll_var = lambda pcoll: {v: k
for k, v in utils.pcoll_by_name().items()}.get(
pcoll, None)
self._streams = {
pcoll: ElementStream(
pcoll,
pcoll_var(pcoll),
CacheKey.from_pcoll(pcoll_var(pcoll), pcoll).to_str(),
max_n,
max_duration_secs)
for pcoll in pcolls
}
self._start = time.time()
self._duration_secs = max_duration_secs
self._set_computed = bcj.is_cache_complete(str(id(user_pipeline)))
# Run a separate thread for marking the PCollections done. This is because
# the pipeline run may be asynchronous.
self._mark_computed = threading.Thread(target=self._mark_all_computed)
self._mark_computed.daemon = True
self._mark_computed.start()
def _mark_all_computed(self) -> None:
"""Marks all the PCollections upon a successful pipeline run."""
if not self._result:
return
while not PipelineState.is_terminal(self._result.state):
with self._result_lock:
bcj = ie.current_env().get_background_caching_job(self._user_pipeline)
if bcj and bcj.is_done():
self._result.wait_until_finish()
elif time.time() - self._start >= self._duration_secs:
self._result.cancel()
self._result.wait_until_finish()
elif all(s.is_done() for s in self._streams.values()):
self._result.cancel()
self._result.wait_until_finish()
time.sleep(0.1)
# Mark the PCollection as computed so that Interactive Beam wouldn't need to
# re-compute.
if self._result.state is PipelineState.DONE and self._set_computed:
ie.current_env().mark_pcollection_computed(self._pcolls)
[docs] def is_computed(self) -> bool:
"""Returns True if all PCollections are computed."""
return all(s.is_computed() for s in self._streams.values())
[docs] def stream(self, pcoll: beam.pvalue.PCollection) -> ElementStream:
"""Returns an ElementStream for a given PCollection."""
return self._streams[pcoll]
[docs] def computed(self) -> None:
"""Returns all computed ElementStreams."""
return {p: s for p, s in self._streams.items() if s.is_computed()}
[docs] def uncomputed(self) -> None:
"""Returns all uncomputed ElementStreams."""
return {p: s for p, s in self._streams.items() if not s.is_computed()}
[docs] def cancel(self) -> None:
"""Cancels the recording."""
with self._result_lock:
self._result.cancel()
[docs] def wait_until_finish(self) -> None:
"""Waits until the pipeline is done and returns the final state.
This also marks any PCollections as computed right away if the pipeline is
successful.
"""
if not self._result:
return beam.runners.runner.PipelineState.DONE
self._mark_computed.join()
return self._result.state
[docs] def describe(self) -> Dict[str, int]:
"""Returns a dictionary describing the cache and recording."""
cache_manager = ie.current_env().get_cache_manager(self._user_pipeline)
size = sum(
cache_manager.size('full', s.cache_key) for s in self._streams.values())
return {'size': size, 'duration': self._duration_secs}
[docs]class RecordingManager:
"""Manages recordings of PCollections for a given pipeline."""
def __init__(
self,
user_pipeline: beam.Pipeline,
pipeline_var: str = None,
test_limiters: List['Limiter'] = None) -> None: # noqa: F821
self.user_pipeline: beam.Pipeline = user_pipeline
self.pipeline_var: str = pipeline_var if pipeline_var else ''
self._recordings: set[Recording] = set()
self._start_time_sec: float = 0
self._test_limiters = test_limiters if test_limiters else []
def _watch(self, pcolls: List[beam.pvalue.PCollection]) -> None:
"""Watch any pcollections not being watched.
This allows for the underlying caching layer to identify the PCollection as
something to be cached.
"""
watched_pcollections = set()
watched_dataframes = set()
for watching in ie.current_env().watching():
for _, val in watching:
if isinstance(val, beam.pvalue.PCollection):
watched_pcollections.add(val)
elif isinstance(val, DeferredBase):
watched_dataframes.add(val)
# Convert them one-by-one to generate a unique label for each. This allows
# caching at a more fine-grained granularity.
#
# TODO(https://github.com/apache/beam/issues/20929): investigate the mixing
# pcollections in multiple pipelines error when using the default label.
for df in watched_dataframes:
pcoll, _ = utils.deferred_df_to_pcollection(df)
watched_pcollections.add(pcoll)
for pcoll in pcolls:
if pcoll not in watched_pcollections:
ie.current_env().watch(
{'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
def _clear(self) -> None:
"""Clears the recording of all non-source PCollections."""
cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
# Only clear the PCollections that aren't being populated from the
# BackgroundCachingJob.
computed = ie.current_env().computed_pcollections
cacheables = [
c for c in utils.cacheables().values()
if c.pcoll.pipeline is self.user_pipeline and c.pcoll not in computed
]
all_cached = set(str(c.to_key()) for c in cacheables)
source_pcolls = getattr(cache_manager, 'capture_keys', set())
to_clear = all_cached - source_pcolls
self._clear_pcolls(cache_manager, set(to_clear))
def _clear_pcolls(self, cache_manager, pcolls):
for pc in pcolls:
cache_manager.clear('full', pc)
[docs] def clear(self) -> None:
"""Clears all cached PCollections for this RecordingManager."""
cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
if cache_manager:
cache_manager.cleanup()
[docs] def cancel(self: None) -> None:
"""Cancels the current background recording job."""
bcj.attempt_to_cancel_background_caching_job(self.user_pipeline)
for r in self._recordings:
r.wait_until_finish()
self._recordings = set()
# The recordings rely on a reference to the BCJ to correctly finish. So we
# evict the BCJ after they complete.
ie.current_env().evict_background_caching_job(self.user_pipeline)
[docs] def describe(self) -> Dict[str, int]:
"""Returns a dictionary describing the cache and recording."""
cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
capture_size = getattr(cache_manager, 'capture_size', 0)
descriptions = [r.describe() for r in self._recordings]
size = sum(d['size'] for d in descriptions) + capture_size
start = self._start_time_sec
bcj = ie.current_env().get_background_caching_job(self.user_pipeline)
if bcj:
state = bcj.state
else:
state = PipelineState.STOPPED
return {
'size': size,
'start': start,
'state': state,
'pipeline_var': self.pipeline_var
}
[docs] def record_pipeline(self) -> bool:
"""Starts a background caching job for this RecordingManager's pipeline."""
runner = self.user_pipeline.runner
if isinstance(runner, ir.InteractiveRunner):
runner = runner._underlying_runner
# Make sure that sources without a user reference are still cached.
ie.current_env().add_user_pipeline(self.user_pipeline)
utils.watch_sources(self.user_pipeline)
# Attempt to run background caching job to record any sources.
warnings.filterwarnings(
'ignore',
'options is deprecated since First stable release. References to '
'<pipeline>.options will not be supported',
category=DeprecationWarning)
if bcj.attempt_to_run_background_caching_job(
runner,
self.user_pipeline,
options=self.user_pipeline.options,
limiters=self._test_limiters):
self._start_time_sec = time.time()
return True
return False
[docs] def record(
self,
pcolls: List[beam.pvalue.PCollection],
max_n: int,
max_duration: Union[int, str]) -> Recording:
# noqa: F821
"""Records the given PCollections."""
# Assert that all PCollection come from the same user_pipeline.
for pcoll in pcolls:
assert pcoll.pipeline is self.user_pipeline, (
'{} belongs to a different user-defined pipeline ({}) than that of'
' other PCollections ({}).'.format(
pcoll, pcoll.pipeline, self.user_pipeline))
if isinstance(max_duration, str) and max_duration != 'inf':
max_duration_secs = pd.to_timedelta(max_duration).total_seconds()
else:
max_duration_secs = max_duration
# Make sure that all PCollections to be shown are watched. If a PCollection
# has not been watched, make up a variable name for that PCollection and
# watch it. No validation is needed here because the watch logic can handle
# arbitrary variables.
self._watch(pcolls)
self.record_pipeline()
# Get the subset of computed PCollections. These do not to be recomputed.
computed_pcolls = set(
pcoll for pcoll in pcolls
if pcoll in ie.current_env().computed_pcollections)
# Start a pipeline fragment to start computing the PCollections.
uncomputed_pcolls = set(pcolls).difference(computed_pcolls)
if uncomputed_pcolls:
# Clear the cache of the given uncomputed PCollections because they are
# incomplete.
self._clear()
cache_path = ie.current_env().options.cache_root
is_remote_run = cache_path and ie.current_env(
).options.cache_root.startswith('gs://')
pf.PipelineFragment(
list(uncomputed_pcolls),
self.user_pipeline.options).run(blocking=is_remote_run)
result = ie.current_env().pipeline_result(self.user_pipeline)
else:
result = None
recording = Recording(
self.user_pipeline, pcolls, result, max_n, max_duration_secs)
self._recordings.add(recording)
return recording
[docs] def read(
self,
pcoll_name: str,
pcoll: beam.pvalue.PValue,
max_n: int,
max_duration_secs: float) -> Union[None, ElementStream]:
# noqa: F821
"""Reads an ElementStream of a computed PCollection.
Returns None if an error occurs. The caller is responsible of validating if
the given pcoll_name and pcoll can identify a watched and computed
PCollection without ambiguity in the notebook.
"""
try:
cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
return ElementStream(
pcoll, pcoll_name, cache_key, max_n, max_duration_secs)
except (KeyboardInterrupt, SystemExit):
raise
except Exception as e:
# Caller should handle all validations. Here to avoid redundant
# validations, simply log errors if caller fails to do so.
_LOGGER.error(str(e))
return None