# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from apache_beam.typehints import typehints
from apache_beam.typehints.batch import BatchConverter
from apache_beam.typehints.batch import N

class PytorchBatchConverter(BatchConverter):
  """Converts between individual ``torch.Tensor`` elements and batch tensors.

  A batch is a single tensor in which one dimension (``partition_dimension``)
  indexes the elements; the remaining dimensions make up the element shape.
  """
  def __init__(
      self,
      batch_type,
      element_type,
      dtype,
      element_shape=(),
      partition_dimension=0):
    """Store the batch layout.

    Args:
      batch_type: type hint of a batch; forwarded to ``BatchConverter``.
      element_type: type hint of one element; forwarded to
        ``BatchConverter``.
      dtype: dtype shared by elements and batches.
      element_shape: shape of a single element.
      partition_dimension: index of the batch dimension along which elements
        are laid out.
    """
    super().__init__(batch_type, element_type)
    self.dtype = dtype
    self.element_shape = element_shape
    self.partition_dimension = partition_dimension

  @staticmethod
  @BatchConverter.register(name="pytorch")
  def from_typehints(element_type,
                     batch_type) -> Optional['PytorchBatchConverter']:
    """Build a converter from an element type hint and a batch type hint.

    Args:
      element_type: a ``PytorchTypeHint.PytorchTypeConstraint``; any other
        value is treated as a dtype with the empty shape ``()``.
      batch_type: a ``PytorchTypeHint.PytorchTypeConstraint`` or
        ``torch.Tensor`` itself (interpreted as a 1-D batch of shape
        ``(N,)`` with the element's dtype).

    Returns:
      A ``PytorchBatchConverter`` whose partition dimension is the position
      of ``N`` in the batch shape.

    Raises:
      TypeError: if ``batch_type`` is neither a constraint nor
        ``torch.Tensor``, if the dtypes differ, or if the batch shape with
        ``N`` removed does not equal the element shape.
      ValueError: propagated from ``list.index`` when ``N`` does not occur
        in the batch shape.
    """
    if not isinstance(element_type, PytorchTypeHint.PytorchTypeConstraint):
      element_type = PytorchTensor[element_type, ()]

    if not isinstance(batch_type, PytorchTypeHint.PytorchTypeConstraint):
      if not batch_type == torch.Tensor:
        raise TypeError(
            "batch type must be torch.Tensor or "
            "beam.typehints.pytorch_type_compatibility.PytorchTensor[..]")
      batch_type = PytorchTensor[element_type.dtype, (N, )]

    if not batch_type.dtype == element_type.dtype:
      raise TypeError(
          "batch type and element type must have equivalent dtypes "
          f"(batch={batch_type.dtype}, element={element_type.dtype})")

    # Removing the N entry from the batch shape must leave exactly the
    # element shape; the position of N is the partition dimension.
    computed_element_shape = list(batch_type.shape)
    partition_dimension = computed_element_shape.index(N)
    computed_element_shape.pop(partition_dimension)
    if not tuple(computed_element_shape) == element_type.shape:
      raise TypeError(
          "Could not align batch type's batch dimension with element type. "
          f"(batch type dimensions: {batch_type.shape}, element type "
          f"dimensions: {element_type.shape})")

    return PytorchBatchConverter(
        batch_type,
        element_type,
        batch_type.dtype,
        element_type.shape,
        partition_dimension)

  def produce_batch(self, elements):
    """Stack ``elements`` into one tensor along the partition dimension."""
    return torch.stack(elements, dim=self.partition_dimension)

  def explode_batch(self, batch):
    """Convert an instance of B to Generator[E].

    Yields the slices of ``batch`` taken along the partition dimension.
    """
    # movedim (rather than swapaxes) keeps the remaining dimensions in their
    # original order, so each yielded element has the same layout that
    # produce_batch stacked, even when partition_dimension >= 2.
    yield from torch.movedim(batch, self.partition_dimension, 0)

  def combine_batches(self, batches):
    """Concatenate ``batches`` along the partition dimension."""
    return torch.cat(batches, dim=self.partition_dimension)

  def get_length(self, batch):
    """Return the size of ``batch`` along the partition dimension."""
    return batch.size(dim=self.partition_dimension)

  def estimate_byte_size(self, batch):
    """Return element count times per-element byte size of ``batch``."""
    return batch.nelement() * batch.element_size()
class PytorchTypeHint():
  """Builds tensor type constraints through subscript syntax.

  Subscripting an instance with ``[dtype]`` or ``[dtype, shape]`` returns a
  ``PytorchTypeConstraint`` (see ``__getitem__``).
  """
  class PytorchTypeConstraint(typehints.TypeConstraint):
    """Type constraint on a ``torch.Tensor``'s dtype and leading dimensions."""
    def __init__(self, dtype, shape=()):
      """Record the expected ``dtype`` and ``shape``.

      Entries of ``shape`` equal to ``N`` match any size in ``type_check``.
      """
      self.dtype = dtype
      self.shape = shape

    def type_check(self, batch):
      """Validate ``batch`` against this constraint.

      Raises:
        TypeError: if ``batch`` is not a ``torch.Tensor``, has a different
          dtype, has fewer dimensions than ``self.shape``, or differs from
          ``self.shape`` in any dimension not declared as ``N``.
      """
      if not isinstance(batch, torch.Tensor):
        raise TypeError(f"Batch {batch!r} is not an instance of torch.Tensor")
      if not batch.dtype == self.dtype:
        raise TypeError(
            f"Batch {batch!r} does not have expected dtype: {self.dtype!r}")

      # A tensor with fewer dimensions than declared cannot match; report it
      # as a type error instead of letting the per-dimension comparison
      # below run out of dimensions.
      if batch.dim() < len(self.shape):
        raise TypeError(
            f"Batch {batch!r} does not have expected shape: {self.shape!r}")

      for expected, actual in zip(self.shape, batch.shape):
        if not expected == N and not actual == expected:
          raise TypeError(
              f"Batch {batch!r} does not have expected shape: {self.shape!r}")

    def _consistent_with_check_(self, sub):
      # TODO Check sub against batch type, and element type
      return True

    def __key(self):
      # Identity of a constraint for equality and hashing.
      return (self.dtype, self.shape)

    def __eq__(self, other) -> bool:
      if isinstance(other, PytorchTypeHint.PytorchTypeConstraint):
        return self.__key() == other.__key()
      return NotImplemented

    def __hash__(self) -> int:
      return hash(self.__key())

    def __repr__(self):
      # The default 1-D batch shape is omitted from the representation.
      if self.shape == (N, ):
        return f'PytorchTensor[{self.dtype!r}]'
      else:
        return f'PytorchTensor[{self.dtype!r}, {self.shape!r}]'

  def __getitem__(self, value):
    """Create a constraint from ``[dtype]`` or ``[dtype, shape]``.

    Args:
      value: either a dtype (the shape defaults to ``(N,)``) or a 2-tuple
        ``(dtype, shape)``.

    Raises:
      ValueError: if ``value`` is a tuple whose length is not 2.
    """
    if not isinstance(value, tuple):
      return self.PytorchTypeConstraint(value, shape=(N, ))

    if len(value) != 2:
      raise ValueError(
          "Expected a dtype or a (dtype, shape) pair, got a tuple of "
          f"length {len(value)}: {value!r}")

    dtype, shape = value
    return self.PytorchTypeConstraint(dtype, shape=shape)
# Module-level PytorchTypeHint instance. Subscripting it with a dtype or a
# (dtype, shape) pair goes through PytorchTypeHint.__getitem__ and yields a
# PytorchTypeConstraint.
PytorchTensor = PytorchTypeHint()