Source code for apache_beam.runners.dask.transform_evaluator

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Transform Beam PTransforms into Dask Bag operations.

A minimum set of operation substitutions, to adapt Beam's PTransform model
to Dask Bag functions.

TODO(alxr): Translate ops from https://docs.dask.org/en/latest/bag-api.html.
"""
import abc
import dataclasses
import math
import typing as t

import apache_beam
import dask.bag as db
from apache_beam.pipeline import AppliedPTransform
from apache_beam.runners.dask.overrides import _Create
from apache_beam.runners.dask.overrides import _Flatten
from apache_beam.runners.dask.overrides import _GroupByKeyOnly

OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]


@dataclasses.dataclass
class DaskBagOp(abc.ABC):
  """Abstract base for operations that map a Beam transform onto a Dask bag.

  Attributes:
    applied: The applied Beam transform this operation wraps.
  """
  applied: AppliedPTransform

  @property
  def transform(self):
    """Return the ``transform`` attribute of the wrapped applied transform."""
    wrapped = self.applied
    return wrapped.transform

  @abc.abstractmethod
  def apply(self, input_bag: OpInput) -> db.Bag:
    """Build the output bag from ``input_bag``; subclasses must override.

    Args:
      input_bag: A bag, a sequence of bags, or ``None``.

    Returns:
      The resulting Dask bag.
    """
class NoOp(DaskBagOp):
  """Operation that hands its input back without touching it."""
  def apply(self, input_bag: OpInput) -> db.Bag:
    """Return ``input_bag`` exactly as it was received."""
    passthrough = input_bag
    return passthrough
class Create(DaskBagOp):
  """Operation that builds a Dask bag from the wrapped transform's values."""
  def apply(self, input_bag: OpInput) -> db.Bag:
    """Create a bag from ``self.transform.values``.

    Args:
      input_bag: Must be ``None``; this operation takes no input.

    Returns:
      A bag built with ``db.from_sequence`` over the transform's values.
    """
    assert input_bag is None, 'Create expects no input!'
    create_transform = t.cast(_Create, self.transform)
    values = create_transform.values
    # Partition size is ceil(sqrt(len(values)) / sqrt(100)), floored at one
    # so an empty or tiny sequence still gets a valid partition size.
    scaled_root = math.sqrt(len(values)) / math.sqrt(100)
    chunk_size = max(1, math.ceil(scaled_root))
    return db.from_sequence(values, partition_size=chunk_size)
class ParDo(DaskBagOp):
  """Operation that maps the transform's ``fn.process`` and flattens."""
  def apply(self, input_bag: db.Bag) -> db.Bag:
    """Map ``fn.process`` over ``input_bag`` and flatten the mapped bag.

    The transform's ``args`` and ``kwargs`` are forwarded to ``map``.
    """
    pardo = t.cast(apache_beam.ParDo, self.transform)
    process = pardo.fn.process
    mapped = input_bag.map(process, *pardo.args, **pardo.kwargs)
    return mapped.flatten()
class Map(DaskBagOp):
  """Operation that maps the transform's ``fn.process`` over a bag."""
  def apply(self, input_bag: db.Bag) -> db.Bag:
    """Map ``fn.process`` over ``input_bag``.

    The transform's ``args`` and ``kwargs`` are forwarded to ``map``.
    """
    beam_map = t.cast(apache_beam.Map, self.transform)
    process = beam_map.fn.process
    mapped = input_bag.map(process, *beam_map.args, **beam_map.kwargs)
    return mapped
class GroupByKey(DaskBagOp):
  """Operation that groups a bag's items by their first element."""
  def apply(self, input_bag: db.Bag) -> db.Bag:
    """Group ``input_bag`` by ``item[0]`` and keep each member's ``[1]``.

    Returns:
      A bag of ``(key, [second element of each grouped item])`` tuples.
    """
    def key(item):
      # Grouping key: the first element of the item.
      first = item[0]
      return first

    def value(item):
      # Each grouped entry is (key, members); reduce every member to its
      # second element.
      group_key, members = item
      seconds = [member[1] for member in members]
      return group_key, seconds

    grouped = input_bag.groupby(key)
    return grouped.map(value)
class Flatten(DaskBagOp):
  """Operation that concatenates several bags into a single bag."""
  def apply(self, input_bag: OpInput) -> db.Bag:
    """Concatenate the bags in ``input_bag`` with ``db.concat``.

    Args:
      input_bag: A list (or list subclass) of bags.

    Returns:
      The result of ``db.concat`` over the given bags.

    Raises:
      AssertionError: If ``input_bag`` is not a list.
    """
    # isinstance rather than an exact type() comparison, so that list
    # subclasses are accepted as well.
    assert isinstance(input_bag, list), 'Must take a sequence of bags!'
    return db.concat(input_bag)
# Maps each supported Beam transform class to the DaskBagOp subclass that
# implements it.
TRANSLATIONS = {
    _Create: Create,
    apache_beam.ParDo: ParDo,
    apache_beam.Map: Map,
    _GroupByKeyOnly: GroupByKey,
    _Flatten: Flatten,
}