Source code for apache_beam.ml.rag.ingestion.postgres

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Callable
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union

import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
from apache_beam.ml.rag.types import Chunk

_LOGGER = logging.getLogger(__name__)

MetadataSpec = Union[ColumnSpec, Dict[str, ColumnSpec]]


class _PostgresQueryBuilder:
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: List[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Builds SQL queries for writing Chunks with Embeddings to Postgres.
    """
    self.table_name = table_name

    self.column_specs = column_specs
    self.conflict_resolution = conflict_resolution

    # Validate no duplicate column names
    names = [col.column_name for col in self.column_specs]
    duplicates = set(name for name in names if names.count(name) > 1)
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    # Create NamedTuple type
    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    type_name = f"VectorRecord_{table_name}"
    self.record_type = NamedTuple(type_name, fields)  # type: ignore

    # Register coder
    registry.register_coder(self.record_type, RowCoder)

    # Set default update fields to all non-conflict fields if update fields are
    # not specified
    if self.conflict_resolution:
      self.conflict_resolution.maybe_set_default_update_fields(
          [col.column_name for col in self.column_specs if col.column_name])

  def build_insert(self) -> str:
    """Build INSERT query with proper type casting."""
    # Get column names and placeholders
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]

    # Build base query
    query = f"""
        INSERT INTO {self.table_name}
        ({', '.join(fields)})
        VALUES ({', '.join(placeholders)})
    """

    # Add conflict handling if configured
    if self.conflict_resolution:
      query += f" {self.conflict_resolution.get_conflict_clause()}"

    _LOGGER.info("Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
    """Creates a function to convert Chunks to records."""
    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore

    return convert


[docs] class PostgresVectorWriterConfig(VectorDatabaseWriteConfig): def __init__( self, connection_config: ConnectionConfig, table_name: str, *, # pylint: disable=dangerous-default-value write_config: WriteConfig = WriteConfig(), column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build( ), conflict_resolution: Optional[ConflictResolution] = ConflictResolution( on_conflict_fields=[], action='IGNORE')): """Configuration for writing vectors to Postgres using jdbc. Supports flexible schema configuration through column specifications and conflict resolution strategies. Args: connection_config: :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`. table_name: Target table name. write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control batch sizes, authosharding, etc. column_specs: Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how embeddings and metadata are written a database schema. If None, uses default Chunk schema. conflict_resolution: Optional :class:`~.postgres_common.ConflictResolution` strategy for handling insert conflicts. ON CONFLICT DO NOTHING by default. Examples: Simple case with default schema: >>> config = PostgresVectorWriterConfig( ... connection_config=ConnectionConfig(...), ... table_name='embeddings' ... ) Custom schema with metadata fields: >>> specs = (ColumnSpecsBuilder() ... .with_id_spec(column_name="my_id_column") ... .with_embedding_spec(column_name="embedding_vec") ... .add_metadata_field(field="source", column_name="src") ... .add_metadata_field( ... "timestamp", ... column_name="created_at", ... sql_typecast="::timestamp" ... ) ... .build()) Minimal schema (only ID + embedding written) >>> column_specs = (ColumnSpecsBuilder() ... .with_id_spec() ... .with_embedding_spec() ... .build()) >>> config = PostgresVectorWriterConfig( ... connection_config=ConnectionConfig(...), ... table_name='embeddings', ... column_specs=specs ... ) """ self.connection_config = connection_config self.write_config = write_config # NamedTuple is created and registered here during pipeline construction self.query_builder = _PostgresQueryBuilder( table_name, column_specs=column_specs, conflict_resolution=conflict_resolution)
[docs] def create_write_transform(self) -> beam.PTransform: return _WriteToPostgresVectorDatabase(self)
class _WriteToPostgresVectorDatabase(beam.PTransform): """Implementation of Postgres vector database write. """ def __init__(self, config: PostgresVectorWriterConfig): self.config = config self.query_builder = config.query_builder self.connection_config = config.connection_config self.write_config = config.write_config def expand(self, pcoll: beam.PCollection[Chunk]): return ( pcoll | "Convert to Records" >> beam.Map(self.query_builder.create_converter()) | "Write to Postgres" >> WriteToJdbc( table_name=self.query_builder.table_name, driver_class_name="org.postgresql.Driver", jdbc_url=self.connection_config.jdbc_url, username=self.connection_config.username, password=self.connection_config.password, statement=self.query_builder.build_insert(), connection_properties=self.connection_config.connection_properties, connection_init_sqls=self.connection_config.connection_init_sqls, autosharding=self.write_config.autosharding, max_connections=self.write_config.max_connections, write_batch_size=self.write_config.write_batch_size, **self.connection_config.additional_jdbc_args))