#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import threading
import time
import warnings
import pandas as pd
import apache_beam as beam
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import background_caching_job as bcj
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import interactive_runner as ir
from apache_beam.runners.interactive import pipeline_fragment as pf
from apache_beam.runners.interactive import utils
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.runner import PipelineState
_LOGGER = logging.getLogger(__name__)
class ElementStream:
  """A stream of elements from a given PCollection.

  Elements are read from the 'full' cache entry identified by ``cache_key``.
  Each call to ``read`` stops once ``max_n`` elements have been yielded or
  ``max_duration_secs`` of processing time has been observed in the cache.
  """
  def __init__(
      self,
      pcoll,  # type: beam.pvalue.PCollection
      var,  # type: str
      cache_key,  # type: str
      max_n,  # type: int
      max_duration_secs  # type: float
      ):
    """Creates an ElementStream.

    Args:
      pcoll: the PCollection that supplies this stream with data.
      var: the variable name that defined ``pcoll``; may be None when no
        watched variable refers to it.
      cache_key: the cache key under which ``pcoll`` is recorded.
      max_n: maximum number of elements to yield per ``read``.
      max_duration_secs: maximum processing time, in seconds, per ``read``.
    """
    self._pcoll = pcoll
    self._cache_key = cache_key
    self._pipeline = ie.current_env().user_pipeline(pcoll.pipeline)
    self._var = var
    self._n = max_n
    self._duration_secs = max_duration_secs
    # A small state variable that when True, indicates that no more new elements
    # will be yielded if read() is called again.
    self._done = False

  @property
  def var(self):
    # type: () -> str
    """Returns the variable name that defined this PCollection."""
    return self._var

  @property
  def pcoll(self):
    # type: () -> beam.pvalue.PCollection
    """Returns the PCollection that supplies this stream with data."""
    return self._pcoll

  @property
  def cache_key(self):
    # type: () -> str
    """Returns the cache key for this stream."""
    return self._cache_key

  def display_id(self, suffix):
    # type: (str) -> str
    """Returns a unique id able to be displayed in a web browser."""
    return utils.obfuscate(self._cache_key, suffix)

  def is_computed(self):
    # type: () -> bool
    """Returns True if no more elements will be recorded."""
    return self._pcoll in ie.current_env().computed_pcollections

  def is_done(self):
    # type: () -> bool
    """Returns True if no more new elements will be yielded."""
    return self._done

  def read(self, tail=True):
    # type: (bool) -> Any # noqa: F821
    """Reads the elements currently recorded.

    This is a generator that yields decoded elements (with window info) from
    the cache. TestStreamPayload events found in the cache are not yielded.
    They are only fed to the processing-time limiter.

    Args:
      tail: forwarded to the cache manager's ``read``. NOTE(review):
        presumably controls whether reading follows newly written cache
        contents; confirm against the cache manager implementation.
    """
    cache_manager = ie.current_env().get_cache_manager(self._pipeline)
    # Retrieve the coder for the particular PCollection which will be used to
    # decode elements read from cache.
    coder = cache_manager.load_pcoder('full', self._cache_key)

    # Import limiters here to prevent a circular import.
    from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
    from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter

    # Read the elements from the cache.
    reader, _ = cache_manager.read('full', self._cache_key, tail=tail)

    # Because a single TestStreamFileRecord can yield multiple elements, we
    # limit the count again here in the to_element_list call.
    #
    # There are two ways of exiting this loop: either a limiter was triggered
    # or all elements from the cache were read. In the latter situation, it may
    # be the case that the pipeline was still running. Thus, another invocation
    # of `read` will yield new elements.
    count_limiter = CountLimiter(self._n)
    time_limiter = ProcessingTimeLimiter(self._duration_secs)
    limiters = (count_limiter, time_limiter)
    for e in utils.to_element_list(reader,
                                   coder,
                                   include_window_info=True,
                                   n=self._n,
                                   include_time_events=True):
      # From the to_element_list we either get TestStreamPayload.Events if
      # include_time_events or decoded elements from the reader. Make sure we
      # only count the decoded elements to break early.
      if isinstance(e, beam_runner_api_pb2.TestStreamPayload.Event):
        time_limiter.update(e)
      else:
        count_limiter.update(e)
        yield e

      if any(l.is_triggered() for l in limiters):
        break

    # A limiter being triggered means that we have fulfilled the user's request.
    # This implies that reading from the cache again won't yield any new
    # elements. WLOG, this applies to the user pipeline being terminated.
    # Note: this point is only reached when the caller drains the generator.
    if any(l.is_triggered()
           for l in limiters) or ie.current_env().is_terminated(self._pipeline):
      self._done = True
class Recording:
  """A group of PCollections from a given pipeline run."""
  def __init__(
      self,
      user_pipeline,  # type: beam.Pipeline
      pcolls,  # type: List[beam.pvalue.PCollection] # noqa: F821
      result,  # type: beam.runner.PipelineResult
      max_n,  # type: int
      max_duration_secs,  # type: float
      ):
    """Creates a Recording and starts its completion-tracking thread.

    Args:
      user_pipeline: the user pipeline the PCollections belong to.
      pcolls: the PCollections being recorded.
      result: the result of the pipeline run computing ``pcolls``, or None
        when no run was started.
      max_n: maximum number of elements each ElementStream yields per read.
      max_duration_secs: time after which the run is cancelled.
    """
    self._user_pipeline = user_pipeline
    self._result = result
    self._result_lock = threading.Lock()
    self._pcolls = pcolls
    # Invert the name -> PCollection mapping once, rather than re-inverting it
    # for every PCollection.
    var_by_pcoll = {v: k for k, v in utils.pcoll_by_name().items()}
    self._streams = {
        pcoll: ElementStream(
            pcoll,
            var_by_pcoll.get(pcoll),
            CacheKey.from_pcoll(var_by_pcoll.get(pcoll), pcoll).to_str(),
            max_n,
            max_duration_secs)
        for pcoll in pcolls
    }
    self._start = time.time()
    self._duration_secs = max_duration_secs
    self._set_computed = bcj.is_cache_complete(str(id(user_pipeline)))

    # Run a separate thread for marking the PCollections done. This is because
    # the pipeline run may be asynchronous.
    self._mark_computed = threading.Thread(target=self._mark_all_computed)
    self._mark_computed.daemon = True
    self._mark_computed.start()

  def _mark_all_computed(self):
    # type: () -> None
    """Marks all the PCollections upon a successful pipeline run.

    Runs on the daemon thread started in ``__init__``. Until the run reaches a
    terminal state, polls every 0.1s. It waits for the run once the
    background caching job is done. It cancels the run once
    ``max_duration_secs`` has elapsed or every ElementStream is done.
    """
    if not self._result:
      return

    while not PipelineState.is_terminal(self._result.state):
      with self._result_lock:
        bg_job = ie.current_env().get_background_caching_job(
            self._user_pipeline)
        if bg_job and bg_job.is_done():
          self._result.wait_until_finish()
        elif (time.time() - self._start >= self._duration_secs or
              all(s.is_done() for s in self._streams.values())):
          self._result.cancel()
          self._result.wait_until_finish()
      time.sleep(0.1)

    # Mark the PCollection as computed so that Interactive Beam wouldn't need to
    # re-compute. Compare by value: runner states are strings, and identity
    # comparison only works by accident of interning.
    if self._result.state == PipelineState.DONE and self._set_computed:
      ie.current_env().mark_pcollection_computed(self._pcolls)

  def is_computed(self):
    # type: () -> bool
    """Returns True if all PCollections are computed."""
    return all(s.is_computed() for s in self._streams.values())

  def stream(self, pcoll):
    # type: (beam.pvalue.PCollection) -> ElementStream
    """Returns an ElementStream for a given PCollection."""
    return self._streams[pcoll]

  def computed(self):
    # type: () -> dict
    """Returns all computed ElementStreams, keyed by PCollection."""
    return {p: s for p, s in self._streams.items() if s.is_computed()}

  def uncomputed(self):
    # type: () -> dict
    """Returns all uncomputed ElementStreams, keyed by PCollection."""
    return {p: s for p, s in self._streams.items() if not s.is_computed()}

  def cancel(self):
    # type: () -> None
    """Cancels the recording.

    This is a no-op when no pipeline run backs this recording (``result`` is
    None). That happens when every requested PCollection was already computed.
    """
    if not self._result:
      return
    with self._result_lock:
      self._result.cancel()

  def wait_until_finish(self):
    # type: () -> str
    """Waits until the pipeline is done and returns the final state.

    This also marks any PCollections as computed right away if the pipeline is
    successful.
    """
    if not self._result:
      return PipelineState.DONE

    self._mark_computed.join()
    return self._result.state

  def describe(self):
    # type: () -> dict[str, int]
    """Returns a dictionary describing the cache and recording."""
    cache_manager = ie.current_env().get_cache_manager(self._user_pipeline)

    size = sum(
        cache_manager.size('full', s.cache_key) for s in self._streams.values())
    return {'size': size, 'duration': self._duration_secs}
class RecordingManager:
  """Manages recordings of PCollections for a given pipeline."""
  def __init__(self, user_pipeline, pipeline_var=None, test_limiters=None):
    # type: (beam.Pipeline, str, list[Limiter]) -> None # noqa: F821
    self.user_pipeline = user_pipeline  # type: beam.Pipeline
    self.pipeline_var = pipeline_var if pipeline_var else ''  # type: str
    self._recordings = set()  # type: set[Recording]
    self._start_time_sec = 0  # type: float
    self._test_limiters = test_limiters if test_limiters else []

  @staticmethod
  def _duration_to_secs(max_duration):
    # type: (Union[int, float, str]) -> Union[int, float] # noqa: F821
    """Converts a user-supplied maximum duration into seconds.

    The string 'inf' becomes ``float('inf')``, so that numeric comparisons
    against elapsed time keep working. Other strings are parsed with
    ``pandas.to_timedelta`` (e.g. '10s', '1min'). Non-string values are
    returned unchanged.
    """
    if isinstance(max_duration, str):
      if max_duration == 'inf':
        return float('inf')
      return pd.to_timedelta(max_duration).total_seconds()
    return max_duration

  def _watch(self, pcolls):
    # type: (List[beam.pvalue.PCollection]) -> None # noqa: F821
    """Watch any pcollections not being watched.

    This allows for the underlying caching layer to identify the PCollection as
    something to be cached.
    """
    watched_pcollections = set()
    watched_dataframes = set()
    for watching in ie.current_env().watching():
      for _, val in watching:
        if isinstance(val, beam.pvalue.PCollection):
          watched_pcollections.add(val)
        elif isinstance(val, DeferredBase):
          watched_dataframes.add(val)

    # Convert them one-by-one to generate a unique label for each. This allows
    # caching at a more fine-grained granularity.
    #
    # TODO(https://github.com/apache/beam/issues/20929): investigate the mixing
    # pcollections in multiple pipelines error when using the default label.
    for df in watched_dataframes:
      pcoll, _ = utils.deferred_df_to_pcollection(df)
      watched_pcollections.add(pcoll)
    for pcoll in pcolls:
      if pcoll not in watched_pcollections:
        ie.current_env().watch(
            {'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})

  def _clear(self):
    # type: () -> None
    """Clears the cache of this pipeline's uncomputed, non-source PCollections.

    Only PCollections that aren't populated by the BackgroundCachingJob (i.e.
    whose keys are not in the cache manager's ``capture_keys``) are cleared.
    """
    cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)

    computed = ie.current_env().computed_pcollections
    cacheables = [
        c for c in utils.cacheables().values()
        if c.pcoll.pipeline is self.user_pipeline and c.pcoll not in computed
    ]
    all_cached = set(str(c.to_key()) for c in cacheables)
    source_pcolls = getattr(cache_manager, 'capture_keys', set())
    to_clear = all_cached - source_pcolls

    self._clear_pcolls(cache_manager, set(to_clear))

  def _clear_pcolls(self, cache_manager, pcolls):
    """Clears the 'full' cache entry for each cache key in ``pcolls``."""
    for pc in pcolls:
      cache_manager.clear('full', pc)

  def clear(self):
    # type: () -> None
    """Clears all cached PCollections for this RecordingManager."""
    cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
    if cache_manager:
      cache_manager.cleanup()

  def cancel(self):
    # type: () -> None
    """Cancels the current background recording job."""
    bcj.attempt_to_cancel_background_caching_job(self.user_pipeline)

    for r in self._recordings:
      r.wait_until_finish()
    self._recordings = set()

    # The recordings rely on a reference to the BCJ to correctly finish. So we
    # evict the BCJ after they complete.
    ie.current_env().evict_background_caching_job(self.user_pipeline)

  def describe(self):
    # type: () -> dict
    """Returns a dictionary describing the cache and recording.

    Keys: 'size', 'start', 'state' (the background caching job's state, or
    STOPPED if there is none), and 'pipeline_var'.
    """
    cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
    capture_size = getattr(cache_manager, 'capture_size', 0)

    descriptions = [r.describe() for r in self._recordings]
    size = sum(d['size'] for d in descriptions) + capture_size
    start = self._start_time_sec
    bg_job = ie.current_env().get_background_caching_job(self.user_pipeline)
    state = bg_job.state if bg_job else PipelineState.STOPPED
    return {
        'size': size,
        'start': start,
        'state': state,
        'pipeline_var': self.pipeline_var
    }

  def record_pipeline(self):
    # type: () -> bool
    """Starts a background caching job for this RecordingManager's pipeline.

    Returns True if a background caching job was started, False otherwise.
    """
    runner = self.user_pipeline.runner
    if isinstance(runner, ir.InteractiveRunner):
      runner = runner._underlying_runner

    # Make sure that sources without a user reference are still cached.
    ie.current_env().add_user_pipeline(self.user_pipeline)
    utils.watch_sources(self.user_pipeline)

    # Attempt to run background caching job to record any sources.
    warnings.filterwarnings(
        'ignore',
        'options is deprecated since First stable release. References to '
        '<pipeline>.options will not be supported',
        category=DeprecationWarning)
    if bcj.attempt_to_run_background_caching_job(
        runner,
        self.user_pipeline,
        options=self.user_pipeline.options,
        limiters=self._test_limiters):
      self._start_time_sec = time.time()
      return True
    return False

  def record(self, pcolls, max_n, max_duration):
    # type: (List[beam.pvalue.PCollection], int, Union[int,str]) -> Recording # noqa: F821
    """Records the given PCollections.

    Args:
      pcolls: the PCollections to record. All must belong to this manager's
        user pipeline.
      max_n: maximum number of elements each ElementStream yields per read.
      max_duration: maximum recording duration, either a number of seconds or
        a string accepted by ``pandas.to_timedelta``. 'inf' means unbounded.

    Returns:
      A Recording tracking the given PCollections.

    Raises:
      AssertionError: if a PCollection belongs to a different pipeline.
    """
    # Assert that all PCollection come from the same user_pipeline. Raise
    # explicitly (rather than via `assert`) so the check survives `python -O`,
    # keeping AssertionError for existing callers.
    for pcoll in pcolls:
      if pcoll.pipeline is not self.user_pipeline:
        raise AssertionError(
            '{} belongs to a different user-defined pipeline ({}) than that of'
            ' other PCollections ({}).'.format(
                pcoll, pcoll.pipeline, self.user_pipeline))

    max_duration_secs = self._duration_to_secs(max_duration)

    # Make sure that all PCollections to be shown are watched. If a PCollection
    # has not been watched, make up a variable name for that PCollection and
    # watch it. No validation is needed here because the watch logic can handle
    # arbitrary variables.
    self._watch(pcolls)
    self.record_pipeline()

    # Get the subset of computed PCollections. These do not need to be
    # recomputed.
    already_computed = ie.current_env().computed_pcollections
    computed_pcolls = set(
        pcoll for pcoll in pcolls if pcoll in already_computed)

    # Start a pipeline fragment to start computing the PCollections.
    uncomputed_pcolls = set(pcolls).difference(computed_pcolls)
    if uncomputed_pcolls:
      # Clear the cache of the given uncomputed PCollections because they are
      # incomplete.
      self._clear()

      cache_path = ie.current_env().options.cache_root
      is_remote_run = bool(cache_path) and cache_path.startswith('gs://')
      pf.PipelineFragment(
          list(uncomputed_pcolls),
          self.user_pipeline.options).run(blocking=is_remote_run)
      result = ie.current_env().pipeline_result(self.user_pipeline)
    else:
      result = None

    recording = Recording(
        self.user_pipeline, pcolls, result, max_n, max_duration_secs)
    self._recordings.add(recording)

    return recording

  def read(self, pcoll_name, pcoll, max_n, max_duration_secs):
    # type: (str, beam.pvalue.PValue, int, float) -> Union[None, ElementStream] # noqa: F821
    """Reads an ElementStream of a computed PCollection.

    Returns None if an error occurs. The caller is responsible for validating
    whether the given pcoll_name and pcoll can identify a watched and computed
    PCollection without ambiguity in the notebook.
    """
    try:
      cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
      return ElementStream(
          pcoll, pcoll_name, cache_key, max_n, max_duration_secs)
    except Exception as e:
      # Caller should handle all validations. Here to avoid redundant
      # validations, simply log errors if caller fails to do so.
      # (KeyboardInterrupt/SystemExit are BaseException and still propagate.)
      _LOGGER.error(str(e))
      return None