#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests common to all coder implementations."""
import logging
import math
import unittest
import dill
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.timestamp import MIN_TIMESTAMP
import observable
from apache_beam.transforms import window
from apache_beam.utils import timestamp
from apache_beam.utils import windowed_value
from apache_beam.coders import coders
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
# Defined out of line for picklability.
class CustomCoder(coders.Coder):
    """Coder used by the tests: a value ``x`` is stored as ``str(x + 1)``."""

    def encode(self, x):
        """Return the string form of ``x + 1``."""
        shifted = x + 1
        return str(shifted)

    def decode(self, encoded):
        """Invert ``encode``: parse ``encoded`` as an int and subtract one."""
        shifted = int(encoded)
        return shifted - 1
[docs]class CodersTest(unittest.TestCase):
# These class methods ensure that we test each defined coder in both
# nested and unnested context.
@classmethod
def setUpClass(cls):
    """Start the run with empty records of the coder types observed."""
    # ``seen`` and ``seen_nested`` are filled in by ``_observe`` and
    # ``_observe_nested`` and inspected in ``tearDownClass``.
    cls.seen, cls.seen_nested = set(), set()
@classmethod
def tearDownClass(cls):
    """Fail the run if a standard coder class was never exercised.

    The standard coders are the ``coders.Coder`` subclasses found in the
    ``coders`` module namespace whose name does not contain ``'Base'``,
    minus the explicitly exempted classes below.  Each of them must appear
    both in ``cls.seen`` (unnested) and in ``cls.seen_nested`` (nested).

    Raises:
      AssertionError: if any standard coder is missing from either record.
    """
    exempt = {
        coders.Coder,
        coders.FastCoder,
        coders.ProtoCoder,
        coders.ToStringCoder,
    }
    standard = {
        c for c in coders.__dict__.values()
        if isinstance(c, type)
        and issubclass(c, coders.Coder)
        and 'Base' not in c.__name__
    } - exempt

    # Raise explicitly rather than via ``assert`` so that the coverage check
    # is not stripped under ``python -O`` and the failure names the coders.
    for context, observed in (('unnested', cls.seen),
                              ('nested', cls.seen_nested)):
        missing = standard - observed
        if missing:
            raise AssertionError(
                'Coders never checked in %s context: %s'
                % (context, sorted(c.__name__ for c in missing)))
@classmethod
def _observe(cls, coder):
    """Record the type of ``coder`` as seen, then record its components."""
    coder_type = type(coder)
    cls.seen.add(coder_type)
    cls._observe_nested(coder)
@classmethod
def _observe_nested(cls, coder):
    """Recursively record the component coder types of a ``TupleCoder``."""
    if not isinstance(coder, coders.TupleCoder):
        # Only tuple coders are descended into.
        return
    for component in coder.coders():
        cls.seen_nested.add(type(component))
        cls._observe_nested(component)
def check_coder(self, coder, *values):
    """Assert that ``coder`` round-trips and sizes each of ``values``.

    For every value this checks that decode(encode(v)) equals v, that
    ``estimate_size`` matches the encoded length and the implementation's
    own estimate, and that the implementation reports no observables.  The
    coder is then serialized twice with dill and the two copies are checked
    against each other.
    """
    self._observe(coder)

    for value in values:
        round_tripped = coder.decode(coder.encode(value))
        self.assertEqual(value, round_tripped)

        estimated = coder.estimate_size(value)
        self.assertEqual(estimated, len(coder.encode(value)))

        self.assertEqual(coder.estimate_size(value),
                         coder.get_impl().estimate_size(value))

        size_and_observables = (
            coder.get_impl().get_estimated_size_and_observables(value))
        self.assertEqual(size_and_observables,
                         (coder.get_impl().estimate_size(value), []))

    # Two independent dill round trips of the coder itself.
    first_copy, second_copy = [
        dill.loads(dill.dumps(coder)) for _ in range(2)]

    for value in values:
        self.assertEqual(value,
                         first_copy.decode(second_copy.encode(value)))
        if coder.is_deterministic():
            # Deterministic coders must produce identical bytes.
            self.assertEqual(first_copy.encode(value),
                             second_copy.encode(value))
def test_custom_coder(self):
    # CustomCoder on its own, then as the first component of a TupleCoder.
    self.check_coder(CustomCoder(), 1, -10, 5)

    pair_coder = coders.TupleCoder((CustomCoder(), coders.BytesCoder()))
    pairs = [(1, 'a'), (-10, 'b'), (5, 'c')]
    self.check_coder(pair_coder, *pairs)
def test_pickle_coder(self):
    # A string, an int, a float and a tuple through PickleCoder.
    samples = ('a', 1, 1.5, (1, 2, 3))
    self.check_coder(coders.PickleCoder(), *samples)
def test_deterministic_coder(self):
    # The deterministic wrapper accepts primitives and tuples but is
    # expected to raise TypeError for values containing a dict.
    fallback = coders.FastPrimitivesCoder()
    strict = coders.DeterministicFastPrimitivesCoder(fallback, 'step')
    self.check_coder(strict, 'a', 1, 1.5, (1, 2, 3))

    for rejected in ({}, [1, {}]):
        with self.assertRaises(TypeError):
            self.check_coder(strict, rejected)

    # The dict-bearing values sit in the slot of the plain coder here.
    mixed = coders.TupleCoder((strict, fallback))
    self.check_coder(mixed, (1, {}), ('a', [{}]))
def test_dill_coder(self):
    # DillCoder must handle plain values as well as a closure cell object.
    # ``__closure__`` is the portable spelling (Python 2.6+ and 3) of the
    # Python-2-only ``func_closure`` attribute.
    cell_value = (lambda x: lambda: x)(0).__closure__[0]
    self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
    self.check_coder(
        coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())),
        (1, cell_value))
def test_fast_primitives_coder(self):
    # FastPrimitivesCoder with a SingletonCoder(len) as its fallback coder;
    # each group of values below is checked in a separate check_coder call.
    coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
    value_groups = [
        (None, 1, -1, 1.5, 'str\0str', u'unicode\0\u0101'),
        ((), (1, 2, 3)),
        ([], [1, 2, 3]),
        ({}, {'a': 'b'}, {0: {}, 1: len}),
        (set(), {'a', 'b'}),
        (True, False),
        (len,),
    ]
    for group in value_groups:
        self.check_coder(coder, *group)

    self.check_coder(coders.TupleCoder((coder,)), ('a',), (1,))
def test_bytes_coder(self):
    # A single character, a NUL byte and a 1000-character string.
    payloads = ['a', '\0', 'z' * 1000]
    self.check_coder(coders.BytesCoder(), *payloads)
def test_varint_coder(self):
    # Small ints.
    small_ints = range(-10, 10)
    self.check_coder(coders.VarIntCoder(), *small_ints)

    # Multi-byte encoding starts at 128
    around_128 = range(120, 140)
    self.check_coder(coders.VarIntCoder(), *around_128)

    # Large values: alternating-sign powers of e, up to the largest
    # exponent below log(max signed 64-bit int).
    max_int64 = 0x7fffffffffffffff
    exponent_limit = int(math.log(max_int64))
    large_ints = [int(math.pow(-1, k) * math.exp(k))
                  for k in range(exponent_limit)]
    self.check_coder(coders.VarIntCoder(), *large_ints)
def test_float_coder(self):
    # Evenly spaced values, powers of two, the infinities, then pairs.
    exponents = range(-100, 100)

    linear = [0.1 * x for x in exponents]
    self.check_coder(coders.FloatCoder(), *linear)

    powers_of_two = [2 ** (0.1 * x) for x in exponents]
    self.check_coder(coders.FloatCoder(), *powers_of_two)

    self.check_coder(coders.FloatCoder(), float('-Inf'), float('Inf'))

    pair_coder = coders.TupleCoder(
        (coders.FloatCoder(), coders.FloatCoder()))
    self.check_coder(pair_coder, (0, 1), (-100, 100), (0.5, 0.25))
def test_singleton_coder(self):
    # Each SingletonCoder is checked with the one value it was built for.
    first, second = 'anything', 'something else'
    for singleton in (first, second):
        self.check_coder(coders.SingletonCoder(singleton), singleton)

    both = coders.TupleCoder(
        (coders.SingletonCoder(first), coders.SingletonCoder(second)))
    self.check_coder(both, (first, second))
def test_interval_window_coder(self):
    # Windows built from three first arguments crossed with 200 second
    # arguments, followed by one window nested inside a TupleCoder.
    first_args = [-2**52, 0, 2**52]
    windows = [window.IntervalWindow(x, y)
               for x in first_args
               for y in range(-100, 100)]
    self.check_coder(coders.IntervalWindowCoder(), *windows)

    nested = coders.TupleCoder((coders.IntervalWindowCoder(),))
    self.check_coder(nested, (window.IntervalWindow(0, 10),))
def test_timestamp_coder(self):
    # Three groups of microsecond values of increasing magnitude, each
    # checked with a fresh TimestampCoder.
    micros_groups = [
        range(-100, 100),
        (-1234567890, 1234567890),
        (-1234567890123456789, 1234567890123456789),
    ]
    for micros_values in micros_groups:
        stamps = [timestamp.Timestamp(micros=m) for m in micros_values]
        self.check_coder(coders.TimestampCoder(), *stamps)

    # Nested inside a TupleCoder.
    nested = coders.TupleCoder(
        (coders.TimestampCoder(), coders.BytesCoder()))
    self.check_coder(nested, (timestamp.Timestamp.of(27), 'abc'))
def test_tuple_coder(self):
    kv_coder = coders.TupleCoder(
        (coders.VarIntCoder(), coders.BytesCoder()))

    # Verify cloud object representation
    expected_cloud_object = {
        '@type': 'kind:pair',
        'is_pair_like': True,
        'component_encodings': [
            coders.VarIntCoder().as_cloud_object(),
            coders.BytesCoder().as_cloud_object(),
        ],
    }
    self.assertEqual(expected_cloud_object, kv_coder.as_cloud_object())

    # Test binary representation
    encoded = kv_coder.encode((4, 'abc'))
    self.assertEqual('\x04abc', encoded)

    # Test unnested
    unnested_values = [(1, 'a'), (-2, 'a' * 100), (300, 'abc\0' * 5)]
    self.check_coder(kv_coder, *unnested_values)

    # Test nested
    inner = coders.TupleCoder((coders.PickleCoder(), coders.VarIntCoder()))
    outer = coders.TupleCoder((inner, coders.StrUtf8Coder()))
    nested_values = [
        ((1, 2), 'a'),
        ((-2, 5), u'a\u0101' * 100),
        ((300, 1), 'abc\0' * 5),
    ]
    self.check_coder(outer, *nested_values)
def test_tuple_sequence_coder(self):
    # Tuples of varints: a short one, the empty one and a 1000-element one,
    # then a tuple sequence nested inside a TupleCoder.
    int_tuple_coder = coders.TupleSequenceCoder(coders.VarIntCoder())
    samples = [(1, -1, 0), (), tuple(range(1000))]
    self.check_coder(int_tuple_coder, *samples)

    nested = coders.TupleCoder((coders.VarIntCoder(), int_tuple_coder))
    self.check_coder(nested, (1, (1, 2, 3)))
def test_base64_pickle_coder(self):
    # A string, an int, a float and a tuple through Base64PickleCoder.
    samples = ('a', 1, 1.5, (1, 2, 3))
    self.check_coder(coders.Base64PickleCoder(), *samples)
def test_utf8_coder(self):
    # A plain string plus unicode strings with non-ASCII and NUL characters.
    texts = ['a', u'ab\u00FF', u'\u0101\0']
    self.check_coder(coders.StrUtf8Coder(), *texts)
def test_iterable_coder(self):
    iterable_coder = coders.IterableCoder(coders.VarIntCoder())

    # Verify cloud object representation
    expected_cloud_object = {
        '@type': 'kind:stream',
        'is_stream_like': True,
        'component_encodings': [coders.VarIntCoder().as_cloud_object()],
    }
    self.assertEqual(expected_cloud_object,
                     iterable_coder.as_cloud_object())

    # Test unnested
    self.check_coder(iterable_coder, [1], [-1, 0, 100])

    # Test nested
    nested = coders.TupleCoder(
        (coders.VarIntCoder(), coders.IterableCoder(coders.VarIntCoder())))
    self.check_coder(nested, (1, [1, 2, 3]))
def test_iterable_coder_unknown_length(self):
    # Element counts, in order: empty, single element, multiple elements,
    # and multiple elements with underlying stream buffer overflow.
    for count in (0, 1, 100, 80000):
        self._test_iterable_coder_of_unknown_length(count)
def _test_iterable_coder_of_unknown_length(self, count):
    """Round-trip a ``count``-element generator through an IterableCoder.

    The input is a generator yielding ``0 .. count - 1``; the decoded
    result must contain the same elements in the same order.

    Args:
      count: number of integers the generator yields.
    """
    def unsized_ints():
        # A generator, so the coder is handed an iterable without len().
        for i in range(count):
            yield i

    iterable_coder = coders.IterableCoder(coders.VarIntCoder())
    expected = list(unsized_ints())
    decoded = iterable_coder.decode(iterable_coder.encode(unsized_ints()))
    # Compare as ordered lists: the order-insensitive assertItemsEqual would
    # let a coder that scrambles elements pass (and exists only on Python 2).
    self.assertEqual(expected, list(decoded))
def test_windowed_value_coder(self):
    coder = coders.WindowedValueCoder(coders.VarIntCoder(),
                                      coders.GlobalWindowCoder())

    # Verify cloud object representation
    expected_cloud_object = {
        '@type': 'kind:windowed_value',
        'is_wrapper': True,
        'component_encodings': [
            coders.VarIntCoder().as_cloud_object(),
            coders.GlobalWindowCoder().as_cloud_object(),
        ],
    }
    self.assertEqual(expected_cloud_object, coder.as_cloud_object())

    # Test binary representation
    expected_bytes = '\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01'
    encoded = coder.encode(window.GlobalWindows.windowed_value(1))
    self.assertEqual(expected_bytes, encoded)

    # Test decoding large timestamp
    decoded = coder.decode(
        '\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00')
    expected_value = windowed_value.create(
        0, MIN_TIMESTAMP.micros, (GlobalWindow(),))
    self.assertEqual(decoded, expected_value)

    # Test unnested
    unnested_coder = coders.WindowedValueCoder(coders.VarIntCoder())
    self.check_coder(
        unnested_coder,
        windowed_value.WindowedValue(3, -100, ()),
        windowed_value.WindowedValue(-1, 100, (1, 2, 3)))

    # Test Global Window
    global_window_coder = coders.WindowedValueCoder(
        coders.VarIntCoder(), coders.GlobalWindowCoder())
    self.check_coder(global_window_coder,
                     window.GlobalWindows.windowed_value(1))

    # Test nested
    nested_coder = coders.TupleCoder((
        coders.WindowedValueCoder(coders.FloatCoder()),
        coders.WindowedValueCoder(coders.StrUtf8Coder())))
    nested_value = (
        windowed_value.WindowedValue(1.5, 0, ()),
        windowed_value.WindowedValue("abc", 10, ('window',)))
    self.check_coder(nested_coder, nested_value)
def test_proto_coder(self):
    # For instructions on how these test proto messages were generated,
    # see coders_test.py
    message_a = test_message.MessageA()
    nested_entry = message_a.field2.add()
    nested_entry.field1 = True
    message_a.field1 = u'hello world'

    message_b = test_message.MessageA()
    message_b.field1 = u'beam'

    proto_coder = coders.ProtoCoder(message_a.__class__)
    self.check_coder(proto_coder, message_a)

    # Nested inside a TupleCoder.
    pair_coder = coders.TupleCoder((proto_coder, coders.BytesCoder()))
    self.check_coder(pair_coder, (message_a, 'a'), (message_b, 'b'))
def test_global_window_coder(self):
    gw_coder = coders.GlobalWindowCoder()
    the_window = window.GlobalWindow()

    # Verify cloud object representation
    cloud_object = gw_coder.as_cloud_object()
    self.assertEqual({'@type': 'kind:global_window'}, cloud_object)

    # Test binary representation
    encoded = gw_coder.encode(the_window)
    self.assertEqual('', encoded)
    decoded = gw_coder.decode('')
    self.assertEqual(the_window, decoded)

    # Test unnested
    self.check_coder(gw_coder, the_window)

    # Test nested
    pair_coder = coders.TupleCoder((gw_coder, gw_coder))
    self.check_coder(pair_coder, (the_window, the_window))
def test_length_prefix_coder(self):
    coder = coders.LengthPrefixCoder(coders.BytesCoder())

    # Verify cloud object representation
    expected_cloud_object = {
        '@type': 'kind:length_prefix',
        'component_encodings': [coders.BytesCoder().as_cloud_object()],
    }
    self.assertEqual(expected_cloud_object, coder.as_cloud_object())

    # Test binary representation: (expected encoding, input value) pairs.
    binary_cases = [
        ('\x00', ''),
        ('\x01a', 'a'),
        ('\x02bc', 'bc'),
        ('\xff\x7f' + 'z' * 16383, 'z' * 16383),
    ]
    for expected_bytes, value in binary_cases:
        self.assertEqual(expected_bytes, coder.encode(value))

    # Test unnested
    self.check_coder(coder, '', 'a', 'bc', 'def')

    # Test nested
    pair_coder = coders.TupleCoder((coder, coder))
    self.check_coder(pair_coder, ('', 'a'), ('bc', 'def'))
def test_nested_observables(self):
    class FakeObservableIterator(observable.ObservableMixin):
        def __iter__(self):
            return iter([1, 2, 3])

    def observables_of(coder, value):
        # Second item of the (estimated size, observables) result.
        impl = coder.get_impl()
        return impl.get_estimated_size_and_observables(value)[1]

    # Coder for elements from the observable iterator.
    elem_coder = coders.VarIntCoder()
    iter_coder = coders.TupleSequenceCoder(elem_coder)

    # Test nested WindowedValue observable.
    windowed_coder = coders.WindowedValueCoder(iter_coder)
    observ = FakeObservableIterator()
    windowed = windowed_value.WindowedValue(observ, 0, ())
    self.assertEqual(
        observables_of(windowed_coder, windowed),
        [(observ, elem_coder.get_impl())])

    # Test nested tuple observable.
    tuple_coder = coders.TupleCoder((coders.StrUtf8Coder(), iter_coder))
    pair = (u'123', observ)
    self.assertEqual(
        observables_of(tuple_coder, pair),
        [(observ, elem_coder.get_impl())])
if __name__ == '__main__':
    # Set the root logger to INFO before running the tests as a script.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    unittest.main()