Source code for apache_beam.yaml.yaml_testing

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import collections
import functools
import json
import random
import unittest
import uuid
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

import yaml

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
from apache_beam.yaml import yaml_utils


class YamlTestCase(unittest.TestCase):
  """A ``unittest.TestCase`` wrapping one test spec for one YAML pipeline.

  Running the case delegates to ``run_test``. Whatever ``run_test`` returns
  is kept so that ``fixed_test`` can later produce a copy of the test spec
  with the recorded elements substituted for the expected ones.
  """
  def __init__(self, pipeline_spec, test_spec, options, fix_tests):
    """Stores the arguments that ``runTest`` forwards to ``run_test``.

    Args:
      pipeline_spec: the pipeline specification, passed through unchanged.
      test_spec: the test specification; ``id`` and ``fixed_test`` call
        ``.get`` on it, so it is expected to be a mapping.
      options: pipeline options passed through to ``run_test``.
      fix_tests: passed to ``run_test`` as its ``fix_failures`` argument.
    """
    super().__init__()
    self._pipeline_spec = pipeline_spec
    self._test_spec = test_spec
    self._options = options
    self._fix_tests = fix_tests
    # Assigned by runTest().  Initialised here so that fixed_test() does not
    # raise AttributeError when runTest() never ran or raised part-way.
    self._fixes = None

  def runTest(self):
    """Runs the test and remembers the value returned by ``run_test``."""
    self._fixes = run_test(
        self._pipeline_spec, self._test_spec, self._options, self._fix_tests)

  def fixed_test(self):
    """Returns the metadata-stripped test spec with recorded fixes applied.

    When fixes were recorded, each ``(location, name)`` key is used to find
    the matching expectation under ``expected_inputs`` / ``expected_outputs``
    and its ``elements`` are replaced by the recorded values, sorted by their
    JSON serialisation. With no fixes the stripped spec is returned as is.
    """
    fixed_test_spec = yaml_transform.SafeLineLoader.strip_metadata(
        self._test_spec)
    if self._fixes:
      expectation_by_id = {(loc, expectation['name']): expectation
                           for loc in ('expected_inputs', 'expected_outputs')
                           for expectation in fixed_test_spec.get(loc, [])}
      for name_loc, values in self._fixes.items():
        # Sorting by the JSON text gives a deterministic element order.
        expectation_by_id[name_loc]['elements'] = sorted(
            values, key=json.dumps)
    return fixed_test_spec

  def id(self):
    """Returns the test name (or ``'unknown'``) followed by its line."""
    return (
        self._test_spec.get('name', 'unknown') +
        f' (line {yaml_transform.SafeLineLoader.get_line(self._test_spec)})')

  def __str__(self):
    return self.id()
def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
  """Runs ``pipeline_spec`` with the mocks and assertions from ``test_spec``.

  Args:
    pipeline_spec: the pipeline specification, either already loaded or as a
      YAML string (strings are loaded with ``yaml_utils.SafeLineLoader``).
    test_spec: the test specification; it is validated and turned into extra
      transforms by ``inject_test_tranforms``.
    options: pipeline options to run with. When ``None``, options are built
      from the ``options`` entry of ``pipeline_spec`` with
      ``pickle_library='cloudpickle'``.
    fix_failures: when true, mismatching assertions record the actual
      elements instead of failing, and those recordings are returned.

  Returns:
    When ``fix_failures`` is true, a dict mapping ``(location, name)`` to the
    list of recorded elements for every assertion that recorded a result.
    Otherwise ``None``.

  Raises:
    ValueError: if a transform without inputs remains in the rewritten
      pipeline and neither its name nor its type is ``Create`` or listed in
      the test's ``allowed_sources``.
  """
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)

  # Keep the original spec separate from the preprocessed composite: the
  # composite is built from pipeline_spec['pipeline'] only, while 'options'
  # is read from the specification itself further down (create_test follows
  # the same pattern).
  composite_spec = _preprocess_for_testing(pipeline_spec)

  transform_spec, recording_ids = inject_test_tranforms(
      composite_spec, test_spec, fix_failures)

  allowed_sources = set(test_spec.get('allowed_sources', []) + ['Create'])
  for transform in transform_spec['transforms']:
    name_or_type = transform.get('name', transform['type'])
    if yaml_transform.empty_if_explicitly_empty(transform.get('input', [])):
      # The transform consumes something, so it is not a source.
      continue
    if transform.get('name') in allowed_sources:
      continue
    if transform['type'] in allowed_sources:
      continue
    raise ValueError(
        f'Non-mocked source {name_or_type} '
        f'at line {yaml_transform.SafeLineLoader.get_line(transform)}')

  if options is None:
    options = beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle',
        **yaml_transform.SafeLineLoader.strip_metadata(
            pipeline_spec.get('options', {})))

  with beam.Pipeline(options=options) as p:
    _ = p | yaml_transform.YamlTransform(
        transform_spec,
        providers={'AssertEqualAndRecord': AssertEqualAndRecord})

  if fix_failures:
    fixes = {}
    for recording_id in recording_ids:
      if AssertEqualAndRecord.has_recorded_result(recording_id):
        # Drop the leading per-run prefix; callers key on (location, name).
        fixes[recording_id[1:]] = [
            _try_row_as_dict(row)
            for row in AssertEqualAndRecord.get_recorded_result(recording_id)
        ]
        AssertEqualAndRecord.remove_recorded_result(recording_id)
    return fixes
def _preprocess_for_testing(pipeline_spec):
  """Returns the ``pipeline`` entry as a composite with early phases applied.

  The ``pipeline`` entry of ``pipeline_spec`` is turned into a composite
  transform spec, and then each preprocessing phase below is applied to it in
  order through ``yaml_transform.apply_phase``.
  """
  composite = yaml_transform.pipeline_as_composite(pipeline_spec['pipeline'])
  # These are idempotent, so it's OK to do them preemptively.
  phases = (
      yaml_transform.ensure_transforms_have_types,
      yaml_transform.preprocess_source_sink,
      yaml_transform.preprocess_chain,
      yaml_transform.tag_explicit_inputs,
      yaml_transform.normalize_inputs_outputs,
  )
  return functools.reduce(
      lambda current, phase: yaml_transform.apply_phase(phase, current),
      phases,
      composite)
def validate_test_spec(test_spec):
  """Checks the structure of a test specification, raising on the first flaw.

  The checks run in this order:
    * ``test_spec`` is a dict;
    * ``allowed_sources``, when present, is a list;
    * at least one of ``expected_outputs`` / ``expected_inputs`` is non-empty;
    * there are no keys besides ``name``, ``mock_inputs``, ``mock_outputs``,
      ``expected_outputs``, ``expected_inputs`` and ``allowed_sources``
      (after metadata has been stripped);
    * each of the four mock/expectation attributes is a list of dicts, and
      every such dict has a ``name`` and a list-valued ``elements``.

  Args:
    test_spec: the loaded test specification.

  Raises:
    TypeError: if a value has the wrong type or an item lacks ``name`` or
      ``elements``.
    ValueError: if there are no expectations, or unknown attributes exist.
  """
  if not isinstance(test_spec, dict):
    raise TypeError(
        f'Test specification must be an object, got {type(test_spec)}')

  # Used in every later message so the offending test can be located.
  identifier = (
      test_spec.get('name', 'unknown') +
      f' at line {yaml_transform.SafeLineLoader.get_line(test_spec)}')

  if not isinstance(test_spec.get('allowed_sources', []), list):
    raise TypeError(
        f'allowed_sources of test specification {identifier} '
        f'must be a list, got {type(test_spec["allowed_sources"])}')

  if (not test_spec.get('expected_outputs', []) and
      not test_spec.get('expected_inputs', [])):
    raise ValueError(
        f'test specification {identifier} '
        f'must have at least one expected_outputs or expected_inputs')

  unknown_attrs = set(
      yaml_transform.SafeLineLoader.strip_metadata(test_spec).keys()) - {
          'name',
          'mock_inputs',
          'mock_outputs',
          'expected_outputs',
          'expected_inputs',
          'allowed_sources',
      }
  if unknown_attrs:
    raise ValueError(
        f'test specification {identifier} '
        f'has unknown attributes {list(unknown_attrs)}')

  for attr_type in ('mock_inputs',
                    'mock_outputs',
                    'expected_outputs',
                    'expected_inputs'):
    attr = test_spec.get(attr_type, [])
    if not isinstance(attr, list):
      # Report the type of the offending value, not of its (string) key.
      raise TypeError(
          f'{attr_type} of test specification {identifier} '
          f'must be a list, got {type(attr)}')
    for ix, attr_item in enumerate(attr):
      if not isinstance(attr_item, dict):
        raise TypeError(
            f'{attr_type} {ix} of test specification {identifier} '
            f'must be an object, got {type(attr_item)}')
      if 'name' not in attr_item:
        raise TypeError(
            f'{attr_type} {ix} of test specification {identifier} '
            f'missing a name')
      if 'elements' not in attr_item:
        raise TypeError(
            f'{attr_type} {ix} of test specification {identifier} '
            f'missing a elements')
      if not isinstance(attr_item['elements'], list):
        raise TypeError(
            f'{attr_type} {ix} of test specification {identifier} '
            f'must be a list, got {type(attr_item["elements"])}')
def inject_test_tranforms(spec, test_spec, fix_failures):
  """Builds the composite transform spec that a test actually executes.

  Starting from the expectations in ``test_spec``, only the transforms that
  are needed are pulled in. Mocked inputs and mocked outputs are replaced by
  ``Create`` transforms holding the mock elements, and an
  ``AssertEqualAndRecord`` transform is added for every expectation.

  Args:
    spec: a composite transform spec; its ``transforms`` entry is read.
    test_spec: the test specification; validated by ``validate_test_spec``.
    fix_failures: when true, every assertion is given a recording id so that
      actual results can be retrieved after the run.

  Returns:
    A pair ``(composite_spec, recording_ids)``. ``composite_spec`` is a dict
    of type ``composite`` holding the collected transforms, and
    ``recording_ids`` lists the
    ``(prefix, 'expected_outputs' | 'expected_inputs', name)`` tuples that
    were assigned (empty unless ``fix_failures``).

  Raises:
    ValueError: if an output of a transform is required but not mocked while
      another output of that same transform is mocked.
  """
  validate_test_spec(test_spec)

  scope = yaml_transform.LightweightScope(spec['transforms'])

  mocked_inputs_by_id = {
      scope.get_transform_id(mock_input['name']): mock_input
      for mock_input in test_spec.get('mock_inputs', [])
  }

  # Nested as {transform_id: {output_tag: mock_output}}.
  mocked_outputs_by_id = _composite_key_to_nested({
      scope.get_transform_id_and_output_name(mock_output['name']): mock_output
      for mock_output in test_spec.get('mock_outputs', [])
  })

  # A fresh prefix per call keeps recording ids of separate runs apart.
  recording_id_prefix = str(uuid.uuid4())
  recording_ids = []

  # Filled in as a side effect of the (cached) helpers below; the caches
  # ensure each transform or mock is appended at most once.
  transforms = []

  @functools.cache
  def create_inputs(transform_id: str) -> InputsType:
    """Returns the inputs to wire into ``transform_id`` for the test."""
    def require_output_or_outputs(name_or_names):
      if isinstance(name_or_names, str):
        return require_output(name_or_names)
      else:
        return [require_output(name) for name in name_or_names]

    if transform_id in mocked_inputs_by_id:
      return create_mocked_input(transform_id)
    else:
      input_spec = scope.get_transform_spec(transform_id)['input']
      return {
          tag: require_output_or_outputs(input_ref)
          for tag, input_ref in yaml_transform.empty_if_explicitly_empty(
              input_spec).items()
      }

  def require_output(name: str) -> str:
    """Makes ``name`` available and returns the reference to use for it."""
    # The same output may be referenced under different names.
    # Normalize before we cache.
    transform_id, tag = scope.get_transform_id_and_output_name(name)
    return _require_output(transform_id, tag) or name

  @functools.cache
  def _require_output(transform_id: str, tag: str) -> Optional[str]:
    if transform_id in mocked_outputs_by_id:
      if tag not in mocked_outputs_by_id[transform_id]:
        name = next(iter(
            mocked_outputs_by_id[transform_id].values()))['name'].split('.')[0]
        raise ValueError(
            f'Unmocked output {tag} of {name}. '
            'If any used output is mocked all used outputs must be mocked.')
      return create_mocked_output(transform_id, tag)
    else:
      _use_transform(transform_id)
      return None  # Use original name.

  @functools.cache
  def _use_transform(transform_id: str) -> None:
    """Adds a copy of the real transform, with its inputs rewritten."""
    transform_spec = dict(scope.get_transform_spec(transform_id))
    transform_spec['input'] = create_inputs(transform_id)
    transforms.append(transform_spec)

  @functools.cache
  def create_mocked_input(transform_id: str) -> str:
    transform = create_create(
        f'MockInput[{mocked_inputs_by_id[transform_id]["name"]}]',
        mocked_inputs_by_id[transform_id]['elements'],
        mocked_inputs_by_id[transform_id]['name'])
    transforms.append(transform)
    return transform['__uuid__']

  @functools.cache
  def create_mocked_output(transform_id: str, tag: str) -> str:
    transform = create_create(
        f'MockOutput[{mocked_outputs_by_id[transform_id][tag]["name"]}]',
        mocked_outputs_by_id[transform_id][tag]['elements'],
        mocked_outputs_by_id[transform_id][tag]['name'])
    transforms.append(transform)
    return transform['__uuid__']

  def create_create(name, elements, line_source):
    """Returns the spec of a ``Create`` transform yielding ``elements``."""
    return {
        '__uuid__': yaml_utils.SafeLineLoader.create_uuid(),
        '__line__': yaml_utils.SafeLineLoader.get_line(line_source),
        'name': name,
        'type': 'Create',
        'config': {
            'elements': elements,
        },
    }

  def create_assertion(name, inputs, elements, recording_id, line_source):
    """Returns the spec of an ``AssertEqualAndRecord`` over ``inputs``."""
    return {
        '__uuid__': yaml_utils.SafeLineLoader.create_uuid(),
        '__line__': yaml_utils.SafeLineLoader.get_line(line_source),
        'name': name,
        'input': inputs,
        'type': 'AssertEqualAndRecord',
        'config': {
            'elements': elements,
            'recording_id': recording_id,
        },
    }

  for expected_output in test_spec.get('expected_outputs', []):
    if fix_failures:
      recording_id = (
          recording_id_prefix, 'expected_outputs', expected_output['name'])
      recording_ids.append(recording_id)
    else:
      recording_id = None
    # Called for its side effect of pulling in the producing transform(s).
    require_output(expected_output['name'])
    transforms.append(
        create_assertion(
            f'CheckExpectedOutput[{expected_output["name"]}]',
            expected_output['name'],
            expected_output['elements'],
            recording_id,
            expected_output['name']))

  for expected_input in test_spec.get('expected_inputs', []):
    if fix_failures:
      recording_id = (
          recording_id_prefix, 'expected_inputs', expected_input['name'])
      recording_ids.append(recording_id)
    else:
      recording_id = None
    transform_id = scope.get_transform_id(expected_input['name'])
    transforms.append(
        create_assertion(
            f'CheckExpectedInput[{expected_input["name"]}]',
            create_inputs(transform_id),
            expected_input['elements'],
            recording_id,
            expected_input['name']))

  return {
      '__uuid__': yaml_utils.SafeLineLoader.create_uuid(),
      '__line__': 0,
      'type': 'composite',
      'transforms': transforms,
  }, recording_ids
class AssertEqualAndRecord(beam.PTransform):
  """Asserts a PCollection equals given elements, or records what it held.

  With a truthy ``recording_id``, a failed comparison does not raise; the
  actual value handed to the matcher is stored in a class-level registry
  under that id instead. Without one, the matcher's exception propagates.
  """
  # Class-level registry: recording id -> actual value seen by the matcher.
  _recorded_results = {}

  @classmethod
  def store_recorded_result(cls, recording_id, value):
    """Stores ``value`` under ``recording_id``, which must be unused."""
    assert recording_id not in cls._recorded_results
    cls._recorded_results[recording_id] = value

  @classmethod
  def has_recorded_result(cls, recording_id):
    """Tells whether something was stored under ``recording_id``."""
    return recording_id in cls._recorded_results

  @classmethod
  def get_recorded_result(cls, recording_id):
    """Returns the value stored under ``recording_id``."""
    return cls._recorded_results[recording_id]

  @classmethod
  def remove_recorded_result(cls, recording_id):
    """Deletes the entry stored under ``recording_id``."""
    del cls._recorded_results[recording_id]

  def __init__(self, elements, recording_id):
    self._elements = elements
    self._recording_id = recording_id

  def expand(self, pcoll):
    """Applies the comparison (or recording) matcher to ``pcoll``."""
    expected_rows = yaml_provider.dicts_to_rows(self._elements)
    equal_to_matcher = equal_to(expected_rows)

    def matcher(actual):
      try:
        equal_to_matcher(actual)
      except Exception:
        if not self._recording_id:
          raise
        # The id is converted to a tuple before being used as the key.
        AssertEqualAndRecord.store_recorded_result(
            tuple(self._recording_id), actual)

    normalized = pcoll | beam.Map(lambda row: beam.Row(**row._asdict()))
    return assert_that(normalized, matcher)
def create_test(
    pipeline_spec, options=None, max_num_inputs=40, min_num_outputs=3):
  """Generates a test specification for ``pipeline_spec`` by running it.

  Transforms without inputs (other than ``Create``) are treated as sources:
  up to ``max_num_inputs`` of their elements are captured through
  ``_first_n`` and become ``mock_outputs``. Transforms of type
  ``LogForTesting``, transforms whose type starts with ``Write``, and
  transforms for which ``empty_if_explicitly_empty`` of their ``output`` is
  truthy are treated as sinks and become ``expected_inputs``. The pipeline
  is then run through ``run_test`` with ``fix_failures=True`` on random
  samples of the captured elements. The sample size starts at
  ``min_num_outputs`` and doubles (capped at ``max_num_inputs``) until every
  sink recorded at least ``min_num_outputs`` elements or the cap is reached.

  Args:
    pipeline_spec: the pipeline specification, either already loaded or as a
      YAML string.
    options: pipeline options. When ``None`` they are built from the
      ``options`` entry of ``pipeline_spec`` with
      ``pickle_library='cloudpickle'``.
    max_num_inputs: upper bound on elements captured and sampled per source.
    min_num_outputs: initial sample size and the recorded element count
      aimed for at every sink.

  Returns:
    A dict with ``mock_outputs`` and ``expected_inputs`` entries, the latter
    filled with the elements recorded during the last run.

  Raises:
    ValueError: if no sink transform is found, or if an unnamed source or
      sink shares its type with another transform.
  """
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)

  transform_spec = _preprocess_for_testing(pipeline_spec)

  if options is None:
    options = beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle',
        **yaml_transform.SafeLineLoader.strip_metadata(
            pipeline_spec.get('options', {})))

  def get_name(transform):
    """Returns the transform's name, or its type when that is unambiguous."""
    if 'name' in transform:
      return str(transform['name'])
    else:
      if sum(1 for t in transform_spec['transforms']
             if t['type'] == transform['type']) > 1:
        # The f prefix is required for the type to be interpolated.
        raise ValueError(f'Ambiguous unnamed transform {transform["type"]}')
      return str(transform['type'])

  input_transforms = [
      t for t in transform_spec['transforms'] if t['type'] != 'Create' and
      not yaml_transform.empty_if_explicitly_empty(t.get('input', []))
  ]

  mock_outputs = [{
      'name': get_name(t),
      'elements': [
          _try_row_as_dict(row) for row in _first_n(t, options, max_num_inputs)
      ],
  } for t in input_transforms]

  output_transforms = [
      t for t in transform_spec['transforms'] if t['type'] == 'LogForTesting'
      or yaml_transform.empty_if_explicitly_empty(t.get('output', [])) or
      t['type'].startswith('Write')
  ]

  expected_inputs = [{
      'name': get_name(t),
      'elements': [],
  } for t in output_transforms]

  if not expected_inputs:
    # TODO: Optionally take this as a parameter.
    raise ValueError('No output transforms detected.')

  num_inputs = min_num_outputs
  while True:
    test_spec = {
        'mock_outputs': [{
            'name': t['name'],
            'elements': random.sample(
                t['elements'], min(len(t['elements']), num_inputs)),
        } for t in mock_outputs],
        'expected_inputs': expected_inputs,
    }
    fixes = run_test(pipeline_spec, test_spec, options, fix_failures=True)
    if len(fixes) < len(output_transforms):
      # At least one sink recorded nothing at all.
      actual_output_size = 0
    else:
      actual_output_size = min(len(e) for e in fixes.values())
    if actual_output_size >= min_num_outputs:
      break
    elif num_inputs == max_num_inputs:
      break
    else:
      num_inputs = min(2 * num_inputs, max_num_inputs)

  for expected_input in test_spec['expected_inputs']:
    if ('expected_inputs', expected_input['name']) in fixes:
      expected_input['elements'] = fixes['expected_inputs',
                                         expected_input['name']]

  return test_spec
class _DoneException(Exception):
  """Raised by ``RecordElements`` once it has recorded enough elements."""
class RecordElements(beam.PTransform):
  """Records up to ``n`` elements of a PCollection in a class-level registry.

  Each instance gets its own random id under which its elements are kept.
  When an element arrives after ``n`` have been recorded, ``_DoneException``
  is raised.
  """
  # Class-level registry: instance id -> list of recorded elements.
  _recorded_results = collections.defaultdict(list)

  def __init__(self, n):
    self._n = n
    self._id = str(uuid.uuid4())

  def get_and_remove(self):
    """Returns this instance's recorded elements and forgets them."""
    # An instance that recorded nothing yields an empty list.
    return RecordElements._recorded_results.pop(self._id, [])

  def expand(self, pcoll):
    """Maps every element of ``pcoll`` through the recording function."""
    def record(element):
      seen = RecordElements._recorded_results[self._id]
      if len(seen) >= self._n:
        raise _DoneException()
      seen.append(element)

    return pcoll | beam.Map(record)
def _first_n(transform_spec, options, n):
  """Runs ``transform_spec`` alone and returns what ``RecordElements`` kept.

  ``_DoneException`` raised during the run is swallowed, as is any other
  exception whose text mentions ``_DoneException``; everything else is
  re-raised.
  """
  recorder = RecordElements(n)
  try:
    with beam.Pipeline(options=options) as p:
      source = yaml_transform.YamlTransform(
          transform_spec,
          providers={'AssertEqualAndRecord': AssertEqualAndRecord})
      _ = p | source | recorder
  except _DoneException:
    pass
  except Exception as exn:
    # Runners don't always raise a faithful exception type.
    if '_DoneException' not in str(exn):
      raise
  return recorder.get_and_remove()


K1 = TypeVar('K1')
K2 = TypeVar('K2')
V = TypeVar('V')

InputsType = Dict[str, Union[str, List[str]]]


def _composite_key_to_nested(
    d: Mapping[Tuple[K1, K2], V]) -> Mapping[K1, Mapping[K2, V]]:
  """Turns ``{(k1, k2): v}`` into the two-level mapping ``{k1: {k2: v}}``."""
  by_first_key = collections.defaultdict(dict)
  for key_pair, value in d.items():
    outer_key, inner_key = key_pair
    by_first_key[outer_key][inner_key] = value
  return by_first_key


def _try_row_as_dict(row):
  """Returns ``row._asdict()``, or ``row`` itself on ``AttributeError``."""
  try:
    as_dict = row._asdict()
  except AttributeError:
    return row
  return as_dict


# Linter: No need for unittest.main here.