# Source code for apache_beam.coders.row_coder

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pytype: skip-file

import itertools
from array import array

from apache_beam.coders import typecoders
from apache_beam.coders.coder_impl import StreamCoderImpl
from apache_beam.coders.coders import BooleanCoder
from apache_beam.coders.coders import BytesCoder
from apache_beam.coders.coders import Coder
from apache_beam.coders.coders import FastCoder
from apache_beam.coders.coders import FloatCoder
from apache_beam.coders.coders import IterableCoder
from apache_beam.coders.coders import MapCoder
from apache_beam.coders.coders import NullableCoder
from apache_beam.coders.coders import StrUtf8Coder
from apache_beam.coders.coders import TupleCoder
from apache_beam.coders.coders import VarIntCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints import row_type
from apache_beam.typehints.schemas import PYTHON_ANY_URN
from apache_beam.typehints.schemas import LogicalType
from apache_beam.typehints.schemas import named_tuple_from_schema
from apache_beam.typehints.schemas import schema_from_element_type
from apache_beam.utils import proto_utils

__all__ = ["RowCoder"]

class RowCoder(FastCoder):
  """ Coder for `typing.NamedTuple` instances.

  Implements the beam:coder:row:v1 standard coder spec.
  """
  def __init__(self, schema):
    """Initializes a :class:`RowCoder`.

    Args:
      schema (apache_beam.portability.api.schema_pb2.Schema): The protobuf
        representation of the schema of the data that the RowCoder will be used
        to encode/decode.
    """
    self.schema = schema

    # Use non-null coders because null values are represented separately
    # (in the row's null bitmap), so a field coder never sees a None value.
    self.components = [
        _nonnull_coder_from_type(field.type) for field in self.schema.fields
    ]

  def _create_impl(self):
    return RowCoderImpl(self.schema, self.components)

  def is_deterministic(self):
    """Returns True iff every per-field coder is deterministic."""
    return all(c.is_deterministic() for c in self.components)

  def to_type_hint(self):
    """Returns a NamedTuple type generated from this coder's schema."""
    return named_tuple_from_schema(self.schema)

  def __hash__(self):
    # Hash the serialized schema rather than the proto message itself;
    # protobuf messages are generally not hashable.
    return hash(self.schema.SerializeToString())

  def __eq__(self, other):
    # Exact-type match: a subclass of RowCoder never compares equal.
    return type(self) is type(other) and self.schema == other.schema

  def to_runner_api_parameter(self, unused_context):
    """Returns (urn, schema payload, component coders) for portability."""
    return (common_urns.coders.ROW.urn, self.schema, [])

  @staticmethod
  @Coder.register_urn(common_urns.coders.ROW.urn, schema_pb2.Schema)
  def from_runner_api_parameter(schema, components, unused_context):
    """Rebuilds a RowCoder from the Schema payload of a ROW coder proto."""
    return RowCoder(schema)

  @staticmethod
  def from_type_hint(type_hint, registry):
    """Builds a RowCoder from a schema-compatible element type hint."""
    schema = schema_from_element_type(type_hint)
    return RowCoder(schema)

  @staticmethod
  def from_payload(payload):
    # type: (bytes) -> RowCoder
    """Builds a RowCoder from a serialized ``schema_pb2.Schema``."""
    return RowCoder(proto_utils.parse_Bytes(payload, schema_pb2.Schema))

  def __reduce__(self):
    # when pickling, use bytes representation of the schema. schema_pb2.Schema
    # objects cannot be pickled.
    return (RowCoder.from_payload, (self.schema.SerializeToString(), ))
typecoders.registry.register_coder(row_type.RowTypeConstraint, RowCoder)


def _coder_from_type(field_type):
  """Returns the coder for a schema field type.

  The result is wrapped in a :class:`NullableCoder` when
  ``field_type.nullable`` is set.

  Args:
    field_type (schema_pb2.FieldType): the field type to build a coder for.

  Raises:
    ValueError: if the type is not supported by RowCoder.
  """
  coder = _nonnull_coder_from_type(field_type)
  if field_type.nullable:
    return NullableCoder(coder)
  else:
    return coder


def _nonnull_coder_from_type(field_type):
  """Returns the coder for a schema field type, ignoring its nullability.

  Args:
    field_type (schema_pb2.FieldType): the field type to build a coder for.

  Raises:
    ValueError: if the type (including any atomic type not matched below) is
      not supported by RowCoder.
  """
  type_info = field_type.WhichOneof("type_info")
  if type_info == "atomic_type":
    if field_type.atomic_type in (schema_pb2.INT32, schema_pb2.INT64):
      return VarIntCoder()
    elif field_type.atomic_type == schema_pb2.DOUBLE:
      return FloatCoder()
    elif field_type.atomic_type == schema_pb2.STRING:
      return StrUtf8Coder()
    elif field_type.atomic_type == schema_pb2.BOOLEAN:
      return BooleanCoder()
    elif field_type.atomic_type == schema_pb2.BYTES:
      return BytesCoder()
  elif type_info == "array_type":
    return IterableCoder(_coder_from_type(field_type.array_type.element_type))
  elif type_info == "map_type":
    return MapCoder(
        _coder_from_type(field_type.map_type.key_type),
        _coder_from_type(field_type.map_type.value_type))
  elif type_info == "logical_type":
    # Special case for the Any logical type. Just use the default coder for an
    # unknown Python object.
    if field_type.logical_type.urn == PYTHON_ANY_URN:
      return typecoders.registry.get_coder(object)

    logical_type = LogicalType.from_runner_api(field_type.logical_type)
    return LogicalTypeCoder(
        logical_type, _coder_from_type(field_type.logical_type.representation))
  elif type_info == "row_type":
    return RowCoder(field_type.row_type.schema)

  # The Java SDK supports several more types, but the coders are not yet
  # standard, and are not implemented in Python.
  raise ValueError(
      "Encountered a type that is not currently supported by RowCoder: %s" %
      field_type)


class RowCoderImpl(StreamCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  A row is written as:
    1. the number of fields, as a varint;
    2. a length-prefixed null bitmap (empty when no field is None), where bit
       ``i % 8`` of byte ``i // 8`` is set when field ``i`` is None;
    3. each non-None field, encoded nested with its component coder.
  """
  SIZE_CODER = VarIntCoder().get_impl()
  NULL_MARKER_CODER = BytesCoder().get_impl()

  def __init__(self, schema, components):
    self.schema = schema
    self.constructor = named_tuple_from_schema(schema)
    # Keep the Coder objects as well as their impls: _make_value_coder needs
    # Coders to build a TupleCoder, while encoding/decoding uses the impls.
    self._component_coders = list(components)
    self.components = list(c.get_impl() for c in components)
    self.has_nullable_fields = any(
        field.type.nullable for field in self.schema.fields)

  def encode_to_stream(self, value, out, nested):
    """Encodes ``value`` (read field-by-field via ``getattr``) to ``out``.

    Raises:
      ValueError: if a non-nullable field of ``value`` is None.
    """
    nvals = len(self.schema.fields)
    self.SIZE_CODER.encode_to_stream(nvals, out, True)
    attrs = [getattr(value, f.name) for f in self.schema.fields]

    words = array('B')
    if self.has_nullable_fields:
      nulls = list(attr is None for attr in attrs)
      if any(nulls):
        words = array('B', itertools.repeat(0, (nvals + 7) // 8))
        for i, is_null in enumerate(nulls):
          words[i // 8] |= is_null << (i % 8)

    self.NULL_MARKER_CODER.encode_to_stream(words.tobytes(), out, True)

    for c, field, attr in zip(self.components, self.schema.fields, attrs):
      if attr is None:
        if not field.type.nullable:
          raise ValueError(
              "Attempted to encode null for non-nullable field \"{}\".".format(
                  field.name))
        # Null fields are recorded only in the bitmap; no payload is written.
        continue
      c.encode_to_stream(attr, out, True)

  def decode_from_stream(self, in_stream, nested):
    """Decodes one row from ``in_stream`` into this schema's NamedTuple."""
    nvals = self.SIZE_CODER.decode_from_stream(in_stream, True)
    words = array('B')
    words.frombytes(self.NULL_MARKER_CODER.decode_from_stream(in_stream, True))

    if words:
      nwords = len(words)
      # NOTE(review): an encoder may emit a bitmap shorter than
      # ceil(nvals / 8) bytes by omitting trailing all-zero bytes (as a
      # BitSet-style encoder would). Treat bits past the end of the bitmap as
      # 0 (non-null) instead of raising IndexError.
      nulls = (
          i // 8 < nwords and (words[i // 8] >> (i % 8)) & 0x01
          for i in range(nvals))
    else:
      nulls = itertools.repeat(False, nvals)

    # If this coder's schema has more attributes than the encoded value, then
    # the schema must have changed. Populate the unencoded fields with nulls.
    if len(self.components) > nvals:
      nulls = itertools.chain(
          nulls, itertools.repeat(True, len(self.components) - nvals))

    # Note that if this coder's schema has *fewer* attributes than the encoded
    # value, we just need to ignore the additional values, which will occur
    # here because we only decode as many values as we have coders for.
    # NOTE(review): those extra values are not consumed from in_stream, so any
    # data following this row in the same stream would be misread — confirm
    # callers never decode nested rows across such a schema change.
    return self.constructor(
        *(
            None if is_null else c.decode_from_stream(in_stream, True)
            for c, is_null in zip(self.components, nulls)))

  def _make_value_coder(self, nulls=itertools.repeat(False)):
    """Returns a TupleCoder impl over the coders of the non-null fields.

    Args:
      nulls: per-field truthy null flags; ignored when the schema has no
        nullable fields. The shared default is safe because
        ``itertools.repeat`` yields the same value no matter how often it is
        consumed.
    """
    # TupleCoder is built from Coder objects, not impls, so use the retained
    # coders rather than self.components.
    coders = self._component_coders
    if self.has_nullable_fields:
      coders = [coder for coder, is_null in zip(coders, nulls) if not is_null]
    return TupleCoder(coders).get_impl()


class LogicalTypeCoder(FastCoder):
  """Coder for a logical type, delegating to its representation's coder."""
  def __init__(self, logical_type, representation_coder):
    self.logical_type = logical_type
    self.representation_coder = representation_coder

  def _create_impl(self):
    return LogicalTypeCoderImpl(self.logical_type, self.representation_coder)

  def is_deterministic(self):
    return self.representation_coder.is_deterministic()

  def to_type_hint(self):
    return self.logical_type.language_type()


class LogicalTypeCoderImpl(StreamCoderImpl):
  """Converts language values to/from the representation type around I/O."""
  def __init__(self, logical_type, representation_coder):
    self.logical_type = logical_type
    self.representation_coder = representation_coder.get_impl()

  def encode_to_stream(self, value, out, nested):
    return self.representation_coder.encode_to_stream(
        self.logical_type.to_representation_type(value), out, nested)

  def decode_from_stream(self, in_stream, nested):
    return self.logical_type.to_language_type(
        self.representation_coder.decode_from_stream(in_stream, nested))