#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
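"""Utilities for testing Beam YAML pipelines.

A pipeline specification can be run against a test specification that mocks
its sources and asserts on the elements flowing into its sinks
(``run_test``), and a candidate test specification can be generated from a
pipeline by sampling its actual data (``create_test``).

An illustrative sketch of a test specification (the transform names and
element values here are made up; see ``validate_test_spec`` for the allowed
fields):

    mock_outputs:
      - name: ReadFromMySource   # hypothetical source transform to mock
        elements:
          - {id: 1}
          - {id: 2}
    expected_inputs:
      - name: WriteToMySink      # hypothetical sink transform to check
        elements:
          - {id: 1}
          - {id: 2}
"""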
import collections
import functools
import json
import random
import unittest
import uuid
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import yaml
import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
from apache_beam.yaml import yaml_utils
class YamlTestCase(unittest.TestCase):
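  """A unittest.TestCase wrapping a single YAML pipeline test specification.

  Running the case executes ``run_test`` against the pipeline; if constructed
  with ``fix_tests=True``, ``fixed_test`` afterwards returns the test spec
  with any recorded fixes applied to its expected elements.
  """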
def __init__(self, pipeline_spec, test_spec, options, fix_tests):
super().__init__()
self._pipeline_spec = pipeline_spec
self._test_spec = test_spec
self._options = options
self._fix_tests = fix_tests
def runTest(self):
self._fixes = run_test(
self._pipeline_spec, self._test_spec, self._options, self._fix_tests)
def fixed_test(self):
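    """Returns the test spec with any recorded fixes applied."""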
fixed_test_spec = yaml_transform.SafeLineLoader.strip_metadata(
self._test_spec)
if self._fixes:
expectation_by_id = {(loc, expectation['name']): expectation
for loc in ('expected_inputs', 'expected_outputs')
for expectation in fixed_test_spec.get(loc, [])}
for name_loc, values in self._fixes.items():
expectation_by_id[name_loc]['elements'] = sorted(values, key=json.dumps)
return fixed_test_spec
def id(self):
return (
self._test_spec.get('name', 'unknown') +
f' (line {yaml_transform.SafeLineLoader.get_line(self._test_spec)})')
def __str__(self):
return self.id()
def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
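  """Runs the given test specification against a YAML pipeline specification.

  Test transforms (mocks and expectation checks) are injected into the
  pipeline before it is executed; any remaining source transform that is not
  mocked and not listed in ``allowed_sources`` raises a ValueError.  When
  ``fix_failures`` is true, returns a dict mapping
  (expectation location, name) to the actual elements recorded for failing
  expectations.
  """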
if isinstance(pipeline_spec, str):
pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)
pipeline_spec = _preprocess_for_testing(pipeline_spec)
transform_spec, recording_ids = inject_test_tranforms(
pipeline_spec,
test_spec,
fix_failures)
allowed_sources = set(test_spec.get('allowed_sources', []) + ['Create'])
for transform in transform_spec['transforms']:
name_or_type = transform.get('name', transform['type'])
    if (not yaml_transform.empty_if_explicitly_empty(transform.get('input', []))
        and transform.get('name') not in allowed_sources and
        transform['type'] not in allowed_sources):
raise ValueError(
f'Non-mocked source {name_or_type} '
f'at line {yaml_transform.SafeLineLoader.get_line(transform)}')
if options is None:
options = beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle',
**yaml_transform.SafeLineLoader.strip_metadata(
pipeline_spec.get('options', {})))
with beam.Pipeline(options=options) as p:
_ = p | yaml_transform.YamlTransform(
transform_spec,
providers={'AssertEqualAndRecord': AssertEqualAndRecord})
if fix_failures:
fixes = {}
for recording_id in recording_ids:
if AssertEqualAndRecord.has_recorded_result(recording_id):
fixes[recording_id[1:]] = [
_try_row_as_dict(row)
for row in AssertEqualAndRecord.get_recorded_result(recording_id)
]
AssertEqualAndRecord.remove_recorded_result(recording_id)
return fixes
def _preprocess_for_testing(pipeline_spec):
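  """Normalizes the raw pipeline spec into a composite transform spec."""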
spec = yaml_transform.pipeline_as_composite(pipeline_spec['pipeline'])
# These are idempotent, so it's OK to do them preemptively.
for phase in [
yaml_transform.ensure_transforms_have_types,
yaml_transform.preprocess_source_sink,
yaml_transform.preprocess_chain,
yaml_transform.tag_explicit_inputs,
yaml_transform.normalize_inputs_outputs,
]:
spec = yaml_transform.apply_phase(phase, spec)
return spec
def validate_test_spec(test_spec):
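  """Checks that a test specification is well formed, raising if not.

  A test specification is a mapping with an optional name and
  allowed_sources, plus lists under mock_inputs, mock_outputs,
  expected_outputs and expected_inputs, each entry of which must have a
  name and a list of elements.
  """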
if not isinstance(test_spec, dict):
raise TypeError(
f'Test specification must be an object, got {type(test_spec)}')
identifier = (
test_spec.get('name', 'unknown') +
f' at line {yaml_transform.SafeLineLoader.get_line(test_spec)}')
if not isinstance(test_spec.get('allowed_sources', []), list):
raise TypeError(
f'allowed_sources of test specification {identifier} '
f'must be a list, got {type(test_spec["allowed_sources"])}')
if (not test_spec.get('expected_outputs', []) and
not test_spec.get('expected_inputs', [])):
raise ValueError(
f'test specification {identifier} '
        f'must have at least one of expected_outputs or expected_inputs')
unknown_attrs = set(
yaml_transform.SafeLineLoader.strip_metadata(test_spec).keys()) - set([
'name',
'mock_inputs',
'mock_outputs',
'expected_outputs',
'expected_inputs',
'allowed_sources',
])
if unknown_attrs:
raise ValueError(
f'test specification {identifier} '
f'has unknown attributes {list(unknown_attrs)}')
for attr_type in ('mock_inputs',
'mock_outputs',
'expected_outputs',
'expected_inputs'):
attr = test_spec.get(attr_type, [])
if not isinstance(attr, list):
raise TypeError(
f'{attr_type} of test specification {identifier} '
          f'must be a list, got {type(attr)}')
for ix, attr_item in enumerate(attr):
if not isinstance(attr_item, dict):
raise TypeError(
f'{attr_type} {ix} of test specification {identifier} '
f'must be an object, got {type(attr_item)}')
if 'name' not in attr_item:
raise TypeError(
f'{attr_type} {ix} of test specification {identifier} '
f'missing a name')
if 'elements' not in attr_item:
raise TypeError(
f'{attr_type} {ix} of test specification {identifier} '
            f'missing elements')
if not isinstance(attr_item['elements'], list):
raise TypeError(
            f'elements of {attr_type} {ix} of test specification '
            f'{identifier} must be a list, got {type(attr_item["elements"])}')
class AssertEqualAndRecord(beam.PTransform):
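  """Asserts that a PCollection matches the given elements.

  If a recording_id is given, mismatches are recorded under that id (for
  later use in fixing the test spec) rather than failing the pipeline.
  """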
_recorded_results = {}
@classmethod
def store_recorded_result(cls, recording_id, value):
assert recording_id not in cls._recorded_results
cls._recorded_results[recording_id] = value
@classmethod
def has_recorded_result(cls, recording_id):
return recording_id in cls._recorded_results
@classmethod
def get_recorded_result(cls, recording_id):
return cls._recorded_results[recording_id]
@classmethod
def remove_recorded_result(cls, recording_id):
del cls._recorded_results[recording_id]
def __init__(self, elements, recording_id):
self._elements = elements
self._recording_id = recording_id
def expand(self, pcoll):
equal_to_matcher = equal_to(yaml_provider.dicts_to_rows(self._elements))
def matcher(actual):
try:
equal_to_matcher(actual)
except Exception:
if self._recording_id:
AssertEqualAndRecord.store_recorded_result(
tuple(self._recording_id), actual)
else:
raise
return assert_that(
pcoll | beam.Map(lambda row: beam.Row(**row._asdict())), matcher)
def create_test(
pipeline_spec, options=None, max_num_inputs=40, min_num_outputs=3):
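  """Generates a candidate test specification for the given pipeline.

  Source transforms are sampled (up to ``max_num_inputs`` elements each) to
  build mock_outputs, and transforms detected as outputs (LogForTesting,
  types starting with Write, or those with an output declared) become
  expected_inputs.  The pipeline is re-run on growing samples until each
  output sees at least ``min_num_outputs`` elements or the input cap is
  reached, and the observed elements are recorded as the expectations.
  """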
if isinstance(pipeline_spec, str):
pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)
transform_spec = _preprocess_for_testing(pipeline_spec)
if options is None:
options = beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle',
**yaml_transform.SafeLineLoader.strip_metadata(
pipeline_spec.get('options', {})))
def get_name(transform):
if 'name' in transform:
return str(transform['name'])
else:
if sum(1 for t in transform_spec['transforms']
if t['type'] == transform['type']) > 1:
        raise ValueError(f'Ambiguous unnamed transform {transform["type"]}')
return str(transform['type'])
input_transforms = [
t for t in transform_spec['transforms'] if t['type'] != 'Create' and
not yaml_transform.empty_if_explicitly_empty(t.get('input', []))
]
mock_outputs = [{
'name': get_name(t),
'elements': [
_try_row_as_dict(row) for row in _first_n(t, options, max_num_inputs)
],
} for t in input_transforms]
output_transforms = [
t for t in transform_spec['transforms'] if t['type'] == 'LogForTesting' or
yaml_transform.empty_if_explicitly_empty(t.get('output', [])) or
t['type'].startswith('Write')
]
expected_inputs = [{
'name': get_name(t),
'elements': [],
} for t in output_transforms]
if not expected_inputs:
# TODO: Optionally take this as a parameter.
raise ValueError('No output transforms detected.')
num_inputs = min_num_outputs
while True:
test_spec = {
'mock_outputs': [{
'name': t['name'],
'elements': random.sample(
t['elements'], min(len(t['elements']), num_inputs)),
} for t in mock_outputs],
'expected_inputs': expected_inputs,
}
fixes = run_test(pipeline_spec, test_spec, options, fix_failures=True)
if len(fixes) < len(output_transforms):
actual_output_size = 0
else:
actual_output_size = min(len(e) for e in fixes.values())
if actual_output_size >= min_num_outputs:
break
elif num_inputs == max_num_inputs:
break
else:
num_inputs = min(2 * num_inputs, max_num_inputs)
for expected_input in test_spec['expected_inputs']:
if ('expected_inputs', expected_input['name']) in fixes:
expected_input['elements'] = fixes['expected_inputs',
expected_input['name']]
return test_spec
class _DoneException(Exception):
pass
class RecordElements(beam.PTransform):
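  """Records up to n elements of a PCollection into a class-level listing.

  Once n elements have been recorded, a _DoneException is raised to halt
  the sampling pipeline early.
  """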
_recorded_results = collections.defaultdict(list)
def __init__(self, n):
self._n = n
self._id = str(uuid.uuid4())
def get_and_remove(self):
listing = RecordElements._recorded_results[self._id]
del RecordElements._recorded_results[self._id]
return listing
def expand(self, pcoll):
def record(element):
listing = RecordElements._recorded_results[self._id]
if len(listing) < self._n:
listing.append(element)
else:
raise _DoneException()
return pcoll | beam.Map(record)
def _first_n(transform_spec, options, n):
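  """Runs the given transform spec and returns up to n elements it produces."""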
recorder = RecordElements(n)
try:
with beam.Pipeline(options=options) as p:
_ = (
p
| yaml_transform.YamlTransform(
transform_spec,
providers={'AssertEqualAndRecord': AssertEqualAndRecord})
| recorder)
except _DoneException:
pass
except Exception as exn:
# Runners don't always raise a faithful exception type.
    if '_DoneException' not in str(exn):
raise
return recorder.get_and_remove()
K1 = TypeVar('K1')
K2 = TypeVar('K2')
V = TypeVar('V')
InputsType = Dict[str, Union[str, List[str]]]
def _composite_key_to_nested(
d: Mapping[Tuple[K1, K2], V]) -> Mapping[K1, Mapping[K2, V]]:
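  """Converts a {(k1, k2): v} mapping into a nested {k1: {k2: v}} mapping."""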
nested = collections.defaultdict(dict)
for (k1, k2), v in d.items():
nested[k1][k2] = v
return nested
def _try_row_as_dict(row):
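  """Returns the row as a dict if possible, otherwise the row itself."""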
try:
return row._asdict()
except AttributeError:
return row
# Linter: No need for unittest.main here.