#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Manages displaying pipeline graph and execution status on the frontend.
This module is experimental. No backwards-compatibility guarantees.
"""
# pytype: skip-file
import collections
import threading
import time
from typing import TYPE_CHECKING
from apache_beam.runners.interactive.display import interactive_pipeline_graph
# IPython is an optional dependency. When it can be imported, progress output
# goes through IPython's display(); otherwise the ImportError branch below
# falls back to the builtin print.
try:
  import IPython  # pylint: disable=import-error
  from IPython import get_ipython  # pylint: disable=import-error
  from IPython.display import display as ip_display  # pylint: disable=import-error
  # _display_progress defines how outputs are printed on the frontend.
  _display_progress = ip_display
  # The formatter setup below only runs at runtime; it is skipped while
  # TYPE_CHECKING is true.
  if not TYPE_CHECKING:
    def _formatter(string, pp, cycle):  # pylint: disable=unused-argument
      # Writes the string itself through the printer's text() method.
      # NOTE(review): presumably this makes str values show up without the
      # surrounding repr quotes when displayed -- confirm against IPython docs.
      pp.text(string)
    # Only register the formatter when get_ipython() returns a truthy shell
    # object (presumably None when not running inside IPython -- confirm).
    if get_ipython():
      plain = get_ipython().display_formatter.formatters['text/plain']  # pylint: disable=undefined-variable
      # Use _formatter for every str rendered by the 'text/plain' formatter.
      plain.for_type(str, _formatter)
except ImportError:
  # IPython is unavailable: IPython = None is the marker that
  # DisplayManager.update_display tests before rendering the graph, and
  # progress text is emitted with plain print instead.
  IPython = None
  _display_progress = print
class DisplayManager(object):
  """Manages displaying pipeline graph and execution status on the frontend."""
  def __init__(
      self,
      pipeline_proto,
      pipeline_analyzer,
      cache_manager,
      pipeline_graph_renderer):
    """Constructor of DisplayManager.

    Args:
      pipeline_proto: (Pipeline proto)
      pipeline_analyzer: (PipelineAnalyzer) the pipeline analyzer that
          corresponds to this round of execution. This will provide more
          detailed information about the pipeline.
      cache_manager: (interactive_runner.CacheManager) DisplayManager fetches
          the latest status of pipeline execution by querying cache_manager.
      pipeline_graph_renderer: (pipeline_graph_renderer.PipelineGraphRenderer)
          decides how a pipeline graph is rendered.
    """
    # Every parameter except cache_manager is expected to remain constant.
    self._analyzer = pipeline_analyzer
    self._cache_manager = cache_manager
    self._renderer = pipeline_graph_renderer

    # Query the analyzer once for each id collection and reuse the results
    # below instead of calling the same getter repeatedly.
    required_trans_ids = self._analyzer.tl_required_trans_ids()
    referenced_pcoll_ids = self._analyzer.tl_referenced_pcoll_ids()
    caches_used = self._analyzer.caches_used()

    self._pipeline_graph = interactive_pipeline_graph.InteractivePipelineGraph(
        pipeline_proto,
        required_transforms=required_trans_ids,
        referenced_pcollections=referenced_pcoll_ids,
        cached_pcollections=caches_used)

    # Number of transforms executed this round: the required transforms minus
    # the ones that only read from or write to the cache.
    executed_count = (
        len(required_trans_ids) - len(self._analyzer.read_cache_ids()) -
        len(self._analyzer.write_cache_ids()))
    # Total is the number of direct subtransforms of the first root transform.
    root_transform = pipeline_proto.components.transforms[
        pipeline_proto.root_transform_ids[0]]
    total_count = len(root_transform.subtransforms)

    # _text_to_print keeps track of information to be displayed. It is ordered
    # so that the summary is always printed first, followed by one (initially
    # empty) entry per referenced PCollection.
    self._text_to_print = collections.OrderedDict()
    self._text_to_print['summary'] = (
        'Using %s cached PCollections\nExecuting %s of %s '
        'transforms.') % (len(caches_used), executed_count, total_count)
    self._text_to_print.update(
        {pcoll_id: ""
         for pcoll_id in referenced_pcoll_ids})

    # _pcollection_stats maps pcoll_id to
    # { 'cache_label': cache_label, 'version': version,
    #   'sample': pcoll_in_list }
    # 'version' starts at -1 and 'sample' starts empty until the first read
    # from the cache manager replaces them.
    pipeline_info = self._analyzer.pipeline_info()
    self._pcollection_stats = {
        pcoll_id: {
            'cache_label': pipeline_info.cache_label(pcoll_id),
            'version': -1,
            'sample': []
        }
        for pcoll_id in referenced_pcoll_ids
    }

    # _producers maps pcoll_id to the unique_name of a transform that outputs
    # it. The first producer seen is kept unless a later producer has no '/'
    # in its unique_name, in which case the later one overrides it.
    self._producers = {}
    for transform in pipeline_proto.components.transforms.values():
      for pcoll_id in transform.outputs.values():
        if pcoll_id not in self._producers or '/' not in transform.unique_name:
          self._producers[pcoll_id] = transform.unique_name

    # For periodic update. The lock serializes update_display() calls made
    # from the caller's thread and from the background updater thread.
    self._lock = threading.Lock()
    self._periodic_update = False

  def update_display(self, force=False):
    """Updates display on the frontend.

    Retrieves the latest execution status by querying CacheManager and updates
    display on the frontend. The assumption is that there is only one pipeline
    in a cell, because it clears up everything in the cell output every update
    cycle.

    Args:
      force: (bool) whether to force updating when no stats change happens.
    """
    with self._lock:
      stats_updated = False
      for pcoll_id, stats in self._pcollection_stats.items():
        cache_label = stats['cache_label']
        version = stats['version']
        if force or not self._cache_manager.is_latest_version(
            version, 'sample', cache_label):
          pcoll_list, version = self._cache_manager.read('sample', cache_label)
          # Materialize the sample exactly once. Everything below reads the
          # stored list, so a one-shot iterator returned by the cache manager
          # is not consumed twice (which would format an empty sample).
          sample = list(pcoll_list)
          stats['sample'] = sample
          stats['version'] = version
          stats_updated = True
          if pcoll_id in self._analyzer.tl_referenced_pcoll_ids():
            self._text_to_print[pcoll_id] = '%s produced %s' % (
                self._producers[pcoll_id],
                interactive_pipeline_graph.format_sample(sample, 5))

      if force or stats_updated:
        self._pipeline_graph.update_pcollection_stats(self._pcollection_stats)
        # IPython is None when the import at module load failed; in that case
        # only the text progress below is emitted.
        if IPython:
          from IPython import display
          display.clear_output(True)
          rendered_graph = self._renderer.render_pipeline_graph(
              self._pipeline_graph)
          display.display(display.HTML(rendered_graph))
        _display_progress('Running...')
        # Entries that are still empty have no sample yet; skip them.
        for text in self._text_to_print.values():
          if text:
            _display_progress(text)

  def start_periodic_update(self):
    """Start a thread that periodically updates the display.

    Performs one forced update immediately, then starts a daemon thread that
    calls update_display() every 0.02 seconds until stop_periodic_update()
    clears the flag.
    """
    self.update_display(True)
    self._periodic_update = True

    def _updater():
      while self._periodic_update:
        self.update_display()
        time.sleep(.02)

    t = threading.Thread(target=_updater)
    # Daemon thread: it does not keep the interpreter alive on exit.
    t.daemon = True
    t.start()

  def stop_periodic_update(self):
    """Stop periodically updating the display.

    Performs one last forced update, then clears the flag polled by the
    updater thread so that its loop ends.
    """
    self.update_display(True)
    self._periodic_update = False