# Source code for apache_beam.ml.rag.ingestion.cloudsql

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import asdict
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution


@dataclass
class LanguageConnectorConfig:
  """Configuration options for the CloudSQL Java language connector.

  Set these parameters to connect to a CloudSQL instance through the Java
  language connector (JDBC socket factory). For details see
  https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/blob/main/docs/jdbc.md

  This base class is engine-agnostic: engine-specific subclasses must
  implement :meth:`to_jdbc_url` (and may override
  :meth:`additional_jdbc_args`) before :meth:`to_connection_config` can be
  used.

  Attributes:
    username: Database username.
    password: Database password. Can be an empty string when using IAM.
    database_name: Name of the database to connect to.
    instance_name: Instance connection name. Format:
      '<PROJECT>:<REGION>:<INSTANCE>'
    ip_types: Optional list of IP types to try, in order of preference
      (e.g. ['PRIVATE', 'PUBLIC']). Sent as the comma-separated ``ipTypes``
      URL property.
    enable_iam_auth: Whether to enable IAM authentication. Default is False.
    target_principal: Optional service account to impersonate for the
      connection.
    delegates: Optional list of service accounts for delegated impersonation.
    quota_project: Optional project ID for quota and billing.
    connection_properties: Optional JDBC connection properties dict. These
      are passed to the connection config separately from the JDBC URL.
      Example: {'ssl': 'true'}
    additional_properties: Additional properties to be added to the JDBC URL.
      They are applied after the properties derived from the fields above
      and therefore override them (except ``socketFactory``, which is always
      set by the engine-specific subclass). Connector options without a
      dedicated field can be passed here.
      Example: {'someProperty': 'true'}
  """
  username: str
  password: str
  database_name: str
  instance_name: str
  ip_types: Optional[List[str]] = None
  enable_iam_auth: bool = False
  target_principal: Optional[str] = None
  delegates: Optional[List[str]] = None
  quota_project: Optional[str] = None
  connection_properties: Optional[Dict[str, str]] = None
  additional_properties: Optional[Dict[str, Any]] = None

  def _base_jdbc_properties(self) -> Dict[str, Any]:
    """Collect the JDBC URL properties shared by every database engine.

    Returns:
      An insertion-ordered dict of URL properties. Optional settings are
      only included when set; ``additional_properties`` is merged last so
      its entries override any derived value with the same key.
    """
    properties: Dict[str, Any] = {"cloudSqlInstance": self.instance_name}
    if self.ip_types:
      properties["ipTypes"] = ",".join(self.ip_types)
    if self.enable_iam_auth:
      properties["enableIamAuth"] = "true"
    if self.target_principal:
      properties["cloudSqlTargetPrincipal"] = self.target_principal
    if self.delegates:
      properties["cloudSqlDelegates"] = ",".join(self.delegates)
    if self.quota_project:
      properties["cloudSqlAdminQuotaProject"] = self.quota_project
    if self.additional_properties:
      properties.update(self.additional_properties)
    return properties

  def _build_jdbc_url(self, socketFactory: str, database_type: str) -> str:
    """Assemble ``jdbc:<database_type>:///<database_name>?k=v&...``.

    Args:
      socketFactory: Fully qualified Java class name of the socket factory.
      database_type: JDBC sub-protocol, e.g. ``"postgresql"``.

    Returns:
      The JDBC URL with all properties joined by ``&``. Values are inserted
      verbatim (no URL encoding).
    """
    properties = self._base_jdbc_properties()
    # Assigned after ``additional_properties`` has been merged, so the
    # engine-specific factory always wins.
    properties['socketFactory'] = socketFactory
    query = "&".join(f"{key}={value}" for key, value in properties.items())
    return f"jdbc:{database_type}:///{self.database_name}?{query}"

  def to_jdbc_url(self) -> str:
    """Return the JDBC URL for this configuration.

    Engine-specific subclasses must override this method.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not define a JDBC URL; use an "
        "engine-specific connector config (e.g. the CloudSQL Postgres "
        "writer config) instead of LanguageConnectorConfig directly.")

  def to_connection_config(self) -> "ConnectionConfig":
    """Build the JdbcIO :class:`~.jdbc_common.ConnectionConfig`.

    Returns:
      A ConnectionConfig carrying the URL from :meth:`to_jdbc_url`, the
      credentials, ``connection_properties`` and the result of
      :meth:`additional_jdbc_args`.

    Raises:
      NotImplementedError: If the concrete class does not override
        :meth:`to_jdbc_url`.
    """
    return ConnectionConfig(
        jdbc_url=self.to_jdbc_url(),
        username=self.username,
        password=self.password,
        connection_properties=self.connection_properties,
        additional_jdbc_args=self.additional_jdbc_args())

  def additional_jdbc_args(self) -> Dict[str, List[Any]]:
    """Return extra arguments for the JDBC connection; none by default."""
    return {}
@dataclass
class _PostgresConnectorConfig(LanguageConnectorConfig):
  """CloudSQL language-connector settings bound to PostgreSQL.

  Supplies the PostgreSQL socket factory and JDBC sub-protocol when the
  connection URL is built, plus the extra JDBC arguments (a ``classpath``
  list) returned by :meth:`additional_jdbc_args`.
  """

  def to_jdbc_url(self) -> str:
    """Build the PostgreSQL JDBC URL for this CloudSQL instance.

    Returns:
      A ``jdbc:postgresql:///<database_name>?...`` URL whose
      ``socketFactory`` property is
      ``com.google.cloud.sql.postgres.SocketFactory``.
    """
    factory_class = "com.google.cloud.sql.postgres.SocketFactory"
    return self._build_jdbc_url(
        socketFactory=factory_class, database_type="postgresql")

  def additional_jdbc_args(self) -> Dict[str, List[Any]]:
    """Return the ``classpath`` entries for the PostgreSQL connection.

    The entries are ``group:artifact:version`` strings for the PostgreSQL
    JDBC driver and the CloudSQL Postgres socket factory. How they are
    consumed is up to the JDBC transform (presumably added to the Java
    classpath; confirm against ``ConnectionConfig``).
    """
    postgres_driver = "org.postgresql:postgresql:42.2.16"
    cloudsql_socket_factory = (
        "com.google.cloud.sql:postgres-socket-factory:1.25.0")
    return {'classpath': [postgres_driver, cloudsql_socket_factory]}

  @classmethod
  def from_base_config(
      cls, config: LanguageConnectorConfig) -> "_PostgresConnectorConfig":
    """Copy every field of ``config`` into a new PostgreSQL connector config.

    ``dataclasses.asdict`` is used, so list/dict fields are deep-copied
    rather than shared with ``config``.
    """
    field_values = asdict(config)
    return cls(**field_values)
class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
  """Vector writer config for CloudSQL PostgreSQL via the Java connector."""

  # Private sentinel: lets ``__init__`` tell "argument omitted" (use the
  # default ON CONFLICT DO NOTHING strategy) apart from an explicit
  # ``conflict_resolution=None``, which is forwarded unchanged.
  _DEFAULT_CONFLICT_RESOLUTION: Any = object()

  def __init__(
      self,
      connection_config: LanguageConnectorConfig,
      table_name: str,
      *,
      write_config: Optional[WriteConfig] = None,
      column_specs: Optional[List[ColumnSpec]] = None,
      conflict_resolution: Optional[ConflictResolution] = (
          _DEFAULT_CONFLICT_RESOLUTION),  # type: ignore[assignment]
  ):
    """Configuration for writing vectors to CloudSQL Postgres.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies.

    Args:
      connection_config: :class:`LanguageConnectorConfig` describing the
        CloudSQL instance, database and credentials.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc. If None, a default ``WriteConfig()``
        is used.
      column_specs: Use :class:`~.postgres.ColumnSpecsBuilder` to configure
        how embeddings and metadata are written to the database schema. If
        None, uses the default Chunk schema.
      conflict_resolution: Optional
        :class:`~.postgres_common.ConflictResolution` strategy for handling
        insert conflicts. ON CONFLICT DO NOTHING by default. An explicit
        ``None`` is forwarded unchanged to
        :class:`PostgresVectorWriterConfig`.

    Attributes:
      connector_config: The ``connection_config`` re-typed as the internal
        PostgreSQL connector config used to build the JDBC connection.

    Examples:
      Basic usage with the default schema:

      >>> connector = LanguageConnectorConfig(
      ...     username='user',
      ...     password='password',
      ...     database_name='vectors',
      ...     instance_name='my-project:us-central1:my-instance')
      >>> config = CloudSQLPostgresVectorWriterConfig(
      ...     connection_config=connector,
      ...     table_name='embeddings')

      Custom schema with metadata fields:

      >>> specs = (ColumnSpecsBuilder()
      ...     .with_id_spec(column_name="my_id_column")
      ...     .with_embedding_spec(column_name="embedding_vec")
      ...     .add_metadata_field(field="source", column_name="src")
      ...     .add_metadata_field(
      ...         "timestamp",
      ...         column_name="created_at",
      ...         sql_typecast="::timestamp"
      ...     )
      ...     .build())
      >>> config = CloudSQLPostgresVectorWriterConfig(
      ...     connection_config=connector,
      ...     table_name='embeddings',
      ...     column_specs=specs)

      Minimal schema (only ID + embedding written):

      >>> specs = (ColumnSpecsBuilder()
      ...     .with_id_spec()
      ...     .with_embedding_spec()
      ...     .build())
      >>> config = CloudSQLPostgresVectorWriterConfig(
      ...     connection_config=connector,
      ...     table_name='embeddings',
      ...     column_specs=specs)
    """
    # Build defaults per call instead of once at import time, so separate
    # writer configs never share the same mutable default objects.
    if write_config is None:
      write_config = WriteConfig()
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()
    if conflict_resolution is self._DEFAULT_CONFLICT_RESOLUTION:
      conflict_resolution = ConflictResolution(
          on_conflict_fields=[], action='IGNORE')

    self.connector_config = _PostgresConnectorConfig.from_base_config(
        connection_config)
    super().__init__(
        connection_config=self.connector_config.to_connection_config(),
        write_config=write_config,
        table_name=table_name,
        column_specs=column_specs,
        conflict_resolution=conflict_resolution)