Source code for apache_beam.runners.interactive.display.interactive_pipeline_graph

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Helper to render pipeline graph in IPython when running interactively.

This module is experimental. No backwards-compatibility guarantees.
"""

# pytype: skip-file

import re

from apache_beam.runners.interactive.display import pipeline_graph


[docs] def nice_str(o): s = repr(o) s = s.replace('"', "'") s = s.replace('\\', '|') s = re.sub(r'[^\x20-\x7F]', ' ', s) assert '"' not in s if len(s) > 35: s = s[:35] + '...' return s
[docs] def format_sample(contents, count=1000): contents = list(contents) elems = ', '.join([nice_str(o) for o in contents[:count]]) if len(contents) > count: elems += ', ...' assert '"' not in elems return '{%s}' % elems
[docs] class InteractivePipelineGraph(pipeline_graph.PipelineGraph): """Creates the DOT representation of an interactive pipeline. Thread-safe.""" def __init__( self, pipeline, required_transforms=None, referenced_pcollections=None, cached_pcollections=None): """Constructor of PipelineGraph. Args: pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered. required_transforms: (list/set of str) ID of top level PTransforms that lead to visible results. referenced_pcollections: (list/set of str) ID of PCollections that are referenced by top level PTransforms executed (i.e. required_transforms) cached_pcollections: (set of str) a set of PCollection IDs of those whose cached results are used in the execution. """ self._required_transforms = required_transforms or set() self._referenced_pcollections = referenced_pcollections or set() self._cached_pcollections = cached_pcollections or set() super().__init__( pipeline=pipeline, default_vertex_attrs={ 'color': 'gray', 'fontcolor': 'gray' }, default_edge_attrs={'color': 'gray'}) transform_updates, pcollection_updates = self._generate_graph_update_dicts() self._update_graph(transform_updates, pcollection_updates)
[docs] def update_pcollection_stats(self, pcollection_stats): """Updates PCollection stats. Args: pcollection_stats: (dict of dict) maps PCollection IDs to informations. In particular, we only care about the field 'sample' which should be a the PCollection result in as a list. """ edge_dict = {} for pcoll_id, stats in pcollection_stats.items(): attrs = {} pcoll_list = stats['sample'] if pcoll_list: attrs['label'] = format_sample(pcoll_list, 1) attrs['labeltooltip'] = format_sample(pcoll_list, 10) else: attrs['label'] = '?' edge_dict[pcoll_id] = attrs self._update_graph(edge_dict=edge_dict)
def _generate_graph_update_dicts(self): """Generate updates specific to interactive pipeline. Returns: vertex_dict: (Dict[str, Dict[str, str]]) maps vertex name to attributes edge_dict: (Dict[str, Dict[str, str]]) maps vertex name to attributes """ transform_dict = {} # maps PTransform IDs to properties pcoll_dict = {} # maps PCollection IDs to properties for transform_id, transform_proto in self._top_level_transforms(): transform_dict[transform_proto.unique_name] = { 'required': transform_id in self._required_transforms } for pcoll_id in transform_proto.outputs.values(): pcoll_dict[pcoll_id] = { 'cached': pcoll_id in self._cached_pcollections, 'referenced': pcoll_id in self._referenced_pcollections } def vertex_properties_to_attributes(vertex): """Converts PCollection properties to DOT vertex attributes.""" attrs = {} if 'leaf' in vertex: attrs['style'] = 'invis' elif vertex.get('required'): attrs['color'] = 'blue' attrs['fontcolor'] = 'blue' else: attrs['color'] = 'grey' return attrs def edge_properties_to_attributes(edge): """Converts PTransform properties to DOT edge attributes.""" attrs = {} if edge.get('cached'): attrs['color'] = 'red' elif edge.get('referenced'): attrs['color'] = 'black' else: attrs['color'] = 'grey' return attrs vertex_dict = {} # maps vertex names to attributes edge_dict = {} # maps edge names to attributes for transform_name, transform_properties in transform_dict.items(): vertex_dict[transform_name] = vertex_properties_to_attributes( transform_properties) for pcoll_id, pcoll_properties in pcoll_dict.items(): edge_dict[pcoll_id] = edge_properties_to_attributes(pcoll_properties) return vertex_dict, edge_dict