Source code for apache_beam.ml.rag.ingestion.milvus_search

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional

from pymilvus import MilvusClient
from pymilvus.exceptions import MilvusException

import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE
from apache_beam.ml.rag.utils import MilvusConnectionParameters
from apache_beam.ml.rag.utils import MilvusHelpers
from apache_beam.ml.rag.utils import retry_with_backoff
from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
from apache_beam.transforms import DoFn

_LOGGER = logging.getLogger(__name__)


@dataclass
class MilvusWriteConfig:
  """Configuration parameters for writing data to Milvus collections.

  This class defines the parameters needed to write data to a Milvus
  collection, including collection targeting, batching behavior, and
  operation timeouts.

  Args:
    collection_name: Name of the target Milvus collection to write data to.
      Must be a non-empty string.
    partition_name: Name of the specific partition within the collection to
      write to. If empty, writes to the default partition.
    timeout: Maximum time in seconds to wait for write operations to
      complete. If None, uses the client's default timeout.
    write_config: Configuration for write operations including batch size
      and other write-specific settings.
    kwargs: Additional keyword arguments for write operations. Enables
      forward compatibility with future Milvus client parameters.
  """
  collection_name: str
  partition_name: str = ""
  timeout: Optional[float] = None
  write_config: WriteConfig = field(default_factory=WriteConfig)
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    if not self.collection_name:
      raise ValueError("Collection name must be provided")

  @property
  def write_batch_size(self):
    """Returns the batch size for write operations.

    Returns:
      The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not
      specified.
    """
    return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE
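
# Illustrative sketch (editor's addition, not part of the Beam module):
# constructing a MilvusWriteConfig. The collection and partition names are
# hypothetical; write_batch_size falls back to DEFAULT_WRITE_BATCH_SIZE
# because the WriteConfig below sets no explicit batch size.
def _example_milvus_write_config() -> int:
  config = MilvusWriteConfig(
      collection_name="docs",  # hypothetical collection
      partition_name="2024_q1",  # hypothetical partition; "" means default
      timeout=30.0,
      write_config=WriteConfig())
  # An empty collection name would raise ValueError in __post_init__.
  return config.write_batch_size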


@dataclass
class MilvusVectorWriterConfig(VectorDatabaseWriteConfig):
  """Configuration for writing vector data to Milvus collections.

  This class extends VectorDatabaseWriteConfig to provide Milvus-specific
  configuration for ingesting vector embeddings and associated metadata. It
  defines how Apache Beam chunks are converted to Milvus records and handles
  the write operation parameters.

  The configuration includes connection parameters, write settings, and
  column specifications that determine how chunk data is mapped to Milvus
  fields.

  Args:
    connection_params: Configuration for connecting to the Milvus server,
      including URI, credentials, and connection options.
    write_config: Configuration for write operations including collection
      name, partition, batch size, and timeouts.
    column_specs: List of column specifications defining how chunk fields
      are mapped to Milvus collection fields. Defaults to standard RAG
      fields (id, embedding, sparse_embedding, content, metadata).

  Example:
    config = MilvusVectorWriterConfig(
        connection_params=MilvusConnectionParameters(
            uri="http://localhost:19530"),
        write_config=MilvusWriteConfig(collection_name="my_collection"),
        column_specs=MilvusVectorWriterConfig.default_column_specs())
  """
  connection_params: MilvusConnectionParameters
  write_config: MilvusWriteConfig
  column_specs: List[ColumnSpec] = field(
      default_factory=lambda: MilvusVectorWriterConfig.default_column_specs())

  def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]:
    """Creates a function to convert Apache Beam Chunks to Milvus records.

    Returns:
      A function that takes a Chunk and returns a dictionary representing a
      Milvus record with fields mapped according to column_specs.
    """
    def convert(chunk: Chunk) -> Dict[str, Any]:
      result = {}
      for col in self.column_specs:
        result[col.column_name] = col.value_fn(chunk)
      return result

    return convert

  def create_write_transform(self) -> beam.PTransform:
    """Creates the Apache Beam transform for writing to Milvus.

    Returns:
      A PTransform that can be applied to a PCollection of Chunks to write
      them to the configured Milvus collection.
    """
    return _WriteToMilvusVectorDatabase(self)

  @staticmethod
  def default_column_specs() -> List[ColumnSpec]:
    """Returns default column specifications for RAG use cases.

    Creates column mappings for standard RAG fields: id, dense embedding,
    sparse embedding, content text, and metadata. These specifications
    define how Chunk fields are converted to Milvus-compatible formats.

    Returns:
      List of ColumnSpec objects defining the default field mappings.
    """
    column_specs = ColumnSpecsBuilder()
    return column_specs\
        .with_id_spec()\
        .with_embedding_spec(convert_fn=lambda values: list(values))\
        .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\
        .with_content_spec()\
        .with_metadata_spec(convert_fn=lambda values: dict(values))\
        .build()
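
# Illustrative sketch (editor's addition, not part of the Beam module):
# end-to-end wiring of the class above. It builds a config, shows what the
# converter emits for one Chunk, and applies the write transform in a
# pipeline. The URI, collection name, and chunk field values are
# placeholders; the sparse embedding assumes the (indices, values) tuple
# layout of apache_beam.ml.rag.types.Embedding.
def _example_write_pipeline() -> None:
  from apache_beam.ml.rag.types import Content
  from apache_beam.ml.rag.types import Embedding

  config = MilvusVectorWriterConfig(
      connection_params=MilvusConnectionParameters(
          uri="http://localhost:19530"),  # placeholder endpoint
      write_config=MilvusWriteConfig(collection_name="my_collection"))

  chunk = Chunk(
      id="chunk-1",
      content=Content(text="Apache Beam writes vectors to Milvus."),
      embedding=Embedding(
          dense_embedding=[0.1, 0.2, 0.3],
          sparse_embedding=([0, 2], [0.5, 0.25])),
      metadata={"source": "example.md"})

  # The converter flattens a Chunk into a dict keyed by column name; with
  # the default column specs that includes keys such as "id", "embedding",
  # "sparse_embedding", "content", and "metadata".
  record = config.create_converter()(chunk)
  _LOGGER.debug("Converted record keys: %s", sorted(record))

  with beam.Pipeline() as p:
    _ = (
        p
        | "Create" >> beam.Create([chunk])
        | "Write to Milvus" >> config.create_write_transform())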


class _WriteToMilvusVectorDatabase(beam.PTransform):
  """Apache Beam PTransform for writing vector data to Milvus.

  This transform handles the conversion of Apache Beam Chunks to Milvus
  records and coordinates the write operations. It applies the configured
  converter function and uses a DoFn for batched writes to optimize
  performance.

  Args:
    config: MilvusVectorWriterConfig containing all necessary parameters
      for the write operation.
  """
  def __init__(self, config: MilvusVectorWriterConfig):
    self.config = config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Expands the PTransform to convert chunks and write to Milvus.

    Args:
      pcoll: PCollection of Chunk objects to write to Milvus.

    Returns:
      PCollection of dictionaries representing the records written to
      Milvus.
    """
    return (
        pcoll
        | "Convert to Records" >> beam.Map(self.config.create_converter())
        | beam.ParDo(
            _WriteMilvusFn(
                self.config.connection_params, self.config.write_config)))


class _WriteMilvusFn(DoFn):
  """DoFn that handles batched writes to Milvus.

  This DoFn accumulates records in batches and flushes them to Milvus when
  the batch size is reached or when the bundle finishes. This approach
  optimizes performance by reducing the number of individual write
  operations.

  Args:
    connection_params: Configuration for connecting to the Milvus server.
    write_config: Configuration for write operations including batch size
      and collection details.
  """
  def __init__(
      self,
      connection_params: MilvusConnectionParameters,
      write_config: MilvusWriteConfig):
    self._connection_params = connection_params
    self._write_config = write_config
    self.batch = []

  def process(self, element, *args, **kwargs):
    """Processes individual records, batching them for efficient writes.

    Args:
      element: A dictionary representing a Milvus record to write.
      *args: Additional positional arguments.
      **kwargs: Additional keyword arguments.

    Yields:
      The original element after adding it to the batch.
    """
    _ = args, kwargs  # Unused parameters
    self.batch.append(element)
    if len(self.batch) >= self._write_config.write_batch_size:
      self._flush()
    yield element

  def finish_bundle(self):
    """Called when a bundle finishes processing.

    Flushes any remaining records in the batch to ensure all data is
    written.
    """
    self._flush()

  def _flush(self):
    """Flushes the current batch of records to Milvus.

    Opens a _MilvusSink connection and writes all batched records, then
    clears the batch for the next set of records.
    """
    if len(self.batch) == 0:
      return
    with _MilvusSink(self._connection_params, self._write_config) as sink:
      sink.write(self.batch)
      self.batch = []

  def display_data(self):
    """Returns display data for monitoring and debugging.

    Returns:
      Dictionary containing database, collection, and batch size
      information for display in the Apache Beam monitoring UI.
    """
    res = super().display_data()
    res["database"] = self._connection_params.db_name
    res["collection"] = self._write_config.collection_name
    res["batch_size"] = self._write_config.write_batch_size
    return res


class _MilvusSink:
  """Low-level sink for writing data directly to Milvus.

  This class handles the direct interaction with the Milvus client for
  upsert operations. It manages the connection lifecycle and provides
  context manager support for proper resource cleanup.

  Args:
    connection_params: Configuration for connecting to the Milvus server.
    write_config: Configuration for write operations including collection
      and partition targeting.
  """
  def __init__(
      self,
      connection_params: MilvusConnectionParameters,
      write_config: MilvusWriteConfig):
    self._connection_params = connection_params
    self._write_config = write_config
    self._client = None

  def write(self, documents):
    """Writes a batch of documents to the Milvus collection.

    Performs an upsert operation to insert new documents or update existing
    ones based on primary key.

    Args:
      documents: List of dictionaries representing Milvus records to write.
        Each dictionary should contain fields matching the collection
        schema.
    """
    if not self._client:
      # Fall back to a direct connection when write() is called outside the
      # context manager; otherwise reuse the client from __enter__.
      self._client = MilvusClient(
          **unpack_dataclass_with_kwargs(self._connection_params))
    resp = self._client.upsert(
        collection_name=self._write_config.collection_name,
        partition_name=self._write_config.partition_name,
        data=documents,
        timeout=self._write_config.timeout,
        **self._write_config.kwargs)
    _LOGGER.debug(
        "Upserted into Milvus: upsert_count=%d, cost=%d",
        resp.get("upsert_count", 0),
        resp.get("cost", 0))

  def __enter__(self):
    """Enters the context manager and establishes the Milvus connection.

    Returns:
      Self, enabling use in 'with' statements.
    """
    if not self._client:
      connection_params = unpack_dataclass_with_kwargs(
          self._connection_params)
      # Extract retry parameters from connection_params.
      max_retries = connection_params.pop('max_retries', 3)
      retry_delay = connection_params.pop('retry_delay', 1.0)
      retry_backoff_factor = connection_params.pop(
          'retry_backoff_factor', 2.0)

      def create_client():
        return MilvusClient(**connection_params)

      self._client = retry_with_backoff(
          create_client,
          max_retries=max_retries,
          retry_delay=retry_delay,
          retry_backoff_factor=retry_backoff_factor,
          operation_name="Milvus connection",
          exception_types=(MilvusException, ))
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Exits the context manager and closes the Milvus connection.

    Args:
      exc_type: Exception type if an exception was raised.
      exc_val: Exception value if an exception was raised.
      exc_tb: Exception traceback if an exception was raised.
    """
    _ = exc_type, exc_val, exc_tb  # Unused parameters
    if self._client:
      self._client.close()