# Source code for apache_beam.runners.dask.overrides

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import typing as t

import apache_beam as beam
from apache_beam import typehints
from apache_beam.io.iobase import SourceBase
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PTransformOverride
from apache_beam.runners.direct.direct_runner import _GroupAlsoByWindowDoFn
from apache_beam.transforms import ptransform
from apache_beam.transforms.window import GlobalWindows

K = t.TypeVar("K")
V = t.TypeVar("V")


@dataclasses.dataclass
class _Create(beam.PTransform):
  """Stand-in transform that ``dask_overrides`` substitutes for ``beam.Create``.

  The replacement is constructed from the ``values`` attribute of the
  ``beam.Create`` it replaces.
  """
  # Elements copied from the replaced ``beam.Create``. The variadic form
  # ``Tuple[Any, ...]`` is used because ``Tuple[Any]`` would describe a tuple
  # of exactly one element.
  values: t.Tuple[t.Any, ...]

  def expand(self, input_or_inputs):
    """Return a PCollection created from ``input_or_inputs`` via ``from_``."""
    return beam.pvalue.PCollection.from_(input_or_inputs)

  def get_windowing(self, inputs: t.Any) -> beam.Windowing:
    """Return a ``GlobalWindows`` windowing; ``inputs`` is not consulted."""
    return beam.Windowing(GlobalWindows())


@typehints.with_input_types(K)
@typehints.with_output_types(K)
class _Reshuffle(beam.PTransform):
  """Stand-in transform that ``dask_overrides`` substitutes for
  ``beam.Reshuffle``.

  The type hints declare the same element type ``K`` on input and output.
  """
  def expand(self, input_or_inputs):
    """Return a PCollection created from ``input_or_inputs`` via ``from_``."""
    output_pcoll = beam.pvalue.PCollection.from_(input_or_inputs)
    return output_pcoll


@dataclasses.dataclass
class _Read(beam.PTransform):
  """Stand-in transform that ``dask_overrides`` substitutes for
  ``beam.io.Read``.

  The replacement is constructed from the ``source`` attribute of the
  ``beam.io.Read`` it replaces.
  """
  # The source object carried over from the replaced read transform.
  source: SourceBase

  def expand(self, input_or_inputs):
    """Return a PCollection created from ``input_or_inputs`` via ``from_``."""
    output_pcoll = beam.pvalue.PCollection.from_(input_or_inputs)
    return output_pcoll


@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKeyOnly(beam.PTransform):
  """Grouping stage used as the middle step of ``_GroupByKey``.

  Declared to accept ``(K, V)`` pairs and to emit ``(K, Iterable[V])`` pairs.
  """
  def expand(self, input_or_inputs):
    """Return a PCollection created from ``input_or_inputs`` via ``from_``."""
    output_pcoll = beam.pvalue.PCollection.from_(input_or_inputs)
    return output_pcoll

  def infer_output_type(self, input_type):
    """Map a key/value ``input_type`` to ``KV[key, Iterable[value]]``."""
    # Split the incoming hint into its key and value components first.
    kv_types = typehints.trivial_inference.key_value_types(input_type)
    k_type, v_type = kv_types
    grouped_values_type = typehints.Iterable[v_type]
    return typehints.KV[k_type, grouped_values_type]


@typehints.with_input_types(t.Tuple[K, t.Iterable[V]])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupAlsoByWindow(beam.ParDo):
  """``ParDo`` over the direct runner's ``_GroupAlsoByWindowDoFn``.

  Used as the final step of ``_GroupByKey``. Input and output are both
  declared as ``(K, Iterable[V])`` pairs.
  """
  def __init__(self, windowing):
    """Wrap a ``_GroupAlsoByWindowDoFn`` built from ``windowing``.

    Args:
      windowing: passed to ``_GroupAlsoByWindowDoFn`` and also kept on the
        instance as ``self.windowing``.
    """
    do_fn = _GroupAlsoByWindowDoFn(windowing)
    super().__init__(do_fn)
    self.windowing = windowing

  def expand(self, input_or_inputs):
    """Return a PCollection created from ``input_or_inputs`` via ``from_``."""
    output_pcoll = beam.pvalue.PCollection.from_(input_or_inputs)
    return output_pcoll


@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKey(beam.PTransform):
  """Composite that ``dask_overrides`` substitutes for ``beam.GroupByKey``.

  Expands into three labelled steps: "ReifyWindows", "GroupByKey" and
  "GroupByWindow".
  """
  def expand(self, input_or_inputs):
    """Chain the three sub-transforms onto ``input_or_inputs``."""
    reify_windows = beam.ParDo(beam.GroupByKey.ReifyWindows())
    reified = input_or_inputs | "ReifyWindows" >> reify_windows
    grouped_by_key = reified | "GroupByKey" >> _GroupByKeyOnly()
    # The last step is configured with the windowing of the *original* input,
    # not that of the intermediate collections.
    group_also_by_window = _GroupAlsoByWindow(input_or_inputs.windowing)
    return grouped_by_key | "GroupByWindow" >> group_also_by_window


class _Flatten(beam.PTransform):
  """Stand-in transform that ``dask_overrides`` substitutes for
  ``beam.Flatten``."""
  def expand(self, input_or_inputs):
    """Return a new PCollection whose boundedness is derived from the inputs.

    A single ``PCollection`` contributes its own ``is_bounded`` value; for any
    other input, the result is bounded only if every member is bounded.
    """
    # NOTE(cisaacstern): a bare PCollection must be special-cased, otherwise
    # the `all(...)` below raises
    #   `TypeError: 'PCollection' object is not iterable`
    # for single-element flattens, i.e.,
    #   `(pcoll, ) | beam.Flatten() | ...`
    got_single_pcoll = isinstance(input_or_inputs, beam.PCollection)
    bounded = (
        input_or_inputs.is_bounded if got_single_pcoll else all(
            member.is_bounded for member in input_or_inputs))
    return beam.pvalue.PCollection(self.pipeline, is_bounded=bounded)


def dask_overrides() -> t.List[PTransformOverride]:
  """Build the list of ``PTransformOverride`` objects defined in this module.

  Each override matches one Beam transform by its exact class and supplies the
  corresponding private replacement transform:

    * ``beam.Create``    -> ``_Create`` (carrying over ``values``)
    * ``beam.Reshuffle`` -> ``_Reshuffle``
    * ``beam.io.Read``   -> ``_Read`` (carrying over ``source``)
    * ``beam.GroupByKey``-> ``_GroupByKey``
    * ``beam.Flatten``   -> ``_Flatten``

  Returns:
    A fresh list with one instance of each override, in the order above.
  """

  # NOTE: every ``matches`` compares ``__class__`` for equality rather than
  # using ``isinstance``, so subclasses of the listed transforms are not
  # replaced.

  class CreateOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.Create

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Create(t.cast(beam.Create, applied_ptransform.transform).values)

  class ReshuffleOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.Reshuffle

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Reshuffle()

  class ReadOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.io.Read

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Read(t.cast(beam.io.Read, applied_ptransform.transform).source)

  class GroupByKeyOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.GroupByKey

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _GroupByKey()

  class FlattenOverride(PTransformOverride):
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      return applied_ptransform.transform.__class__ == beam.Flatten

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Flatten()

  return [
      CreateOverride(),
      ReshuffleOverride(),
      ReadOverride(),
      GroupByKeyOverride(),
      FlattenOverride(),
  ]