# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import Dict
from typing import Optional

from google.api_core.exceptions import NotFound
from google.cloud import bigtable
from google.cloud.bigtable import Client
from google.cloud.bigtable.row_filters import CellsColumnLimitFilter
from google.cloud.bigtable.row_filters import RowFilter

import apache_beam as beam
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel

# Public API of this module: the handler class is the only exported name.
__all__ = [
    'BigTableEnrichmentHandler',
]

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
    """A handler for :class:`apache_beam.transforms.enrichment.Enrichment`
    transform to interact with GCP BigTable.

    Args:
      project_id (str): GCP project-id of the BigTable cluster.
      instance_id (str): GCP instance-id of the BigTable cluster.
      table_id (str): GCP table-id of the BigTable.
      row_key (str): unique row-key field name from the input `beam.Row`
        object to use as `row_key` for BigTable querying.
      row_filter: a
        :class:`google.cloud.bigtable.row_filters.RowFilter` to filter data
        read with ``read_row()``. Defaults to `CellsColumnLimitFilter(1)`.
      app_profile_id (str): App profile ID to use for BigTable.
        See https://cloud.google.com/bigtable/docs/app-profiles for more
        details.
      encoding (str): encoding type to convert the string to bytes and
        vice-versa from BigTable. Default is `utf-8`.
      exception_level: a `enum.Enum` value from
        ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel``
        to set the level when an empty row is returned from the BigTable
        query. Defaults to ``ExceptionLevel.WARN``.
      include_timestamp (bool): If enabled, the timestamp associated with the
        value is returned as `(value, timestamp)` for each `row_key`.
        Defaults to `False` - only the latest value without the timestamp is
        returned.
    """

    def __init__(
        self,
        project_id: str,
        instance_id: str,
        table_id: str,
        row_key: str,
        row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1),
        *,
        app_profile_id: Optional[str] = None,
        encoding: str = 'utf-8',
        exception_level: ExceptionLevel = ExceptionLevel.WARN,
        include_timestamp: bool = False,
    ):
        """Store the connection settings; no connection is opened here."""
        self._project_id = project_id
        self._instance_id = instance_id
        self._table_id = table_id
        self._row_key = row_key
        self._row_filter = row_filter
        self._app_profile_id = app_profile_id
        self._encoding = encoding
        self._exception_level = exception_level
        self._include_timestamp = include_timestamp

    def __enter__(self):
        """Connect to the Google BigTable cluster.

        Returns:
          This handler, so that it can be bound with ``with ... as``.
        """
        self.client = Client(project=self._project_id)
        self.instance = self.client.instance(self._instance_id)
        self._table = bigtable.table.Table(
            table_id=self._table_id,
            instance=self.instance,
            app_profile_id=self._app_profile_id)
        return self

    def __call__(self, request: beam.Row, *args, **kwargs):
        """Reads a row from the GCP BigTable and returns a `Tuple` of request
        and response.

        Args:
          request: the input `beam.Row` to enrich.

        Returns:
          A tuple ``(request, response)`` where ``response`` is a `beam.Row`
          built from a ``{column_family: {column: value}}`` mapping. The
          response row has no fields when no matching row was found and the
          exception level did not raise.

        Raises:
          KeyError: if the configured row-key field is missing from
            ``request``.
          NotFound: if the BigTable read reports that the target does not
            exist.
          ValueError: if no row matched and the exception level is
            ``ExceptionLevel.RAISE``.
        """
        request_dict = request._asdict()
        try:
            row_key_str = str(request_dict[self._row_key])
        except KeyError as e:
            # Report the missing *field name*; the key value is unknown here.
            raise KeyError(
                'row_key %s not found in input PCollection.' %
                self._row_key) from e

        row_key = row_key_str.encode(self._encoding)
        try:
            row = self._table.read_row(row_key, filter_=self._row_filter)
        except NotFound as e:
            raise NotFound(
                'GCP BigTable cluster `%s:%s:%s` not found.' %
                (self._project_id, self._instance_id, self._table_id)) from e

        response_dict: Dict[str, Any] = {}
        if row:
            response_dict = self._row_to_dict(row)
        elif self._exception_level == ExceptionLevel.WARN:
            _LOGGER.warning(
                'no matching row found for row_key: %s '
                'with row_filter: %s',
                row_key_str,
                self._row_filter)
        elif self._exception_level == ExceptionLevel.RAISE:
            raise ValueError(
                'no matching row found for row_key: %s '
                'with row_filter=%s' % (row_key_str, self._row_filter))
        # Any other exception level falls through with an empty response.

        return request, beam.Row(**response_dict)

    def _row_to_dict(self, row) -> Dict[str, Any]:
        """Convert the cells of a BigTable row into a nested, decoded dict."""
        encoding = self._encoding
        families: Dict[str, Any] = {}
        for family_id, columns in row.cells.items():
            decoded: Dict[str, Any] = {}
            for qualifier, cells in columns.items():
                column = qualifier.decode(encoding)
                if self._include_timestamp:
                    decoded[column] = [
                        (cell.value.decode(encoding), cell.timestamp)
                        for cell in cells
                    ]
                else:
                    # Only the first cell is kept; per the class docstring
                    # this is the latest value.
                    decoded[column] = cells[0].value.decode(encoding)
            families[family_id] = decoded
        return families

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Clean the instantiated BigTable client."""
        self.client = None
        self.instance = None
        self._table = None

    def get_cache_key(self, request: beam.Row) -> str:
        """Returns a string formatted with row key since it is unique to
        a request made to `Bigtable`."""
        return "%s: %s" % (self._row_key, request._asdict()[self._row_key])