#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
from abc import abstractmethod
from typing import Callable
from typing import List
from typing import NamedTuple
from typing import Optional
import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec
from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution
from apache_beam.ml.rag.types import Chunk
# Logger for this module; build_insert logs generated INSERT statements here.
_LOGGER = logging.getLogger(name=__name__)
class _ConflictResolutionStrategy(ABC):
  """Interface for the conflict-handling suffix of a MySQL INSERT.

  Implementations turn the list of inserted column names into SQL text that
  the query builder appends after the ``VALUES (...)`` list.
  """
  @abstractmethod
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return the MySQL conflict clause for an INSERT.

    Args:
      all_columns: Names of every column written by the INSERT.
    """
    ...
class _NoConflictStrategy(_ConflictResolutionStrategy):
  """Strategy that adds no conflict clause to the INSERT.

  Without a clause, inserting a duplicate key is left to fail in the database.
  """
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    del all_columns  # Unused: there is nothing to update.
    return str()
class _UpdateStrategy(_ConflictResolutionStrategy):
  """Strategy that overwrites columns of the existing row on a duplicate key.

  Emits ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...`` so each selected
  column takes the value from the attempted insert.
  """
  def __init__(self, update_fields: Optional[List[str]] = None):
    """
    Args:
      update_fields: Columns to overwrite on conflict. When None or empty,
        every inserted column is overwritten.
    """
    self.update_fields = update_fields

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return the ``ON DUPLICATE KEY UPDATE`` clause.

    Args:
      all_columns: Names of every column written by the INSERT; used when no
        explicit update_fields were configured.

    Raises:
      ValueError: if there is no column to update.
    """
    # Use provided fields or default to all columns (an empty list also
    # falls back to all_columns).
    fields_to_update = self.update_fields or all_columns
    if not fields_to_update:
      # Explicit check rather than `assert`: asserts are stripped under
      # `python -O`, which would silently emit invalid SQL here.
      raise ValueError(
          "ON DUPLICATE KEY UPDATE requires at least one column to update.")
    updates = [f"{field} = VALUES({field})" for field in fields_to_update]
    return f"ON DUPLICATE KEY UPDATE {', '.join(updates)}"
class _IgnoreStrategy(_ConflictResolutionStrategy):
  """Strategy that keeps the existing row when a duplicate key is inserted.

  The clause assigns the primary key to itself, so the "update" performed on
  a duplicate changes nothing and the new row is effectively ignored.
  """
  def __init__(self, primary_key_field: str):
    """
    Args:
      primary_key_field: Column used in the no-op self-assignment.
    """
    self.primary_key_field = primary_key_field

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    # all_columns is not needed: only the primary key is referenced.
    pk = self.primary_key_field
    return "ON DUPLICATE KEY UPDATE {0} = {0}".format(pk)
def _create_conflict_strategy(
    conflict_resolution: Optional[ConflictResolution]
) -> _ConflictResolutionStrategy:
  """Return the strategy object implementing ``conflict_resolution``.

  Args:
    conflict_resolution: Requested conflict handling, or None to append no
      conflict clause.

  Raises:
    ValueError: if the action is IGNORE without a primary_key_field, or the
      action is neither UPDATE nor IGNORE.
  """
  if conflict_resolution is None:
    return _NoConflictStrategy()
  action = conflict_resolution.action
  if action == "UPDATE":
    return _UpdateStrategy(conflict_resolution.update_fields)
  if action == "IGNORE":
    # Explicit check instead of `assert` (stripped under `python -O`); an
    # empty key would also yield invalid SQL ("UPDATE  = ").
    if not conflict_resolution.primary_key_field:
      raise ValueError(
          "primary_key_field must be set for IGNORE conflict resolution.")
    return _IgnoreStrategy(conflict_resolution.primary_key_field)
  raise ValueError(f"Unknown conflict resolution {action}")
class _MySQLQueryBuilder:
  """Builds the INSERT statement and record converter for a MySQL table."""
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: List[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Builds SQL queries for writing Chunks with Embeddings to MySQL.

    Also creates a NamedTuple record type with one field per column spec and
    registers a RowCoder for it.

    Args:
      table_name: Target table; interpolated verbatim into the INSERT
        statement.
      column_specs: Columns to write, in order. Column names must be unique.
      conflict_resolution: Optional conflict handling; None appends no
        conflict clause.

    Raises:
      ValueError: if two column specs share a column name.
    """
    self.table_name = table_name
    self.column_specs = column_specs
    self.conflict_resolution_strategy = _create_conflict_strategy(
        conflict_resolution)

    # Single pass instead of names.count(), which is quadratic.
    seen = set()
    duplicates = set()
    for col in self.column_specs:
      if col.column_name in seen:
        duplicates.add(col.column_name)
      seen.add(col.column_name)
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    type_name = self._record_type_name(table_name)
    self.record_type = NamedTuple(type_name, fields)  # type: ignore
    registry.register_coder(self.record_type, RowCoder)

  @staticmethod
  def _record_type_name(table_name: str) -> str:
    """Return a valid Python identifier naming the record NamedTuple.

    NamedTuple rejects type names that are not identifiers, while MySQL table
    references may contain characters such as '.' (``db.table``), '-' or
    backticks. Such characters are replaced with '_'; names that are already
    valid identifiers are returned unchanged.
    """
    name = f"VectorRecord_{table_name}"
    if name.isidentifier():
      return name
    # The "_" prefix lets isidentifier() test c as a non-leading character.
    return "".join(c if ("_" + c).isidentifier() else "_" for c in name)

  def build_insert(self) -> str:
    """Return the parameterized INSERT statement for this table.

    Columns and value placeholders follow ``column_specs`` order; the conflict
    strategy's clause is appended after the VALUES list.
    """
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]
    # Build base query
    query = f"""
        INSERT INTO {self.table_name}
        ({', '.join(fields)})
        VALUES ({', '.join(placeholders)})
    """
    conflict_clause = self.conflict_resolution_strategy.get_conflict_clause(
        all_columns=fields)
    query += f" {conflict_clause}"
    _LOGGER.info("MySQL Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
    """Creates a function to convert Chunks to records.

    Each record field is filled by the matching column spec's ``value_fn``.
    """
    # The return annotation exposes the record type for output-type inference.
    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore
    return convert
class MySQLVectorWriterConfig(VectorDatabaseWriteConfig):
  def __init__(
      self,
      connection_config: ConnectionConfig,
      table_name: str,
      *,
      write_config: Optional[WriteConfig] = None,
      column_specs: Optional[List[ColumnSpec]] = None,
      conflict_resolution: Optional[ConflictResolution] = None):
    """Configuration for writing vectors to MySQL using jdbc.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies with MySQL-specific syntax.

    Args:
        connection_config:
          :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
        table_name: Target table name.
        write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
          batch sizes, autosharding, etc. If None, ``WriteConfig()`` is used.
        column_specs:
            Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how
            embeddings and metadata are written to the database
            schema. If None, uses default Chunk schema with MySQL vector
            functions.
        conflict_resolution: Optional
            :class:`~.mysql_common.ConflictResolution`
            strategy for handling insert conflicts, implemented with MySQL's
            ``ON DUPLICATE KEY UPDATE``. None by default, meaning errors are
            thrown when attempting to insert duplicates.

    Raises:
        ValueError: if ``column_specs`` contains duplicate column names or
          ``conflict_resolution`` has an unknown action.

    Examples:
        Simple case with default schema:

        >>> config = MySQLVectorWriterConfig(
        ...     connection_config=ConnectionConfig(...),
        ...     table_name='embeddings'
        ... )

        Custom schema with metadata fields and MySQL functions:

        >>> specs = (ColumnSpecsBuilder()
        ...         .with_id_spec(column_name="my_id_column")
        ...         .with_embedding_spec(
        ...             column_name="embedding_vec",
        ...             placeholder="string_to_vector(?)"
        ...         )
        ...         .add_metadata_field(field="source", column_name="src")
        ...         .add_metadata_field(
        ...             "timestamp",
        ...             column_name="created_at",
        ...             placeholder="STR_TO_DATE(?, '%Y-%m-%d %H:%i:%s')"
        ...         )
        ...         .build())

        Minimal schema (only ID + embedding written):

        >>> column_specs = (ColumnSpecsBuilder()
        ...     .with_id_spec()
        ...     .with_embedding_spec()
        ...     .build())
        >>> config = MySQLVectorWriterConfig(
        ...     connection_config=ConnectionConfig(...),
        ...     table_name='embeddings',
        ...     column_specs=column_specs,
        ...     conflict_resolution=ConflictResolution(
        ...         on_conflict_fields=["id"],
        ...         action="UPDATE",
        ...         update_fields=["embedding", "content"]
        ...     )
        ... )

        Using MySQL JSON functions:

        >>> specs = (ColumnSpecsBuilder()
        ...     .with_id_spec()
        ...     .with_embedding_spec()
        ...     .with_metadata_spec(
        ...         column_name="metadata_json",
        ...         placeholder="CAST(? AS JSON)"
        ...     )
        ...     .build())
    """
    self.connection_config = connection_config
    # Defaults are built per instance instead of once at definition time, so
    # separate configs never share one WriteConfig or column-spec list.
    self.write_config = (
        write_config if write_config is not None else WriteConfig())
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()
    # NamedTuple is created and registered here during pipeline construction
    self.query_builder = _MySQLQueryBuilder(
        table_name,
        column_specs=column_specs,
        conflict_resolution=conflict_resolution)

  def create_write_transform(self) -> beam.PTransform:
    """Return the PTransform that writes Chunks to MySQL using this config."""
    # Presumably the hook VectorDatabaseWriteConfig expects subclasses to
    # implement; without it _WriteToMySQLVectorDatabase is never constructed.
    return _WriteToMySQLVectorDatabase(self)
 
class _WriteToMySQLVectorDatabase(beam.PTransform):
  """Implementation of MySQL vector database write.

  Converts each Chunk into the query builder's NamedTuple record type, then
  writes the records with ``WriteToJdbc`` using the
  ``com.mysql.cj.jdbc.Driver`` JDBC driver.
  """
  def __init__(self, config: MySQLVectorWriterConfig):
    """
    Args:
      config: Connection, write, and schema settings for the sink.
    """
    # Initialize PTransform state (label etc.) before adding our own fields.
    super().__init__()
    self.config = config
    self.query_builder = config.query_builder
    self.connection_config = config.connection_config
    self.write_config = config.write_config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Convert Chunks to records and write them to the configured table."""
    converter = self.query_builder.create_converter()
    statement = self.query_builder.build_insert()
    return (
        pcoll
        | "Convert to Records" >> beam.Map(converter)
        | "Write to MySQL" >> WriteToJdbc(
            table_name=self.query_builder.table_name,
            driver_class_name="com.mysql.cj.jdbc.Driver",
            jdbc_url=self.connection_config.jdbc_url,
            username=self.connection_config.username,
            password=self.connection_config.password,
            statement=statement,
            connection_properties=self.connection_config.connection_properties,
            connection_init_sqls=self.connection_config.connection_init_sqls,
            autosharding=self.write_config.autosharding,
            max_connections=self.write_config.max_connections,
            write_batch_size=self.write_config.write_batch_size,
            **self.connection_config.additional_jdbc_args))