Source code for apache_beam.ml.rag.ingestion.postgres

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union

import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
from apache_beam.ml.rag.types import EmbeddableItem

# Logger named after this module.
_LOGGER = logging.getLogger(name=__name__)

# Type alias: either a single ColumnSpec or a dict of ColumnSpecs keyed by
# string. NOTE(review): not referenced in the visible part of this module.
MetadataSpec = Union[ColumnSpec, dict[str, ColumnSpec]]


class _PostgresQueryBuilder:
  """Builds the INSERT statement and record converter for a Postgres write.

  At construction time this creates a NamedTuple record type whose fields
  mirror ``column_specs`` and registers a ``RowCoder`` for it.
  """
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: list[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Builds SQL queries for writing EmbeddableItems to Postgres.

    Args:
      table_name: Target table name. It is interpolated verbatim into the
        INSERT statement and may be schema-qualified
        (e.g. ``"public.embeddings"``).
      column_specs: Columns to write, in order. Column names must be unique.
      conflict_resolution: Optional conflict strategy. Its clause is appended
        to the INSERT statement. Its ``maybe_set_default_update_fields`` is
        called with this builder's column names, which may update the passed
        object in place.

    Raises:
      ValueError: If ``column_specs`` contains duplicate column names.
    """
    self.table_name = table_name

    self.column_specs = column_specs
    self.conflict_resolution = conflict_resolution

    # Validate no duplicate column names (single pass; avoids calling
    # list.count once per element).
    names = [col.column_name for col in self.column_specs]
    seen = set()
    duplicates = set()
    for name in names:
      if name in seen:
        duplicates.add(name)
      seen.add(name)
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    # Create NamedTuple type. NamedTuple type names must be valid Python
    # identifiers, so characters that cannot appear in one (e.g. the '.' in a
    # schema-qualified table name) are replaced with '_'. Identifier-safe
    # table names produce exactly the same type name as before.
    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    type_name = f"VectorRecord_{self._identifier_suffix(table_name)}"
    self.record_type = NamedTuple(type_name, fields)  # type: ignore

    # Register coder
    registry.register_coder(self.record_type, RowCoder)

    # Set default update fields to all non-conflict fields if update fields are
    # not specified
    if self.conflict_resolution:
      self.conflict_resolution.maybe_set_default_update_fields(
          [col.column_name for col in self.column_specs if col.column_name])

  @staticmethod
  def _identifier_suffix(name: str) -> str:
    """Returns ``name`` with non-identifier characters replaced by ``'_'``.

    A character is kept when it may continue a Python identifier, so the
    result is always safe to append after an identifier prefix.
    """
    return ''.join(ch if ('_' + ch).isidentifier() else '_' for ch in name)

  def build_insert(self) -> str:
    """Build the parameterized INSERT statement.

    Column names and per-column placeholders come from ``column_specs`` in
    order. If a conflict resolution is configured, its clause is appended.

    Returns:
      The SQL statement text, with placeholders instead of literal values.
    """
    # Get column names and placeholders
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]

    # Build base query
    query = f"""
        INSERT INTO {self.table_name}
        ({', '.join(fields)})
        VALUES ({', '.join(placeholders)})
    """

    # Add conflict handling if configured
    if self.conflict_resolution:
      query += f" {self.conflict_resolution.get_conflict_clause()}"

    _LOGGER.info("Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[EmbeddableItem], NamedTuple]:
    """Creates a function to convert EmbeddableItems to records.

    Returns:
      A callable mapping an EmbeddableItem to an instance of
      ``self.record_type``. Each field is filled by calling that column's
      ``value_fn`` on the item.
    """
    def convert(chunk: EmbeddableItem) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore

    return convert


class PostgresVectorWriterConfig(VectorDatabaseWriteConfig):
  def __init__(
      self,
      connection_config: ConnectionConfig,
      table_name: str,
      *,
      # pylint: disable=dangerous-default-value
      write_config: WriteConfig = WriteConfig(),
      column_specs: Optional[list[ColumnSpec]] = ColumnSpecsBuilder.
      with_defaults().build(),
      conflict_resolution: Optional[ConflictResolution] = ConflictResolution(
          on_conflict_fields=[], action='IGNORE')):
    """Configuration for writing vectors to Postgres using jdbc.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies.

    Args:
      connection_config:
        :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc.
      column_specs:
        Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how
        embeddings and metadata are written to a database schema.
        If None, uses default EmbeddableItem schema.
      conflict_resolution: Optional
        :class:`~.postgres_common.ConflictResolution`
        strategy for handling insert conflicts. ON CONFLICT DO NOTHING by
        default.

    Examples:
        Simple case with default schema:

        >>> config = PostgresVectorWriterConfig(
        ...     connection_config=ConnectionConfig(...),
        ...     table_name='embeddings'
        ... )

        Custom schema with metadata fields:

        >>> specs = (ColumnSpecsBuilder()
        ...         .with_id_spec(column_name="my_id_column")
        ...         .with_embedding_spec(column_name="embedding_vec")
        ...         .add_metadata_field(field="source", column_name="src")
        ...         .add_metadata_field(
        ...             "timestamp",
        ...             column_name="created_at",
        ...             sql_typecast="::timestamp"
        ...         )
        ...         .build())

        Minimal schema (only ID + embedding written):

        >>> column_specs = (ColumnSpecsBuilder()
        ...     .with_id_spec()
        ...     .with_embedding_spec()
        ...     .build())
        >>> config = PostgresVectorWriterConfig(
        ...     connection_config=ConnectionConfig(...),
        ...     table_name='embeddings',
        ...     column_specs=column_specs
        ... )
    """
    self.connection_config = connection_config
    self.write_config = write_config
    # Honor the documented contract that column_specs=None selects the
    # default EmbeddableItem schema; the query builder cannot iterate None.
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()
    # NamedTuple is created and registered here during pipeline construction
    self.query_builder = _PostgresQueryBuilder(
        table_name,
        column_specs=column_specs,
        conflict_resolution=conflict_resolution)

  def create_write_transform(self) -> beam.PTransform:
    """Returns the PTransform that writes EmbeddableItems with this config."""
    return _WriteToPostgresVectorDatabase(self)
class _WriteToPostgresVectorDatabase(beam.PTransform):
  """Implementation of Postgres vector database write.

  Maps each EmbeddableItem to a record using the config's query builder. The
  records are then written with ``WriteToJdbc`` using the
  ``org.postgresql.Driver`` driver class and the builder's INSERT statement.
  """
  def __init__(self, config: PostgresVectorWriterConfig):
    """
    Args:
      config: Writer configuration that supplies the query builder, the
        connection settings and the JdbcIO write settings.
    """
    # Initialize PTransform base-class state (label, type hints) as Beam
    # transforms are expected to do.
    super().__init__()
    self.config = config
    self.query_builder = config.query_builder
    self.connection_config = config.connection_config
    self.write_config = config.write_config

  def expand(self, pcoll: beam.PCollection[EmbeddableItem]):
    """Converts items to records and writes them to Postgres over JDBC."""
    return (
        pcoll
        | "Convert to Records" >> beam.Map(
            self.query_builder.create_converter())
        | "Write to Postgres" >> WriteToJdbc(
            table_name=self.query_builder.table_name,
            driver_class_name="org.postgresql.Driver",
            jdbc_url=self.connection_config.jdbc_url,
            username=self.connection_config.username,
            password=self.connection_config.password,
            statement=self.query_builder.build_insert(),
            connection_properties=self.connection_config.connection_properties,
            connection_init_sqls=self.connection_config.connection_init_sqls,
            autosharding=self.write_config.autosharding,
            max_connections=self.write_config.max_connections,
            write_batch_size=self.write_config.write_batch_size,
            **self.connection_config.additional_jdbc_args))