#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""``PTransform`` for reading from and writing to Web APIs."""
import abc
import concurrent.futures
import contextlib
import enum
import json
import logging
import sys
import time
from datetime import timedelta
from typing import Any
from typing import Dict
from typing import Generic
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import redis
from google.api_core.exceptions import TooManyRequests
import apache_beam as beam
from apache_beam import pvalue
from apache_beam.coders import coders
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics import Metrics
from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC
from apache_beam.transforms.util import BatchElements
from apache_beam.utils import retry
# Generic type variables for the request sent to, and the response received
# from, a Web API through RequestResponseIO.
RequestT = TypeVar('RequestT')
ResponseT = TypeVar('ResponseT')

_LOGGER = logging.getLogger(__name__)

# Number of seconds a single request to the external source may take before
# it is considered timed out.
DEFAULT_TIMEOUT_SECS = 30

# Time-to-live, in seconds, of a cache record: one full day.
DEFAULT_CACHE_ENTRY_TTL_SEC = int(timedelta(days=1).total_seconds())

# Names exported by ``from ... import *``.
__all__ = [
    'RequestResponseIO',
    'ExponentialBackOffRepeater',
    'DefaultThrottler',
    'NoOpsRepeater',
    'RedisCache',
]
class UserCodeExecutionException(Exception):
  """Root of the errors raised while user code calls a Web API."""


class UserCodeQuotaException(UserCodeExecutionException):
  """Raised when the Web API client reports a quota or API-overuse error.

  A more specific ``UserCodeExecutionException``.
  """


class UserCodeTimeoutException(UserCodeExecutionException):
  """Raised when a call into user code does not finish within its timeout.

  A more specific ``UserCodeExecutionException``.
  """
def retry_on_exception(exception: Exception):
  """Decides whether ``exception`` should trigger another attempt.

  Only errors that point at a temporarily unavailable or overloaded remote
  service are retryable: ``TooManyRequests`` (HTTP 429),
  ``UserCodeTimeoutException`` and ``UserCodeQuotaException``. Anything else,
  including a plain ``UserCodeExecutionException``, is not retried.

  Returns:
    True if ``exception`` is an instance of one of the retryable types.
  """
  retryable_types = (
      TooManyRequests, UserCodeTimeoutException, UserCodeQuotaException)
  return isinstance(exception, retryable_types)
class _MetricsCollector:
  """Holds the Beam counters that report RequestResponseIO usage.

  Every counter is registered under the namespace given to the constructor.
  """
  def __init__(self, namespace: str):
    """
    Args:
      namespace: Namespace for the metrics.
    """
    def _counter(metric_name):
      return Metrics.counter(namespace, metric_name)

    self.requests = _counter('requests')
    self.responses = _counter('responses')
    self.failures = _counter('failures')
    self.throttled_requests = _counter('throttled_requests')
    self.throttled_secs = _counter('cumulativeThrottlingSeconds')
    self.timeout_requests = _counter('requests_timed_out')
    self.call_counter = _counter('call_invocations')
    self.setup_counter = _counter('setup_counter')
    self.teardown_counter = _counter('teardown_counter')
    self.backoff_counter = _counter('backoff_counter')
    self.sleeper_counter = _counter('sleeper_counter')
    self.should_backoff_counter = _counter('should_backoff_counter')
class Caller(contextlib.AbstractContextManager,
             abc.ABC,
             Generic[RequestT, ResponseT]):
  """Interface for the user code that actually calls a Web API.

  Clients that need setting up or tearing down should override ``__enter__``
  and ``__exit__`` respectively. By default ``__enter__`` returns the caller
  itself and ``__exit__`` does nothing.
  """

  @abc.abstractmethod
  def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT:
    """Sends ``request`` to the Web API and returns its ``ResponseT``.

    ``RequestResponseIO`` expects implementations to signal failures by
    raising a ``UserCodeExecutionException``, ``UserCodeQuotaException`` or
    ``UserCodeTimeoutException``.
    """

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    return None

  def get_cache_key(self, request: RequestT) -> str:
    """Returns the key under which the response to ``request`` is cached.

    The same key is used later to look the response up. The default returns
    an empty string; the Redis cache in this module then falls back to using
    the whole request as the key. Override this to pick a different key, e.g.
    `BigTableEnrichmentHandler` returns the element's row key.
    """
    return ""

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    """Returns keyword arguments suitable for `beam.BatchElements`.

    The default, an empty mapping, means requests are not batched.
    """
    return {}
class ShouldBackOff(abc.ABC):
  """Extension point for adaptive throttling (backing off) of requests.

  It currently declares no methods.
  """
class Repeater(abc.ABC):
  """Strategy for re-sending a request while a configurable condition holds."""

  @abc.abstractmethod
  def repeat(
      self,
      caller: Caller[RequestT, ResponseT],
      request: RequestT,
      timeout: float,
      metrics_collector: Optional[_MetricsCollector]) -> ResponseT:
    """Executes ``request`` with ``caller``, repeating it as the strategy
    dictates. Invoked by RequestResponseIO when a repeater is enabled.

    Args:
      caller: a `~apache_beam.io.requestresponse.Caller` object that
        calls the API.
      request: input request to repeat.
      timeout: time to wait for the request to complete.
      metrics_collector: (Optional) a
        `~apache_beam.io.requestresponse._MetricsCollector` object
        to collect the metrics for RequestResponseIO.

    Returns:
      The response produced by ``caller``.
    """
def _execute_request(
    caller: Caller[RequestT, ResponseT],
    request: RequestT,
    timeout: float,
    metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
  """Runs ``caller(request)`` on a worker thread, waiting at most ``timeout``.

  Args:
    caller: the `Caller` that performs the API call.
    request: the request passed to ``caller``.
    timeout: seconds to wait for ``caller`` to return; ``None`` waits without
      a limit.
    metrics_collector: (Optional) collector whose ``timeout_requests`` and
      ``failures`` counters are incremented on the matching errors.

  Returns:
    The value returned by ``caller(request)``.

  Raises:
    TooManyRequests: re-raised unchanged when ``caller`` raises it.
    UserCodeTimeoutException: when ``caller`` does not return within
      ``timeout`` seconds.
    UserCodeExecutionException: when ``caller`` raises a ``RuntimeError``.
    Any other exception raised by ``caller`` propagates unchanged.
  """
  # Deliberately not used as a context manager: leaving a ``with`` block calls
  # ``shutdown(wait=True)``, which would block until a timed-out call finally
  # returns and so make ``timeout`` ineffective.
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
  try:
    future = executor.submit(caller, request)
    try:
      return future.result(timeout=timeout)
    except TooManyRequests as e:
      _LOGGER.info(
          'request could not be completed. got code %i from the service.',
          e.code)
      raise
    except concurrent.futures.TimeoutError as e:
      if metrics_collector:
        metrics_collector.timeout_requests.inc(1)
      raise UserCodeTimeoutException(
          f'Timeout {timeout} exceeded '
          f'while completing request: {request}') from e
    except RuntimeError as e:
      if metrics_collector:
        metrics_collector.failures.inc(1)
      raise UserCodeExecutionException('could not complete request') from e
  finally:
    # Return without waiting: a timed-out call keeps running on its worker
    # thread until it returns, and its result is discarded. A retry may
    # therefore overlap with it.
    executor.shutdown(wait=False)
class ExponentialBackOffRepeater(Repeater):
  """Configure exponential backoff retry strategy.

  It retries for exceptions due to the remote service such as
  TooManyRequests (HTTP 429), UserCodeTimeoutException, UserCodeQuotaException.
  It utilizes the decorator
  :func:`apache_beam.utils.retry.with_exponential_backoff`, configured with
  ``num_retries=2`` and ``retry_on_exception`` as the retry filter.
  """
  def __init__(self):
    pass

  @retry.with_exponential_backoff(
      num_retries=2, retry_filter=retry_on_exception)
  def repeat(
      self,
      caller: Caller[RequestT, ResponseT],
      request: RequestT,
      timeout: float,
      metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
    """repeat method is called from the RequestResponseIO when
    a repeater is enabled.

    Args:
      caller: a `~apache_beam.io.requestresponse.Caller` object that
        calls the API.
      request: input request to repeat.
      timeout: time to wait for the request to complete.
      metrics_collector: (Optional) a
        `~apache_beam.io.requestresponse._MetricsCollector` object to
        collect the metrics for RequestResponseIO.

    Returns:
      The response produced by ``caller``.
    """
    return _execute_request(caller, request, timeout, metrics_collector)
class NoOpsRepeater(Repeater):
  """Executes a request just once irrespective of any exception."""

  def repeat(
      self,
      caller: Caller[RequestT, ResponseT],
      request: RequestT,
      timeout: float,
      metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
    """Executes ``request`` once via ``caller``; errors propagate unchanged.

    Args:
      caller: a `~apache_beam.io.requestresponse.Caller` object that
        calls the API.
      request: the request to send.
      timeout: time to wait for the request to complete.
      metrics_collector: (Optional) a
        `~apache_beam.io.requestresponse._MetricsCollector` object to
        collect the metrics for RequestResponseIO.

    Returns:
      The response produced by ``caller``.
    """
    return _execute_request(caller, request, timeout, metrics_collector)
class PreCallThrottler(abc.ABC):
  """Base class for throttling applied before a request is sent.

  It declares no abstract methods.
  """
class DefaultThrottler(PreCallThrottler):
  """Default throttler that uses
  :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`

  Args:
    window_ms (int): length of history to consider, in ms, to set throttling.
    bucket_ms (int): granularity of time buckets that we store data in, in ms.
    overload_ratio (float): the target ratio between requests sent and
      successful requests. This is "K" in the formula in
      https://landing.google.com/sre/book/chapters/handling-overload.html.
    delay_secs (int): minimum number of seconds to throttle a request.
  """
  def __init__(
      self,
      window_ms: int = 1,
      bucket_ms: int = 1,
      overload_ratio: float = 2,
      delay_secs: int = 5):
    # AdaptiveThrottler consulted before each request and informed of
    # successful ones.
    self.throttler = AdaptiveThrottler(
        window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
    # Seconds to sleep each time the throttler rejects a request.
    self.delay_secs = delay_secs
class _FilterCacheReadFn(beam.DoFn):
  """Routes the results of a cache lookup.

  Each input is a ``(request, cached_response)`` pair. Pairs whose cached
  response is truthy are emitted unchanged on the main output. For all other
  pairs (any falsy response counts as a miss, not only ``None``) just the
  request is emitted, on the ``cache_misses`` tagged output.
  """
  def process(self, element: Tuple[RequestT, ResponseT], *args, **kwargs):
    cached_response = element[1]
    if cached_response:
      yield element
      return
    yield pvalue.TaggedOutput('cache_misses', element[0])
class _Call(beam.PTransform[beam.PCollection[RequestT],
                            beam.PCollection[ResponseT]]):
  """(Internal-only) PTransform that invokes a remote function on each element
  of the input PCollection.

  This PTransform uses a `Caller` object to invoke the actual API calls,
  and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of
  clients when applicable. Additionally, a timeout value is specified to
  regulate the duration of each call, defaults to 30 seconds.

  Args:
    caller: a `Caller` object that invokes API call.
    timeout (float): timeout value in seconds to wait for response from API.
    should_backoff: (Optional) provides methods for backoff.
    repeater: (Optional) provides methods to repeat requests to API. When
      omitted, each request is executed exactly once (`NoOpsRepeater`).
    throttler: (Optional) provides methods to pre-throttle a request.
  """
  def __init__(
      self,
      caller: Caller[RequestT, ResponseT],
      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
      should_backoff: Optional[ShouldBackOff] = None,
      repeater: Optional[Repeater] = None,
      throttler: Optional[PreCallThrottler] = None,
  ):
    self._caller = caller
    self._timeout = timeout
    # NOTE(review): stored but not passed on to _CallDoFn in expand().
    self._should_backoff = should_backoff
    # _CallDoFn calls ``repeater.repeat`` unconditionally, so a missing
    # repeater falls back to a single attempt instead of failing on None.
    self._repeater = repeater if repeater is not None else NoOpsRepeater()
    self._throttler = throttler

  def expand(
      self,
      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
    return requests | beam.ParDo(
        _CallDoFn(self._caller, self._timeout, self._repeater, self._throttler))
class _CallDoFn(beam.DoFn):
  """`DoFn` that sends each request to the Web API through a `Caller`.

  The caller's ``__enter__``/``__exit__`` run in ``setup``/``teardown``. Each
  request optionally waits on a `DefaultThrottler` and is then executed via
  the configured `Repeater`.
  """
  def __init__(
      self,
      caller: Caller[RequestT, ResponseT],
      timeout: float,
      repeater: Repeater,
      throttler: Optional[PreCallThrottler]):
    """
    Args:
      caller: the `Caller` that performs the API call.
      timeout: seconds to wait for each call to complete.
      repeater: strategy used to execute (and possibly repeat) each request.
      throttler: (Optional) a `DefaultThrottler`; ``None`` disables
        client-side throttling.
    """
    self._metrics_collector = None
    self._caller = caller
    self._timeout = timeout
    self._repeater = repeater
    self._throttler = throttler

  def setup(self):
    self._caller.__enter__()
    self._metrics_collector = _MetricsCollector(self._caller.__str__())
    self._metrics_collector.setup_counter.inc(1)

  def _wait_for_throttler(self):
    """Sleeps ``delay_secs`` at a time while the throttler rejects the
    request, recording throttling metrics."""
    is_throttled_request = False
    while self._throttler.throttler.throttle_request(time.time() *
                                                     MSEC_TO_SEC):
      _LOGGER.info(
          "Delaying request for %d seconds", self._throttler.delay_secs)
      time.sleep(self._throttler.delay_secs)
      self._metrics_collector.throttled_secs.inc(self._throttler.delay_secs)
      is_throttled_request = True
    if is_throttled_request:
      self._metrics_collector.throttled_requests.inc(1)

  def process(self, request: RequestT, *args, **kwargs):
    self._metrics_collector.requests.inc(1)
    if self._throttler:
      self._wait_for_throttler()
    req_time = time.time()
    response = self._repeater.repeat(
        self._caller, request, self._timeout, self._metrics_collector)
    self._metrics_collector.responses.inc(1)
    # Without a throttler (e.g. RequestResponseIO with throttler=None) there
    # is nothing to report the success to.
    if self._throttler:
      self._throttler.throttler.successful_request(req_time * MSEC_TO_SEC)
    yield response

  def teardown(self):
    # setup() may have failed before the collector was created.
    if self._metrics_collector is not None:
      self._metrics_collector.teardown_counter.inc(1)
    self._caller.__exit__(*sys.exc_info())
class Cache(abc.ABC):
  """Abstract cache backend for
  :class:`apache_beam.io.requestresponse.RequestResponseIO`.

  Implement this class to add cache support to RequestResponseIO.
  """

  @abc.abstractmethod
  def get_read(self):
    """Returns a PTransform that reads from the cache."""

  @abc.abstractmethod
  def get_write(self):
    """Returns a PTransform that writes to the cache."""

  @property
  @abc.abstractmethod
  def request_coder(self):
    """The coder used to encode requests for the cache."""

  @request_coder.setter
  @abc.abstractmethod
  def request_coder(self, request_coder: coders.Coder):
    """Replaces the coder used to encode requests for the cache."""

  @property
  @abc.abstractmethod
  def source_caller(self):
    """The `Caller` whose requests are being cached."""

  @source_caller.setter
  @abc.abstractmethod
  def source_caller(self, caller: Caller):
    """Sets the `Caller` of
    :class:`apache_beam.io.requestresponse.RequestResponseIO` from which the
    cache obtains each request's cache key."""
class _RedisMode(enum.Enum):
"""
Mode of operation for redis cache when using
`~apache_beam.io.requestresponse._RedisCaller`.
"""
READ = 0
WRITE = 1
class _RedisCaller(Caller):
  """An implementation of
  `~apache_beam.io.requestresponse.Caller` for Redis client.

  It provides the functionality for making requests to Redis server using
  :class:`apache_beam.io.requestresponse.RequestResponseIO`.
  """
  def __init__(
      self,
      host: str,
      port: int,
      time_to_live: Union[int, timedelta],
      *,
      request_coder: Optional[coders.Coder],
      response_coder: Optional[coders.Coder],
      kwargs: Optional[Dict[str, Any]] = None,
      source_caller: Optional[Caller] = None,
      mode: _RedisMode,
  ):
    """
    Args:
      host (str): The hostname or IP address of the Redis server.
      port (int): The port number of the Redis server.
      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
        records stored in Redis. Provide an integer (in seconds) or a
        `datetime.timedelta` object.
      request_coder: (Optional[`coders.Coder`]) coder for requests stored
        in Redis.
      response_coder: (Optional[`coders.Coder`]) coder for responses stored
        in Redis. When None, responses are JSON-encoded from ``_asdict()``
        and decoded back into a `beam.Row`.
      kwargs: Optional(Dict[str, Any]) additional keyword arguments that
        are required to connect to your redis server. Same as `redis.Redis()`.
      source_caller: (Optional[`Caller`]): The source caller using this Redis
        cache in case of fetching the cache request to store in Redis.
      mode: `_RedisMode` An enum type specifying the operational mode of
        the `_RedisCaller`.
    """
    self.host, self.port = host, port
    self.time_to_live = time_to_live
    self.request_coder = request_coder
    self.response_coder = response_coder
    self.kwargs = kwargs
    self.source_caller = source_caller
    self.mode = mode
    # Created in __enter__ (i.e. in DoFn.setup), closed in __exit__.
    self.client = None

  def __enter__(self):
    # ``kwargs`` defaults to None; unpacking None would raise a TypeError.
    self.client = redis.Redis(self.host, self.port, **(self.kwargs or {}))
    return self

  def _encode_request(self, request):
    """Returns the encoded Redis key for ``request``.

    Uses ``source_caller.get_cache_key(request)`` when it is truthy, otherwise
    the whole request.
    """
    cache_request = (
        self.source_caller.get_cache_key(request)
        if self.source_caller is not None else None)
    if cache_request:
      return self.request_coder.encode(cache_request)
    return self.request_coder.encode(request)

  def _read_cache(self, element):
    """Looks ``element`` up in Redis.

    Returns:
      ``(element, response)``, where ``response`` is None on a cache miss or
      when the stored value cannot be decoded.
    """
    encoded_request = self._encode_request(element)
    encoded_response = self.client.get(encoded_request)
    if not encoded_response:
      # no cache entry present for this request.
      return element, None
    if self.response_coder is None:
      try:
        response_dict = json.loads(encoded_response.decode('utf-8'))
        response = beam.Row(**response_dict)
      except Exception:
        # Lazy logging args: ``'...' % element`` breaks when the request is
        # itself a tuple.
        _LOGGER.warning(
            'cannot decode response from redis cache for %s.', element)
        return element, None
    else:
      response = self.response_coder.decode(encoded_response)
    return element, response

  def _write_cache(self, element):
    """Stores ``element[1]`` under the key of request ``element[0]``.

    Returns:
      ``element`` unchanged, whether or not the write happened.
    """
    encoded_request = self._encode_request(element[0])
    if self.response_coder is None:
      try:
        encoded_response = json.dumps(element[1]._asdict()).encode('utf-8')
      except Exception:
        _LOGGER.warning(
            'cannot encode response %s for %s to store in '
            'redis cache.',
            element[1],
            element[0])
        return element
    else:
      encoded_response = self.response_coder.encode(element[1])
    # Write to cache with TTL. Set nx to True to prevent overwriting for the
    # same key.
    self.client.set(
        encoded_request, encoded_response, self.time_to_live, nx=True)
    return element

  def __call__(self, element, *args, **kwargs):
    handle = (
        self._read_cache if self.mode == _RedisMode.READ else self._write_cache)
    # A list input (batched elements) is handled element by element.
    if isinstance(element, list):
      return [handle(e) for e in element]
    return handle(element)

  def __exit__(self, exc_type, exc_val, exc_tb):
    # __enter__ may never have run (e.g. setup failed), so there may be no
    # client to close.
    client = getattr(self, 'client', None)
    if client is not None:
      client.close()
class _ReadFromRedis(beam.PTransform[beam.PCollection[RequestT],
                                     beam.PCollection[ResponseT]]):
  """`PTransform` that looks each request up in a Redis cache.

  Emits ``(request, cached_response)`` pairs, where ``cached_response`` is
  None when no usable entry exists.
  """
  def __init__(
      self,
      host: str,
      port: int,
      time_to_live: Union[int, timedelta],
      *,
      kwargs: Optional[Dict[str, Any]] = None,
      request_coder: Optional[coders.Coder],
      response_coder: Optional[coders.Coder],
      source_caller: Optional[Caller[RequestT, ResponseT]] = None,
  ):
    """
    Args:
      host (str): The hostname or IP address of the Redis server.
      port (int): The port number of the Redis server.
      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
        records stored in Redis. Provide an integer (in seconds) or a
        `datetime.timedelta` object.
      kwargs: Optional(Dict[str, Any]) additional keyword arguments that
        are required to connect to your redis server. Same as `redis.Redis()`.
      request_coder: (Optional[`coders.Coder`]) coder used to build the Redis
        key from a request.
      response_coder: (Optional[`coders.Coder`]) coder for decoding responses
        read from Redis.
      source_caller: (Optional[`Caller`]): The caller whose
        ``get_cache_key`` supplies the cache key for each request.
    """
    self.request_coder = request_coder
    self.response_coder = response_coder
    self.redis_caller = _RedisCaller(
        host=host,
        port=port,
        time_to_live=time_to_live,
        mode=_RedisMode.READ,
        kwargs=kwargs,
        source_caller=source_caller,
        request_coder=request_coder,
        response_coder=response_coder,
    )

  def expand(
      self,
      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
    redis_lookup = RequestResponseIO(self.redis_caller)
    return requests | redis_lookup
class _WriteToRedis(beam.PTransform[beam.PCollection[Tuple[RequestT,
                                                           ResponseT]],
                                    beam.PCollection[ResponseT]]):
  """`PTransform` that stores ``(request, response)`` pairs in a Redis cache.

  Each input pair is passed through unchanged.
  """
  def __init__(
      self,
      host: str,
      port: int,
      time_to_live: Union[int, timedelta],
      *,
      kwargs: Optional[Dict[str, Any]] = None,
      request_coder: Optional[coders.Coder],
      response_coder: Optional[coders.Coder],
      source_caller: Optional[Caller[RequestT, ResponseT]] = None,
  ):
    """
    Args:
      host (str): The hostname or IP address of the Redis server.
      port (int): The port number of the Redis server.
      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
        records stored in Redis. Provide an integer (in seconds) or a
        `datetime.timedelta` object.
      kwargs: Optional(Dict[str, Any]) additional keyword arguments that
        are required to connect to your redis server. Same as `redis.Redis()`.
      request_coder: (Optional[`coders.Coder`]) coder used to build the Redis
        key from a request.
      response_coder: (Optional[`coders.Coder`]) coder for encoding responses
        written to Redis.
      source_caller: (Optional[`Caller`]): The caller whose
        ``get_cache_key`` supplies the cache key for each request.
    """
    self.request_coder = request_coder
    self.response_coder = response_coder
    self.redis_caller = _RedisCaller(
        host=host,
        port=port,
        time_to_live=time_to_live,
        mode=_RedisMode.WRITE,
        kwargs=kwargs,
        source_caller=source_caller,
        request_coder=request_coder,
        response_coder=response_coder,
    )

  def expand(
      self, elements: beam.PCollection[Tuple[RequestT, ResponseT]]
  ) -> beam.PCollection[ResponseT]:
    redis_store = RequestResponseIO(self.redis_caller)
    return elements | redis_store
def ensure_coders_exist(request_coder):
  """Validates that a request coder is available for caching.

  Args:
    request_coder: the coder that would encode requests for the cache.

  Raises:
    ValueError: if ``request_coder`` is falsy (for example None).
  """
  if request_coder:
    return
  raise ValueError(
      'need request coder to be able to use '
      'Cache with RequestResponseIO.')
class RedisCache(Cache):
  """Configure cache using Redis for
  :class:`apache_beam.io.requestresponse.RequestResponseIO`."""
  def __init__(
      self,
      host: str,
      port: int,
      time_to_live: Union[int, timedelta] = DEFAULT_CACHE_ENTRY_TTL_SEC,
      *,
      request_coder: Optional[coders.Coder] = None,
      response_coder: Optional[coders.Coder] = None,
      **kwargs,
  ):
    """
    Args:
      host (str): The hostname or IP address of the Redis server.
      port (int): The port number of the Redis server.
      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
        records stored in Redis. Provide an integer (in seconds) or a
        `datetime.timedelta` object.
      request_coder: (Optional[`coders.Coder`]) coder for encoding requests.
        Required before `get_read`/`get_write` can be used.
      response_coder: (Optional[`coders.Coder`]) coder for encoding and
        decoding the responses stored in Redis.
      kwargs: Optional additional keyword arguments that
        are required to connect to your redis server. Same as `redis.Redis()`.
    """
    self._host = host
    self._port = port
    self._time_to_live = time_to_live
    self._request_coder = request_coder
    self._response_coder = response_coder
    self._kwargs = kwargs if kwargs else {}
    # Set by RequestResponseIO.expand before the read/write transforms are
    # built.
    self._source_caller = None

  def get_read(self):
    """get_read returns a PTransform for reading from the cache.

    Raises:
      ValueError: if no request coder has been configured.
    """
    ensure_coders_exist(self._request_coder)
    return _ReadFromRedis(
        self._host,
        self._port,
        time_to_live=self._time_to_live,
        kwargs=self._kwargs,
        request_coder=self._request_coder,
        response_coder=self._response_coder,
        source_caller=self._source_caller)

  def get_write(self):
    """returns a PTransform for writing to the cache.

    Raises:
      ValueError: if no request coder has been configured.
    """
    ensure_coders_exist(self._request_coder)
    return _WriteToRedis(
        self._host,
        self._port,
        time_to_live=self._time_to_live,
        kwargs=self._kwargs,
        request_coder=self._request_coder,
        response_coder=self._response_coder,
        source_caller=self._source_caller)

  @property
  def source_caller(self):
    return self._source_caller

  @source_caller.setter
  def source_caller(self, source_caller: Caller):
    self._source_caller = source_caller

  @property
  def request_coder(self):
    return self._request_coder

  @request_coder.setter
  def request_coder(self, request_coder: coders.Coder):
    self._request_coder = request_coder
class FlattenBatch(beam.DoFn):
  """Emits every item of each incoming batch as its own element."""
  def process(self, elements, *args, **kwargs):
    yield from elements
class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
                                        beam.PCollection[ResponseT]]):
  """A :class:`RequestResponseIO` transform to read and write to APIs.

  Processes an input :class:`~apache_beam.pvalue.PCollection` of requests
  by making a call to the API as defined in `Caller`'s `__call__` method
  and returns a :class:`~apache_beam.pvalue.PCollection` of responses.
  """
  def __init__(
      self,
      caller: Caller[RequestT, ResponseT],
      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
      should_backoff: Optional[ShouldBackOff] = None,
      repeater: Repeater = ExponentialBackOffRepeater(),
      cache: Optional[Cache] = None,
      throttler: PreCallThrottler = DefaultThrottler(),
  ):
    """
    Instantiates a RequestResponseIO transform.

    Args:
      caller: an implementation of
        `Caller` object that makes call to the API.
      timeout (float): timeout value in seconds to wait for response from API.
      should_backoff: (Optional) provides methods for backoff.
      repeater: provides method to repeat failed requests to API due to service
        errors. Defaults to
        :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to
        repeat requests with exponential backoff. A falsy value (e.g. None)
        selects :class:`apache_beam.io.requestresponse.NoOpsRepeater`.
      cache: (Optional) a `~apache_beam.io.requestresponse.Cache` object
        to use the appropriate cache.
      throttler: provides methods to pre-throttle a request. Defaults to
        :class:`apache_beam.io.requestresponse.DefaultThrottler` for
        client-side adaptive throttling using
        :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`.
        Only `DefaultThrottler` instances are applied; any other value,
        including None, disables client-side throttling.
    """
    # NOTE: the default repeater/throttler objects are created once, when
    # this signature is evaluated, and are shared by every instance that
    # relies on the defaults.
    self._caller = caller
    self._timeout = timeout
    self._should_backoff = should_backoff
    self._repeater = repeater if repeater else NoOpsRepeater()
    self._cache = cache
    self._throttler = throttler
    self._batching_kwargs = self._caller.batch_elements_kwargs()

  def expand(
      self,
      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
    # TODO(riteshghorse): handle Throttle PTransforms when available.
    inputs = requests
    if self._cache:
      # The cache asks the caller for each request's cache key.
      self._cache.source_caller = self._caller
      # Read from the cache, then split hits from misses; only the misses are
      # sent to the external service.
      outputs = inputs | self._cache.get_read()
      cached_responses, inputs = (
          outputs
          | beam.ParDo(_FilterCacheReadFn()).with_outputs(
              'cache_misses', main='cached_responses'))

    # Batch elements if batching is enabled.
    if self._batching_kwargs:
      inputs = inputs | BatchElements(**self._batching_kwargs)

    # DefaultThrottler is applied inside the DoFn of the _Call PTransform.
    # With any other throttler, requests are made to the external source as
    # they come.
    throttler = (
        self._throttler
        if isinstance(self._throttler, DefaultThrottler) else None)
    responses = (
        inputs
        | _Call(
            caller=self._caller,
            timeout=self._timeout,
            should_backoff=self._should_backoff,
            repeater=self._repeater,
            throttler=throttler))

    # if batching is enabled then handle accordingly.
    if self._batching_kwargs:
      responses = responses | "FlattenBatch" >> beam.ParDo(FlattenBatch())

    if self._cache:
      # write to cache.
      _ = responses | self._cache.get_write()
      return (cached_responses, responses) | beam.Flatten()
    return responses