Source code for apache_beam.ml.rag.ingestion.mysql

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
from abc import ABC
from abc import abstractmethod
from typing import Callable
from typing import List
from typing import NamedTuple
from typing import Optional

import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec
from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution
from apache_beam.ml.rag.types import Chunk

_LOGGER = logging.getLogger(__name__)


class _ConflictResolutionStrategy(ABC):
  """Interface for objects that render the conflict part of an INSERT.

  Each concrete strategy turns the list of inserted column names into the
  SQL text that is appended after the ``VALUES (...)`` section.
  """
  @abstractmethod
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return the MySQL conflict clause for an INSERT statement.

    Args:
      all_columns: Names of every column listed in the INSERT statement.

    Returns:
      The SQL fragment to append to the INSERT statement.
    """


class _NoConflictStrategy(_ConflictResolutionStrategy):
  """Strategy used when the INSERT carries no conflict handling at all."""
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return an empty clause; ``all_columns`` is ignored."""
    no_clause = ""
    return no_clause


class _UpdateStrategy(_ConflictResolutionStrategy):
  """Strategy for UPDATE action on conflict.

  Renders an ``ON DUPLICATE KEY UPDATE`` clause that overwrites the chosen
  columns with the values of the row being inserted.

  Args:
    update_fields: Columns to overwrite when a duplicate key is hit. When
      ``None`` or empty, every column of the INSERT statement is overwritten.
  """
  def __init__(self, update_fields: Optional[List[str]] = None):
    self.update_fields = update_fields

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Build the ``ON DUPLICATE KEY UPDATE`` clause.

    Args:
      all_columns: Names of every column listed in the INSERT statement; used
        as the fallback when no explicit update fields were configured.

    Returns:
      The clause, with one ``col = VALUES(col)`` assignment per field.

    Raises:
      ValueError: If neither ``update_fields`` nor ``all_columns`` contains
        any column name.
    """
    # Use provided fields or default to all columns
    fields_to_update = self.update_fields or all_columns
    # Raise explicitly instead of asserting: asserts are removed under
    # `python -O`, which would let an incomplete clause reach the database.
    if not fields_to_update:
      raise ValueError(
          "Cannot build ON DUPLICATE KEY UPDATE clause: no columns to update")

    updates = ", ".join(
        f"{field} = VALUES({field})" for field in fields_to_update)
    return f"ON DUPLICATE KEY UPDATE {updates}"


class _IgnoreStrategy(_ConflictResolutionStrategy):
  """Strategy that leaves the existing row untouched on a duplicate key.

  Args:
    primary_key_field: Column name used in the self-assignment emitted by
      :meth:`get_conflict_clause`.
  """
  def __init__(self, primary_key_field: str):
    self.primary_key_field = primary_key_field

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return a clause that assigns the key column to itself.

    ``all_columns`` is not used by this strategy.
    """
    key = self.primary_key_field
    return f"ON DUPLICATE KEY UPDATE {key} = {key}"


def _create_conflict_strategy(
    conflict_resolution: Optional[ConflictResolution]
) -> _ConflictResolutionStrategy:
  """Select the clause-generating strategy for a conflict configuration.

  Args:
    conflict_resolution: The user-supplied conflict configuration, or ``None``
      when no conflict handling is wanted.

  Returns:
    ``_NoConflictStrategy`` for ``None``, ``_UpdateStrategy`` for the
    ``"UPDATE"`` action and ``_IgnoreStrategy`` for the ``"IGNORE"`` action.

  Raises:
    ValueError: If the action is ``"IGNORE"`` without a ``primary_key_field``,
      or if the action is not one of the values handled above.
  """
  if conflict_resolution is None:
    return _NoConflictStrategy()
  if conflict_resolution.action == "UPDATE":
    return _UpdateStrategy(conflict_resolution.update_fields)
  if conflict_resolution.action == "IGNORE":
    # Explicit check rather than an assert so the validation still runs when
    # Python is started with -O (which strips assert statements).
    if conflict_resolution.primary_key_field is None:
      raise ValueError(
          "primary_key_field is required when the conflict resolution "
          "action is IGNORE")
    return _IgnoreStrategy(conflict_resolution.primary_key_field)
  raise ValueError(f"Unknown conflict resolution {conflict_resolution.action}")


class _MySQLQueryBuilder:
  """Builds the INSERT statement and the record type used to write Chunks.

  On construction this creates a ``NamedTuple`` type with one field per
  column spec and registers ``RowCoder`` for it in the coder registry.
  """
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: List[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Builds SQL queries for writing Chunks with Embeddings to MySQL.

    Args:
      table_name: Name of the table used in the INSERT statement.
      column_specs: Column specifications; each supplies the column name,
        Python type, SQL placeholder and value extraction function.
      conflict_resolution: Optional conflict configuration used to pick the
        clause appended to the INSERT statement.

    Raises:
      ValueError: If two column specs share the same column name.
    """
    self.table_name = table_name

    self.column_specs = column_specs
    self.conflict_resolution_strategy = _create_conflict_strategy(
        conflict_resolution)

    # Single pass over the specs to find repeated column names.
    seen = set()
    duplicates = set()
    for col in self.column_specs:
      if col.column_name in seen:
        duplicates.add(col.column_name)
      seen.add(col.column_name)
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    type_name = f"VectorRecord_{table_name}"
    if not type_name.isidentifier():
      # NamedTuple type names must be valid identifiers, but table names may
      # legitimately contain other characters (e.g. "db.table" or backticks).
      # Only the Python type name is sanitized; the SQL keeps table_name as-is.
      type_name = "".join(
          ch if (ord(ch) < 128 and ch.isalnum()) or ch == "_" else "_"
          for ch in type_name)
    self.record_type = NamedTuple(type_name, fields)  # type: ignore

    registry.register_coder(self.record_type, RowCoder)

  def build_insert(self) -> str:
    """Return the INSERT statement, including any conflict clause.

    The statement lists every configured column, uses each spec's placeholder
    in the VALUES section, and is logged at INFO level before being returned.
    """
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]

    # Build base query
    query = f"""
        INSERT INTO {self.table_name}
        ({', '.join(fields)})
        VALUES ({', '.join(placeholders)})
    """
    conflict_clause = self.conflict_resolution_strategy.get_conflict_clause(
        all_columns=fields)
    query += f" {conflict_clause}"

    _LOGGER.info("MySQL Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
    """Creates a function to convert Chunks to records.

    Returns:
      A callable that builds one ``record_type`` instance from a Chunk by
      applying each column spec's ``value_fn`` to it.
    """
    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore

    return convert


class MySQLVectorWriterConfig(VectorDatabaseWriteConfig):
  def __init__(
      self,
      connection_config: ConnectionConfig,
      table_name: str,
      *,
      # pylint: disable=dangerous-default-value
      write_config: WriteConfig = WriteConfig(),
      column_specs: Optional[List[ColumnSpec]] = (
          ColumnSpecsBuilder.with_defaults().build()),
      conflict_resolution: Optional[ConflictResolution] = None):
    """Configuration for writing vectors to MySQL using jdbc.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies with MySQL-specific syntax.

    Args:
      connection_config:
        :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc.
      column_specs:
        Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how
        embeddings and metadata are written to the database schema. If None,
        uses default Chunk schema with MySQL vector functions.
      conflict_resolution: Optional
        :class:`~.mysql_common.ConflictResolution` strategy for handling
        insert conflicts. ON DUPLICATE KEY UPDATE. None by default, meaning
        errors are thrown when attempting to insert duplicates.

    Examples:
      Simple case with default schema:

      >>> config = MySQLVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings'
      ... )

      Custom schema with metadata fields and MySQL functions:

      >>> specs = (ColumnSpecsBuilder()
      ...         .with_id_spec(column_name="my_id_column")
      ...         .with_embedding_spec(
      ...             column_name="embedding_vec",
      ...             placeholder="string_to_vector(?)"
      ...         )
      ...         .add_metadata_field(field="source", column_name="src")
      ...         .add_metadata_field(
      ...             "timestamp",
      ...             column_name="created_at",
      ...             placeholder="STR_TO_DATE(?, '%Y-%m-%d %H:%i:%s')"
      ...         )
      ...         .build())

      Minimal schema (only ID + embedding written):

      >>> column_specs = (ColumnSpecsBuilder()
      ...     .with_id_spec()
      ...     .with_embedding_spec()
      ...     .build())

      Custom schema combined with conflict resolution:

      >>> config = MySQLVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=specs,
      ...     conflict_resolution=ConflictResolution(
      ...         on_conflict_fields=["id"],
      ...         action="UPDATE",
      ...         update_fields=["embedding", "content"]
      ...     )
      ... )

      Using MySQL JSON functions:

      >>> specs = (ColumnSpecsBuilder()
      ...     .with_id_spec()
      ...     .with_embedding_spec()
      ...     .with_metadata_spec(
      ...         column_name="metadata_json",
      ...         placeholder="CAST(? AS JSON)"
      ...     )
      ...     .build())
    """
    self.connection_config = connection_config
    self.write_config = write_config

    # Honor the documented "If None" behavior instead of failing later when
    # the query builder iterates over the specs.
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()

    # NamedTuple is created and registered here during pipeline construction.
    # The specs are copied because the default list is evaluated once and
    # would otherwise be shared by every config relying on the default.
    self.query_builder = _MySQLQueryBuilder(
        table_name,
        column_specs=list(column_specs),
        conflict_resolution=conflict_resolution)

  def create_write_transform(self) -> beam.PTransform:
    """Return the PTransform that writes Chunks using this configuration."""
    return _WriteToMySQLVectorDatabase(self)


class _WriteToMySQLVectorDatabase(beam.PTransform):
  """PTransform that turns Chunks into records and writes them via JDBC."""
  def __init__(self, config: MySQLVectorWriterConfig):
    self.config = config
    self.query_builder = config.query_builder
    self.connection_config = config.connection_config
    self.write_config = config.write_config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Convert each Chunk to a record, then hand the records to WriteToJdbc.

    Args:
      pcoll: The PCollection of Chunks to write.

    Returns:
      The result of applying ``WriteToJdbc`` to the converted records.
    """
    builder = self.query_builder
    connection = self.connection_config
    write_options = self.write_config

    # Step 1: map every Chunk onto the record type built by the query builder.
    records = (
        pcoll
        | "Convert to Records" >> beam.Map(builder.create_converter()))

    # Step 2: configure the JDBC sink with the generated INSERT statement.
    jdbc_sink = WriteToJdbc(
        table_name=builder.table_name,
        driver_class_name="com.mysql.cj.jdbc.Driver",
        jdbc_url=connection.jdbc_url,
        username=connection.username,
        password=connection.password,
        statement=builder.build_insert(),
        connection_properties=connection.connection_properties,
        connection_init_sqls=connection.connection_init_sqls,
        autosharding=write_options.autosharding,
        max_connections=write_options.max_connections,
        write_batch_size=write_options.write_batch_size,
        **connection.additional_jdbc_args)

    return records | "Write to MySQL" >> jdbc_sink