# Source code for apache_beam.coders.coders_test_common

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests common to all coder implementations."""

import logging
import math
import unittest

import dill

from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.timestamp import MIN_TIMESTAMP
import observable
from apache_beam.transforms import window
from apache_beam.utils import timestamp
from apache_beam.utils import windowed_value

from apache_beam.coders import coders
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message


# Defined out of line for picklability.
class CustomCoder(coders.Coder):
  """Test-only coder: writes an int x as the text of x + 1."""

  def encode(self, x):
    shifted = x + 1
    return str(shifted)

  def decode(self, encoded):
    # Inverse of encode: parse the text, then undo the +1 offset.
    value = int(encoded)
    return value - 1
class CodersTest(unittest.TestCase):
  """Round-trip, size-estimation and pickling tests for the coders module.

  Every coder exercised through check_coder is recorded. tearDownClass then
  fails if some Coder subclass defined in `coders` (minus a few exemptions)
  was never exercised directly, or never as a component of a TupleCoder.
  """

  # These class methods ensure that we test each defined coder in both
  # nested and unnested context.

  @classmethod
  def setUpClass(cls):
    # Coder types passed directly to check_coder.
    cls.seen = set()
    # Coder types seen as components of a TupleCoder (searched recursively).
    cls.seen_nested = set()

  @classmethod
  def tearDownClass(cls):
    """Fail if any coder class was not tested both unnested and nested.

    The classes checked are all Coder subclasses found in the `coders`
    module whose name does not contain 'Base', minus an explicit exemption
    list.

    Raises:
      AssertionError: carrying the set of coder classes never exercised. An
        explicit raise is used instead of `assert` so that the check still
        runs when Python is started with -O.
    """
    standard = {c for c in coders.__dict__.values()
                if isinstance(c, type) and issubclass(c, coders.Coder) and
                'Base' not in c.__name__}
    # Classes exempt from the coverage requirement.
    standard -= {coders.Coder,
                 coders.FastCoder,
                 coders.ProtoCoder,
                 coders.ToStringCoder}
    missing = standard - cls.seen
    if missing:
      raise AssertionError(missing)
    missing_nested = standard - cls.seen_nested
    if missing_nested:
      raise AssertionError(missing_nested)

  @classmethod
  def _observe(cls, coder):
    """Record `coder` as tested unnested, and its TupleCoder components."""
    cls.seen.add(type(coder))
    cls._observe_nested(coder)

  @classmethod
  def _observe_nested(cls, coder):
    """Record component coder types, recursing only through TupleCoders."""
    if isinstance(coder, coders.TupleCoder):
      for c in coder.coders():
        cls.seen_nested.add(type(c))
        cls._observe_nested(c)

  def check_coder(self, coder, *values):
    """Assert that `coder` round-trips `values` with consistent sizes.

    For each value this checks that:
      * decode(encode(v)) == v;
      * estimate_size matches the encoded length and the impl's estimate;
      * the impl reports no observables.
    It then checks that two independently dill-pickled copies of the coder
    interoperate. For deterministic coders it also checks that both copies
    produce identical bytes.
    """
    self._observe(coder)
    impl = coder.get_impl()
    for v in values:
      # Encode once and reuse the bytes for both round-trip and size checks.
      encoded = coder.encode(v)
      self.assertEqual(v, coder.decode(encoded))
      estimated = coder.estimate_size(v)
      self.assertEqual(estimated, len(encoded))
      self.assertEqual(estimated, impl.estimate_size(v))
      self.assertEqual(impl.get_estimated_size_and_observables(v),
                       (impl.estimate_size(v), []))
    copy1 = dill.loads(dill.dumps(coder))
    copy2 = dill.loads(dill.dumps(coder))
    for v in values:
      self.assertEqual(v, copy1.decode(copy2.encode(v)))
      if coder.is_deterministic():
        self.assertEqual(copy1.encode(v), copy2.encode(v))

  def test_custom_coder(self):
    self.check_coder(CustomCoder(), 1, -10, 5)
    self.check_coder(coders.TupleCoder((CustomCoder(), coders.BytesCoder())),
                     (1, 'a'), (-10, 'b'), (5, 'c'))

  def test_pickle_coder(self):
    self.check_coder(coders.PickleCoder(), 'a', 1, 1.5, (1, 2, 3))

  def test_deterministic_coder(self):
    coder = coders.FastPrimitivesCoder()
    deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
    self.check_coder(deterministic_coder, 'a', 1, 1.5, (1, 2, 3))
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, dict())
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, [1, dict()])

    self.check_coder(coders.TupleCoder((deterministic_coder, coder)),
                     (1, dict()), ('a', [dict()]))

  def test_dill_coder(self):
    # A closure cell object, used as a value DillCoder must handle.
    # `__closure__` works on Python 2.6+ and 3 (`func_closure` is 2-only).
    cell_value = (lambda x: lambda: x)(0).__closure__[0]
    self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
    self.check_coder(
        coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())),
        (1, cell_value))

  def test_fast_primitives_coder(self):
    coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
    self.check_coder(coder, None, 1, -1, 1.5, 'str\0str', u'unicode\0\u0101')
    self.check_coder(coder, (), (1, 2, 3))
    self.check_coder(coder, [], [1, 2, 3])
    self.check_coder(coder, dict(), {'a': 'b'}, {0: dict(), 1: len})
    self.check_coder(coder, set(), {'a', 'b'})
    self.check_coder(coder, True, False)
    self.check_coder(coder, len)
    self.check_coder(coders.TupleCoder((coder,)), ('a',), (1,))

  def test_bytes_coder(self):
    self.check_coder(coders.BytesCoder(), 'a', '\0', 'z' * 1000)

  def test_varint_coder(self):
    # Small ints.
    self.check_coder(coders.VarIntCoder(), *range(-10, 10))
    # Multi-byte encoding starts at 128
    self.check_coder(coders.VarIntCoder(), *range(120, 140))
    # Large values
    MAX_64_BIT_INT = 0x7fffffffffffffff
    self.check_coder(coders.VarIntCoder(),
                     *[int(math.pow(-1, k) * math.exp(k))
                       for k in range(0, int(math.log(MAX_64_BIT_INT)))])

  def test_float_coder(self):
    self.check_coder(coders.FloatCoder(),
                     *[float(0.1 * x) for x in range(-100, 100)])
    self.check_coder(coders.FloatCoder(),
                     *[float(2 ** (0.1 * x)) for x in range(-100, 100)])
    self.check_coder(coders.FloatCoder(), float('-Inf'), float('Inf'))
    self.check_coder(
        coders.TupleCoder((coders.FloatCoder(), coders.FloatCoder())),
        (0, 1), (-100, 100), (0.5, 0.25))

  def test_singleton_coder(self):
    a = 'anything'
    b = 'something else'
    self.check_coder(coders.SingletonCoder(a), a)
    self.check_coder(coders.SingletonCoder(b), b)
    self.check_coder(coders.TupleCoder((coders.SingletonCoder(a),
                                        coders.SingletonCoder(b))), (a, b))

  def test_interval_window_coder(self):
    self.check_coder(coders.IntervalWindowCoder(),
                     *[window.IntervalWindow(x, y)
                       for x in [-2**52, 0, 2**52]
                       for y in range(-100, 100)])
    self.check_coder(
        coders.TupleCoder((coders.IntervalWindowCoder(),)),
        (window.IntervalWindow(0, 10),))

  def test_timestamp_coder(self):
    self.check_coder(coders.TimestampCoder(),
                     *[timestamp.Timestamp(micros=x) for x in range(-100, 100)])
    self.check_coder(coders.TimestampCoder(),
                     timestamp.Timestamp(micros=-1234567890),
                     timestamp.Timestamp(micros=1234567890))
    self.check_coder(coders.TimestampCoder(),
                     timestamp.Timestamp(micros=-1234567890123456789),
                     timestamp.Timestamp(micros=1234567890123456789))
    self.check_coder(
        coders.TupleCoder((coders.TimestampCoder(), coders.BytesCoder())),
        (timestamp.Timestamp.of(27), 'abc'))

  def test_tuple_coder(self):
    kv_coder = coders.TupleCoder((coders.VarIntCoder(), coders.BytesCoder()))
    # Verify cloud object representation
    self.assertEqual(
        {
            '@type': 'kind:pair',
            'is_pair_like': True,
            'component_encodings': [
                coders.VarIntCoder().as_cloud_object(),
                coders.BytesCoder().as_cloud_object()],
        },
        kv_coder.as_cloud_object())
    # Test binary representation
    self.assertEqual(
        '\x04abc',
        kv_coder.encode((4, 'abc')))
    # Test unnested
    self.check_coder(
        kv_coder,
        (1, 'a'),
        (-2, 'a' * 100),
        (300, 'abc\0' * 5))
    # Test nested
    self.check_coder(
        coders.TupleCoder(
            (coders.TupleCoder((coders.PickleCoder(), coders.VarIntCoder())),
             coders.StrUtf8Coder())),
        ((1, 2), 'a'),
        ((-2, 5), u'a\u0101' * 100),
        ((300, 1), 'abc\0' * 5))

  def test_tuple_sequence_coder(self):
    int_tuple_coder = coders.TupleSequenceCoder(coders.VarIntCoder())
    self.check_coder(int_tuple_coder, (1, -1, 0), (), tuple(range(1000)))
    self.check_coder(
        coders.TupleCoder((coders.VarIntCoder(), int_tuple_coder)),
        (1, (1, 2, 3)))

  def test_base64_pickle_coder(self):
    self.check_coder(coders.Base64PickleCoder(), 'a', 1, 1.5, (1, 2, 3))

  def test_utf8_coder(self):
    self.check_coder(coders.StrUtf8Coder(), 'a', u'ab\u00FF', u'\u0101\0')

  def test_iterable_coder(self):
    iterable_coder = coders.IterableCoder(coders.VarIntCoder())
    # Verify cloud object representation
    self.assertEqual(
        {
            '@type': 'kind:stream',
            'is_stream_like': True,
            'component_encodings': [coders.VarIntCoder().as_cloud_object()]
        },
        iterable_coder.as_cloud_object())
    # Test unnested
    self.check_coder(iterable_coder,
                     [1],
                     [-1, 0, 100])
    # Test nested
    self.check_coder(
        coders.TupleCoder((coders.VarIntCoder(),
                           coders.IterableCoder(coders.VarIntCoder()))),
        (1, [1, 2, 3]))

  def test_iterable_coder_unknown_length(self):
    # Empty
    self._test_iterable_coder_of_unknown_length(0)
    # Single element
    self._test_iterable_coder_of_unknown_length(1)
    # Multiple elements
    self._test_iterable_coder_of_unknown_length(100)
    # Multiple elements with underlying stream buffer overflow.
    self._test_iterable_coder_of_unknown_length(80000)

  def _test_iterable_coder_of_unknown_length(self, count):
    """Round-trip a generator of `count` ints through an IterableCoder.

    A generator has no len(). The decoded elements must equal
    0..count-1 in the original order.
    """
    def iter_generator(count):
      for i in range(count):
        yield i
    iterable_coder = coders.IterableCoder(coders.VarIntCoder())
    decoded = iterable_coder.decode(
        iterable_coder.encode(iter_generator(count)))
    # Compare as ordered lists. assertItemsEqual would ignore element order,
    # and it does not exist in Python 3.
    self.assertEqual(list(iter_generator(count)), list(decoded))

  def test_windowed_value_coder(self):
    coder = coders.WindowedValueCoder(coders.VarIntCoder(),
                                      coders.GlobalWindowCoder())
    # Verify cloud object representation
    self.assertEqual(
        {
            '@type': 'kind:windowed_value',
            'is_wrapper': True,
            'component_encodings': [
                coders.VarIntCoder().as_cloud_object(),
                coders.GlobalWindowCoder().as_cloud_object(),
            ],
        },
        coder.as_cloud_object())
    # Test binary representation
    self.assertEqual('\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01',
                     coder.encode(window.GlobalWindows.windowed_value(1)))

    # Test decoding large timestamp
    self.assertEqual(
        coder.decode('\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'),
        windowed_value.create(0, MIN_TIMESTAMP.micros, (GlobalWindow(),)))

    # Test unnested
    self.check_coder(
        coders.WindowedValueCoder(coders.VarIntCoder()),
        windowed_value.WindowedValue(3, -100, ()),
        windowed_value.WindowedValue(-1, 100, (1, 2, 3)))

    # Test Global Window
    self.check_coder(
        coders.WindowedValueCoder(coders.VarIntCoder(),
                                  coders.GlobalWindowCoder()),
        window.GlobalWindows.windowed_value(1))

    # Test nested
    self.check_coder(
        coders.TupleCoder((
            coders.WindowedValueCoder(coders.FloatCoder()),
            coders.WindowedValueCoder(coders.StrUtf8Coder()))),
        (windowed_value.WindowedValue(1.5, 0, ()),
         windowed_value.WindowedValue("abc", 10, ('window',))))

  def test_proto_coder(self):
    # For instructions on how these test proto message were generated,
    # see coders_test.py
    ma = test_message.MessageA()
    mab = ma.field2.add()
    mab.field1 = True
    ma.field1 = u'hello world'

    mb = test_message.MessageA()
    mb.field1 = u'beam'

    proto_coder = coders.ProtoCoder(ma.__class__)
    self.check_coder(proto_coder, ma)
    self.check_coder(coders.TupleCoder((proto_coder, coders.BytesCoder())),
                     (ma, 'a'), (mb, 'b'))

  def test_global_window_coder(self):
    coder = coders.GlobalWindowCoder()
    value = window.GlobalWindow()
    # Verify cloud object representation
    self.assertEqual({'@type': 'kind:global_window'},
                     coder.as_cloud_object())
    # Test binary representation
    self.assertEqual('', coder.encode(value))
    self.assertEqual(value, coder.decode(''))
    # Test unnested
    self.check_coder(coder, value)
    # Test nested
    self.check_coder(coders.TupleCoder((coder, coder)),
                     (value, value))

  def test_length_prefix_coder(self):
    coder = coders.LengthPrefixCoder(coders.BytesCoder())
    # Verify cloud object representation
    self.assertEqual(
        {
            '@type': 'kind:length_prefix',
            'component_encodings': [coders.BytesCoder().as_cloud_object()]
        },
        coder.as_cloud_object())
    # Test binary representation
    self.assertEqual('\x00', coder.encode(''))
    self.assertEqual('\x01a', coder.encode('a'))
    self.assertEqual('\x02bc', coder.encode('bc'))
    self.assertEqual('\xff\x7f' + 'z' * 16383, coder.encode('z' * 16383))
    # Test unnested
    self.check_coder(coder, '', 'a', 'bc', 'def')
    # Test nested
    self.check_coder(coders.TupleCoder((coder, coder)),
                     ('', 'a'),
                     ('bc', 'def'))

  def test_nested_observables(self):
    class FakeObservableIterator(observable.ObservableMixin):

      def __iter__(self):
        return iter([1, 2, 3])

    # Coder for elements from the observable iterator.
    elem_coder = coders.VarIntCoder()
    iter_coder = coders.TupleSequenceCoder(elem_coder)

    # Test nested WindowedValue observable.
    coder = coders.WindowedValueCoder(iter_coder)
    observ = FakeObservableIterator()
    value = windowed_value.WindowedValue(observ, 0, ())
    self.assertEqual(
        coder.get_impl().get_estimated_size_and_observables(value)[1],
        [(observ, elem_coder.get_impl())])

    # Test nested tuple observable.
    coder = coders.TupleCoder((coders.StrUtf8Coder(), iter_coder))
    value = (u'123', observ)
    self.assertEqual(
        coder.get_impl().get_estimated_size_and_observables(value)[1],
        [(observ, elem_coder.get_impl())])
if __name__ == '__main__':
  # Show INFO-level log output while running the suite directly.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()