#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""DaskRunner, executing remote jobs on dask.distributed.

The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import collections
import dataclasses
import typing as t
from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import DaskBagWindowedIterator
from apache_beam.runners.dask.transform_evaluator import Flatten
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.transforms.sideinputs import SideInputMap
from apache_beam.utils.interactive_utils import is_in_notebook
try:
  # Added to try to prevent threading related issues, see
  # https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
  import dask.distributed as ddist
except ImportError:
  ddist = {}
class DaskOptions(PipelineOptions):
  """Pipeline options configuring the Dask client and Bag partitioning."""

  @staticmethod
  def _parse_timeout(candidate):
    """Argparse `type` converter for `--dask_connection_timeout`.

    Returns:
      `candidate` as an `int`, or `dask.config.no_default` (Dask's "no value
      given" sentinel) when it cannot be converted to an `int`.
    """
    # NOTE(review): non-integer values such as '2.5' are silently mapped to
    # `no_default` rather than rejected -- confirm this is intended.
    try:
      return int(candidate)
    except (TypeError, ValueError):
      import dask
      return dask.config.no_default

  @staticmethod
  def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
    """Parse keyword arguments for `dask.Bag`s; used in graph translation.

    Pops `npartitions` and `partition_size` out of `dask_options` (mutating
    it, so the remaining entries can be handed to the Dask client) and
    returns the ones that were set to a truthy value.
    """
    out = {}
    if npartitions := dask_options.pop('npartitions', None):
      out['npartitions'] = npartitions
    if partition_size := dask_options.pop('partition_size', None):
      out['partition_size'] = partition_size
    return out

  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    # Each `dest` below doubles as a keyword argument name: the Bag options
    # are extracted by `_extract_bag_kwargs`, the rest go to the client.
    parser.add_argument(
        '--dask_client_address',
        dest='address',
        type=str,
        default=None,
        help='Address of a dask Scheduler server. Will default to a '
        '`dask.LocalCluster()`.')
    parser.add_argument(
        '--dask_connection_timeout',
        dest='timeout',
        type=DaskOptions._parse_timeout,
        help='Timeout duration for initial connection to the scheduler.')
    parser.add_argument(
        '--dask_scheduler_file',
        dest='scheduler_file',
        type=str,
        default=None,
        help='Path to a file with scheduler information if available.')
    # TODO(alxr): Add options for security.
    parser.add_argument(
        '--dask_client_name',
        dest='name',
        type=str,
        default=None,
        help='Gives the client a name that will be included in logs generated '
        'on the scheduler for matters relating to this client.')
    parser.add_argument(
        '--dask_connection_limit',
        dest='connection_limit',
        type=int,
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')
    # A Bag is sized either by partition count or by partition length, never
    # both, hence the mutually exclusive group.
    partitions_parser = parser.add_mutually_exclusive_group()
    partitions_parser.add_argument(
        '--dask_npartitions',
        dest='npartitions',
        type=int,
        default=None,
        help='The desired number of `dask.Bag` partitions. When unspecified, '
        'an educated guess is made.')
    partitions_parser.add_argument(
        '--dask_partition_size',
        dest='partition_size',
        type=int,
        default=None,
        help='The length of each `dask.Bag` partition. When unspecified, '
        'an educated guess is made.')
@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
  """Result of a pipeline submitted to Dask, tracked through its futures."""

  # The annotations are strings so that defining this class does not evaluate
  # `ddist.Client` / `ddist.Future`: when `dask.distributed` is not installed
  # the module-level `ddist` fallback is a plain dict without those attributes
  # and eager evaluation would make importing this module fail.
  client: 'ddist.Client'
  futures: 't.Sequence[ddist.Future]'

  def __post_init__(self):
    # The dataclass-generated `__init__` does not call the base initializer,
    # so set the initial pipeline state here.
    super().__init__(PipelineState.RUNNING)

  def wait_until_finish(self, duration=None) -> str:
    """Block until every future has completed and return the final state.

    Args:
      duration: Maximum time to wait, in milliseconds. `None` waits without a
        limit.

    Returns:
      `PipelineState.DONE` once all futures completed successfully.

    Raises:
      Whatever is raised while waiting (e.g. a worker error, or a timeout from
      `as_completed`); the state is set to `PipelineState.FAILED` first.
    """
    # Beam passes milliseconds; `as_completed` takes seconds.
    timeout = None if duration is None else duration / 1000
    try:
      # Without gathering results, worker errors are not raised on the client:
      # https://distributed.dask.org/en/stable/resilience.html#user-code-failures
      # so results are requested (and discarded) to surface errors here. The
      # iterative `as_completed(..., with_results=True)` is used instead of the
      # aggregate `client.gather` to minimize the memory footprint of results.
      # NOTE(review): a client-side timeout also marks the pipeline FAILED even
      # though its futures may still be running.
      for _ in ddist.as_completed(self.futures,
                                  timeout=timeout,
                                  with_results=True):
        pass
      self._state = PipelineState.DONE
    except BaseException:  # pylint: disable=broad-except
      self._state = PipelineState.FAILED
      raise
    return self._state

  def cancel(self) -> str:
    """Cancel all futures through the client; returns `CANCELLED`."""
    self._state = PipelineState.CANCELLING
    self.client.cancel(self.futures)
    self._state = PipelineState.CANCELLED
    return self._state

  def metrics(self):
    """Not supported yet; always raises `NotImplementedError`."""
    # TODO(alxr): Collect and return metrics...
    raise NotImplementedError('collecting metrics will come later!')
 
class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""

  @staticmethod
  def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
    """Build a visitor that translates visited transforms into Dask Bags.

    Args:
      bag_kwargs: Keyword arguments (e.g. `npartitions`) passed to every
        translated op; see `DaskOptions._extract_bag_kwargs`.

    Returns:
      A fresh visitor whose `bags` attribute maps each transform passed to
      `visit_transform` to the Bag produced for it, in visit order.
    """
    from dask import bag as db
    if bag_kwargs is None:
      bag_kwargs = {}

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      # Transform -> translated Bag, kept in the order transforms are visited.
      bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
          default_factory=collections.OrderedDict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        # Transforms without a registered translation become a `NoOp`.
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node, bag_kwargs=bag_kwargs)
        op_kws = {"input_bag": None, "side_inputs": None}
        inputs = list(transform_node.inputs)
        if inputs:
          bag_inputs = []
          for input_value in inputs:
            if isinstance(input_value, pvalue.PBegin):
              # `PBegin` has no upstream Bag; pass `None` in its place.
              bag_inputs.append(None)
              continue
            prev_op = input_value.producer
            if prev_op in self.bags:
              bag_inputs.append(self.bags[prev_op])
          # Input to `Flatten` could be of length 1, e.g. a single-element
          # tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
          # an iterable, because `Flatten.apply` always takes an iterable.
          if len(bag_inputs) == 1 and not isinstance(op, Flatten):
            op_kws["input_bag"] = bag_inputs[0]
          else:
            op_kws["input_bag"] = bag_inputs
        side_inputs = list(transform_node.side_inputs)
        if side_inputs:
          bag_side_inputs = []
          for si in side_inputs:
            si_asbag = self.bags.get(si.pvalue.producer)
            bag_side_inputs.append(
                SideInputMap(
                    type(si),
                    si._view_options(),
                    DaskBagWindowedIterator(si_asbag, si._window_mapping_fn)))
          op_kws["side_inputs"] = bag_side_inputs
        self.bags[transform_node] = op.apply(**op_kws)

    return DaskBagVisitor()

  @staticmethod
  def is_fnapi_compatible():
    return False

  @staticmethod
  def _dask_client_kwargs(all_options: t.Dict) -> t.Dict:
    """Keep only the entries of `all_options` that `DaskOptions` defines.

    `get_all_options` also returns options of other `PipelineOptions`
    subclasses (e.g. `runner`), which the Dask client does not expect.
    """
    parser = argparse.ArgumentParser(add_help=False)
    DaskOptions._add_argparse_args(parser)
    # Parsing no flags yields a namespace holding every `DaskOptions` dest.
    dask_dests = vars(parser.parse_args([]))
    return {k: v for k, v in all_options.items() if k in dask_dests}

  @staticmethod
  def _sink_bags(bags: t.Dict) -> t.List:
    """Return the Bags of transforms whose outputs no other transform reads.

    Bags consumed by another transform are handed to that transform's op
    (as its input or as a side input), so computing the consumers computes
    them too -- assuming, as the translations are expected to, that each op
    derives its output from its inputs. `None` entries are skipped because
    they cannot be submitted to the client.
    """
    consumed = set()
    for node in bags:
      consumed.update(pval.producer for pval in node.inputs)
      consumed.update(si.pvalue.producer for si in node.side_inputs)
    return [
        bag for node, bag in bags.items()
        if node not in consumed and bag is not None
    ]

  def run_pipeline(self, pipeline, options):
    """Translate `pipeline` into Dask Bags and submit it to a Dask client.

    Args:
      pipeline: The pipeline to run; Dask overrides are applied to it.
      options: Pipeline options; only values defined by `DaskOptions` are
        used.

    Returns:
      A `DaskRunnerResult` tracking the submitted computation.

    Raises:
      NotImplementedError: When running inside a notebook.
      ImportError: When `dask.distributed` is not installed.
    """
    # TODO(alxmrs): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')
    try:
      import dask
      import dask.distributed as ddist
    except ImportError as e:
      raise ImportError(
          'DaskRunner is not available. Please install apache_beam[dask].'
      ) from e
    all_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    dask_options = self._dask_client_kwargs(all_options)
    # Removes the Bag-only options so the rest can be passed to the client.
    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
    client = ddist.Client(**dask_options)
    try:
      pipeline.replace_all(dask_overrides())
      dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
      pipeline.visit(dask_visitor)
      # Computing only the last visited Bag would skip every branch of a
      # non-linear pipeline that does not feed the last transform, so compute
      # all sinks instead. `dask.optimize` merges their (shared) graphs.
      sinks = self._sink_bags(dask_visitor.bags)
      futures = client.compute(dask.optimize(*sinks)) if sinks else []
    except BaseException:
      # No result object will own the client, so close it here (per the Dask
      # docs this also closes a local cluster the client started).
      client.close()
      raise
    return DaskRunnerResult(client, futures)