#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
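"""Postgres vector database ingestion via JdbcIO.

Converts :class:`~apache_beam.ml.rag.types.Chunk` elements into rows described
by :class:`~apache_beam.ml.rag.ingestion.postgres_common.ColumnSpec` lists and
writes them with :class:`~apache_beam.io.jdbc.WriteToJdbc`.
"""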
import logging
from typing import Callable
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union
import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
from apache_beam.ml.rag.types import Chunk
_LOGGER = logging.getLogger(__name__)
MetadataSpec = Union[ColumnSpec, Dict[str, ColumnSpec]]
class _PostgresQueryBuilder:
def __init__(
self,
table_name: str,
*,
column_specs: List[ColumnSpec],
conflict_resolution: Optional[ConflictResolution] = None):
"""Builds SQL queries for writing Chunks with Embeddings to Postgres.
"""
self.table_name = table_name
self.column_specs = column_specs
self.conflict_resolution = conflict_resolution
# Validate no duplicate column names
names = [col.column_name for col in self.column_specs]
duplicates = set(name for name in names if names.count(name) > 1)
if duplicates:
raise ValueError(f"Duplicate column names found: {duplicates}")
# Create NamedTuple type
fields = [(col.column_name, col.python_type) for col in self.column_specs]
type_name = f"VectorRecord_{table_name}"
self.record_type = NamedTuple(type_name, fields) # type: ignore
# Register coder
registry.register_coder(self.record_type, RowCoder)
# Set default update fields to all non-conflict fields if update fields are
# not specified
if self.conflict_resolution:
self.conflict_resolution.maybe_set_default_update_fields(
[col.column_name for col in self.column_specs if col.column_name])
def build_insert(self) -> str:
"""Build INSERT query with proper type casting."""
# Get column names and placeholders
fields = [col.column_name for col in self.column_specs]
placeholders = [col.placeholder for col in self.column_specs]
# Build base query
query = f"""
INSERT INTO {self.table_name}
({', '.join(fields)})
VALUES ({', '.join(placeholders)})
"""
# Add conflict handling if configured
if self.conflict_resolution:
query += f" {self.conflict_resolution.get_conflict_clause()}"
_LOGGER.info("Query with placeholders %s", query)
return query
def create_converter(self) -> Callable[[Chunk], NamedTuple]:
"""Creates a function to convert Chunks to records."""
def convert(chunk: Chunk) -> self.record_type: # type: ignore
return self.record_type(
**{col.column_name: col.value_fn(chunk)
for col in self.column_specs}) # type: ignore
return convert
class PostgresVectorWriterConfig(VectorDatabaseWriteConfig):
def __init__(
self,
connection_config: ConnectionConfig,
table_name: str,
*,
# pylint: disable=dangerous-default-value
write_config: WriteConfig = WriteConfig(),
column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build(
),
conflict_resolution: Optional[ConflictResolution] = ConflictResolution(
on_conflict_fields=[], action='IGNORE')):
"""Configuration for writing vectors to Postgres using jdbc.
Supports flexible schema configuration through column specifications and
conflict resolution strategies.
Args:
connection_config:
:class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc.
      column_specs:
        Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how
        embeddings and metadata are written to a database schema. If None,
        uses the default Chunk schema.
conflict_resolution: Optional
:class:`~.postgres_common.ConflictResolution`
strategy for handling insert conflicts. ON CONFLICT DO NOTHING by
default.
Examples:
Simple case with default schema:
>>> config = PostgresVectorWriterConfig(
... connection_config=ConnectionConfig(...),
... table_name='embeddings'
... )
Custom schema with metadata fields:
>>> specs = (ColumnSpecsBuilder()
... .with_id_spec(column_name="my_id_column")
... .with_embedding_spec(column_name="embedding_vec")
... .add_metadata_field(field="source", column_name="src")
... .add_metadata_field(
... "timestamp",
... column_name="created_at",
... sql_typecast="::timestamp"
... )
... .build())
      Minimal schema (only ID + embedding written):
>>> column_specs = (ColumnSpecsBuilder()
... .with_id_spec()
... .with_embedding_spec()
... .build())
>>> config = PostgresVectorWriterConfig(
... connection_config=ConnectionConfig(...),
... table_name='embeddings',
      ...     column_specs=column_specs
... )
"""
self.connection_config = connection_config
self.write_config = write_config
# NamedTuple is created and registered here during pipeline construction
self.query_builder = _PostgresQueryBuilder(
table_name,
column_specs=column_specs,
conflict_resolution=conflict_resolution)
def create_write_transform(self) -> beam.PTransform:
return _WriteToPostgresVectorDatabase(self)
class _WriteToPostgresVectorDatabase(beam.PTransform):
"""Implementation of Postgres vector database write. """
def __init__(self, config: PostgresVectorWriterConfig):
self.config = config
self.query_builder = config.query_builder
self.connection_config = config.connection_config
self.write_config = config.write_config
def expand(self, pcoll: beam.PCollection[Chunk]):
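    # Chunks are first mapped to the registered NamedTuple record type so
    # RowCoder can encode them, then written via JdbcIO using the prepared
    # INSERT statement.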
return (
pcoll
|
"Convert to Records" >> beam.Map(self.query_builder.create_converter())
| "Write to Postgres" >> WriteToJdbc(
table_name=self.query_builder.table_name,
driver_class_name="org.postgresql.Driver",
jdbc_url=self.connection_config.jdbc_url,
username=self.connection_config.username,
password=self.connection_config.password,
statement=self.query_builder.build_insert(),
connection_properties=self.connection_config.connection_properties,
connection_init_sqls=self.connection_config.connection_init_sqls,
autosharding=self.write_config.autosharding,
max_connections=self.write_config.max_connections,
write_batch_size=self.write_config.write_batch_size,
**self.connection_config.additional_jdbc_args))