#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module contains Splittable DoFn logic that's common to all runners."""
import uuid
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.coders import typecoders
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PTransformOverride
from apache_beam.runners.common import DoFnInvoker
from apache_beam.runners.common import DoFnSignature
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.ptransform import PTransform
class SplittableParDoOverride(PTransformOverride):
  """A transform override for ParDo transforms of SplittableDoFns.

  Replaces the ParDo transform with a SplittableParDo transform that performs
  SDF specific logic.
  """

  def matches(self, applied_ptransform):
    """Reports whether this override applies to the given applied transform.

    Args:
      applied_ptransform: the AppliedPTransform node being inspected.

    Returns:
      The result of DoFnSignature.is_splittable_dofn() when the wrapped
      transform is a ParDo, and False for any other kind of transform.
    """
    assert isinstance(applied_ptransform, AppliedPTransform)
    transform = applied_ptransform.transform
    if not isinstance(transform, ParDo):
      # Be explicit instead of falling off the end and returning None.
      return False
    return DoFnSignature(transform.fn).is_splittable_dofn()
 
class SplittableParDo(PTransform):
  """A transform that processes a PCollection using a Splittable DoFn."""

  def __init__(self, ptransform):
    """Wraps an existing ParDo.

    Args:
      ptransform: the ParDo whose DoFn, args, kwargs and side inputs are
        reused when this transform is expanded.
    """
    assert isinstance(ptransform, ParDo)
    self._ptransform = ptransform

  def expand(self, pcoll):
    """Expands into the pair/split/explode/key steps plus the keyed processor.

    Args:
      pcoll: the input PCollection of the wrapped ParDo.

    Returns:
      The result of applying ProcessKeyedElements to the keyed
      ElementAndRestriction values derived from `pcoll`.
    """
    sdf = self._ptransform.fn
    signature = DoFnSignature(sdf)
    # The invoker is only used to look up the restriction coder, so no
    # process() invocation machinery is requested.
    invoker = DoFnInvoker.create_invoker(signature, process_invocation=False)

    element_coder = typecoders.registry.get_coder(pcoll.element_type)
    restriction_coder = invoker.invoke_restriction_coder()

    keyed_elements = (pcoll
                      | 'pair' >> ParDo(PairWithRestrictionFn(sdf))
                      | 'split' >> ParDo(SplitRestrictionFn(sdf))
                      | 'explode' >> ParDo(ExplodeWindowsFn())
                      | 'random' >> ParDo(RandomUniqueKeyFn()))

    return keyed_elements | ProcessKeyedElements(
        sdf, element_coder, restriction_coder,
        pcoll.windowing, self._ptransform.args, self._ptransform.kwargs,
        self._ptransform.side_inputs)
class ElementAndRestriction(object):
  """A holder for an element and a restriction."""

  def __init__(self, element, restriction):
    """Stores the pair unchanged.

    Args:
      element: the input element.
      restriction: the restriction associated with `element`.
    """
    self.element = element
    self.restriction = restriction
class PairWithRestrictionFn(beam.DoFn):
  """A transform that pairs each element with a restriction."""

  def __init__(self, do_fn):
    """Remembers the DoFn whose initial restriction will be requested.

    Args:
      do_fn: the DoFn handed to DoFnSignature in start_bundle().
    """
    self._do_fn = do_fn

  def start_bundle(self):
    """Builds the invoker used by process() for the current bundle."""
    signature = DoFnSignature(self._do_fn)
    # Only restriction-related calls are made through this invoker, so
    # process() invocation support is not requested.
    self._invoker = DoFnInvoker.create_invoker(
        signature, process_invocation=False)

  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Yields the element together with its initial restriction.

    Args:
      element: the input element.
      window: declared in the signature but not read in this method.
      *args: ignored.
      **kwargs: ignored.

    Yields:
      A single ElementAndRestriction holding `element` and the value returned
      by the invoker's invoke_initial_restriction(element).
    """
    initial_restriction = self._invoker.invoke_initial_restriction(element)
    yield ElementAndRestriction(element, initial_restriction)
class SplitRestrictionFn(beam.DoFn):
  """A transform that performs initial splitting of Splittable DoFn inputs."""

  def __init__(self, do_fn):
    """Remembers the DoFn whose split logic will be invoked.

    Args:
      do_fn: the DoFn handed to DoFnSignature in start_bundle().
    """
    self._do_fn = do_fn

  def start_bundle(self):
    """Builds the invoker used by process() for the current bundle."""
    signature = DoFnSignature(self._do_fn)
    # Only restriction-related calls are made through this invoker, so
    # process() invocation support is not requested.
    self._invoker = DoFnInvoker.create_invoker(
        signature, process_invocation=False)

  def process(self, element_and_restriction, *args, **kwargs):
    """Splits the incoming restriction and re-pairs each part with the element.

    Args:
      element_and_restriction: an object exposing `element` and `restriction`
        attributes (an ElementAndRestriction).
      *args: ignored.
      **kwargs: ignored.

    Yields:
      One ElementAndRestriction per part returned by the invoker's
      invoke_split(element, restriction), each carrying the same element.
    """
    element = element_and_restriction.element
    restriction = element_and_restriction.restriction
    restriction_parts = self._invoker.invoke_split(element, restriction)
    for part in restriction_parts:
      yield ElementAndRestriction(element, part)
class ExplodeWindowsFn(beam.DoFn):
  """A transform that forces the runner to explode windows.

  This is done to make sure that Splittable DoFn processes an element for each
  of the windows that element belongs to.
  """

  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Yields the element unchanged.

    Args:
      element: the input element.
      window: declared in the signature but never read here; presumably its
        presence is what triggers the per-window invocation described in the
        class docstring (confirm against the runner).
      *args: ignored.
      **kwargs: ignored.

    Yields:
      `element`, exactly as received.
    """
    yield element
class RandomUniqueKeyFn(beam.DoFn):
  """A transform that assigns a unique key to each element."""

  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Yields the element keyed by a freshly generated random UUID.

    Args:
      element: the input element.
      window: declared in the signature but not read in this method.
      *args: ignored.
      **kwargs: ignored.

    Yields:
      A `(key, element)` tuple where `key` is the 16-byte value of
      `uuid.uuid4().bytes`.
    """
    # We ignore UUID collisions here since they are extremely rare.
    yield (uuid.uuid4().bytes, element)
class ProcessKeyedElements(PTransform):
  """A primitive transform that performs SplittableDoFn magic.

  Input to this transform should be a PCollection of keyed ElementAndRestriction
  objects.
  """

  def __init__(
      self, sdf, element_coder, restriction_coder, windowing_strategy,
      ptransform_args, ptransform_kwargs, ptransform_side_inputs):
    """Records everything needed to execute the Splittable DoFn.

    All arguments are stored as same-named attributes without modification.

    Args:
      sdf: the Splittable DoFn to run.
      element_coder: coder for the input elements.
      restriction_coder: coder for the restrictions.
      windowing_strategy: windowing of the original input PCollection.
      ptransform_args: positional arguments of the original ParDo.
      ptransform_kwargs: keyword arguments of the original ParDo.
      ptransform_side_inputs: side inputs of the original ParDo.
    """
    self.sdf = sdf
    self.element_coder = element_coder
    self.restriction_coder = restriction_coder
    self.windowing_strategy = windowing_strategy
    self.ptransform_args = ptransform_args
    self.ptransform_kwargs = ptransform_kwargs
    self.ptransform_side_inputs = ptransform_side_inputs

  def expand(self, pcoll):
    """Returns a new output PCollection on the input's pipeline.

    No sub-transforms are applied here; this method only creates the output
    PCollection object.

    Args:
      pcoll: the keyed input PCollection.

    Returns:
      A new pvalue.PCollection constructed with `pcoll.pipeline`.
    """
    return pvalue.PCollection(pcoll.pipeline)