#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""DaskRunner, executing remote jobs on Dask.distributed.
The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import dataclasses
import typing as t
from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.utils.interactive_utils import is_in_notebook
class DaskOptions(PipelineOptions):
  """Dask-specific pipeline options.

  The `dest` of every flag deliberately matches a keyword argument of
  `dask.distributed.Client(...)`, so the parsed options can be forwarded
  to the client constructor by the runner.
  """
  @staticmethod
  def _parse_timeout(candidate):
    """Coerce a `--dask_connection_timeout` value to an int.

    Falls back to `dask.config.no_default` (dask's "use the configured
    default" sentinel) when the value is missing or not numeric, instead
    of failing argument parsing.
    """
    try:
      return int(candidate)
    except (TypeError, ValueError):
      # Imported lazily so merely defining the options class does not
      # require dask to be installed.
      import dask
      return dask.config.no_default

  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    """Register the `--dask_*` command-line flags on `parser`."""
    parser.add_argument(
        '--dask_client_address',
        dest='address',
        type=str,
        default=None,
        help='Address of a dask Scheduler server. Will default to a '
        '`dask.LocalCluster()`.')
    parser.add_argument(
        '--dask_connection_timeout',
        dest='timeout',
        type=DaskOptions._parse_timeout,
        help='Timeout duration for initial connection to the scheduler.')
    parser.add_argument(
        '--dask_scheduler_file',
        dest='scheduler_file',
        type=str,
        default=None,
        help='Path to a file with scheduler information if available.')
    # TODO(alxr): Add options for security.
    parser.add_argument(
        '--dask_client_name',
        dest='name',
        type=str,
        default=None,
        help='Gives the client a name that will be included in logs generated '
        'on the scheduler for matters relating to this client.')
    parser.add_argument(
        '--dask_connection_limit',
        dest='connection_limit',
        type=int,
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')
@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
  """Result handle for a pipeline submitted to a Dask cluster.

  Attributes:
    client: the `dask.distributed.Client` the work was submitted to.
    futures: the futures returned by `client.compute(...)`; gathering
      them drives the pipeline to completion.
  """
  # Class-level import keeps dask out of module import time while still
  # making the annotation types below resolvable.
  from dask import distributed

  client: distributed.Client
  futures: t.Sequence[distributed.Future]

  def __post_init__(self):
    # Dataclasses skip PipelineResult.__init__, so seed the state here.
    super().__init__(PipelineState.RUNNING)

  def wait_until_finish(self, duration=None) -> str:
    """Block until the pipeline finishes and return the terminal state.

    Args:
      duration: optional timeout in milliseconds (Beam's convention);
        converted to seconds before being handed to dask.

    Raises:
      Whatever the gathered futures raised; the state is set to FAILED
      before re-raising.
    """
    try:
      if duration is not None:
        # Beam expresses timeouts in milliseconds; dask expects seconds.
        duration /= 1000
      self.client.wait_for_workers(timeout=duration)
      self.client.gather(self.futures, errors='raise')
      self._state = PipelineState.DONE
    except:  # pylint: disable=broad-except
      # Deliberately broad: any failure (including interrupts) marks the
      # result FAILED, then propagates to the caller.
      self._state = PipelineState.FAILED
      raise
    return self._state

  def cancel(self) -> str:
    """Cancel all outstanding futures and return the terminal state."""
    self._state = PipelineState.CANCELLING
    self.client.cancel(self.futures)
    self._state = PipelineState.CANCELLED
    return self._state

  def metrics(self):
    # TODO(alxr): Collect and return metrics...
    raise NotImplementedError('collecting metrics will come later!')
class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""

  @staticmethod
  def to_dask_bag_visitor() -> PipelineVisitor:
    """Build a visitor that translates each PTransform into a dask Bag.

    The visitor walks the (already-overridden) pipeline graph and records
    one Bag per applied transform in `bags`, wiring each op to the Bags
    produced by its upstream transforms.
    """
    from dask import bag as db

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      # Maps each visited transform to the Bag its evaluator produced.
      bags: t.Dict[AppliedPTransform,
                   db.Bag] = dataclasses.field(default_factory=dict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        # Unknown transforms fall back to NoOp rather than failing.
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node)

        inputs = list(transform_node.inputs)
        if inputs:
          bag_inputs = []
          for input_value in inputs:
            # PBegin marks a source; it contributes a placeholder input.
            if isinstance(input_value, pvalue.PBegin):
              bag_inputs.append(None)

            # Link to the upstream transform's Bag when one was recorded.
            prev_op = input_value.producer
            if prev_op in self.bags:
              bag_inputs.append(self.bags[prev_op])

          # Single-input ops receive the Bag directly, not a 1-list.
          if len(bag_inputs) == 1:
            self.bags[transform_node] = op.apply(bag_inputs[0])
          else:
            self.bags[transform_node] = op.apply(bag_inputs)
        else:
          self.bags[transform_node] = op.apply(None)

    return DaskBagVisitor()

  @staticmethod
  def is_fnapi_compatible():
    return False

  def run_pipeline(self, pipeline, options):
    """Translate `pipeline` to dask Bags and submit them to a Client.

    Returns:
      A DaskRunnerResult holding the client and the submitted futures.

    Raises:
      NotImplementedError: when run from a notebook (unsupported).
      ImportError: when dask.distributed is not installed.
    """
    # TODO(alxr): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')

    try:
      import dask.distributed as ddist
    except ImportError:
      raise ImportError(
          'DaskRunner is not available. Please install apache_beam[dask].')

    # NOTE(review): this forwards every non-default pipeline option to
    # ddist.Client(**...), not just the DaskOptions flags — confirm the
    # client tolerates (or never receives) unrelated Beam options.
    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    client = ddist.Client(**dask_options)

    pipeline.replace_all(dask_overrides())

    dask_visitor = self.to_dask_bag_visitor()
    pipeline.visit(dask_visitor)

    futures = client.compute(list(dask_visitor.bags.values()))
    return DaskRunnerResult(client, futures)