Source code for apache_beam.yaml.yaml_utils

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os
import re
import uuid
from collections.abc import Iterable
from collections.abc import Mapping
from typing import Any
from typing import Tuple

import yaml
from yaml import SafeLoader


class SafeLineLoader(SafeLoader):
  """A yaml loader that attaches line information to mappings and strings."""
  class TaggedString(str):
    """A string class to which we can attach metadata.

    This is primarily used to trace a string's origin back to its place in a
    yaml file.
    """
    def __reduce__(self):
      # Pickle as an ordinary string.
      return str, (str(self), )

  def construct_scalar(self, node):
    value = super().construct_scalar(node)
    if isinstance(value, str):
      value = SafeLineLoader.TaggedString(value)
      value._line_ = node.start_mark.line + 1
    return value

  def construct_mapping(self, node, deep=False):
    mapping = super().construct_mapping(node, deep=deep)
    mapping['__line__'] = node.start_mark.line + 1
    mapping['__uuid__'] = self.create_uuid()
    return mapping

  @classmethod
  def create_uuid(cls):
    return str(uuid.uuid4())

  @classmethod
  def strip_metadata(cls, spec, tagged_str=True):
    if isinstance(spec, Mapping):
      return {
          cls.strip_metadata(key, tagged_str): cls.strip_metadata(
              value, tagged_str)
          for (key, value) in spec.items()
          if key not in ('__line__', '__uuid__')
      }
    elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)):
      return [cls.strip_metadata(value, tagged_str) for value in spec]
    elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str:
      return str(spec)
    else:
      return spec

  @staticmethod
  def get_line(obj):
    if isinstance(obj, dict):
      return obj.get('__line__', 'unknown')
    else:
      return getattr(obj, '_line_', 'unknown')
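

# Illustrative usage sketch for SafeLineLoader (not part of the original
# module).  The helper name and the yaml snippet below are hypothetical,
# shown only to demonstrate the line metadata this loader attaches.
def _example_safe_line_loader_usage():
  spec = yaml.load('pipeline:\n  type: chain\n', Loader=SafeLineLoader)
  # Mappings carry a 1-based '__line__' marker recording where they start.
  assert SafeLineLoader.get_line(spec['pipeline']) == 2
  # strip_metadata drops the '__line__'/'__uuid__' keys and unwraps strings.
  assert SafeLineLoader.strip_metadata(spec) == {'pipeline': {'type': 'chain'}}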


def patch_yaml(original_str: str, updated):
  """Updates a yaml string to match the updated object with minimal changes.

  This only changes the portions of original_str that differ between
  original_str and updated in an attempt to preserve comments and formatting.
  """
  if not original_str and updated:
    return yaml.dump(updated, sort_keys=False)

  if original_str[-1] != '\n':
    # Add a trailing newline to avoid having to constantly check this edge
    # case.  (It's also a good idea generally...)
    original_str += '\n'

  # The yaml parser returns positions in terms of line and column numbers.
  # Here we construct the mapping between the two.
  line_starts = [0]
  ix = original_str.find('\n')
  while ix != -1:
    line_starts.append(ix + 1)
    ix = original_str.find('\n', ix + 1)

  def pos(line_or_mark, column=0):
    if isinstance(line_or_mark, yaml.Mark):
      line = line_or_mark.line
      column = line_or_mark.column
    else:
      line = line_or_mark
    return line_starts[line] + column

  # Here we define a custom loader with hooks that record where each element
  # is found so we can swap it out appropriately.
  spans = {}

  class SafeMarkLoader(SafeLoader):
    pass

  # We create special subclass types to ensure each returned node is
  # a distinct object.
  marked_types = {}

  def record_yaml_scalar(constructor):
    def wrapper(self, node):
      raw_data = constructor(self, node)
      typ = type(raw_data)
      if typ not in marked_types:
        marked_types[typ] = type(f'Marked_{typ}', (type(raw_data), ), {})
      marked_data = marked_types[typ](raw_data)
      spans[id(marked_data)] = node.start_mark, node.end_mark
      return marked_data

    return wrapper

  SafeMarkLoader.add_constructor(
      'tag:yaml.org,2002:seq',
      record_yaml_scalar(SafeMarkLoader.construct_sequence))
  SafeMarkLoader.add_constructor(
      'tag:yaml.org,2002:map',
      record_yaml_scalar(SafeMarkLoader.construct_mapping))
  for typ in ('bool', 'int', 'float', 'binary', 'timestamp', 'str'):
    SafeMarkLoader.add_constructor(
        f'tag:yaml.org,2002:{typ}',
        record_yaml_scalar(getattr(SafeMarkLoader, f'construct_yaml_{typ}')))

  # Now load the original yaml using our special parser.
  original = yaml.load(original_str, Loader=SafeMarkLoader)

  # This (recursively) finds the portions of the original string that must
  # be replaced with new content.
  def diff(a: Any, b: Any) -> Iterable[Tuple[int, int, str]]:
    if a == b:
      return
    elif (isinstance(a, dict) and isinstance(b, dict) and
          set(a.keys()) == set(b.keys()) and
          all(id(v) in spans for v in a.values())):
      for k, v in a.items():
        yield from diff(v, b[k])
    elif (isinstance(a, list) and isinstance(b, list) and a and b and
          all(id(v) in spans for v in a)):
      # Diff the matching entries.
      for va, vb in zip(a, b):
        yield from diff(va, vb)
      if len(b) < len(a):
        # Remove extra entries.
        yield (
            # End of last preserved element.
            pos(spans[id(a[len(b) - 1])][1]),
            # End of last original element.
            pos(spans[id(a[-1])][1]),
            '')
      elif len(b) > len(a):
        # Add extra entries.
        list_start, list_end = spans[id(a)]
        start_char = original_str[pos(list_start)]
        if start_char == '[':
          for v in b[len(a):]:
            yield pos(list_end) - 1, pos(list_end) - 1, ', ' + json.dumps(v)
        else:
          assert start_char == '-'
          indent = original_str[pos(list_start.line):pos(list_start)] + '- '
          content = original_str[pos(list_start):pos(list_end)].rstrip()
          actual_end_pos = pos(list_start) + len(content)
          for v in b[len(a):]:
            if isinstance(v, (list, dict)):
              v_str = (
                  yaml.dump(v, sort_keys=False)
                  # Indent.
                  .replace('\n', '\n' + ' ' * len(indent))
                  # Remove blank line indents.
                  .replace(' ' * len(indent) + '\n', '\n').rstrip())
            else:
              v_str = json.dumps(v)
            yield actual_end_pos, actual_end_pos, '\n' + indent + v_str
    else:
      start, end = spans[id(a)]
      indent = original_str[pos(start.line):pos(start)]
      # We strip trailing whitespace as the "end" of an element is often on
      # a subsequent line where the subsequent element actually starts.
      content = original_str[pos(start):pos(end)].rstrip()
      actual_end_pos = pos(start) + len(content)
      trailing = original_str[actual_end_pos:original_str.
                              find('\n', actual_end_pos)]
      if isinstance(b, (list, dict)):
        if indent.strip() in ('', '-') and not trailing.strip():
          # This element wholly occupies its set of lines, so it is safe to
          # use a multi-line yaml representation (appropriately indented).
          yield (
              pos(start),
              actual_end_pos,
              yaml.dump(b, sort_keys=False)
              # Indent.
              .replace('\n', '\n' + ' ' * len(indent))
              # Remove blank line indents.
              .replace(' ' * len(indent) + '\n', '\n').rstrip())
        else:
          # Force flow style.
          yield (
              pos(start),
              actual_end_pos,
              yaml.dump(b, default_flow_style=True, line_break=False).strip())
      elif isinstance(b, str) and re.match('^[A-Za-z0-9_]+$', b):
        # A simple string literal.
        yield pos(start), actual_end_pos, b
      else:
        # A scalar.
        yield pos(start), actual_end_pos, json.dumps(b)

  # Now stick it all together.
  last_end = 0
  content = []
  for start, end, new_content in sorted(diff(original, updated)):
    content.append(original_str[last_end:start])
    content.append(new_content)
    last_end = end
  content.append(original_str[last_end:])
  return ''.join(content)
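

# Illustrative usage sketch for patch_yaml (not part of the original module).
# The helper name and the yaml text below are hypothetical; they show how only
# the changed value is rewritten while the surrounding comment is preserved.
def _example_patch_yaml_usage():
  original = 'config:\n  runner: DirectRunner  # local runner\n'
  updated = {'config': {'runner': 'DataflowRunner'}}
  patched = patch_yaml(original, updated)
  assert patched == 'config:\n  runner: DataflowRunner  # local runner\n'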


def locate_data_file(relpath):
  return os.path.join(os.path.dirname(__file__), relpath)