#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
from abc import abstractmethod
from typing import Callable
from typing import List
from typing import NamedTuple
from typing import Optional
import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec
from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution
from apache_beam.ml.rag.types import Chunk
# Module-level logger; build_insert() reports each generated INSERT here.
_LOGGER = logging.getLogger(name=__name__)
class _ConflictResolutionStrategy(ABC):
    """Interface for rendering the duplicate-key clause of a MySQL INSERT.

    Implementations return the SQL text that is appended after the VALUES
    list of the INSERT statement, or an empty string when nothing should be
    appended.
    """

    @abstractmethod
    def get_conflict_clause(self, all_columns: List[str]) -> str:
        """Generate the MySQL conflict clause.

        Args:
            all_columns: Names of every column written by the INSERT.

        Returns:
            The clause text to append to the INSERT statement.
        """
        ...
class _NoConflictStrategy(_ConflictResolutionStrategy):
    """Appends nothing to the INSERT, leaving duplicate keys to error out."""

    def get_conflict_clause(self, all_columns: List[str]) -> str:
        # The column list is irrelevant: no clause is ever produced.
        empty_clause = ""
        return empty_clause
class _UpdateStrategy(_ConflictResolutionStrategy):
    """Strategy for UPDATE action on conflict.

    Renders ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...`` for the
    configured ``update_fields``, or for every inserted column when no
    fields were configured.
    """

    def __init__(self, update_fields: Optional[List[str]] = None):
        # None or an empty list both mean "update every inserted column"
        # (see the `or` fallback in get_conflict_clause).
        self.update_fields = update_fields

    def get_conflict_clause(self, all_columns: List[str]) -> str:
        """Build the ON DUPLICATE KEY UPDATE clause.

        Args:
            all_columns: Every column written by the INSERT; used when no
                explicit ``update_fields`` were configured.

        Returns:
            The ``ON DUPLICATE KEY UPDATE ...`` clause text.

        Raises:
            ValueError: If there is no column to update.
        """
        fields_to_update = self.update_fields or all_columns
        if not fields_to_update:
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would let an empty and syntactically invalid
            # "ON DUPLICATE KEY UPDATE " clause reach the database.
            raise ValueError(
                "ON DUPLICATE KEY UPDATE requires at least one column to "
                "update, but no update_fields or columns were provided.")
        # NOTE(review): MySQL 8.0.20+ deprecates VALUES() in this position in
        # favour of row aliases; confirm supported server versions before
        # switching syntax.
        updates = [f"{field} = VALUES({field})" for field in fields_to_update]
        return f"ON DUPLICATE KEY UPDATE {', '.join(updates)}"
class _IgnoreStrategy(_ConflictResolutionStrategy):
    """Strategy for IGNORE action on conflict.

    On a duplicate key, the primary key column is assigned to itself, so the
    update path changes nothing and the existing row is kept as-is.
    """

    def __init__(self, primary_key_field: str):
        self.primary_key_field = primary_key_field

    def get_conflict_clause(self, all_columns: List[str]) -> str:
        # Only the key column is referenced; all_columns is not needed.
        key = self.primary_key_field
        return f"ON DUPLICATE KEY UPDATE {key} = {key}"
def _create_conflict_strategy(
        conflict_resolution: Optional[ConflictResolution]
) -> _ConflictResolutionStrategy:
    """Map a ConflictResolution spec to the strategy that renders its clause.

    Args:
        conflict_resolution: The duplicate-key handling spec, or None when no
            conflict clause should be added.

    Returns:
        ``_NoConflictStrategy`` for None, ``_UpdateStrategy`` for action
        "UPDATE", or ``_IgnoreStrategy`` for action "IGNORE".

    Raises:
        ValueError: If action is "IGNORE" without a ``primary_key_field``, or
            if the action is not recognised.
    """
    if conflict_resolution is None:
        return _NoConflictStrategy()

    action = conflict_resolution.action
    if action == "UPDATE":
        return _UpdateStrategy(conflict_resolution.update_fields)
    if action == "IGNORE":
        # Explicit check instead of `assert` (stripped under `python -O`);
        # without a key column the clause would render as "None = None".
        if conflict_resolution.primary_key_field is None:
            raise ValueError(
                "ConflictResolution with action='IGNORE' requires "
                "primary_key_field to be set.")
        return _IgnoreStrategy(conflict_resolution.primary_key_field)
    raise ValueError(f"Unknown conflict resolution {action}")
class _MySQLQueryBuilder:
    """Builds the INSERT statement and Chunk-to-record converter for MySQL.

    The column specs decide which columns are inserted, which placeholder
    expression each one uses, and how each value is extracted from a Chunk.
    A NamedTuple record type with one field per column is created and
    registered with ``RowCoder`` at construction time.
    """

    def __init__(
            self,
            table_name: str,
            *,
            column_specs: List[ColumnSpec],
            conflict_resolution: Optional[ConflictResolution] = None):
        """Builds SQL queries for writing Chunks with Embeddings to MySQL.

        Args:
            table_name: Target table. It is interpolated verbatim into the
                SQL text, so it must come from trusted configuration.
            column_specs: One spec per inserted column; column names must be
                unique.
            conflict_resolution: Optional duplicate-key handling; None adds
                no conflict clause to the INSERT.

        Raises:
            ValueError: If two column specs share the same ``column_name``,
                or (via ``_create_conflict_strategy``) if the conflict
                resolution action is unknown.
        """
        self.table_name = table_name
        self.column_specs = column_specs
        self.conflict_resolution_strategy = _create_conflict_strategy(
            conflict_resolution)

        # Single pass over the names; calling list.count() per name is O(n^2).
        seen = set()
        duplicates = set()
        for name in (col.column_name for col in self.column_specs):
            if name in seen:
                duplicates.add(name)
            seen.add(name)
        if duplicates:
            raise ValueError(f"Duplicate column names found: {duplicates}")

        fields = [(col.column_name, col.python_type)
                  for col in self.column_specs]
        type_name = self._record_type_name(table_name)
        self.record_type = NamedTuple(type_name, fields)  # type: ignore
        registry.register_coder(self.record_type, RowCoder)

    @staticmethod
    def _record_type_name(table_name: str) -> str:
        """Return a valid Python identifier naming the record type for a table.

        NamedTuple rejects type names that are not identifiers, while MySQL
        table names may be schema-qualified (``db.table``) or contain other
        punctuation. Names that are already valid are kept unchanged;
        otherwise every character other than an ASCII letter, digit or ``_``
        is replaced with ``_``.
        """
        type_name = f"VectorRecord_{table_name}"
        if type_name.isidentifier():
            return type_name
        safe = "".join(
            ch if ch.isascii() and (ch.isalnum() or ch == "_") else "_"
            for ch in table_name)
        return f"VectorRecord_{safe}"

    def build_insert(self) -> str:
        """Return the INSERT statement for the configured columns.

        Each value is bound through its column spec's placeholder
        expression, and the conflict strategy's clause is appended last
        (an empty clause appends only whitespace).
        """
        fields = [col.column_name for col in self.column_specs]
        placeholders = [col.placeholder for col in self.column_specs]
        # Build base query
        query = f"""
        INSERT INTO {self.table_name}
        ({', '.join(fields)})
        VALUES ({', '.join(placeholders)})
        """
        conflict_clause = self.conflict_resolution_strategy.get_conflict_clause(
            all_columns=fields)
        query += f" {conflict_clause}"
        _LOGGER.info("MySQL Query with placeholders %s", query)
        return query

    def create_converter(self) -> Callable[[Chunk], NamedTuple]:
        """Creates a function to convert Chunks to records.

        The returned callable builds a ``record_type`` instance whose fields
        are filled by calling each column spec's ``value_fn`` on the Chunk.
        """
        def convert(chunk: Chunk) -> self.record_type:  # type: ignore
            return self.record_type(  # type: ignore
                **{col.column_name: col.value_fn(chunk)
                   for col in self.column_specs})

        return convert
class MySQLVectorWriterConfig(VectorDatabaseWriteConfig):
    """Configuration for writing Chunks with embeddings to MySQL over JDBC."""

    def __init__(
            self,
            connection_config: ConnectionConfig,
            table_name: str,
            *,
            # pylint: disable=dangerous-default-value
            write_config: WriteConfig = WriteConfig(),
            column_specs: Optional[List[ColumnSpec]] = (
                ColumnSpecsBuilder.with_defaults().build()),
            conflict_resolution: Optional[ConflictResolution] = None):
        """Configuration for writing vectors to MySQL using jdbc.

        Supports flexible schema configuration through column specifications
        and conflict resolution strategies with MySQL-specific syntax.

        Args:
            connection_config:
                :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
            table_name: Target table name.
            write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
                batch sizes, autosharding, etc.
            column_specs:
                Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how
                embeddings and metadata are written to the database
                schema. If None, falls back to
                ``ColumnSpecsBuilder.with_defaults().build()``, the default
                Chunk schema.
            conflict_resolution: Optional
                :class:`~.mysql_common.ConflictResolution`
                strategy for handling insert conflicts, rendered as a MySQL
                ``ON DUPLICATE KEY UPDATE`` clause.
                None by default, meaning errors are thrown when attempting to
                insert duplicates.

        Examples:
            Simple case with default schema:

            >>> config = MySQLVectorWriterConfig(
            ...     connection_config=ConnectionConfig(...),
            ...     table_name='embeddings'
            ... )

            Custom schema with metadata fields and MySQL functions:

            >>> specs = (ColumnSpecsBuilder()
            ...         .with_id_spec(column_name="my_id_column")
            ...         .with_embedding_spec(
            ...             column_name="embedding_vec",
            ...             placeholder="string_to_vector(?)"
            ...         )
            ...         .add_metadata_field(field="source", column_name="src")
            ...         .add_metadata_field(
            ...             "timestamp",
            ...             column_name="created_at",
            ...             placeholder="STR_TO_DATE(?, '%Y-%m-%d %H:%i:%s')"
            ...         )
            ...         .build())

            Minimal schema (only ID + embedding written):

            >>> column_specs = (ColumnSpecsBuilder()
            ...     .with_id_spec()
            ...     .with_embedding_spec()
            ...     .build())
            >>> config = MySQLVectorWriterConfig(
            ...     connection_config=ConnectionConfig(...),
            ...     table_name='embeddings',
            ...     column_specs=column_specs,
            ...     conflict_resolution=ConflictResolution(
            ...         on_conflict_fields=["id"],
            ...         action="UPDATE",
            ...         update_fields=["embedding", "content"]
            ...     )
            ... )

            Using MySQL JSON functions:

            >>> specs = (ColumnSpecsBuilder()
            ...     .with_id_spec()
            ...     .with_embedding_spec()
            ...     .with_metadata_spec(
            ...         column_name="metadata_json",
            ...         placeholder="CAST(? AS JSON)"
            ...     )
            ...     .build())
        """
        if column_specs is None:
            # Honour the documented "None means default schema" contract
            # instead of failing later when the query builder iterates None.
            column_specs = ColumnSpecsBuilder.with_defaults().build()
        self.connection_config = connection_config
        self.write_config = write_config
        # NamedTuple is created and registered here during pipeline construction
        self.query_builder = _MySQLQueryBuilder(
            table_name,
            column_specs=column_specs,
            conflict_resolution=conflict_resolution)

    def create_write_transform(self) -> beam.PTransform:
        """Creates the PTransform that writes Chunks to the configured table.

        Returns:
            A ``_WriteToMySQLVectorDatabase`` transform bound to this config.
        """
        # NOTE(review): assumes VectorDatabaseWriteConfig exposes the write
        # transform through this hook; without it _WriteToMySQLVectorDatabase
        # is never constructed from this config.
        return _WriteToMySQLVectorDatabase(self)
class _WriteToMySQLVectorDatabase(beam.PTransform):
    """Implementation of MySQL vector database write.

    Converts each Chunk into the query builder's NamedTuple record type, then
    writes the records with ``WriteToJdbc`` using the
    ``com.mysql.cj.jdbc.Driver`` driver class and the builder's INSERT
    statement.
    """

    def __init__(self, config: MySQLVectorWriterConfig):
        # Initialise the PTransform base (label, type-hint state) before
        # attaching this transform's own attributes.
        super().__init__()
        self.config = config
        self.query_builder = config.query_builder
        self.connection_config = config.connection_config
        self.write_config = config.write_config

    def expand(self, pcoll: beam.PCollection[Chunk]):
        """Convert the Chunks to records and write them to MySQL.

        Args:
            pcoll: The Chunks to write.

        Returns:
            The output of the ``WriteToJdbc`` transform.
        """
        converter = self.query_builder.create_converter()
        statement = self.query_builder.build_insert()
        return (
            pcoll
            | "Convert to Records" >> beam.Map(converter)
            | "Write to MySQL" >> WriteToJdbc(
                table_name=self.query_builder.table_name,
                driver_class_name="com.mysql.cj.jdbc.Driver",
                jdbc_url=self.connection_config.jdbc_url,
                username=self.connection_config.username,
                password=self.connection_config.password,
                statement=statement,
                connection_properties=self.connection_config.
                connection_properties,
                connection_init_sqls=self.connection_config.
                connection_init_sqls,
                autosharding=self.write_config.autosharding,
                max_connections=self.write_config.max_connections,
                write_batch_size=self.write_config.write_batch_size,
                **self.connection_config.additional_jdbc_args))