#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for gracefully handling errors and excluding bad elements."""
import traceback
from apache_beam import transforms
[docs]
class ErrorHandler:
  """ErrorHandlers are used to skip and otherwise process bad records.
  Error handlers allow one to implement the "dead letter queue" pattern in
  a fluent manner, disaggregating the error processing specification from
  the main processing chain.
  This is typically used as follows::
    with error_handling.ErrorHandler(WriteToSomewhere(...)) as error_handler:
      result = pcoll | SomeTransform().with_error_handler(error_handler)
  in which case errors encountered by `SomeTransform()`` in processing pcoll
  will be written by the PTransform `WriteToSomewhere(...)` and excluded from
  `result` rather than failing the pipeline.
  To implement `with_error_handling` on a PTransform, one caches the provided
  error handler for use in `expand`.  During `expand()` one can invoke
  `error_handler.add_error_pcollection(...)` any number of times with
  PCollections containing error records to be processed by the given error
  handler, or (if applicable) simply invoke `with_error_handling(...)` on any
  subtransforms.
  The `with_error_handling` should accept `None` to indicate that error handling
  is not enabled (and make implementation-by-forwarding-error-handlers easier).
  In this case, any non-recoverable errors should fail the pipeline (e.g.
  propagate exceptions in `process` methods) rather than silently ignore errors.
  """
  def __init__(self, consumer):
    self._consumer = consumer
    self._creation_traceback = traceback.format_stack()[-2]
    self._error_pcolls = []
    self._closed = False
  def __enter__(self):
    self._error_pcolls = []
    self._closed = False
    return self
  def __exit__(self, *exec_info):
    if exec_info[0] is None:
      self.close()
[docs]
  def close(self):
    """Indicates all error-producing operations have reported any errors.
    Invokes the provided error consuming PTransform on any provided error
    PCollections.
    """
    self._output = (
        tuple(self._error_pcolls) | transforms.Flatten() | self._consumer)
    self._closed = True 
[docs]
  def output(self):
    """Returns result of applying the error consumer to the error pcollections.
    """
    if not self._closed:
      raise RuntimeError(
          "Cannot access the output of an error handler "
          "until it has been closed.")
    return self._output 
[docs]
  def add_error_pcollection(self, pcoll):
    """Called by a class implementing error handling on the error records.
    """
    pcoll.pipeline._register_error_handler(self)
    self._error_pcolls.append(pcoll) 
[docs]
  def verify_closed(self):
    """Called at end of pipeline construction to ensure errors are not ignored.
    """
    if not self._closed:
      raise RuntimeError(
          "Unclosed error handler initialized at %s" % self._creation_traceback) 
 
class _IdentityPTransform(transforms.PTransform):
  def expand(self, pcoll):
    return pcoll
[docs]
class CollectingErrorHandler(ErrorHandler):
  """An ErrorHandler that simply collects all errors for further processing.
  This ErrorHandler requires the set of errors be retrieved via `output()`
  and consumed (or explicitly discarded).
  """
  def __init__(self):
    super().__init__(_IdentityPTransform())
    self._creation_traceback = traceback.format_stack()[-2]
    self._output_accessed = False
[docs]
  def output(self):
    self._output_accessed = True
    return super().output() 
[docs]
  def verify_closed(self):
    if not self._output_accessed:
      raise RuntimeError(
          "CollectingErrorHandler requires the output to be retrieved. "
          "Initialized at %s" % self._creation_traceback)
    return super().verify_closed()