Source code for apache_beam.ml.transforms.embeddings.huggingface

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["SentenceTransformerEmbeddings", "InferenceAPIEmbeddings"]

import logging
import os
from collections.abc import Callable
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import Optional

import requests

import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler

try:
  from sentence_transformers import SentenceTransformer
except ImportError:
  SentenceTransformer = None

_LOGGER = logging.getLogger(__name__)


# TODO: https://github.com/apache/beam/issues/29621
# Use HuggingFaceModelHandlerTensor once the import issue is fixed.
# Right now, the hugging face model handler import torch and tensorflow
# at the same time, which adds too much weigth to the container unnecessarily.
class _SentenceTransformerModelHandler(ModelHandler):
  """
  Note: Intended for internal use and guarantees no backwards compatibility.
  """
  def __init__(
      self,
      model_name: str,
      model_class: Callable,
      load_model_args: Optional[dict] = None,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      max_seq_length: Optional[int] = None,
      large_model: bool = False,
      **kwargs):
    self._max_seq_length = max_seq_length
    self.model_name = model_name
    self._model_class = model_class
    self._load_model_args = load_model_args
    self._min_batch_size = min_batch_size
    self._max_batch_size = max_batch_size
    self._large_model = large_model
    self._kwargs = kwargs

    if not SentenceTransformer:
      raise ImportError(
          "sentence-transformers is required to use "
          "SentenceTransformerEmbeddings."
          "Please install it with using `pip install sentence-transformers`.")

  def run_inference(
      self,
      batch: Sequence[str],
      model: SentenceTransformer,
      inference_args: Optional[dict[str, Any]] = None,
  ):
    inference_args = inference_args or {}
    return model.encode(batch, **inference_args)

  def load_model(self):
    model = self._model_class(self.model_name, **self._load_model_args)
    if self._max_seq_length:
      model.max_seq_length = self._max_seq_length
    return model

  def share_model_across_processes(self) -> bool:
    return self._large_model

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    batch_sizes = {}
    if self._min_batch_size:
      batch_sizes["min_batch_size"] = self._min_batch_size
    if self._max_batch_size:
      batch_sizes["max_batch_size"] = self._max_batch_size
    return batch_sizes

  def __repr__(self) -> str:
    # ModelHandler is internal to the user and is not exposed.
    # Hence we need to override the __repr__ method to expose
    # the name of the class.
    return 'SentenceTransformerEmbeddings'


[docs] class SentenceTransformerEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, columns: list[str], max_seq_length: Optional[int] = None, image_model: bool = False, **kwargs): """ Embedding config for sentence-transformers. This config can be used with MLTransform to embed text data. Models are loaded using the RunInference PTransform with the help of ModelHandler. Args: model_name: Name of the model to use. The model should be hosted on HuggingFace Hub or compatible with sentence_transformers. For image embedding models, see https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long for a list of available sentence_transformers models. columns: List of columns to be embedded. max_seq_length: Max sequence length to use for the model if applicable. image_model: Whether the model is generating image embeddings. min_batch_size: The minimum batch size to be used for inference. max_batch_size: The maximum batch size to be used for inference. large_model: Whether to share the model across processes. """ super().__init__(columns=columns, **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length self.image_model = image_model
[docs] def get_model_handler(self): return _SentenceTransformerModelHandler( model_class=SentenceTransformer, max_seq_length=self.max_seq_length, model_name=self.model_name, load_model_args=self.load_model_args, min_batch_size=self.min_batch_size, max_batch_size=self.max_batch_size, large_model=self.large_model)
[docs] def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: # wrap the model handler in an appropriate embedding handler to provide # some type checking. if self.image_model: return ( RunInference( model_handler=_ImageEmbeddingHandler(self), inference_args=self.inference_args, )) return ( RunInference( model_handler=_TextEmbeddingHandler(self), inference_args=self.inference_args, ))
class _InferenceAPIHandler(ModelHandler): def __init__(self, config: 'InferenceAPIEmbeddings'): super().__init__() self._config = config def load_model(self): session = requests.Session() # if the token is not provided during construction time, it might have # been provided with custom container, which we can get it during runtume. if not self._config.hf_token: hf_token = os.environ.get("HF_TOKEN") if not hf_token: raise ValueError( 'HF_TOKEN environment variable not set. ' 'Please set the environment variable or pass the token as an ' 'argument.') session.headers.update({"Authorization": f"Bearer {hf_token}"}) return session session.headers.update(self._config.authorization_token) return session def run_inference( self, batch, session: requests.Session, inference_args=None): response = session.post( self._config.api_url, headers=self._config.authorization_token, json={ "inputs": batch, "options": inference_args }) return response.json()
[docs] class InferenceAPIEmbeddings(EmbeddingsManager): """ Feature extraction using HuggingFace's Inference API. Intended to be used for feature-extraction. For other tasks, please refer to https://huggingface.co/inference-api. Args: hf_token: HuggingFace token. columns: List of columns to be embedded. model_name: Model name used for feature extraction. api_url: API url for feature extraction. If specified, model_name will be ignored. If none, the default url for feature extraction will be used. """ def __init__( self, hf_token: Optional[str], columns: list[str], model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long api_url: Optional[str] = None, **kwargs, ): super().__init__(columns=columns, **kwargs) self._authorization_token = {"Authorization": f"Bearer {hf_token}"} self._model_name = model_name self.hf_token = hf_token if not api_url: if not self._model_name: raise ValueError("Either api_url or model_name must be provided.") self._api_url = ( f"https://api-inference.huggingface.co/pipeline/feature-extraction/{self._model_name}" # pylint: disable=line-too-long ) else: self._api_url = api_url _LOGGER.info("HuggingFace API URL: %s")
[docs] def get_token(self): return os.environ.get('HF_TOKEN')
@property def api_url(self): return self._api_url @property def authorization_token(self): return self._authorization_token
[docs] def get_model_handler(self): return _InferenceAPIHandler(self)
[docs] def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: options = { # sometimes the model is not ready and returns an error response # instead of waiting. So we wait for the model to be ready. 'wait_for_model': True, 'use_cache': True, } self.inference_args.update(options) return ( RunInference( model_handler=_TextEmbeddingHandler(self), inference_args=self.inference_args))