# Source code for apache_beam.ml.rag.ingestion.base
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from typing import Any
import apache_beam as beam
from apache_beam.ml.rag.types import Chunk


class VectorDatabaseWriteConfig(ABC):
  """Abstract base class for vector database configurations in RAG pipelines.

  VectorDatabaseWriteConfig defines the interface for configuring vector
  database writes in RAG pipelines. Implementations should provide
  database-specific configuration and create appropriate write transforms.

  The configuration flow:

  1. A subclass provides database-specific configuration (table names, etc.).
  2. create_write_transform() creates the appropriate PTransform for writing.
  3. That transform handles converting Chunks to the database-specific format.

  Example implementation:

    >>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
    ...     def __init__(self, table: str):
    ...         self.table = table
    ...
    ...     def create_write_transform(self):
    ...         return beam.io.WriteToBigQuery(
    ...             table=self.table
    ...         )
  """
  @abstractmethod
  def create_write_transform(self) -> beam.PTransform[Chunk, Any]:
    """Creates a PTransform that writes embeddings to the vector database.

    Returns:
        A PTransform that accepts PCollection[Chunk] and writes the chunks'
        embeddings and metadata to the configured vector database.
        The transform should handle:

        - Converting Chunk format to database schema
        - Setting up database connection/client
        - Writing with appropriate batching/error handling

    Raises:
        NotImplementedError: If reached through the base implementation
            (e.g. a subclass calling super().create_write_transform()).
    """
    # @abstractmethod already prevents instantiating subclasses that do not
    # override this; the raise only guards explicit calls to the base body.
    raise NotImplementedError(type(self))


class VectorDatabaseWriteTransform(beam.PTransform):
  """A PTransform for writing embedded chunks to vector databases.

  This transform uses a VectorDatabaseWriteConfig to write chunks with
  embeddings to a vector database. It validates the config when constructed
  and, when applied, delegates to the transform returned by
  ``database_config.create_write_transform()``.

  Example usage:

    >>> config = MyVectorDatabaseWriteConfig(...)  # any concrete subclass
    >>> with beam.Pipeline() as p:
    ...     chunks = p | beam.Create([...])  # PCollection[Chunk]
    ...     chunks | VectorDatabaseWriteTransform(config)

  Args:
      database_config: Configuration for the target vector database.
          Must be an instance of a VectorDatabaseWriteConfig subclass that
          implements create_write_transform().

  Raises:
      TypeError: If database_config is not a VectorDatabaseWriteConfig instance.
  """
  def __init__(self, database_config: VectorDatabaseWriteConfig) -> None:
    """Initialize transform with database config.

    Args:
        database_config: Configuration for target vector database.

    Raises:
        TypeError: If database_config is not a VectorDatabaseWriteConfig
            instance.
    """
    # Initialize the beam.PTransform base class before setting our own state.
    super().__init__()
    # Fail fast at pipeline-construction time rather than when expand() runs.
    if not isinstance(database_config, VectorDatabaseWriteConfig):
      raise TypeError(
          f"database_config must be VectorDatabaseWriteConfig, "
          f"got {type(database_config)}")
    self.database_config = database_config

  def expand(self, pcoll: beam.PCollection[Chunk]) -> Any:
    """Creates and applies the database-specific write transform.

    Args:
        pcoll: PCollection of Chunks with embeddings to write to the
            vector database. Each Chunk must have:

            - An embedding
            - An ID
            - Metadata used to filter results as specified by database config

    Returns:
        The result of applying the database-specific write transform to
        ``pcoll`` (implementation specific). This is the output of the write,
        not a PTransform.
    """
    # A fresh write transform is requested from the config each time this
    # composite is expanded.
    write_transform = self.database_config.create_write_transform()
    return pcoll | write_transform