#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
import uuid
from collections.abc import Iterable
from collections.abc import Mapping
from typing import Any
from typing import Tuple
import yaml
from yaml import SafeLoader
class SafeLineLoader(SafeLoader):
  """A yaml loader that attaches line information to mappings and strings."""
  class TaggedString(str):
    """A string class to which we can attach metadata.

    This is primarily used to trace a string's origin back to its place in a
    yaml file.
    """
    def __reduce__(self):
      # Pickle as an ordinary string.
      return str, (str(self), )

  def construct_scalar(self, node):
    """Constructs a scalar, tagging string values with their line number.

    String results are wrapped in a TaggedString whose `_line_` attribute is
    the node's start line plus one (yaml marks count lines from zero).
    Non-string results are returned untouched.
    """
    value = super().construct_scalar(node)
    if isinstance(value, str):
      value = SafeLineLoader.TaggedString(value)
      value._line_ = node.start_mark.line + 1
    return value

  def construct_mapping(self, node, deep=False):
    """Constructs a mapping and records where it came from.

    Two extra entries are added to the returned mapping: `__line__`, the
    node's start line plus one, and `__uuid__`, a freshly generated
    identifier (see create_uuid).
    """
    mapping = super().construct_mapping(node, deep=deep)
    mapping['__line__'] = node.start_mark.line + 1
    mapping['__uuid__'] = self.create_uuid()
    return mapping

  @classmethod
  def create_uuid(cls):
    """Returns a new random (version 4) UUID as a string."""
    return str(uuid.uuid4())

  @staticmethod
  def get_line(obj):
    """Returns the line recorded for obj, or 'unknown' if there is none.

    Dicts are looked up by their `__line__` entry, anything else by its
    `_line_` attribute.
    """
    if isinstance(obj, dict):
      return obj.get('__line__', 'unknown')
    return getattr(obj, '_line_', 'unknown')
 
def patch_yaml(original_str: str, updated):
  """Updates a yaml string to match the updated with minimal changes.

  This only changes the portions of original_str that differ between
  original_str and updated in an attempt to preserve comments and formatting.

  Args:
    original_str: The yaml text to patch.
    updated: The parsed data that the returned yaml text should represent.

  Returns:
    The patched yaml text. If original_str holds no data at all (it is empty,
    or its root document loads as null) there is nothing to patch against:
    original_str is returned as is when updated is None, otherwise a fresh
    dump of updated is returned.
  """
  if not original_str:
    # Nothing to preserve; also avoids indexing into an empty string below.
    return '' if updated is None else yaml.dump(updated, sort_keys=False)

  if not original_str.endswith('\n'):
    # Add a trailing newline to avoid having to constantly check this edge case.
    # (It's also a good idea generally...)
    original_str += '\n'

  # The yaml parser returns positions in terms of line and column numbers.
  # Here we construct the mapping between the two.
  line_starts = [0]
  line_starts.extend(m.end() for m in re.finditer('\n', original_str))

  def pos(line_or_mark, column=0):
    """Converts a yaml Mark, or a (line, column) pair, to a string offset."""
    if isinstance(line_or_mark, yaml.Mark):
      line = line_or_mark.line
      column = line_or_mark.column
    else:
      line = line_or_mark
    return line_starts[line] + column

  # Here we define a custom loader with hooks that record where each element is
  # found so we can swap it out appropriately.
  spans = {}

  class SafeMarkLoader(SafeLoader):
    pass

  # We create special subclass types to ensure each returned node is
  # a distinct object (spans are keyed by id()).
  marked_types = {}

  def record_yaml_scalar(constructor):
    def wrapper(self, node):
      raw_data = constructor(self, node)
      typ = type(raw_data)
      if typ not in marked_types:
        # bool cannot be subclassed.  An int subclass is used instead; the
        # loaded value is only ever compared with ==, where True == 1.
        base = int if typ is bool else typ
        marked_types[typ] = type(f'Marked_{typ.__name__}', (base, ), {})
      try:
        marked_data = marked_types[typ](raw_data)
      except TypeError:
        # date/datetime values cannot be copy-constructed from an instance.
        # The yaml constructor allocates a new object for each of them, so
        # the raw value is already distinct and can be keyed by id().
        marked_data = raw_data
      spans[id(marked_data)] = node.start_mark, node.end_mark
      return marked_data

    return wrapper

  SafeMarkLoader.add_constructor(
      'tag:yaml.org,2002:seq',
      record_yaml_scalar(SafeMarkLoader.construct_sequence))
  SafeMarkLoader.add_constructor(
      'tag:yaml.org,2002:map',
      record_yaml_scalar(SafeMarkLoader.construct_mapping))
  # Nulls are deliberately not recorded: None is a singleton, so its id
  # cannot identify one particular node.
  for tag in ('bool', 'int', 'float', 'binary', 'timestamp', 'str'):
    SafeMarkLoader.add_constructor(
        f'tag:yaml.org,2002:{tag}',
        record_yaml_scalar(getattr(SafeMarkLoader, f'construct_yaml_{tag}')))

  #  Now load the original yaml using our special parser.
  original = yaml.load(original_str, Loader=SafeMarkLoader)

  if id(original) not in spans:
    # The root carries no span (the document is empty or null), so there is
    # no location to patch.
    if updated is None:
      return original_str
    return yaml.dump(updated, sort_keys=False)

  def block_dump(value, width):
    """Dumps value in block style, indenting continuation lines by width."""
    padding = ' ' * width
    return (
        yaml.dump(value, sort_keys=False)
        # Indent.
        .replace('\n', '\n' + padding)
        # Remove blank line indents.
        .replace(padding + '\n', '\n').rstrip())

  # This (recursively) finds the portion of the original string that must
  # be replaced with new content, as (start, end, replacement) triples.
  def diff(a: Any, b: Any) -> Iterable[Tuple[int, int, str]]:
    if a == b:
      return
    elif (isinstance(a, dict) and isinstance(b, dict) and
          set(a.keys()) == set(b.keys()) and
          all(id(v) in spans for v in a.values())):
      for k, v in a.items():
        yield from diff(v, b[k])
    elif (isinstance(a, list) and isinstance(b, list) and a and b and
          all(id(v) in spans for v in a)):
      # Diff the matching entries.
      for va, vb in zip(a, b):
        yield from diff(va, vb)
      if len(b) < len(a):
        # Remove extra entries
        yield (
            # End of last preserved element.
            pos(spans[id(a[len(b) - 1])][1]),
            # End of last original element.
            pos(spans[id(a[-1])][1]),
            '')
      elif len(b) > len(a):
        # Add extra entries
        list_start, list_end = spans[id(a)]
        start_char = original_str[pos(list_start)]
        if start_char == '[':
          # Insert just before the closing bracket.
          insert_at = pos(list_end) - 1
          for v in b[len(a):]:
            yield insert_at, insert_at, ', ' + json.dumps(v)
        else:
          assert start_char == '-'
          indent = original_str[pos(list_start.line):pos(list_start)] + '- '
          content = original_str[pos(list_start):pos(list_end)].rstrip()
          actual_end_pos = pos(list_start) + len(content)
          for v in b[len(a):]:
            if isinstance(v, (list, dict)):
              v_str = block_dump(v, len(indent))
            else:
              v_str = json.dumps(v)
            yield actual_end_pos, actual_end_pos, '\n' + indent + v_str
    else:
      start, end = spans[id(a)]
      indent = original_str[pos(start.line):pos(start)]
      # We strip trailing whitespace as the "end" of an element is often on
      # a subsequent line where the subsequent element actually starts.
      content = original_str[pos(start):pos(end)].rstrip()
      actual_end_pos = pos(start) + len(content)
      trailing = original_str[actual_end_pos:original_str.
                              find('\n', actual_end_pos)]
      if isinstance(b, (list, dict)):
        if indent.strip() in ('', '-') and not trailing.strip():
          # This element wholly occupies its set of lines, so it is safe to use
          # a multi-line yaml representation (appropriately indented).
          yield pos(start), actual_end_pos, block_dump(b, len(indent))
        else:
          # Force flow style, on a single line (an unbounded width keeps the
          # emitter from wrapping) and with the keys in their given order.
          yield (
              pos(start),
              actual_end_pos,
              yaml.dump(
                  b,
                  default_flow_style=True,
                  sort_keys=False,
                  width=float('inf')).strip())
      elif (isinstance(b, str) and re.fullmatch('[A-Za-z0-9_]+', b) and
            yaml.safe_load(b) == b):
        # A simple string literal.  The round trip check keeps strings such as
        # "123", "true" or "null" quoted so they do not change type.
        yield pos(start), actual_end_pos, b
      else:
        # A scalar.
        yield pos(start), actual_end_pos, json.dumps(b)

  # Now stick it all together.
  # Sort on position only: sorted() is stable, so several insertions at the
  # same position keep the order in which they were generated.
  last_end = 0
  content = []
  for start, end, new_content in sorted(diff(original, updated),
                                        key=lambda edit: edit[:2]):
    content.append(original_str[last_end:start])
    content.append(new_content)
    last_end = end
  content.append(original_str[last_end:])
  return ''.join(content)
def locate_data_file(relpath):
  """Joins relpath onto the directory that contains this module's file."""
  module_dir = os.path.dirname(__file__)
  return os.path.join(module_dir, relpath)