#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from typing import Any

import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
class VectorDatabaseWriteConfig(ABC):
"""Abstract base class for vector database configurations in RAG pipelines.
VectorDatabaseWriteConfig defines the interface for configuring vector
database writes in RAG pipelines. Implementations should provide
database-specific configuration and create appropriate write transforms.
The configuration flow:
1. Subclass provides database-specific configuration (table names, etc)
2. create_write_transform() creates appropriate PTransform for writing
3. Transform handles converting Chunks to database-specific format
Example implementation:
>>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
... def __init__(self, table: str):
... self.embedding_column = embedding_column
...
... def create_write_transform(self):
... return beam.io.WriteToBigQuery(
... table=self.table
... )
"""
  @abstractmethod
  def create_write_transform(self) -> beam.PTransform[Chunk, Any]:
    """Creates a PTransform that writes embeddings to the vector database.

    Returns:
      A PTransform that accepts a PCollection[Chunk] and writes the chunks'
      embeddings and metadata to the configured vector database.

    The transform should handle:

    - Converting the Chunk format to the database schema
    - Setting up the database connection/client
    - Writing with appropriate batching/error handling
    """
    raise NotImplementedError(type(self))
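# Illustrative sketch (not part of the Beam API): a minimal concrete
# VectorDatabaseWriteConfig that "writes" chunks by logging their ids rather
# than contacting a real database. The class name and logging behavior are
# hypothetical, intended only as a template for database-specific configs.
import logging


class _LoggingVectorWriterConfig(VectorDatabaseWriteConfig):
  """Toy config for illustration: logs chunk ids instead of writing them."""
  def create_write_transform(self) -> beam.PTransform[Chunk, Any]:
    def log_chunk(chunk: Chunk) -> Chunk:
      # Assumes Chunk exposes an `id` field, per apache_beam.ml.rag.types.
      # A real implementation would map Chunk fields to the database schema
      # and batch its writes.
      logging.info("would write chunk %s", chunk.id)
      return chunk

    return "LogChunks" >> beam.Map(log_chunk)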
class VectorDatabaseWriteTransform(beam.PTransform):
"""A PTransform for writing embedded chunks to vector databases.
This transform uses a VectorDatabaseWriteConfig to write chunks with
embeddings to vector database. It handles validating the config and applying
the database-specific write transform.
Example usage:
>>> config = BigQueryVectorConfig(
... table='project.dataset.embeddings',
... embedding_column='embedding'
... )
>>>
>>> with beam.Pipeline() as p:
... chunks = p | beam.Create([...]) # PCollection[Chunk]
... chunks | VectorDatabaseWriteTransform(config)
Args:
database_config: Configuration for the target vector database.
Must be a subclass of VectorDatabaseWriteConfig that implements
create_write_transform().
Raises:
TypeError: If database_config is not a VectorDatabaseWriteConfig instance.
"""
  def __init__(self, database_config: VectorDatabaseWriteConfig):
    """Initialize transform with database config.

    Args:
      database_config: Configuration for the target vector database.
    """
    if not isinstance(database_config, VectorDatabaseWriteConfig):
      raise TypeError(
          f"database_config must be VectorDatabaseWriteConfig, "
          f"got {type(database_config)}")
    self.database_config = database_config
  def expand(self,
             pcoll: beam.PCollection[Chunk]) -> beam.PCollection[Any]:
    """Creates and applies the database-specific write transform.

    Args:
      pcoll: PCollection of Chunks with embeddings to write to the
        vector database. Each Chunk must have:

        - An embedding
        - An ID
        - Metadata used to filter results, as specified by the database config

    Returns:
      The result of writing to the database (implementation specific).
    """
    write_transform = self.database_config.create_write_transform()
    return pcoll | write_transform
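# Usage sketch, assuming the hypothetical _LoggingVectorWriterConfig above.
# Chunk, Content, and Embedding come from apache_beam.ml.rag.types; the field
# names and values used here are assumptions for illustration, not a
# definitive recipe.
def _example_usage():
  from apache_beam.ml.rag.types import Content
  from apache_beam.ml.rag.types import Embedding

  with beam.Pipeline() as p:
    chunks = p | beam.Create([
        Chunk(
            id="chunk-1",
            content=Content(text="hello vector store"),
            embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])),
    ])
    _ = chunks | VectorDatabaseWriteTransform(_LoggingVectorWriterConfig())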