Source code for apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import List

import proto
from google.api_core.exceptions import NotFound
from google.cloud import aiplatform

import apache_beam as beam
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel

# Public API of this module: the two enrichment handlers defined below.
__all__ = [
    'VertexAIFeatureStoreEnrichmentHandler',
    'VertexAIFeatureStoreLegacyEnrichmentHandler'
]

# Module-level logger, named after this module, used by both handlers to
# report warnings.
_LOGGER = logging.getLogger(__name__)


def _not_found_err_message(
    feature_store_name: str, feature_view_name: str, entity_id: str) -> str:
  """returns a string formatted with given parameters"""
  return (
      "make sure the Feature Store: %s with Feature View "
      "%s has entity_id: %s" %
      (feature_store_name, feature_view_name, entity_id))


class VertexAIFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
                                                                    beam.Row]):
  """Enrichment handler to interact with Vertex AI Feature Store.

  Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
  transform when the Vertex AI Feature Store is set up for
  Bigtable Online serving.

  With the Bigtable Online serving approach, the client fetches all the
  available features for an entity-id. The entity-id is extracted from the
  `row_key` field in the input `beam.Row` object. To filter the features to
  enrich, use the `join_fn` param in
  :class:`apache_beam.transforms.enrichment.Enrichment`.

  **NOTE:** The default severity to report exceptions is logging a warning.
  For this handler, Vertex AI client returns the same exception
  `Requested entity was not found` even though the feature store doesn't
  exist. So make sure the feature store instance exists or set
  `exception_level` as `ExceptionLevel.RAISE`.
  """
  def __init__(
      self,
      project: str,
      location: str,
      api_endpoint: str,
      feature_store_name: str,
      feature_view_name: str,
      row_key: str,
      *,
      exception_level: ExceptionLevel = ExceptionLevel.WARN,
      **kwargs,
  ):
    """Initializes an instance of `VertexAIFeatureStoreEnrichmentHandler`.

    Args:
      project (str): The GCP project-id for the Vertex AI Feature Store.
      location (str): The region for the Vertex AI Feature Store.
      api_endpoint (str): The API endpoint for the Vertex AI Feature Store.
      feature_store_name (str): The name of the Vertex AI Feature Store.
      feature_view_name (str): The name of the feature view within the
        Feature Store.
      row_key (str): The row key field name containing the unique id
        for the feature values.
      exception_level: a `enum.Enum` value from
        `apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`
        to set the level when the requested entity is not found.
        Defaults to `ExceptionLevel.WARN`.
      kwargs: Optional keyword arguments to configure the
        `aiplatform.gapic.FeatureOnlineStoreServiceClient`. If it contains
        `client_options`, that value is read and updated as a dict.

    Raises:
      ValueError: if `client_options` carries a non-empty `api_endpoint`
        that differs from the `api_endpoint` argument.
      NotFound: if the feature store lookup yields nothing.
    """
    self.project = project
    self.location = location
    self.api_endpoint = api_endpoint
    self.feature_store_name = feature_store_name
    self.feature_view_name = feature_view_name
    self.row_key = row_key
    self.exception_level = exception_level
    self.kwargs = kwargs if kwargs else {}
    if 'client_options' in self.kwargs:
      client_options = self.kwargs['client_options']
      # Use `.get` so that `client_options` without an `api_endpoint` entry
      # is filled in from the argument instead of raising a KeyError.
      configured_endpoint = client_options.get('api_endpoint')
      if not configured_endpoint:
        client_options['api_endpoint'] = self.api_endpoint
      elif configured_endpoint != self.api_endpoint:
        raise ValueError(
            'Multiple values received for api_endpoint in '
            'api_endpoint and client_options parameters.')
    else:
      self.kwargs['client_options'] = {"api_endpoint": self.api_endpoint}

    # check if the feature store exists
    try:
      admin_client = aiplatform.gapic.FeatureOnlineStoreAdminServiceClient(
          **self.kwargs)
    except Exception:
      # Best effort: failing to build the admin client only skips the
      # existence check, it does not fail construction of the handler.
      _LOGGER.warning(
          'Due to insufficient admin permission, could not verify '
          'the existence of feature store. If the `exception_level` '
          'is set to WARN then make sure the feature store exists '
          'otherwise the data enrichment will not happen without '
          'throwing an error.')
    else:
      location_path = admin_client.common_location_path(
          project=self.project, location=self.location)
      feature_store_path = admin_client.feature_online_store_path(
          project=self.project,
          location=self.location,
          feature_online_store=self.feature_store_name)
      # Errors raised by this lookup are not caught here and propagate
      # to the caller.
      feature_store = admin_client.get_feature_online_store(
          name=feature_store_path)

      if not feature_store:
        raise NotFound(
            'Vertex AI Feature Store %s does not exists in %s' %
            (self.feature_store_name, location_path))

  def __enter__(self):
    """Connect with the Vertex AI Feature Store."""
    self.client = aiplatform.gapic.FeatureOnlineStoreServiceClient(
        **self.kwargs)
    self.feature_view_path = self.client.feature_view_path(
        self.project,
        self.location,
        self.feature_store_name,
        self.feature_view_name)

  def __call__(self, request: beam.Row, *args, **kwargs):
    """Fetches feature value for an entity-id from Vertex AI Feature Store.

    Args:
      request: the input `beam.Row` to enrich.

    Returns:
      A tuple of the original `request` and a `beam.Row` built from the
      fetched feature values. The second element is an empty `beam.Row`
      when the entity is not found and `exception_level` is anything other
      than `ExceptionLevel.RAISE`.

    Raises:
      KeyError: if `request` has no field named `row_key`.
      ValueError: if the entity is not found and `exception_level` is
        `ExceptionLevel.RAISE`.
    """
    try:
      entity_id = request._asdict()[self.row_key]
    except KeyError:
      raise KeyError(
          "Enrichment requests to Vertex AI Feature Store should "
          "contain a field: %s in the input `beam.Row` to join "
          "the input with fetched response. This is used as the "
          "`FeatureViewDataKey` to fetch feature values "
          "corresponding to this key." % self.row_key)
    try:
      response = self.client.fetch_feature_values(
          request=aiplatform.gapic.FetchFeatureValuesRequest(
              data_key=aiplatform.gapic.FeatureViewDataKey(key=entity_id),
              feature_view=self.feature_view_path,
              data_format=aiplatform.gapic.FeatureViewDataFormat.PROTO_STRUCT,
          ))
    except NotFound:
      message = _not_found_err_message(
          self.feature_store_name, self.feature_view_name, entity_id)
      if self.exception_level == ExceptionLevel.RAISE:
        raise ValueError(message)
      if self.exception_level == ExceptionLevel.WARN:
        _LOGGER.warning(message)
      # Every non-raising level returns an empty enrichment row. Falling
      # through would touch `response`, which is unbound on this path.
      return request, beam.Row()

    response_dict = dict(response.proto_struct)
    return request, beam.Row(**response_dict)

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Clean the instantiated Vertex AI client."""
    self.client = None

  def get_cache_key(self, request: beam.Row) -> str:
    """Returns a string formatted with unique entity-id for the feature values.
    """
    return 'entity_id: %s' % request._asdict()[self.row_key]
class VertexAIFeatureStoreLegacyEnrichmentHandler(EnrichmentSourceHandler):
  """Enrichment handler to interact with Vertex AI Feature Store (Legacy).

  Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
  transform for the Vertex AI Feature Store (Legacy).

  By default, it fetches all the features values for an entity-id. The
  entity-id is extracted from the `row_key` field in the input `beam.Row`
  object. You can specify the features names using `feature_ids` to fetch
  specific features.
  """
  def __init__(
      self,
      project: str,
      location: str,
      api_endpoint: str,
      feature_store_id: str,
      entity_type_id: str,
      feature_ids: List[str],
      row_key: str,
      *,
      exception_level: ExceptionLevel = ExceptionLevel.WARN,
      **kwargs,
  ):
    """Initializes an instance of `VertexAIFeatureStoreLegacyEnrichmentHandler`.

    Args:
      project (str): The GCP project for the Vertex AI Feature Store (Legacy).
      location (str): The region for the Vertex AI Feature Store (Legacy).
      api_endpoint (str): The API endpoint for the
        Vertex AI Feature Store (Legacy).
      feature_store_id (str): The id of the Vertex AI Feature Store (Legacy).
      entity_type_id (str): The entity type of the feature store.
      feature_ids (List[str]): A list of feature-ids to fetch
        from the Feature Store.
      row_key (str): The row key field name containing the entity id
        for the feature values.
      exception_level: a `enum.Enum` value from
        `apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`
        to set the level when a fetched feature carries no value.
        Defaults to `ExceptionLevel.WARN`.
      kwargs: Optional keyword arguments to configure the
        `aiplatform.gapic.FeaturestoreOnlineServingServiceClient`. If it
        contains `client_options`, that value is read and updated as a dict.
        A `credentials` entry, if present, is also used for the feature
        store existence check.

    Raises:
      ValueError: if `client_options` carries a non-empty `api_endpoint`
        that differs from the `api_endpoint` argument.
      NotFound: if the feature store does not exist.
    """
    self.project = project
    self.location = location
    self.api_endpoint = api_endpoint
    self.feature_store_id = feature_store_id
    self.entity_type_id = entity_type_id
    self.feature_ids = feature_ids
    self.row_key = row_key
    self.exception_level = exception_level
    self.kwargs = kwargs if kwargs else {}
    if 'client_options' in self.kwargs:
      client_options = self.kwargs['client_options']
      # Use `.get` so that `client_options` without an `api_endpoint` entry
      # is filled in from the argument instead of raising a KeyError.
      configured_endpoint = client_options.get('api_endpoint')
      if not configured_endpoint:
        client_options['api_endpoint'] = self.api_endpoint
      elif configured_endpoint != self.api_endpoint:
        raise ValueError(
            'Multiple values received for api_endpoint in '
            'api_endpoint and client_options parameters.')
    else:
      self.kwargs['client_options'] = {"api_endpoint": self.api_endpoint}

    # checks if feature store exists
    try:
      _ = aiplatform.Featurestore(
          featurestore_name=self.feature_store_id,
          project=self.project,
          location=self.location,
          credentials=self.kwargs.get('credentials'),
      )
    except NotFound as err:
      # Re-raise with a message naming the store; keep the original cause.
      raise NotFound(
          'Vertex AI Feature Store (Legacy) %s does not exist' %
          self.feature_store_id) from err

  def __enter__(self):
    """Connect with the Vertex AI Feature Store (Legacy)."""
    self.client = aiplatform.gapic.FeaturestoreOnlineServingServiceClient(
        **self.kwargs)
    self.entity_type_path = self.client.entity_type_path(
        self.project,
        self.location,
        self.feature_store_id,
        self.entity_type_id)

  def __call__(self, request: beam.Row, *args, **kwargs):
    """Fetches feature value for an entity-id from
    Vertex AI Feature Store (Legacy).

    Args:
      request: the input `beam.Row` to enrich.

    Returns:
      A tuple of the original `request` and a `beam.Row` holding one field
      per fetched feature that carried a value.

    Raises:
      KeyError: if `request` has no field named `row_key`.
      ValueError: if the read request raises `NotFound`, or if a feature
        has no value and `exception_level` is `ExceptionLevel.RAISE`.
    """
    try:
      entity_id = request._asdict()[self.row_key]
    except KeyError:
      raise KeyError(
          "Enrichment requests to Vertex AI Feature Store should "
          "contain a field: %s in the input `beam.Row` to join "
          "the input with fetched response. This is used as the "
          "`FeatureViewDataKey` to fetch feature values "
          "corresponding to this key." % self.row_key)

    try:
      selector = aiplatform.gapic.FeatureSelector(
          id_matcher=aiplatform.gapic.IdMatcher(ids=self.feature_ids))
      response = self.client.read_feature_values(
          request=aiplatform.gapic.ReadFeatureValuesRequest(
              entity_type=self.entity_type_path,
              entity_id=entity_id,
              feature_selector=selector))
    except NotFound:
      # NOTE(review): this path raises regardless of `exception_level`;
      # confirm whether WARN is expected to be honored here as well.
      raise ValueError(
          _not_found_err_message(
              self.feature_store_id, self.entity_type_id, entity_id))

    response_dict = {}
    proto_to_dict = proto.Message.to_dict(response.entity_view)
    # Pair each feature descriptor from the header with the data entry at
    # the same position in the entity view.
    for key, msg in zip(response.header.feature_descriptors,
                        proto_to_dict['data']):
      if msg and 'value' in msg:
        # Keep only the first entry of the `value` mapping.
        response_dict[key.id] = list(msg['value'].values())[0]
        # skip fetching the metadata
      elif self.exception_level == ExceptionLevel.RAISE:
        raise ValueError(
            _not_found_err_message(
                self.feature_store_id, self.entity_type_id, entity_id))
      elif self.exception_level == ExceptionLevel.WARN:
        _LOGGER.warning(
            _not_found_err_message(
                self.feature_store_id, self.entity_type_id, entity_id))
    return request, beam.Row(**response_dict)

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Clean the instantiated Vertex AI client."""
    self.client = None

  def get_cache_key(self, request: beam.Row) -> str:
    """Returns a string formatted with unique entity-id for the feature values.
    """
    return 'entity_id: %s' % request._asdict()[self.row_key]