#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Helper functions and test harnesses for source implementations.
This module contains helper functions and test harnesses for checking
correctness of source (a subclass of ``iobase.BoundedSource``) and range
tracker (a subclass of ``iobase.RangeTracker``) implementations.
Contains a few lightweight utilities (e.g. ``read_from_source()`` for reading
items from a source), as well as heavyweight property testing and stress
testing harnesses that help to get a large amount of test coverage with little
code.
Most notable ones are:
* ``assert_sources_equal_reference_source()`` helps to test that the data read
by the union of sources produced by ``BoundedSource.split()`` is the same as
the data read by the original source.
* If your source implements dynamic work rebalancing, use the
``assert_split_at_fraction_*()`` family of functions - they test the behavior
of ``RangeTracker.try_split()``, in particular, that various consistency
properties are respected and the total set of data read by the source is
preserved when splits happen. Use ``assert_split_at_fraction_behavior()`` to
test individual cases of ``RangeTracker.try_split()`` and use
``assert_split_at_fraction_exhaustive()`` as a heavy-weight stress test
including concurrency. We strongly recommend using both.
For example usages, see the unit tests of modules such as
* apache_beam.io.source_test_utils_test.py
* apache_beam.io.avroio_test.py
"""
# pytype: skip-file
import logging
import threading
import weakref
from collections import namedtuple
from multiprocessing.pool import ThreadPool
from apache_beam.io import iobase
from apache_beam.testing.util import equal_to
# Names exported by ``from <this module> import *``. Helpers whose names start
# with an underscore, and the supporting classes/namedtuples defined below,
# are intentionally not listed here.
__all__ = [
'read_from_source',
'assert_sources_equal_reference_source',
'assert_reentrant_reads_succeed',
'assert_split_at_fraction_behavior',
'assert_split_at_fraction_binary',
'assert_split_at_fraction_exhaustive',
'assert_split_at_fraction_fails',
'assert_split_at_fraction_succeeds_and_consistent'
]
# Module-level logger; used by the exhaustive test to report how many
# concurrent splitting trials were run.
_LOGGER = logging.getLogger(__name__)
class ExpectedSplitOutcome(object):
  """Outcomes a caller may demand from a single dynamic split attempt.

  One of these values is passed as ``expected_outcome`` to
  ``assert_split_at_fraction_behavior()``; any other value makes that check
  raise ``ValueError``.
  """
  # The split attempt has to succeed; the usual consistency checks then apply.
  MUST_SUCCEED_AND_BE_CONSISTENT = 1
  # The split attempt has to be rejected by the range tracker.
  MUST_FAIL = 2
  # Either result is accepted; a successful split still has to be consistent.
  MUST_BE_CONSISTENT_IF_SUCCEEDS = 3
# Pair of item counts for the primary and residual ranges of a split attempt.
# NOTE(review): nothing in the visible code constructs this type; the split
# helpers return plain two-tuples of the same shape.
SplitAtFractionResult = namedtuple(
    'SplitAtFractionResult', ['num_primary_items', 'num_residual_items'])

# Accumulator filled in by ``assert_split_at_fraction_binary()``: the fractions
# at which a split succeeded, and the subset of those whose residual range
# produced at least one item. Both fields are expected to be mutable lists.
SplitFractionStatistics = namedtuple(
    'SplitFractionStatistics',
    ['successful_fractions', 'non_trivial_fractions'])
def read_from_source(source, start_position=None, stop_position=None):
  """Reads all elements of a ``BoundedSource`` within a position range.

  A range tracker is requested from the source for the given positions and the
  source is then read to exhaustion with that tracker.

  Args:
    source (~apache_beam.io.iobase.BoundedSource):
      :class:`~apache_beam.io.iobase.BoundedSource` implementation.
    start_position: position to start reading at; passed unchanged to
      ``source.get_range_tracker()``.
    stop_position: position to stop reading at; passed unchanged to
      ``source.get_range_tracker()``.

  Returns:
    list: the values produced by the reader, in the order they were produced.
  """
  tracker = source.get_range_tracker(start_position, stop_position)
  assert isinstance(tracker, iobase.RangeTracker)
  return list(source.read(tracker))
def _ThreadPool(threads):
# ThreadPool crashes in old versions of Python (< 2.7.5) if created from a
# child thread. (http://bugs.python.org/issue10015)
if not hasattr(threading.current_thread(), '_children'):
threading.current_thread()._children = weakref.WeakKeyDictionary()
return ThreadPool(threads)
def assert_sources_equal_reference_source(reference_source_info, sources_info):
  """Tests if a reference source is equal to a given set of sources.

  Given a reference source (a :class:`~apache_beam.io.iobase.BoundedSource`
  and a position range) and a list of sources, assert that the union of the
  records read from the list of sources is equal to the records read from the
  reference source.

  Args:
    reference_source_info\
        (Tuple[~apache_beam.io.iobase.BoundedSource, int, int]):
      a three-tuple that gives the reference
      :class:`~apache_beam.io.iobase.BoundedSource`, position to start
      reading at, and position to stop reading at.
    sources_info\
        (Iterable[Tuple[~apache_beam.io.iobase.BoundedSource, int, int]]):
      a set of sources. Each source is a three-tuple that is of the same
      format described above.

  Raises:
    ValueError: if ``reference_source_info`` or an entry of ``sources_info``
      is not a three-tuple starting with a ``BoundedSource``, if a source's
      default output coder has a different type than the reference source's,
      or if the number of records read differs.
  """
  if not (isinstance(reference_source_info, tuple) and
          len(reference_source_info) == 3 and
          isinstance(reference_source_info[0], iobase.BoundedSource)):
    # The value is wrapped in a one-tuple on purpose: applying ``%`` directly
    # to a tuple unpacks it as the format arguments and raises TypeError.
    raise ValueError(
        'reference_source_info must be a three-tuple where the first '
        'item of the tuple gives a '
        'iobase.BoundedSource. Received: %r' % (reference_source_info, ))
  reference_source = reference_source_info[0]
  reference_records = read_from_source(*reference_source_info)

  source_records = []
  for source_info in sources_info:
    # Validation raises ValueError (not AssertionError) so that malformed
    # entries are reported the same way as a malformed reference.
    if not (isinstance(source_info, tuple) and len(source_info) == 3 and
            isinstance(source_info[0], iobase.BoundedSource)):
      raise ValueError(
          'source_info must be a three-tuple where the first '
          'item of the tuple gives a '
          'iobase.BoundedSource. Received: %r' % (source_info, ))
    current_source = source_info[0]
    reference_coder_type = type(reference_source.default_output_coder())
    current_coder_type = type(current_source.default_output_coder())
    if reference_coder_type != current_coder_type:
      raise ValueError(
          'Reference source %r and the source %r must use the same coder. '
          'They are using %r and %r respectively instead.' % (
              reference_source,
              current_source,
              reference_coder_type,
              current_coder_type))
    source_records.extend(read_from_source(*source_info))

  if len(reference_records) != len(source_records):
    raise ValueError(
        'Reference source must produce the same number of records as the '
        'list of sources. Number of records were %d and %d instead.' %
        (len(reference_records), len(source_records)))

  # NOTE(review): ``equal_to`` is defined outside this file. If its matcher
  # raises on mismatch instead of returning a truthy value, the ValueError
  # below is never reached and the matcher's own exception propagates -
  # confirm against apache_beam.testing.util.
  if equal_to(reference_records)(source_records):
    raise ValueError(
        'Reference source and provided list of sources must produce the '
        'same set of records.')
def assert_reentrant_reads_succeed(source_info):
  """Tests if a given source can be read in a reentrant manner.

  Assume that given source produces the set of values ``{v1, v2, v3, ... vn}``.
  For ``i`` in range ``[1, n-1]`` this method performs a reentrant read after
  reading ``i`` elements and verifies that both the original and reentrant read
  produce the expected set of values.

  Args:
    source_info (Tuple[~apache_beam.io.iobase.BoundedSource, int, int]):
      a three-tuple that gives the reference
      :class:`~apache_beam.io.iobase.BoundedSource`, position to start reading
      at, and a position to stop reading at.

  Raises:
    ValueError: if source is too trivial or reentrant read result
      in an incorrect read.
  """
  source, start_position, stop_position = source_info
  assert isinstance(source, iobase.BoundedSource)

  expected_values = list(
      source.read(source.get_range_tracker(start_position, stop_position)))
  if len(expected_values) < 2:
    raise ValueError(
        'Source is too trivial since it produces only %d '
        'values. Please give a source that reads at least 2 '
        'values.' % len(expected_values))

  # ``i`` covers 1..n-1 inclusive, as documented. Stopping one short would
  # skip the last interruption point and would make the check vacuous for a
  # source that produces exactly two values.
  for i in range(1, len(expected_values)):
    read_iter = source.read(
        source.get_range_tracker(start_position, stop_position))
    original_read = [next(read_iter) for _ in range(i)]

    # Reentrant read: a complete, independent read started while the original
    # read is suspended after ``i`` elements.
    reentrant_read = list(
        source.read(source.get_range_tracker(start_position, stop_position)))

    # Continuing original read.
    original_read.extend(read_iter)

    # NOTE(review): ``equal_to`` is defined outside this file. If its matcher
    # raises on mismatch instead of returning a truthy value, the ValueErrors
    # below are never reached and the matcher's own exception propagates -
    # confirm against apache_beam.testing.util.
    if equal_to(original_read)(expected_values):
      raise ValueError(
          'Source did not produce expected values when '
          'performing a reentrant read after reading %d values. '
          'Expected %r received %r.' % (i, expected_values, original_read))
    if equal_to(reentrant_read)(expected_values):
      raise ValueError(
          'A reentrant read of source after reading %d values '
          'did not produce expected values. Expected %r '
          'received %r.' % (i, expected_values, reentrant_read))
def assert_split_at_fraction_behavior(
    source, num_items_to_read_before_split, split_fraction, expected_outcome):
  """Checks how a source behaves when it is split at a given fraction.

  The whole source is read once to obtain the expected items. A second read is
  then started, interrupted after **num_items_to_read_before_split** items,
  and a dynamic split at **split_fraction** is attempted. The attempt has to
  match **expected_outcome**, and a successful split has to be consistent
  (primary and residual ranges together reproduce the expected items).

  Args:
    source (~apache_beam.io.iobase.BoundedSource): the source to perform
      dynamic splitting on.
    num_items_to_read_before_split (int): number of items to read before
      splitting.
    split_fraction (float): fraction to split at.
    expected_outcome (int): a value from
      :class:`~apache_beam.io.source_test_utils.ExpectedSplitOutcome`.

  Returns:
    Tuple[int, int]: the number of items read from the primary range and from
    the residual range. When no split happened the first value counts all
    records of the source and the second value is ``-1``.
  """
  assert isinstance(source, iobase.BoundedSource)
  all_items = read_from_source(source, None, None)
  return _assert_split_at_fraction_behavior(
      source=source,
      expected_items=all_items,
      num_items_to_read_before_split=num_items_to_read_before_split,
      split_fraction=split_fraction,
      expected_outcome=expected_outcome)
def _assert_split_at_fraction_behavior(
    source,
    expected_items,
    num_items_to_read_before_split,
    split_fraction,
    expected_outcome,
    start_position=None,
    stop_position=None):
  """Attempts one dynamic split and validates the result.

  Reads ``num_items_to_read_before_split`` items, asks the range tracker to
  split at ``split_fraction``, validates the value returned by
  ``try_split()``, compares the attempt against ``expected_outcome``, finishes
  the read and finally delegates the data consistency checks to
  ``_verify_single_split_fraction_result()``.

  Args:
    source: the source to read and split.
    expected_items: items expected from reading the whole position range.
    num_items_to_read_before_split: number of items to consume before the
      split is attempted.
    split_fraction: fraction at which the split is requested.
    expected_outcome: a value from ``ExpectedSplitOutcome``.
    start_position: start of the position range handed to
      ``source.get_range_tracker()``.
    stop_position: end of the position range handed to
      ``source.get_range_tracker()``.

  Returns:
    The ``(num_primary_items, num_residual_items)`` tuple produced by
    ``_verify_single_split_fraction_result()``.

  Raises:
    ValueError: if the split result is malformed, contradicts the tracker
      state, does not match ``expected_outcome``, or if ``expected_outcome``
      is not a known value.
  """
  range_tracker = source.get_range_tracker(start_position, stop_position)
  assert isinstance(range_tracker, iobase.RangeTracker)

  # Reading 'num_items_to_read_before_split' items.
  reader_iter = iter(source.read(range_tracker))
  current_items = [
      next(reader_iter) for _ in range(num_items_to_read_before_split)
  ]

  suggested_split_position = range_tracker.position_at_fraction(split_fraction)
  stop_position_before_split = range_tracker.stop_position()
  split_result = range_tracker.try_split(suggested_split_position)

  if split_result is not None:
    if len(split_result) != 2:
      raise ValueError(
          'Split result must be a tuple that contains split '
          'position and split fraction. Received: %r' % (split_result, ))
    split_position, observed_fraction = split_result

    if range_tracker.stop_position() != split_position:
      # Both observed values are passed as a 2-tuple of format arguments.
      raise ValueError(
          'After a successful split, the stop position of the '
          'RangeTracker must be the same as the returned split '
          'position. Observed %r and %r which are different.' %
          (range_tracker.stop_position(), split_position))

    # The fraction reported by the tracker is the value being validated (and
    # reported), not the fraction that was requested.
    if observed_fraction < 0 or observed_fraction > 1:
      raise ValueError(
          'Split fraction must be within the range [0,1]. '
          'Observed split fraction was %r.' % (observed_fraction, ))

  stop_position_after_split = range_tracker.stop_position()
  if split_result and stop_position_after_split == stop_position_before_split:
    raise ValueError(
        'Stop position %r did not change after a successful '
        'split of source %r at fraction %r.' %
        (stop_position_before_split, source, split_fraction))

  if expected_outcome == ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT:
    if not split_result:
      raise ValueError(
          'Expected split of source %r at fraction %r to be '
          'successful after reading %d elements. But '
          'the split failed.' %
          (source, split_fraction, num_items_to_read_before_split))
  elif expected_outcome == ExpectedSplitOutcome.MUST_FAIL:
    if split_result:
      raise ValueError(
          'Expected split of source %r at fraction %r after '
          'reading %d elements to fail. But splitting '
          'succeeded with result %r.' % (
              source,
              split_fraction,
              num_items_to_read_before_split,
              split_result))
  elif (
      expected_outcome != ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS):
    raise ValueError('Unknown type of expected outcome: %r' % expected_outcome)

  # Finish the interrupted read; after a successful split this stops at the
  # tracker's new stop position.
  current_items.extend(reader_iter)

  residual_range = (
      split_result[0], stop_position_before_split) if split_result else None

  return _verify_single_split_fraction_result(
      source,
      expected_items,
      current_items,
      split_result,
      (range_tracker.start_position(), range_tracker.stop_position()),
      residual_range,
      split_fraction)
def _range_to_str(start, stop):
return '[' + (str(start) + ',' + str(stop) + ')')
def _verify_single_split_fraction_result(
    source,
    expected_items,
    current_items,
    split_successful,
    primary_range,
    residual_range,
    split_fraction):
  """Checks that the ranges left by a split attempt reproduce the source data.

  Re-reads the primary range (and the residual range when the split succeeded)
  and verifies that (a) the primary range yields exactly the items the
  original reader produced and (b) primary plus residual items equal
  ``expected_items``.

  Args:
    source: the source that was read and split.
    expected_items: items expected from reading the whole position range.
    current_items: items produced by the reader on which the split was
      attempted.
    split_successful: truthy if the split succeeded (callers pass the value
      returned by ``try_split()``).
    primary_range: ``(start, stop)`` positions of the primary range.
    residual_range: ``(start, stop)`` positions of the residual range; must be
      falsy when the split did not succeed.
    split_fraction: the fraction at which the split was requested; only used
      in error messages.

  Returns:
    tuple: ``(num_primary_items, num_residual_items)`` where the second value
    is ``-1`` when the split did not succeed.

  Raises:
    ValueError: if either consistency check fails.
  """
  assert primary_range
  primary_items = read_from_source(source, *primary_range)

  if split_successful:
    residual_items = read_from_source(source, *residual_range)
    residual_range_str = _range_to_str(*residual_range)
  else:
    # For unsuccessful splits, residual_range should be None.
    assert not residual_range
    residual_items = []
    # There is no range to unpack here; unpacking ``None`` would raise a
    # TypeError while building the error message below.
    residual_range_str = str(residual_range)

  total_items = primary_items + residual_items

  if current_items != primary_items:
    raise ValueError(
        'Current source %r and a source created using the '
        'range of the primary source %r determined '
        'by performing dynamic work rebalancing at fraction '
        '%r produced different values. Expected '
        'these sources to produce the same list of values.' %
        (source, _range_to_str(*primary_range), split_fraction))

  if expected_items != total_items:
    raise ValueError(
        'Items obtained by reading the source %r for primary '
        'and residual ranges %s and %s did not produce the '
        'expected list of values.' %
        (source, _range_to_str(*primary_range), residual_range_str))

  return (len(primary_items), len(residual_items) if split_successful else -1)
def assert_split_at_fraction_succeeds_and_consistent(
    source, num_items_to_read_before_split, split_fraction):
  """Asserts that a dynamic split succeeds and leaves consistent ranges.

  Equivalent to the following pseudocode::

    original_range_tracker = source.get_range_tracker(None, None)
    original_reader = source.read(original_range_tracker)
    items_before_split = read N items from original_reader
    suggested_split_position = original_range_tracker.position_at_fraction(
        split_fraction)
    original_stop_position = original_range_tracker.stop_position()
    split_result = original_range_tracker.try_split(suggested_split_position)
    split_position, split_fraction = split_result
    primary_range_tracker = source.get_range_tracker(
        original_range_tracker.start_position(), split_position)
    residual_range_tracker = source.get_range_tracker(
        split_position, original_stop_position)

    assert that: items when reading source.read(primary_range_tracker) ==
      items_before_split + items from continuing to read 'original_reader'
    assert that: items when reading the whole source ==
      items when reading source.read(primary_range_tracker) +
      items when reading source.read(residual_range_tracker)

  Args:
    source: source to perform dynamic work rebalancing on.
    num_items_to_read_before_split: number of items to read before splitting.
    split_fraction: fraction to split at.
  """
  assert_split_at_fraction_behavior(
      source=source,
      num_items_to_read_before_split=num_items_to_read_before_split,
      split_fraction=split_fraction,
      expected_outcome=ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
def assert_split_at_fraction_fails(
    source, num_items_to_read_before_split, split_fraction):
  """Asserts that a dynamic split at the given fraction is rejected.

  Reads 'num_items_to_read_before_split' items from the source and then
  requires the attempt to split at 'split_fraction' to fail.

  Args:
    source: source to perform dynamic splitting on.
    num_items_to_read_before_split: number of items to read before splitting.
    split_fraction: fraction to split at.
  """
  assert_split_at_fraction_behavior(
      source=source,
      num_items_to_read_before_split=num_items_to_read_before_split,
      split_fraction=split_fraction,
      expected_outcome=ExpectedSplitOutcome.MUST_FAIL)
def assert_split_at_fraction_binary(
    source,
    expected_items,
    num_items_to_read_before_split,
    left_fraction,
    left_result,
    right_fraction,
    right_result,
    stats,
    start_position=None,
    stop_position=None):
  """Performs dynamic work rebalancing for fractions within a given range.

  Asserts that given a start position, a source can be split at every
  interesting fraction (halfway between two fractions that differ by at
  least one item) and the results are consistent if a split succeeds.

  The interval ``[left_fraction, right_fraction]`` is bisected recursively;
  recursion stops once the interval is narrower than ``0.001``.

  Args:
    source: source to perform dynamic splitting on.
    expected_items: total set of items expected when reading the source.
    num_items_to_read_before_split: number of items to read before splitting.
    left_fraction: left fraction for binary splitting.
    left_result: result received by splitting at left fraction, or ``None``
      to have it computed here.
    right_fraction: right fraction for binary splitting.
    right_result: result received by splitting at right fraction, or ``None``
      to have it computed here.
    stats: a ``SplitFractionStatistics`` for storing results.
    start_position: start of the position range every split trial reads;
      forwarded to ``source.get_range_tracker()``.
    stop_position: end of the position range every split trial reads;
      forwarded to ``source.get_range_tracker()``.
  """
  assert right_fraction > left_fraction

  if right_fraction - left_fraction < 0.001:
    # This prevents infinite recursion.
    return

  def split_at(fraction):
    # Every trial reads the same position range the caller asked for.
    return _assert_split_at_fraction_behavior(
        source,
        expected_items,
        num_items_to_read_before_split,
        fraction,
        ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS,
        start_position,
        stop_position)

  middle_fraction = (left_fraction + right_fraction) / 2

  if left_result is None:
    left_result = split_at(left_fraction)
  if right_result is None:
    right_result = split_at(right_fraction)
  middle_result = split_at(middle_fraction)

  # A residual count of -1 marks a failed split; a positive count marks a
  # split whose residual range actually contains items.
  if middle_result[1] != -1:
    stats.successful_fractions.append(middle_fraction)
  if middle_result[1] > 0:
    stats.non_trivial_fractions.append(middle_fraction)

  # Two split results are equivalent if primary and residual ranges of them
  # produce the same number of records (simply checking the size of primary
  # enough since the total number of records is constant).
  if left_result[0] != middle_result[0]:
    assert_split_at_fraction_binary(
        source,
        expected_items,
        num_items_to_read_before_split,
        left_fraction,
        left_result,
        middle_fraction,
        middle_result,
        stats,
        start_position=start_position,
        stop_position=stop_position)

  # We special case right_fraction=1.0 since that could fail due to being out
  # of range. (even if a dynamic split fails at 'middle_fraction' and at
  # fraction 1.0, there might be fractions in range ('middle_fraction', 1.0)
  # where dynamic splitting succeeds).
  if right_fraction == 1.0 or middle_result[0] != right_result[0]:
    assert_split_at_fraction_binary(
        source,
        expected_items,
        num_items_to_read_before_split,
        middle_fraction,
        middle_result,
        right_fraction,
        right_result,
        stats,
        start_position=start_position,
        stop_position=stop_position)
# Upper bound on the concurrent read/split trials run for a single item index
# by ``assert_split_at_fraction_exhaustive()`` before it gives up on that item.
MAX_CONCURRENT_SPLITTING_TRIALS_PER_ITEM = 100
# Once the running total of concurrent trials across all items exceeds this
# value, ``assert_split_at_fraction_exhaustive()`` stops the concurrent phase.
MAX_CONCURRENT_SPLITTING_TRIALS_TOTAL = 1000
def assert_split_at_fraction_exhaustive(
    source,
    start_position=None,
    stop_position=None,
    perform_multi_threaded_test=True):
  """Performs and tests dynamic work rebalancing exhaustively.

  Asserts that for each possible start position, a source can be split at
  every interesting fraction (halfway between two fractions that differ by at
  least one item) and the results are consistent if a split succeeds.
  Verifies multi threaded splitting as well.

  Args:
    source (~apache_beam.io.iobase.BoundedSource): the source to perform
      dynamic splitting on.
    start_position: start of the position range to test. Used both for the
      expected items and for every split trial.
    stop_position: end of the position range to test. Used both for the
      expected items and for every split trial.
    perform_multi_threaded_test (bool): if :data:`True` performs a
      multi-threaded test, otherwise this test is skipped.

  Raises:
    ValueError: if the exhaustive splitting test fails.
  """
  expected_items = read_from_source(source, start_position, stop_position)
  if not expected_items:
    raise ValueError('Source %r is empty.' % (source, ))
  if len(expected_items) == 1:
    raise ValueError('Source %r only reads a single item.' % (source, ))

  # all_non_trivial_fractions[i] holds the non-trivial split fractions found
  # after reading i items; the concurrent phase below indexes into it.
  all_non_trivial_fractions = []
  any_successful_fractions = False
  any_non_trivial_fractions = False

  for i in range(len(expected_items)):
    stats = SplitFractionStatistics([], [])
    # The trials have to read the same position range that 'expected_items'
    # was read from, otherwise the comparison is meaningless for sub-ranges.
    assert_split_at_fraction_binary(
        source,
        expected_items,
        i,
        0.0,
        None,
        1.0,
        None,
        stats,
        start_position=start_position,
        stop_position=stop_position)

    if stats.successful_fractions:
      any_successful_fractions = True
    if stats.non_trivial_fractions:
      any_non_trivial_fractions = True

    all_non_trivial_fractions.append(stats.non_trivial_fractions)

  if not any_successful_fractions:
    raise ValueError(
        'SplitAtFraction test completed vacuously: no '
        'successful split fractions found')

  if not any_non_trivial_fractions:
    raise ValueError(
        'SplitAtFraction test completed vacuously: no non-trivial split '
        'fractions found')

  if not perform_multi_threaded_test:
    return

  num_total_trials = 0
  for i, non_trivial_fractions in enumerate(all_non_trivial_fractions):
    if not non_trivial_fractions:
      # This will not happen all the time. Otherwise previous test will fail
      # due to vacuousness.
      continue
    min_non_trivial_fraction = min(non_trivial_fractions)

    num_trials = 0
    have_success = False
    have_failure = False

    thread_pool = _ThreadPool(2)
    try:
      # Repeat the race between reader and splitter until both a successful
      # and a failed split have been observed, or the per-item budget is spent.
      while True:
        num_trials += 1
        if num_trials > MAX_CONCURRENT_SPLITTING_TRIALS_PER_ITEM:
          _LOGGER.warning(
              'After %d concurrent splitting trials at item #%d, observed '
              'only %s, giving up on this item',
              num_trials,
              i,
              'success' if have_success else 'failure')
          break

        if _assert_split_at_fraction_concurrent(source,
                                                expected_items,
                                                i,
                                                min_non_trivial_fraction,
                                                thread_pool,
                                                start_position,
                                                stop_position):
          have_success = True
        else:
          have_failure = True

        if have_success and have_failure:
          _LOGGER.info(
              '%d trials to observe both success and failure of '
              'concurrent splitting at item #%d',
              num_trials,
              i)
          break
    finally:
      thread_pool.close()

    num_total_trials += num_trials

    if num_total_trials > MAX_CONCURRENT_SPLITTING_TRIALS_TOTAL:
      _LOGGER.warning(
          'After %d total concurrent splitting trials, considered '
          'only %d items, giving up.',
          num_total_trials,
          i)
      break

  _LOGGER.info(
      '%d total concurrent splitting trials for %d items',
      num_total_trials,
      len(expected_items))


def _assert_split_at_fraction_concurrent(
    source,
    expected_items,
    num_items_to_read_before_splitting,
    split_fraction,
    thread_pool=None,
    start_position=None,
    stop_position=None):
  """Races a split attempt against the rest of a read and verifies the result.

  After ``num_items_to_read_before_splitting`` items have been read, two tasks
  are submitted to a thread pool: one drains the reader, the other asks the
  range tracker to split at ``split_fraction``. The outcome is then checked
  with ``_verify_single_split_fraction_result()``.

  Args:
    source: the source to read and split.
    expected_items: items expected from reading the tested position range.
    num_items_to_read_before_splitting: number of items to consume before the
      two tasks are started.
    split_fraction: fraction at which the split is requested.
    thread_pool: pool used to run the two tasks. When ``None`` a temporary
      two-thread pool is created and closed again.
    start_position: start of the position range handed to
      ``source.get_range_tracker()``.
    stop_position: end of the position range handed to
      ``source.get_range_tracker()``.

  Returns:
    bool: ``True`` if the split succeeded and its residual range produced at
    least one item, ``False`` otherwise.
  """
  range_tracker = source.get_range_tracker(start_position, stop_position)
  stop_position_before_split = range_tracker.stop_position()
  reader_iter = iter(source.read(range_tracker))

  current_items = [
      next(reader_iter) for _ in range(num_items_to_read_before_splitting)
  ]

  def read_remaining():
    return list(reader_iter)

  def attempt_split():
    position = range_tracker.position_at_fraction(split_fraction)
    return range_tracker.try_split(position)

  pool = thread_pool if thread_pool is not None else _ThreadPool(2)
  try:
    # map() preserves task order: the first result is the drained reader, the
    # second is the value returned by try_split().
    remaining_items, split_result = pool.map(
        lambda task: task(), [read_remaining, attempt_split])
  finally:
    # Only close a pool that was created here; the caller owns 'thread_pool'.
    if thread_pool is None:
      pool.close()

  current_items.extend(remaining_items)
  primary_range = (
      range_tracker.start_position(), range_tracker.stop_position())
  residual_range = (
      split_result[0], stop_position_before_split) if split_result else None

  res = _verify_single_split_fraction_result(
      source,
      expected_items,
      current_items,
      split_result,
      primary_range,
      residual_range,
      split_fraction)

  return res[1] > 0