#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A ValueProvider abstracts the notion of fetching a value that may or
may not be currently available.

This can be used, for example, to parameterize transforms whose values are
only read in at runtime.
"""
# pytype: skip-file
from functools import wraps
from typing import Set
from apache_beam import error
# Public names of this module; every entry is defined further down in
# this file.
__all__ = [
    'ValueProvider',
    'StaticValueProvider',
    'RuntimeValueProvider',
    'NestedValueProvider',
    'check_accessible',
]
class ValueProvider(object):
  """Base class that all other ValueProviders must implement.

  Subclasses override both ``is_accessible`` and ``get``; the versions
  defined here only raise ``NotImplementedError``.
  """

  def is_accessible(self):
    """Whether the contents of this ValueProvider is available to routines
    that run at graph construction time.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError(
        'ValueProvider.is_accessible implemented in derived classes')

  def get(self):
    """Return the value wrapped by this ValueProvider.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError(
        'ValueProvider.get implemented in derived classes')
[docs]class StaticValueProvider(ValueProvider):
  """StaticValueProvider is an implementation of ValueProvider that allows
  for a static value to be provided.
  """
  def __init__(self, value_type, value):
    """
    Args:
        value_type: Type of the static value
        value: Static value
    """
    self.value_type = value_type
    self.value = value_type(value)
[docs]  def is_accessible(self):
    return True 
[docs]  def get(self):
    return self.value 
  def __str__(self):
    return str(self.value)
  def __eq__(self, other):
    if self.value == other:
      return True
    if isinstance(other, StaticValueProvider):
      if (self.value_type == other.value_type and self.value == other.value):
        return True
    return False
  def __hash__(self):
    return hash((type(self), self.value_type, self.value)) 
[docs]class RuntimeValueProvider(ValueProvider):
  """RuntimeValueProvider is an implementation of ValueProvider that
  allows for a value to be provided at execution time rather than
  at graph construction time.
  """
  runtime_options = None
  experiments = set()  # type: Set[str]
  def __init__(self, option_name, value_type, default_value):
    self.option_name = option_name
    self.default_value = default_value
    self.value_type = value_type
[docs]  def is_accessible(self):
    return RuntimeValueProvider.runtime_options is not None 
[docs]  @classmethod
  def get_value(cls, option_name, value_type, default_value):
    if not RuntimeValueProvider.runtime_options:
      return default_value
    candidate = RuntimeValueProvider.runtime_options.get(option_name)
    if candidate:
      return value_type(candidate)
    else:
      return default_value 
[docs]  def get(self):
    if RuntimeValueProvider.runtime_options is None:
      raise error.RuntimeValueProviderError(
          '%s.get() not called from a runtime context' % self)
    return RuntimeValueProvider.get_value(
        self.option_name, self.value_type, self.default_value) 
[docs]  @classmethod
  def set_runtime_options(cls, pipeline_options):
    RuntimeValueProvider.runtime_options = pipeline_options
    RuntimeValueProvider.experiments = RuntimeValueProvider.get_value(
        'experiments', set, set()) 
  def __str__(self):
    return '%s(option: %s, type: %s, default_value: %s)' % (
        self.__class__.__name__,
        self.option_name,
        self.value_type.__name__,
        repr(self.default_value)) 
[docs]class NestedValueProvider(ValueProvider):
  """NestedValueProvider is an implementation of ValueProvider that allows
  for wrapping another ValueProvider object.
  """
  def __init__(self, value, translator):
    """Creates a NestedValueProvider that wraps the provided ValueProvider.
    Args:
      value: ValueProvider object to wrap
      translator: function that is applied to the ValueProvider
    Raises:
      ``RuntimeValueProviderError``: if any of the provided objects are not
        accessible.
    """
    self.value = value
    self.translator = translator
[docs]  def is_accessible(self):
    return self.value.is_accessible() 
[docs]  def get(self):
    try:
      return self.cached_value
    except AttributeError:
      self.cached_value = self.translator(self.value.get())
      return self.cached_value 
  def __str__(self):
    return "%s(value: %s, translator: %s)" % (
        self.__class__.__name__,
        self.value,
        self.translator.__name__,
    ) 
def check_accessible(value_provider_list):
  """A decorator that checks accessibility of a list of ValueProvider objects.

  The decorated callable must be a method: each entry of
  ``value_provider_list`` is looked up with ``getattr`` on the method's
  ``self`` every time the method is called, before the method body runs.

  Args:
    value_provider_list: list of attribute names on ``self``, each of which
      refers to a ValueProvider object
  Raises:
    ``RuntimeValueProviderError``: when the decorated method is called and
      any of the referenced objects is not accessible.
    ``AssertionError``: at decoration time if ``value_provider_list`` is not
      a list.
  """
  assert isinstance(value_provider_list, list)

  def _check_accessible(fnc):
    @wraps(fnc)
    def _f(self, *args, **kwargs):
      # Resolve every attribute first, then verify each provider in order.
      providers = [getattr(self, name) for name in value_provider_list]
      for provider in providers:
        if not provider.is_accessible():
          raise error.RuntimeValueProviderError(
              '%s not accessible' % provider)
      return fnc(self, *args, **kwargs)

    return _f

  return _check_accessible