#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A connector for sending API requests to the GCP Recommendations AI
API (https://cloud.google.com/recommendations).
"""
from __future__ import absolute_import
from typing import Sequence
from typing import Tuple
from google.api_core.retry import Retry
from apache_beam import pvalue
from apache_beam.metrics import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms.util import GroupIntoBatches
from cachetools.func import ttl_cache
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  from google.cloud import recommendationengine
except ImportError:
  raise ImportError(
      'Google Cloud Recommendation AI not supported for this execution '
      'environment (could not import google.cloud.recommendationengine).')
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports

__all__ = [
    'CreateCatalogItem',
    'WriteUserEvent',
    'ImportCatalogItems',
    'ImportUserEvents',
    'PredictUserEvent'
]
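
# Tag for the dead-letter output on which catalog items that failed to be
# created or imported are emitted (via pvalue.TaggedOutput).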
FAILED_CATALOG_ITEMS = "failed_catalog_items"
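

# The client getters below are cached with a one-hour TTL so that DoFn
# instances in the same worker process share one client per service and
# periodically refresh it after the TTL expires.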
@ttl_cache(maxsize=128, ttl=3600)
def get_recommendation_prediction_client():
  """Returns a Recommendation AI - Prediction Service client."""
  _client = recommendationengine.PredictionServiceClient()
  return _client


@ttl_cache(maxsize=128, ttl=3600)
def get_recommendation_catalog_client():
  """Returns a Recommendation AI - Catalog Service client."""
  _client = recommendationengine.CatalogServiceClient()
  return _client


@ttl_cache(maxsize=128, ttl=3600)
def get_recommendation_user_event_client():
  """Returns a Recommendation AI - UserEvent Service client."""
  _client = recommendationengine.UserEventServiceClient()
  return _client


class CreateCatalogItem(PTransform):
  """Creates catalog item information.

  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully created and failed CatalogItems.

  Example usage::

    pipeline | CreateCatalogItem(
        project='example-gcp-project',
        catalog_name='my-catalog')
  """
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog"):
    """Initializes a :class:`CreateCatalogItem` transform.
        Args:
            project (str): Optional. GCP project name in which the catalog
              data will be imported.
            retry: Optional. Designation of what
              errors, if any, should be retried.
            timeout (float): Optional. The amount of time, in seconds, to wait
              for the request to complete.
            metadata: Optional. Strings which
              should be sent along with the request as metadata.
            catalog_name (str): Optional. Name of the catalog.
              Default: 'default_catalog'
        """
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
    return pcoll | ParDo(
        _CreateCatalogItemFn(
            self.project,
            self.retry,
            self.timeout,
            self.metadata,
            self.catalog_name))


class _CreateCatalogItemFn(DoFn):
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
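    # Fully qualified parent catalog resource name expected by the API, e.g.
    # "projects/<project>/locations/global/catalogs/default_catalog".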
    self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_catalog_client()

  def process(self, element):
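    # Each element is expected to be a dict matching the CatalogItem schema;
    # the proto-plus constructor parses it into a CatalogItem message.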
    catalog_item = recommendationengine.CatalogItem(element)
    request = recommendationengine.CreateCatalogItemRequest(
        parent=self.parent, catalog_item=catalog_item)
    try:
      created_catalog_item = self._client.create_catalog_item(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc()
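      # Emit the created item as a plain dict rather than a proto message.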
      yield recommendationengine.CatalogItem.to_dict(created_catalog_item)
    except Exception:
      yield pvalue.TaggedOutput(
          FAILED_CATALOG_ITEMS,
          recommendationengine.CatalogItem.to_dict(catalog_item))


class ImportCatalogItems(PTransform):
  """Imports catalog items in bulk.

  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully imported and failed CatalogItems.

  Example usage::

    pipeline
    | ImportCatalogItems(
        project='example-gcp-project',
        catalog_name='my-catalog')
  """
  def __init__(
      self,
      max_batch_size: int = 5000,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog"):
    """Initializes a :class:`ImportCatalogItems` transform
        Args:
            batch_size (int): Required. Maximum number of catalogitems per
              request.
            project (str): Optional. GCP project name in which the catalog
              data will be imported.
            retry: Optional. Designation of what
              errors, if any, should be retried.
            timeout (float): Optional. The amount of time, in seconds, to wait
              for the request to complete.
            metadata: Optional. Strings which
              should be sent along with the request as metadata.
            catalog_name (str): Optional. Name of the catalog.
              Default: 'default_catalog'
        """
    self.max_batch_size = max_batch_size
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
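    # Input elements must be keyed (key/value pairs). WithShardedKey groups
    # up to max_batch_size values per sharded key, and the DoFn below issues
    # one bulk-import request per batch.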
    return (
        pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
            _ImportCatalogItemsFn(
                self.project,
                self.retry,
                self.timeout,
                self.metadata,
                self.catalog_name)))


class _ImportCatalogItemsFn(DoFn):
  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=(),
      catalog_name=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_catalog_client()

  def process(self, element):
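    # GroupIntoBatches.WithShardedKey emits (sharded key, batch) pairs;
    # element[1] is the iterable of catalog-item dicts in this batch.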
    catalog_items = [recommendationengine.CatalogItem(e) for e in element[1]]
    catalog_inline_source = recommendationengine.CatalogInlineSource(
        {"catalog_items": catalog_items})
    input_config = recommendationengine.InputConfig(
        catalog_inline_source=catalog_inline_source)
    request = recommendationengine.ImportCatalogItemsRequest(
        parent=self.parent, input_config=input_config)
    try:
      operation = self._client.import_catalog_items(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc(len(catalog_items))
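      # result() blocks until the long-running import operation completes.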
      yield operation.result()
    except Exception:
      yield pvalue.TaggedOutput(FAILED_CATALOG_ITEMS, catalog_items)


class WriteUserEvent(PTransform):
  """Writes user event information.

  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully written and failed UserEvents.

  Example usage::

    pipeline
    | WriteUserEvent(
        project='example-gcp-project',
        catalog_name='my-catalog',
        event_store='my_event_store')
  """
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog",
      event_store: str = "default_event_store"):
    """Initializes a :class:`WriteUserEvent` transform.
        Args:
            project (str): Optional. GCP project name in which the catalog
              data will be imported.
            retry: Optional. Designation of what
              errors, if any, should be retried.
            timeout (float): Optional. The amount of time, in seconds, to wait
              for the request to complete.
            metadata: Optional. Strings which
              should be sent along with the request as metadata.
            catalog_name (str): Optional. Name of the catalog.
              Default: 'default_catalog'
            event_store (str): Optional. Name of the event store.
              Default: 'default_event_store'
        """
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name
    self.event_store = event_store

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
    return pcoll | ParDo(
        _WriteUserEventFn(
            self.project,
            self.retry,
            self.timeout,
            self.metadata,
            self.catalog_name,
            self.event_store))


class _WriteUserEventFn(DoFn):
  FAILED_USER_EVENTS = "failed_user_events"
  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=(),
      catalog_name=None,
      event_store=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.parent = f"projects/{project}/locations/global/catalogs/"\
                  f"{catalog_name}/eventStores/{event_store}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_user_event_client()

  def process(self, element):
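    # Each element is expected to be a dict matching the UserEvent schema.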
    user_event = recommendationengine.UserEvent(element)
    request = recommendationengine.WriteUserEventRequest(
        parent=self.parent, user_event=user_event)
    try:
      created_user_event = self._client.write_user_event(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc()
      yield recommendationengine.UserEvent.to_dict(created_user_event)
    except Exception:
      yield pvalue.TaggedOutput(
          self.FAILED_USER_EVENTS,
          recommendationengine.UserEvent.to_dict(user_event))


class ImportUserEvents(PTransform):
  """Imports user events in bulk.

  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully imported and failed UserEvents.

  Example usage::

    pipeline
    | ImportUserEvents(
        project='example-gcp-project',
        catalog_name='my-catalog',
        event_store='my_event_store')
  """
  def __init__(
      self,
      max_batch_size: int = 5000,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog",
      event_store: str = "default_event_store"):
    """Initializes a :class:`WriteUserEvent` transform.
        Args:
            batch_size (int): Required. Maximum number of catalogitems
              per request.
            project (str): Optional. GCP project name in which the catalog
              data will be imported.
            retry: Optional. Designation of what
              errors, if any, should be retried.
            timeout (float): Optional. The amount of time, in seconds, to wait
              for the request to complete.
            metadata: Optional. Strings which
              should be sent along with the request as metadata.
            catalog_name (str): Optional. Name of the catalog.
              Default: 'default_catalog'
            event_store (str): Optional. Name of the event store.
              Default: 'default_event_store'
        """
    self.max_batch_size = max_batch_size
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name
    self.event_store = event_store

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
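    # As in ImportCatalogItems, input must be keyed; batches of up to
    # max_batch_size user events feed one bulk-import request each.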
    return (
        pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
            _ImportUserEventsFn(
                self.project,
                self.retry,
                self.timeout,
                self.metadata,
                self.catalog_name,
                self.event_store)))


class _ImportUserEventsFn(DoFn):
  FAILED_USER_EVENTS = "failed_user_events"
  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=(),
      catalog_name=None,
      event_store=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.parent = f"projects/{project}/locations/global/catalogs/"\
                  f"{catalog_name}/eventStores/{event_store}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_user_event_client()

  def process(self, element):
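    # The batch arrives as a (sharded key, values) pair; convert each dict
    # in the batch into a UserEvent message for the bulk-import request.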
    user_events = [recommendationengine.UserEvent(e) for e in element[1]]
    user_event_inline_source = recommendationengine.UserEventInlineSource(
        {"user_events": user_events})
    input_config = recommendationengine.InputConfig(
        user_event_inline_source=user_event_inline_source)
    request = recommendationengine.ImportUserEventsRequest(
        parent=self.parent, input_config=input_config)
    try:
      operation = self._client.import_user_events(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc(len(user_events))
      # result() blocks until the long-running import operation completes.
      yield recommendationengine.ImportUserEventsResponse.to_dict(
          operation.result())
    except Exception:
      yield pvalue.TaggedOutput(self.FAILED_USER_EVENTS, user_events)


class PredictUserEvent(PTransform):
  """Makes a recommendation prediction.

  The ``PTransform`` returns a PCollection containing, for each input user
  event, the list of prediction results returned by the service.

  Example usage::

    pipeline
    | PredictUserEvent(
        project='example-gcp-project',
        catalog_name='my-catalog',
        event_store='my_event_store',
        placement_id='recently_viewed_default')
  """
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog",
      event_store: str = "default_event_store",
      placement_id: str = None):
    """Initializes a :class:`PredictUserEvent` transform.
        Args:
            project (str): Optional. GCP project name in which the catalog
              data will be imported.
            retry: Optional. Designation of what
              errors, if any, should be retried.
            timeout (float): Optional. The amount of time, in seconds, to wait
              for the request to complete.
            metadata: Optional. Strings which
              should be sent along with the request as metadata.
            catalog_name (str): Optional. Name of the catalog.
              Default: 'default_catalog'
            event_store (str): Optional. Name of the event store.
              Default: 'default_event_store'
            placement_id (str): Required. ID of the recommendation engine
              placement. This id is used to identify the set of models that
              will be used to make the prediction.
        """
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name
    self.event_store = event_store
    if placement_id is None:
      raise ValueError('placement_id must be specified')
    self.placement_id = placement_id

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
    return pcoll | ParDo(
        _PredictUserEventFn(
            self.project,
            self.retry,
            self.timeout,
            self.metadata,
            self.catalog_name,
            self.event_store,
            self.placement_id))


class _PredictUserEventFn(DoFn):
  FAILED_PREDICTIONS = "failed_predictions"
  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=(),
      catalog_name=None,
      event_store=None,
      placement_id=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.name = f"projects/{project}/locations/global/catalogs/"\
                f"{catalog_name}/eventStores/{event_store}/placements/"\
                f"{placement_id}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_prediction_client()

  def process(self, element):
    user_event = recommendationengine.UserEvent(element)
    request = recommendationengine.PredictRequest(
        name=self.name, user_event=user_event)
    try:
      prediction = self._client.predict(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc()
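      # predict() returns a pager; flatten all response pages into a single
      # list of dicts for this input event.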
      yield [
          recommendationengine.PredictResponse.to_dict(p)
          for p in prediction.pages
      ]
    except Exception:
      yield pvalue.TaggedOutput(
          self.FAILED_PREDICTIONS,
          recommendationengine.UserEvent.to_dict(user_event))