Source code for apache_beam.ml.rag.embeddings.huggingface

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RAG-specific embedding implementations using HuggingFace models."""

import io
from collections.abc import Sequence
from typing import Optional

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import _add_embedding_fn
from apache_beam.ml.rag.embeddings.base import create_text_adapter
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import EmbeddingTypeAdapter
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler

try:
  from sentence_transformers import SentenceTransformer
except ImportError:
  SentenceTransformer = None

try:
  from PIL import Image as PILImage
except ImportError:
  PILImage = None  # type: ignore[assignment]


[docs] class HuggingfaceTextEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs): """HuggingFace text embeddings for RAG pipelines. Args: model_name: Name of the sentence-transformers model to use. max_seq_length: Maximum sequence length for the model. **kwargs: Additional arguments passed to :class:`~apache_beam.ml.transforms.base.EmbeddingsManager`, including: - ``load_model_args``: dict passed to ``SentenceTransformer()`` constructor (e.g. ``device``, ``cache_folder``). - ``min_batch_size`` / ``max_batch_size``: Control batching for inference. - ``large_model``: If True, share the model across processes to reduce memory usage. - ``inference_args``: dict passed to ``model.encode()`` (e.g. ``normalize_embeddings``). """ if not SentenceTransformer: raise ImportError( "sentence-transformers is required to use " "HuggingfaceTextEmbeddings." "Please install it with using `pip install sentence-transformers`.") super().__init__(type_adapter=create_text_adapter(), **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length self.model_class = SentenceTransformer
[docs] def get_model_handler(self): """Returns model handler configured with RAG adapter.""" return _SentenceTransformerModelHandler( model_class=self.model_class, max_seq_length=self.max_seq_length, model_name=self.model_name, load_model_args=self.load_model_args, min_batch_size=self.min_batch_size, max_batch_size=self.max_batch_size, large_model=self.large_model)
[docs] def get_ptransform_for_processing( self, **kwargs ) -> beam.PTransform[beam.PCollection[EmbeddableItem], beam.PCollection[EmbeddableItem]]: """Returns PTransform that uses the RAG adapter.""" return RunInference( model_handler=_TextEmbeddingHandler(self), inference_args=self.inference_args).with_output_types(EmbeddableItem)
def _extract_images(items: Sequence[EmbeddableItem]) -> list: """Extract images from items and convert to PIL.Image objects. Supports raw bytes, local file paths, and remote URIs (e.g. gs://, s3://) via Beam's FileSystems. """ images = [] for item in items: if not item.content.image: raise ValueError( "Expected image content in " f"{type(item).__name__} {item.id}, " "got None") img_data = item.content.image if isinstance(img_data, bytes): img = PILImage.open(io.BytesIO(img_data)) else: with FileSystems.open(img_data, 'rb') as f: img = PILImage.open(f) img.load() images.append(img.convert('RGB')) return images def _create_hf_image_adapter( ) -> EmbeddingTypeAdapter[EmbeddableItem, EmbeddableItem]: """Creates adapter for HuggingFace image embedding. Extracts content.image from EmbeddableItems and converts to PIL.Image objects. Supports both raw bytes and file paths. Returns: EmbeddingTypeAdapter for HuggingFace image embedding. """ return EmbeddingTypeAdapter( input_fn=_extract_images, output_fn=_add_embedding_fn)
[docs] class HuggingfaceImageEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs): """HuggingFace image embeddings for RAG pipelines. Generates embeddings for images using sentence-transformers models that support image input (e.g. clip-ViT-B-32). Args: model_name: Name of the sentence-transformers model. Must be an image-text model. See https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models max_seq_length: Maximum sequence length for the model if applicable. **kwargs: Additional arguments passed to :class:`~apache_beam.ml.transforms.base.EmbeddingsManager`, including: - ``load_model_args``: dict passed to ``SentenceTransformer()`` constructor (e.g. ``device``, ``cache_folder``, ``trust_remote_code``). - ``min_batch_size`` / ``max_batch_size``: Control batching for inference. - ``large_model``: If True, share the model across processes to reduce memory usage. - ``inference_args``: dict passed to ``model.encode()`` (e.g. ``normalize_embeddings``). """ if not SentenceTransformer: raise ImportError( "sentence-transformers is required to use " "HuggingfaceImageEmbeddings. " "Please install it with `pip install sentence-transformers`.") if not PILImage: raise ImportError( "Pillow is required to use HuggingfaceImageEmbeddings. " "Please install it with `pip install pillow`.") super().__init__(type_adapter=_create_hf_image_adapter(), **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length self.model_class = SentenceTransformer
[docs] def get_model_handler(self): """Returns model handler configured with RAG adapter.""" return _SentenceTransformerModelHandler( model_class=self.model_class, max_seq_length=self.max_seq_length, model_name=self.model_name, load_model_args=self.load_model_args, min_batch_size=self.min_batch_size, max_batch_size=self.max_batch_size, large_model=self.large_model)
[docs] def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: """Returns PTransform for image embedding.""" return RunInference( model_handler=_ImageEmbeddingHandler(self), inference_args=self.inference_args).with_output_types(EmbeddableItem)