Source code for apache_beam.runners.interactive.testing.pipeline_assertion

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module to verify implicit cache transforms applied by Interactive Beam.

For internal use only; no backwards-compatibility guarantees.
This utility should only be used by Interactive Beam tests. For example, it can
be used to verify if the implicit cache transforms are applied as expected when
running a pipeline with the InteractiveRunner. It can also be used to verify if
a pipeline fragment has pruned unnecessary transforms. It shouldn't be used to
verify equivalence between pipelines if the code to be tested depends on or
mutates user code within transforms in pipelines.
"""

def assert_pipeline_equal(test_case, expected_pipeline, actual_pipeline):
    """Asserts the equivalence between two given apache_beam.Pipeline instances.

    Both pipelines are converted with ``to_runner_api(use_fake_coders=True)``
    (the expected one first) and the resulting protos are handed to
    ``assert_pipeline_proto_equal``.

    Args:
      test_case: (unittest.TestCase) the unittest testcase where it asserts.
      expected_pipeline: (Pipeline) the pipeline instance expected.
      actual_pipeline: (Pipeline) the actual pipeline instance to be asserted.
    """
    expected_proto, actual_proto = (
        pipeline.to_runner_api(use_fake_coders=True)
        for pipeline in (expected_pipeline, actual_pipeline))
    assert_pipeline_proto_equal(test_case, expected_proto, actual_proto)
def assert_pipeline_proto_equal(
        test_case, expected_pipeline_proto, actual_pipeline_proto):
    """Asserts the equivalence between two pipeline proto representations.

    The check is structural rather than a field-by-field proto comparison:

    * the number of transforms and of pcollections must match exactly;
    * the number of windowing strategies and of coders in the expected proto
      must not exceed the number in the actual proto;
    * the transforms referenced by the first entry of each proto's
      ``root_transform_ids`` are compared with ``_assert_transform_equal``.

    Args:
      test_case: (unittest.TestCase) the unittest testcase where it asserts.
      expected_pipeline_proto: the pipeline proto expected.
      actual_pipeline_proto: the actual pipeline proto to be asserted.

    Note:
      Both protos are indexed at ``root_transform_ids[0]``, so each is assumed
      to have at least one root transform id.
    """
    expected_components = expected_pipeline_proto.components
    actual_components = actual_pipeline_proto.components

    test_case.assertEqual(
        len(expected_components.transforms),
        len(actual_components.transforms))
    test_case.assertEqual(
        len(expected_components.pcollections),
        len(actual_components.pcollections))

    # TODO(BEAM-7926): Update tests and make below 2 assertions assertEqual.
    test_case.assertLessEqual(
        len(expected_components.windowing_strategies),
        len(actual_components.windowing_strategies))
    test_case.assertLessEqual(
        len(expected_components.coders), len(actual_components.coders))

    # Pass the expected proto first so the arguments line up with the
    # helper's (expected, actual) parameter order; previously they were
    # swapped, which reversed first/second in assertion failure messages.
    _assert_transform_equal(
        test_case,
        expected_pipeline_proto,
        expected_pipeline_proto.root_transform_ids[0],
        actual_pipeline_proto,
        actual_pipeline_proto.root_transform_ids[0])
def assert_pipeline_proto_contain_top_level_transform(
        test_case, pipeline_proto, transform_label):
    """Asserts the top level transforms contain a transform with the given
    transform label.

    Delegates to ``_assert_pipeline_proto_contains_top_level_transform`` with
    ``contain`` set to ``True``.

    Args:
      test_case: (unittest.TestCase) the unittest testcase where it asserts.
      pipeline_proto: the pipeline proto whose top level transforms are
        inspected.
      transform_label: the label to look for.
    """
    _assert_pipeline_proto_contains_top_level_transform(
        test_case, pipeline_proto, transform_label, contain=True)
def assert_pipeline_proto_not_contain_top_level_transform(
        test_case, pipeline_proto, transform_label):
    """Asserts the top level transforms do not contain a transform with the
    given transform label.

    Delegates to ``_assert_pipeline_proto_contains_top_level_transform`` with
    ``contain`` set to ``False``.

    Args:
      test_case: (unittest.TestCase) the unittest testcase where it asserts.
      pipeline_proto: the pipeline proto whose top level transforms are
        inspected.
      transform_label: the label that must be absent.
    """
    _assert_pipeline_proto_contains_top_level_transform(
        test_case, pipeline_proto, transform_label, contain=False)
def _assert_pipeline_proto_contains_top_level_transform(
        test_case, pipeline_proto, transform_label, contain):
    """Asserts whether any top level transform matches the given label.

    The top level transforms are the ``subtransforms`` entries of the
    transform referenced by ``pipeline_proto.root_transform_ids[0]``. An entry
    matches when ``transform_label in entry`` is true.

    Args:
      test_case: (unittest.TestCase) the unittest testcase where it asserts.
      pipeline_proto: the pipeline proto to inspect.
      transform_label: the label to look for.
      contain: (bool) the expected outcome of the search; the assertion is
        ``assertEqual(contain, <whether a match was found>)``.
    """
    all_transforms = pipeline_proto.components.transforms
    root_transform = all_transforms[pipeline_proto.root_transform_ids[0]]
    found = any(
        transform_label in candidate
        for candidate in root_transform.subtransforms)
    test_case.assertEqual(contain, found)


def _assert_transform_equal(
        test_case,
        expected_pipeline_proto,
        expected_transform_id,
        actual_pipeline_proto,
        actual_transform_id):
    """Asserts the equivalence between transforms from two given pipelines.

    Compares, in order: the spec urn, the number of subtransforms, and the
    sets built from iterating ``inputs`` and ``outputs``.

    Args:
      test_case: (unittest.TestCase) the unittest testcase where it asserts.
      expected_pipeline_proto: the pipeline proto holding the expected
        transform.
      expected_transform_id: id of the expected transform in
        ``expected_pipeline_proto.components.transforms``.
      actual_pipeline_proto: the pipeline proto holding the actual transform.
      actual_transform_id: id of the actual transform in
        ``actual_pipeline_proto.components.transforms``.
    """
    expected = expected_pipeline_proto.components.transforms[
        expected_transform_id]
    actual = actual_pipeline_proto.components.transforms[actual_transform_id]

    test_case.assertEqual(expected.spec.urn, actual.spec.urn)
    # Skipping payload checking because PTransforms of the same functionality
    # could generate different payloads.
    test_case.assertEqual(
        len(expected.subtransforms), len(actual.subtransforms))
    for field in ("inputs", "outputs"):
        test_case.assertSetEqual(
            set(getattr(expected, field)), set(getattr(actual, field)))