#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module contains Splittable DoFn logic that is specific to DirectRunner.
"""
# pytype: skip-file
import uuid
from threading import Lock
from threading import Timer
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Optional
import apache_beam as beam
from apache_beam import TimeDomain
from apache_beam import pvalue
from apache_beam.coders import typecoders
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PTransformOverride
from apache_beam.runners.common import DoFnContext
from apache_beam.runners.common import DoFnInvoker
from apache_beam.runners.common import DoFnSignature
from apache_beam.runners.common import OutputProcessor
from apache_beam.runners.direct.evaluation_context import DirectStepContext
from apache_beam.runners.direct.util import KeyedWorkItem
from apache_beam.runners.direct.watermark_manager import WatermarkManager
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import ProcessContinuation
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.trigger import _ReadModifyWriteStateTag
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
  from apache_beam.iobase import WatermarkEstimator
class SplittableParDoOverride(PTransformOverride):
  """A transform override for ParDo transforms of SplittableDoFns.

  Replaces the ParDo transform with a SplittableParDo transform that performs
  SDF specific logic.

  NOTE(review): only `matches()` is defined in this class as seen here; the
  hook that actually produces the replacement transform is not visible --
  confirm it is provided elsewhere.
  """
  def matches(self, applied_ptransform):
    """Tells whether `applied_ptransform` wraps a ParDo of a Splittable DoFn.

    Args:
      applied_ptransform: the `AppliedPTransform` under consideration.

    Returns:
      False when the wrapped transform is not a `ParDo`; otherwise the result
      of `DoFnSignature(transform.fn).is_splittable_dofn()`.
    """
    assert isinstance(applied_ptransform, AppliedPTransform)
    transform = applied_ptransform.transform
    if not isinstance(transform, ParDo):
      # The original fell off the end here and returned None implicitly; make
      # the negative answer explicit (still falsy for existing callers).
      return False
    return DoFnSignature(transform.fn).is_splittable_dofn()
 
class SplittableParDo(PTransform):
  """A transform that processes a PCollection using a Splittable DoFn."""
  def __init__(self, ptransform):
    """Keeps a reference to the `ParDo` being expanded.

    Args:
      ptransform: the `ParDo` whose `fn`, `args`, `kwargs` and `side_inputs`
        are read by `expand()`.
    """
    assert isinstance(ptransform, ParDo)
    self._ptransform = ptransform

  def expand(self, pcoll):
    """Expands the wrapped ParDo into the SDF processing chain.

    The input is piped through four ParDos ('pair', 'split', 'explode',
    'random') and the resulting keyed elements are handed to
    `ProcessKeyedElements`.

    Args:
      pcoll: the input PCollection of the original ParDo.

    Returns:
      The result of applying `ProcessKeyedElements` to the keyed elements.
    """
    sdf = self._ptransform.fn
    signature = DoFnSignature(sdf)
    restriction_coder = signature.get_restriction_coder()
    element_coder = typecoders.registry.get_coder(pcoll.element_type)

    keyed_elements = (
        pcoll
        | 'pair' >> ParDo(PairWithRestrictionFn(sdf))
        | 'split' >> ParDo(SplitRestrictionFn(sdf))
        | 'explode' >> ParDo(ExplodeWindowsFn())
        | 'random' >> ParDo(RandomUniqueKeyFn()))

    # The original ParDo's arguments and side inputs travel along so that the
    # downstream transform can forward them to the SDF invocation.
    return keyed_elements | ProcessKeyedElements(
        sdf,
        element_coder,
        restriction_coder,
        pcoll.windowing,
        self._ptransform.args,
        self._ptransform.kwargs,
        self._ptransform.side_inputs)
class ElementAndRestriction(object):
  """A holder for an element, restriction, and watermark estimator state."""
  def __init__(self, element, restriction, watermark_estimator_state):
    """Stores the three values unchanged as public attributes.

    Args:
      element: the input element.
      restriction: the restriction paired with the element.
      watermark_estimator_state: the watermark estimator state paired with
        the element and restriction.
    """
    self.element = element
    self.restriction = restriction
    self.watermark_estimator_state = watermark_estimator_state

  def __repr__(self):
    # Purely a debugging aid; equality and hashing are left untouched.
    return '%s(element=%r, restriction=%r, watermark_estimator_state=%r)' % (
        type(self).__name__,
        self.element,
        self.restriction,
        self.watermark_estimator_state)
class PairWithRestrictionFn(beam.DoFn):
  """A transform that pairs each element with a restriction."""
  def __init__(self, do_fn):
    """Builds the signature of the wrapped `DoFn`.

    Args:
      do_fn: the `DoFn` used to compute the initial restriction and the
        initial watermark estimator state of each element.
    """
    self._signature = DoFnSignature(do_fn)

  def start_bundle(self):
    """Creates the invoker that `process()` uses to query the wrapped DoFn."""
    # Any output routed through this invoker raises, because
    # _NoneShallPassOutputProcessor rejects every call.
    self._invoker = DoFnInvoker.create_invoker(
        self._signature,
        output_processor=_NoneShallPassOutputProcessor(),
        process_invocation=False)

  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Wraps `element` with its initial restriction and estimator state.

    Args:
      element: the input element.
      window: the window parameter; not used in the body.
      *args: ignored.
      **kwargs: ignored.

    Yields:
      A single `ElementAndRestriction` for the element.
    """
    initial_restriction = self._invoker.invoke_initial_restriction(element)
    estimator_provider = (
        self._signature.process_method.watermark_estimator_provider)
    watermark_estimator_state = estimator_provider.initial_estimator_state(
        element, initial_restriction)
    yield ElementAndRestriction(
        element, initial_restriction, watermark_estimator_state)
class SplitRestrictionFn(beam.DoFn):
  """A transform that performs initial splitting of Splittable DoFn inputs."""
  def __init__(self, do_fn):
    """Remembers the wrapped `DoFn`.

    Args:
      do_fn: the `DoFn` whose split logic is invoked for each input.
    """
    self._do_fn = do_fn

  def start_bundle(self):
    """Builds the signature and invoker used by `process()`."""
    signature = DoFnSignature(self._do_fn)
    # Any output routed through this invoker raises, because
    # _NoneShallPassOutputProcessor rejects every call.
    self._invoker = DoFnInvoker.create_invoker(
        signature,
        output_processor=_NoneShallPassOutputProcessor(),
        process_invocation=False)

  def process(self, element_and_restriction, *args, **kwargs):
    """Splits the incoming restriction and re-pairs each part.

    Args:
      element_and_restriction: an `ElementAndRestriction` holding the element,
        the restriction to split and the watermark estimator state.
      *args: ignored.
      **kwargs: ignored.

    Yields:
      One `ElementAndRestriction` per part returned by `invoke_split`, each
      carrying the original element and watermark estimator state.
    """
    element = element_and_restriction.element
    estimator_state = element_and_restriction.watermark_estimator_state
    restriction_parts = self._invoker.invoke_split(
        element, element_and_restriction.restriction)
    for part in restriction_parts:
      yield ElementAndRestriction(element, part, estimator_state)
class ExplodeWindowsFn(beam.DoFn):
  """A transform that forces the runner to explode windows.

  This is done to make sure that Splittable DoFn processes an element for each
  of the windows that element belongs to.
  """
  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Re-emits `element` unchanged.

    Args:
      element: the input element.
      window: the window parameter; not used in the body, declared only so
        that the method requests it.
      *args: ignored.
      **kwargs: ignored.

    Yields:
      The input element.
    """
    yield element
class RandomUniqueKeyFn(beam.DoFn):
  """A transform that assigns a unique key to each element."""
  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
    """Pairs `element` with a freshly generated random key.

    Args:
      element: the input element.
      window: the window parameter; not used in the body.
      *args: ignored.
      **kwargs: ignored.

    Yields:
      A `(key, element)` tuple where `key` is the 16-byte value of a new
      version-4 UUID.
    """
    # We ignore UUID collisions here since they are extremely rare.
    random_key = uuid.uuid4().bytes
    yield (random_key, element)
class ProcessKeyedElements(PTransform):
  """A primitive transform that performs SplittableDoFn magic.

  Input to this transform should be a PCollection of keyed ElementAndRestriction
  objects.
  """
  def __init__(
      self,
      sdf,
      element_coder,
      restriction_coder,
      windowing_strategy,
      ptransform_args,
      ptransform_kwargs,
      ptransform_side_inputs):
    """Records everything needed to later execute the Splittable DoFn.

    All arguments are stored unchanged on same-named attributes.

    Args:
      sdf: the Splittable `DoFn`.
      element_coder: coder for the elements.
      restriction_coder: coder for the restrictions.
      windowing_strategy: windowing of the input PCollection.
      ptransform_args: positional arguments of the original ParDo.
      ptransform_kwargs: keyword arguments of the original ParDo.
      ptransform_side_inputs: side inputs of the original ParDo.
    """
    self.sdf = sdf
    self.element_coder = element_coder
    self.restriction_coder = restriction_coder
    self.windowing_strategy = windowing_strategy
    self.ptransform_args = ptransform_args
    self.ptransform_kwargs = ptransform_kwargs
    self.ptransform_side_inputs = ptransform_side_inputs

  def expand(self, pcoll):
    """Returns a new PCollection derived from `pcoll`; no processing here."""
    return pvalue.PCollection.from_(pcoll)
class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride):
  """A transform override for ProcessKeyedElements transforms.

  NOTE(review): only `matches()` is defined in this class as seen here; the
  hook that produces the replacement transform is not visible -- confirm it
  is provided elsewhere.
  """
  def matches(self, applied_ptransform):
    """Returns True iff the applied transform is a `ProcessKeyedElements`.

    Args:
      applied_ptransform: object whose `transform` attribute is inspected.
    """
    return isinstance(applied_ptransform.transform, ProcessKeyedElements)
 
class ProcessKeyedElementsViaKeyedWorkItems(PTransform):
  """A transform that processes Splittable DoFn input via KeyedWorkItems."""
  def __init__(self, process_keyed_elements_transform):
    """Keeps the `ProcessKeyedElements` transform being replaced.

    Args:
      process_keyed_elements_transform: transform providing the
        `ptransform_args`, `ptransform_kwargs` and `ptransform_side_inputs`
        that `expand()` copies onto the `ProcessElements` step.
    """
    self._process_keyed_elements_transform = process_keyed_elements_transform

  def expand(self, pcoll):
    """Groups the keyed input by key and applies `ProcessElements`.

    Args:
      pcoll: the keyed input PCollection.

    Returns:
      The output of `pcoll | GroupByKey | ProcessElements`.
    """
    source_transform = self._process_keyed_elements_transform
    process_elements = ProcessElements(source_transform)
    # Copy the original ParDo's arguments onto the new primitive step.
    process_elements.args = source_transform.ptransform_args
    process_elements.kwargs = source_transform.ptransform_kwargs
    process_elements.side_inputs = source_transform.ptransform_side_inputs
    return pcoll | beam.core.GroupByKey() | process_elements
class ProcessElements(PTransform):
  """A primitive transform for processing keyed elements or KeyedWorkItems.

  Will be evaluated by
  `runners.direct.transform_evaluator._ProcessElementsEvaluator`.
  """
  def __init__(self, process_keyed_elements_transform):
    """Captures the source transform and exposes its Splittable DoFn.

    Args:
      process_keyed_elements_transform: transform providing `sdf`,
        `ptransform_args` and `ptransform_kwargs`.
    """
    self._process_keyed_elements_transform = process_keyed_elements_transform
    self.sdf = self._process_keyed_elements_transform.sdf

  def expand(self, pcoll):
    """Returns a new PCollection derived from `pcoll`; no processing here."""
    return pvalue.PCollection.from_(pcoll)

  def new_process_fn(self, sdf):
    """Creates a `ProcessFn` for the given Splittable DoFn.

    Args:
      sdf: the Splittable `DoFn` to wrap. The invoker arguments always come
        from the transform given at construction time, whichever `sdf` is
        passed here.

    Returns:
      A new `ProcessFn`.
    """
    source_transform = self._process_keyed_elements_transform
    return ProcessFn(
        sdf,
        source_transform.ptransform_args,
        source_transform.ptransform_kwargs)
class ProcessFn(beam.DoFn):
  """A `DoFn` that executes machinery for invoking a Splittable `DoFn`.

  Input to the `ParDo` step that includes a `ProcessFn` will be a `PCollection`
  of `ElementAndRestriction` objects.

  This class is mainly responsible for following.

  (1) setup environment for properly invoking a Splittable `DoFn`.
  (2) invoke `process()` method of a Splittable `DoFn`.
  (3) after the `process()` invocation of the Splittable `DoFn`, determine if a
  re-invocation of the element is needed. If this is the case, set state and
  a timer for a re-invocation and hold output watermark till this
  re-invocation.
  (4) after the final invocation of a given element clear any previous state set
  for re-invoking the element and release the output watermark.
  """
  def __init__(self, sdf, args_for_invoker, kwargs_for_invoker):
    """Creates the state tags and the invoker for the Splittable DoFn.

    Args:
      sdf: the Splittable `DoFn` to execute.
      args_for_invoker: passed to the invoker as `input_args`.
      kwargs_for_invoker: passed to the invoker as `input_kwargs`.
    """
    self.sdf = sdf
    # Per-key state cells used to resume an element across invocations.
    self._element_tag = _ReadModifyWriteStateTag('element')
    self._restriction_tag = _ReadModifyWriteStateTag('restriction')
    self._watermark_state_tag = _ReadModifyWriteStateTag(
        'watermark_estimator_state')
    self.watermark_hold_tag = _ReadModifyWriteStateTag('watermark_hold')
    # Must be supplied through set_process_element_invoker() before process().
    self._process_element_invoker = None
    # Shared with the invoker so that process() outputs can be picked up.
    self._output_processor = _OutputProcessor()
    self.sdf_invoker = DoFnInvoker.create_invoker(
        DoFnSignature(self.sdf),
        context=DoFnContext('unused_context'),
        output_processor=self._output_processor,
        input_args=args_for_invoker,
        input_kwargs=kwargs_for_invoker)
    # Must be assigned through the step_context property before process().
    self._step_context = None

  @property
  def step_context(self):
    """The `DirectStepContext` used to look up keyed state."""
    return self._step_context

  @step_context.setter
  def step_context(self, step_context):
    assert isinstance(step_context, DirectStepContext)
    self._step_context = step_context

  def set_process_element_invoker(self, process_element_invoker):
    """Sets the `SDFProcessElementInvoker` used by `process()`."""
    assert isinstance(process_element_invoker, SDFProcessElementInvoker)
    self._process_element_invoker = process_element_invoker

  def process(
      self,
      element,
      timestamp=beam.DoFn.TimestampParam,
      window=beam.DoFn.WindowParam,
      *args,
      **kwargs):
    """Runs one invocation of the Splittable DoFn for a keyed input.

    Args:
      element: either a `KeyedWorkItem` (a timer firing; only its
        `encoded_key` is used) or a `(key, values)` pair where `values` holds
        exactly one `ElementAndRestriction`.
      timestamp: timestamp used for the `WindowedValue` given to the SDF.
      window: window used for state access and for the `WindowedValue`.
      *args: forwarded to the process element invoker.
      **kwargs: forwarded to the process element invoker.

    Yields:
      Every output of the invocation except the trailing
      `SDFProcessElementInvoker.Result`.

    Raises:
      AssertionError: if a fresh input does not carry exactly one value, if
        no process element invoker was set, or if the invoker does not end
        with a `SDFProcessElementInvoker.Result`.
      ValueError: when assertions are disabled and a fresh input does not
        carry exactly one value.
    """
    # Stays None for timer firings, which resume from stored state instead.
    value = None
    if isinstance(element, KeyedWorkItem):
      # Must be a timer firing.
      key = element.encoded_key
    else:
      key, values = element
      values = list(values)

      # Value here will either be a WindowedValue or an ElementAndRestriction
      # object.
      # TODO: handle key collisions here.
      non_unique_key_message = (
          'Internal error. Processing of splittable '
          'DoFn cannot continue since elements did not '
          'have unique keys.')
      assert len(values) == 1, non_unique_key_message
      if len(values) != 1:
        # Only reachable when asserts are stripped (python -O). Checked
        # before indexing so an empty list reports this error rather than an
        # IndexError.
        raise ValueError(non_unique_key_message)
      value = values[0]

    state = self._step_context.get_keyed_state(key)
    element_state = state.get_state(window, self._element_tag)
    # Initially element_state is an empty list.
    # NOTE(review): this is a truthiness test, so a stored element that is
    # itself falsy would be treated as a seed call -- confirm whether such
    # elements can occur.
    is_seed_call = not element_state

    if not is_seed_call:
      # Resuming: take element, restriction and estimator state from state.
      element = element_state
      restriction = state.get_state(window, self._restriction_tag)
      watermark_estimator_state = state.get_state(
          window, self._watermark_state_tag)
      windowed_element = WindowedValue(element, timestamp, [window])
    else:
      # After values iterator is expanded above we should have gotten a list
      # with a single ElementAndRestriction object.
      assert isinstance(value, ElementAndRestriction)
      element_and_restriction = value
      element = element_and_restriction.element
      restriction = element_and_restriction.restriction
      watermark_estimator_state = (
          element_and_restriction.watermark_estimator_state)

      # NOTE(review): with the assert above enabled, this WindowedValue branch
      # looks unreachable; kept as-is pending confirmation.
      if isinstance(value, WindowedValue):
        windowed_element = WindowedValue(
            element, value.timestamp, value.windows)
      else:
        windowed_element = WindowedValue(element, timestamp, [window])

    assert self._process_element_invoker
    assert isinstance(self._process_element_invoker, SDFProcessElementInvoker)

    output_values = self._process_element_invoker.invoke_process_element(
        self.sdf_invoker,
        self._output_processor,
        windowed_element,
        restriction,
        watermark_estimator_state,
        *args,
        **kwargs)

    sdf_result = None
    for output in output_values:
      if isinstance(output, SDFProcessElementInvoker.Result):
        # SDFProcessElementInvoker.Result should be the last item yielded.
        sdf_result = output
        break
      yield output

    assert sdf_result, ('SDFProcessElementInvoker must return a '
                        'SDFProcessElementInvoker.Result object as the last '
                        'value of a SDF invoke_process_element() invocation.')

    if not sdf_result.residual_restriction:
      # All work for current residual and restriction pair is complete.
      state.clear_state(window, self._element_tag)
      state.clear_state(window, self._restriction_tag)
      state.clear_state(window, self._watermark_state_tag)
      # Releasing output watermark by setting it to positive infinity.
      state.add_state(
          window, self.watermark_hold_tag, WatermarkManager.WATERMARK_POS_INF)
    else:
      state.add_state(window, self._element_tag, element)
      state.add_state(
          window, self._restriction_tag, sdf_result.residual_restriction)
      state.add_state(
          window, self._watermark_state_tag, watermark_estimator_state)
      # Holding output watermark by setting it to negative infinity.
      state.add_state(
          window, self.watermark_hold_tag, WatermarkManager.WATERMARK_NEG_INF)

      # Setting a timer to be reinvoked to continue processing the element.
      # Currently Python SDK only supports setting timers based on watermark. So
      # forcing a reinvocation by setting a timer for watermark negative
      # infinity.
      # TODO(chamikara): update this by setting a timer for the proper
      # processing time when Python SDK supports that.
      state.set_timer(
          window, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_NEG_INF)
class SDFProcessElementInvoker(object):
  """A utility that invokes SDF `process()` method and requests checkpoints.

  This class is responsible for invoking the `process()` method of a Splittable
  `DoFn` and making sure that invocation terminated properly. Based on the input
  configuration, this class may decide to request a checkpoint for a `process()`
  execution so that runner can process current output and resume the invocation
  at a later time.

  More specifically, when initializing a `SDFProcessElementInvoker`, caller may
  specify the number of output elements or processing time after which a
  checkpoint should be requested. This class is responsible for properly
  requesting a checkpoint based on either of these criteria.

  When the `process()` call of Splittable `DoFn` ends, this class performs
  validations to make sure that processing ended gracefully and returns a
  `SDFProcessElementInvoker.Result` that contains information which can be used
  by the caller to perform another `process()` invocation for the residual.

  A `process()` invocation may decide to give up processing voluntarily by
  returning a `ProcessContinuation` object (see documentation of
  `ProcessContinuation` for more details). So if a 'ProcessContinuation' is
  produced this class ends the execution and performs steps to finalize the
  current invocation.
  """
  class Result(object):
    def __init__(
        self,
        residual_restriction=None,
        process_continuation=None,
        future_output_watermark=None):
      """Returned as a result of a `invoke_process_element()` invocation.

      Args:
        residual_restriction: a restriction for the unprocessed part of the
                             element.
        process_continuation: a `ProcessContinuation` if one was returned as the
                              last element of the SDF `process()` invocation.
        future_output_watermark: output watermark of the results that will be
                                 produced when invoking the Splittable `DoFn`
                                 for the current element with
                                 `residual_restriction`.
      """
      self.residual_restriction = residual_restriction
      self.process_continuation = process_continuation
      self.future_output_watermark = future_output_watermark

  def __init__(self, max_num_outputs, max_duration):
    """Configures when checkpoints are requested.

    Args:
      max_num_outputs: a checkpoint is requested once this many outputs have
        been produced by one invocation; a falsy value disables the limit.
      max_duration: interval handed to `threading.Timer`, after which a
        checkpoint is requested.
    """
    self._max_num_outputs = max_num_outputs
    self._max_duration = max_duration
    # Guards the "already checkpointed" flag of an invocation.
    self._checkpoint_lock = Lock()

  def test_method(self):
    """Always raises `ValueError`."""
    raise ValueError

  def invoke_process_element(
      self,
      sdf_invoker,
      output_processor,
      element,
      restriction,
      watermark_estimator_state,
      *args,
      **kwargs):
    """Invokes `process()` method of a Splittable `DoFn` for a given element.

    This is a generator: nothing runs until the first item is requested.

    Args:
      sdf_invoker: a `DoFnInvoker` for the Splittable `DoFn`.
      output_processor: object that is `reset()` before the invocation and
        whose `output_iter` is read afterwards to obtain the outputs.
      element: the element to process
      restriction: the restriction to process the element with.
      watermark_estimator_state: forwarded to `invoke_process`.
      *args: forwarded to `invoke_process` as `additional_args`.
      **kwargs: accepted but not used in the body.

    Yields:
      Each output of the SDF `process()` call, except a `ProcessContinuation`,
      followed by one `SDFProcessElementInvoker.Result` as the last item.
    """
    assert isinstance(sdf_invoker, DoFnInvoker)

    class CheckpointState(object):
      def __init__(self):
        self.checkpointed = None
        self.residual_restriction = None

    checkpoint_state = CheckpointState()

    def initiate_checkpoint():
      with self._checkpoint_lock:
        if checkpoint_state.checkpointed:
          return
        checkpoint_state.checkpointed = object()
      split = sdf_invoker.try_split(0)
      if split:
        _, checkpoint_state.residual_restriction = split
      else:
        # Clear the checkpoint if the split didn't happen. This counters
        # a very unlikely race condition that the Timer attempted to initiate
        # a checkpoint before invoke_process set the current element allowing
        # for another attempt to checkpoint.
        checkpoint_state.checkpointed = None

    output_processor.reset()
    process_continuation = None
    checkpoint_timer = Timer(self._max_duration, initiate_checkpoint)
    checkpoint_timer.start()
    try:
      sdf_invoker.invoke_process(
          element,
          additional_args=args,
          restriction=restriction,
          watermark_estimator_state=watermark_estimator_state)

      assert output_processor.output_iter is not None
      output_count = 0

      # We have to expand and re-yield here to support ending execution for a
      # given number of output elements as well as to capture the
      # ProcessContinuation of one was returned.
      for output in output_processor.output_iter:
        # A ProcessContinuation, if returned, should be the last element.
        assert not process_continuation
        if isinstance(output, ProcessContinuation):
          # Taking a checkpoint so that we can determine primary and residual
          # restrictions.
          initiate_checkpoint()

          # A ProcessContinuation should always be the last element produced by
          # the output iterator.
          # TODO: support continuing after the specified amount of delay.

          # Continuing here instead of breaking to enforce that this is the last
          # element.
          process_continuation = output
          continue

        yield output
        output_count += 1
        if self._max_num_outputs and output_count >= self._max_num_outputs:
          initiate_checkpoint()
    finally:
      # Stop a pending time-based checkpoint once this invocation is over
      # (normally, by exception, or because the consumer dropped the
      # generator). Otherwise the timer could fire later and call try_split()
      # on an invoker that has moved on to other work.
      checkpoint_timer.cancel()

    # A falsy residual is reported as None, as before.
    residual_restriction = checkpoint_state.residual_restriction
    if not residual_restriction:
      residual_restriction = None
    yield SDFProcessElementInvoker.Result(
        residual_restriction=residual_restriction,
        process_continuation=process_continuation)
class _OutputProcessor(OutputProcessor):
  """Output processor that only remembers the iterator of produced outputs.

  The stored `output_iter` is read back by `SDFProcessElementInvoker`, which
  iterates over it itself.
  """
  def __init__(self):
    # Start out with no captured outputs.
    self.reset()

  def reset(self):
    """Forgets any previously captured output iterator."""
    self.output_iter = None

  def process_outputs(
      self, windowed_input_element, output_iter, watermark_estimator=None):
    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
    """Captures `output_iter` without consuming it.

    `windowed_input_element` and `watermark_estimator` are not used.
    """
    self.output_iter = output_iter
class _NoneShallPassOutputProcessor(OutputProcessor):
  """Output processor that rejects every output by raising."""
  def process_outputs(
      self, windowed_input_element, output_iter, watermark_estimator=None):
    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
    """Always raises; the arguments are not inspected.

    Raises:
      RuntimeError: unconditionally.
    """
    # The original raised a message-less RuntimeError, which gave no hint
    # about the cause when it surfaced in a pipeline failure.
    raise RuntimeError(
        'Unexpected call to process_outputs(): '
        '_NoneShallPassOutputProcessor does not accept any outputs.')