Source code for apache_beam.ml.rag.ingestion.alloydb

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution


@dataclass
class AlloyDBLanguageConnectorConfig:
  """Configuration options for the AlloyDB language connector.

  Contains all parameters needed to configure a connection using the AlloyDB
  Java connector via JDBC. For details see
  https://github.com/GoogleCloudPlatform/alloydb-java-connector/blob/main/docs/jdbc.md
  A construction sketch follows this class definition.

  Attributes:
    username: Database username.
    password: Database password. Can be an empty string when using IAM
      authentication.
    database_name: Name of the database to connect to.
    instance_name: Fully qualified instance name. Format:
      'projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>'
    ip_type: IP type to use for the connection. One of 'PRIVATE' (default),
      'PUBLIC' or 'PSC'.
    enable_iam_auth: Whether to enable IAM authentication. Defaults to False.
    target_principal: Optional service account to impersonate for the
      connection.
    delegates: Optional list of service accounts for delegated impersonation;
      joined with commas in the JDBC URL.
    admin_service_endpoint: Optional custom API service endpoint.
    quota_project: Optional project ID for quota and billing.
    connection_properties: Optional JDBC connection properties dict.
      Example: {'ssl': 'true'}
    additional_properties: Additional properties to be added to the JDBC URL.
      Example: {'someProperty': 'true'}
  """
  username: str
  password: str
  database_name: str
  instance_name: str
  ip_type: str = "PRIVATE"
  enable_iam_auth: bool = False
  target_principal: Optional[str] = None
  delegates: Optional[List[str]] = None
  admin_service_endpoint: Optional[str] = None
  quota_project: Optional[str] = None
  connection_properties: Optional[Dict[str, str]] = None
  additional_properties: Optional[Dict[str, Any]] = None
  def to_jdbc_url(self) -> str:
    """Convert options to a properly formatted JDBC URL.

    Returns:
      JDBC URL string configured with all options.
    """
    # Base URL with database name
    url = f"jdbc:postgresql:///{self.database_name}?"

    # Add required properties
    properties = {
        "socketFactory": "com.google.cloud.alloydb.SocketFactory",
        "alloydbInstanceName": self.instance_name,
        "alloydbIpType": self.ip_type
    }
    if self.enable_iam_auth:
      properties["alloydbEnableIAMAuth"] = "true"
    if self.target_principal:
      properties["alloydbTargetPrincipal"] = self.target_principal
    if self.delegates:
      properties["alloydbDelegates"] = ",".join(self.delegates)
    if self.admin_service_endpoint:
      properties["alloydbAdminServiceEndpoint"] = self.admin_service_endpoint
    if self.quota_project:
      properties["alloydbQuotaProject"] = self.quota_project
    if self.additional_properties:
      properties.update(self.additional_properties)

    property_string = "&".join(f"{k}={v}" for k, v in properties.items())
    return url + property_string
  def to_connection_config(self):
    return ConnectionConfig(
        jdbc_url=self.to_jdbc_url(),
        username=self.username,
        password=self.password,
        connection_properties=self.connection_properties,
        additional_jdbc_args=self.additional_jdbc_args())
  def additional_jdbc_args(self) -> Dict[str, List[Any]]:
    return {
        'classpath': [
            "org.postgresql:postgresql:42.2.16",
            "com.google.cloud:alloydb-jdbc-connector:1.2.0"
        ]
    }
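
# A minimal construction sketch (illustrative only, not part of the module
# API): it shows how the fields above map onto the JDBC URL produced by
# to_jdbc_url(). The project, cluster, instance and credentials below are
# placeholders.
def _example_connector_config() -> str:
  config = AlloyDBLanguageConnectorConfig(
      username="postgres",
      password="",  # may be empty when enable_iam_auth=True
      database_name="vectors",
      instance_name=(
          "projects/my-project/locations/us-central1/"
          "clusters/my-cluster/instances/my-instance"),
      ip_type="PUBLIC",
      enable_iam_auth=True)
  # Expected shape of the result:
  # jdbc:postgresql:///vectors?socketFactory=com.google.cloud.alloydb.SocketFactory
  #   &alloydbInstanceName=projects/.../instances/my-instance
  #   &alloydbIpType=PUBLIC&alloydbEnableIAMAuth=true
  return config.to_jdbc_url()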
class AlloyDBVectorWriterConfig(PostgresVectorWriterConfig):
  def __init__(
      self,
      connection_config: AlloyDBLanguageConnectorConfig,
      table_name: str,
      *,
      # pylint: disable=dangerous-default-value
      write_config: WriteConfig = WriteConfig(),
      column_specs: List[ColumnSpec] = (
          ColumnSpecsBuilder.with_defaults().build()),
      conflict_resolution: Optional[ConflictResolution] = ConflictResolution(
          on_conflict_fields=[], action='IGNORE')):
    """Configuration for writing vectors to AlloyDB.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies.

    Args:
      connection_config: AlloyDB connection configuration.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc.
      column_specs: Use :class:`~.postgres_common.ColumnSpecsBuilder` to
        configure how embeddings and metadata are written to a database
        schema. If None, uses the default Chunk schema.
      conflict_resolution: Optional
        :class:`~.postgres_common.ConflictResolution` strategy for handling
        insert conflicts. ON CONFLICT DO NOTHING by default.

    Examples:
      Basic usage with the default schema:

      >>> config = AlloyDBVectorWriterConfig(
      ...     connection_config=AlloyDBLanguageConnectorConfig(...),
      ...     table_name='embeddings'
      ... )

      Custom schema with metadata fields:

      >>> specs = (ColumnSpecsBuilder()
      ...     .with_id_spec(column_name="my_id_column")
      ...     .with_embedding_spec(column_name="embedding_vec")
      ...     .add_metadata_field(field="source", column_name="src")
      ...     .add_metadata_field(
      ...         "timestamp",
      ...         column_name="created_at",
      ...         sql_typecast="::timestamp"
      ...     )
      ...     .build())
      >>> config = AlloyDBVectorWriterConfig(
      ...     connection_config=AlloyDBLanguageConnectorConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=specs
      ... )

      Minimal schema (only ID and embedding written):

      >>> specs = (ColumnSpecsBuilder()
      ...     .with_id_spec()
      ...     .with_embedding_spec()
      ...     .build())
      >>> config = AlloyDBVectorWriterConfig(
      ...     connection_config=AlloyDBLanguageConnectorConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=specs
      ... )
    """
    super().__init__(
        connection_config=connection_config.to_connection_config(),
        write_config=write_config,
        table_name=table_name,
        column_specs=column_specs,
        conflict_resolution=conflict_resolution)
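
# A usage sketch (illustrative only): wires the connector config into the
# writer and hands it to the generic RAG ingestion transform. The import of
# VectorDatabaseWriteTransform from the base ingestion module is an
# assumption of this sketch; table, instance and field names are placeholders.
def _example_write_chunks(chunks):
  from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform

  connector = AlloyDBLanguageConnectorConfig(
      username="postgres",
      password="secret",
      database_name="vectors",
      instance_name=(
          "projects/my-project/locations/us-central1/"
          "clusters/my-cluster/instances/my-instance"))
  writer_config = AlloyDBVectorWriterConfig(
      connection_config=connector,
      table_name="embeddings",
      # Skip rows that collide on the 'id' column instead of failing.
      conflict_resolution=ConflictResolution(
          on_conflict_fields=['id'], action='IGNORE'))
  return chunks | VectorDatabaseWriteTransform(writer_config)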