#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Transform Beam PTransforms into Dask Bag operations.
A minimal set of operation substitutions, to adapt Beam's PTransform model
to Dask Bag functions.
TODO(alxr): Translate ops from https://docs.dask.org/en/latest/bag-api.html.
"""
import abc
import dataclasses
import typing as t
import apache_beam
import dask.bag as db
from apache_beam.pipeline import AppliedPTransform
from apache_beam.runners.dask.overrides import _Create
from apache_beam.runners.dask.overrides import _Flatten
from apache_beam.runners.dask.overrides import _GroupByKeyOnly
# Input accepted by DaskBagOp.apply: a single bag, a sequence of bags (e.g. for
# Flatten), or None (for source ops such as Create).
OpInput = t.Optional[t.Union[db.Bag, t.Sequence[db.Bag]]]
@dataclasses.dataclass
class DaskBagOp(abc.ABC):
  """Abstract base for a Dask Bag operation built from a Beam transform.

  Attributes:
    applied: The ``AppliedPTransform`` whose ``transform`` this op executes.
  """
  applied: AppliedPTransform

  @property
  def transform(self):
    """Shortcut for ``self.applied.transform``."""
    return self.applied.transform

  @abc.abstractmethod
  def apply(self, input_bag: OpInput) -> db.Bag:
    """Run this op against ``input_bag`` and return the resulting bag.

    Args:
      input_bag: A bag, a sequence of bags, or None, depending on the op.
    """
class NoOp(DaskBagOp):
  """Pass-through op: hands its input back unchanged."""

  def apply(self, input_bag: OpInput) -> db.Bag:
    """Return ``input_bag`` as-is; no Dask operation is added."""
    unchanged = input_bag
    return unchanged
class Create(DaskBagOp):
  """Source op: turn the values held by a ``_Create`` transform into a bag."""

  def apply(self, input_bag: OpInput) -> db.Bag:
    """Build a bag from the transform's ``values``.

    Args:
      input_bag: Must be None; Create is a source and takes no upstream bag.

    Returns:
      A bag produced by ``db.from_sequence`` over the transform's values.
    """
    assert input_bag is None, 'Create expects no input!'
    create_transform = t.cast(_Create, self.transform)
    return db.from_sequence(create_transform.values)
class ParDo(DaskBagOp):
  """Translate a Beam ``ParDo`` into ``Bag.map(...).flatten()``.

  Each element is passed to the DoFn's ``process`` method together with the
  transform's extra ``args``/``kwargs``. The per-element output iterables are
  then flattened into a single bag.
  """

  def apply(self, input_bag: db.Bag) -> db.Bag:
    """Apply the DoFn to every element of ``input_bag``.

    Args:
      input_bag: The upstream bag of elements.

    Returns:
      A bag containing every element emitted by the DoFn.
    """
    transform = t.cast(apache_beam.ParDo, self.transform)
    process = transform.fn.process

    def process_element(element, *args, **kwargs):
      # DoFn.process may return None to mean "no outputs". Normalize that to an
      # empty iterable so the subsequent flatten() does not try to iterate None.
      outputs = process(element, *args, **kwargs)
      return () if outputs is None else outputs

    return input_bag.map(
        process_element, *transform.args, **transform.kwargs).flatten()
class Map(DaskBagOp):
  """Translate a Beam ``Map`` into an element-wise ``Bag.map``.

  NOTE(review): unlike ParDo, the result is not flattened. If ``transform.fn``
  is the DoFn that ``beam.Map`` wraps around the user callable, its ``process``
  may return a one-element list per input, which would leave lists in the
  output bag. Confirm.
  """

  def apply(self, input_bag: db.Bag) -> db.Bag:
    """Map the DoFn's ``process`` over ``input_bag`` with the transform's args."""
    map_transform = t.cast(apache_beam.Map, self.transform)
    element_fn = map_transform.fn.process
    extra_args = map_transform.args
    extra_kwargs = map_transform.kwargs
    return input_bag.map(element_fn, *extra_args, **extra_kwargs)
class GroupByKey(DaskBagOp):
  """Group ``(key, value)`` pairs by key into ``(key, [values...])``."""

  def apply(self, input_bag: db.Bag) -> db.Bag:
    """Group the bag's 2-tuples on their first item.

    Dask's ``groupby`` yields ``(key, [pair, ...])``. Each pair is then
    reduced to its value, so the output is ``(key, [value, ...])``.
    """
    def first_of_pair(pair):
      return pair[0]

    def strip_keys(group):
      grouped_key, pairs = group
      return grouped_key, [pair[1] for pair in pairs]

    grouped = input_bag.groupby(first_of_pair)
    return grouped.map(strip_keys)
class Flatten(DaskBagOp):
  """Translate a Beam ``Flatten`` into ``dask.bag.concat`` of its input bags."""

  def apply(self, input_bag: OpInput) -> db.Bag:
    """Concatenate a sequence of bags into one bag.

    Args:
      input_bag: A list or tuple of bags, consistent with ``OpInput`` allowing
        any sequence of bags.

    Returns:
      A single bag containing the elements of every input bag.

    Raises:
      AssertionError: If ``input_bag`` is not a list or tuple.
    """
    # isinstance (rather than an exact ``type(...) is list`` check) also
    # accepts tuples and list subclasses, which OpInput's Sequence type permits.
    assert isinstance(input_bag, (list, tuple)), 'Must take a sequence of bags!'
    return db.concat(list(input_bag))
# Lookup table from a Beam PTransform class to the DaskBagOp subclass that
# executes it on Dask Bags.
# NOTE(review): ``apache_beam.Map`` is a factory function rather than a
# PTransform class. If callers look up entries by ``type(transform)``, this key
# may never match. Confirm against the caller.
TRANSLATIONS = {
    # Source.
    _Create: Create,
    # Element-wise transforms.
    apache_beam.ParDo: ParDo,
    apache_beam.Map: Map,
    # Grouping and merging.
    _GroupByKeyOnly: GroupByKey,
    _Flatten: Flatten,
}