#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import field
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from google.protobuf.json_format import MessageToDict
from pymilvus import AnnSearchRequest
from pymilvus import Hit
from pymilvus import Hits
from pymilvus import MilvusClient
from pymilvus import SearchResult
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Embedding
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
[docs]
class SearchStrategy(Enum):
    """Available strategies for retrieving information.

    Members:
      HYBRID: Combines the vector and keyword approaches, drawing on both
        semantic understanding and exact matching. Typically gives the most
        comprehensive results and suits queries that mix conceptual and
        specific keyword components.
      VECTOR: Vector similarity search only, based on the semantic similarity
        between the query and the documents. Works well for conceptual
        searches and for finding related content; less sensitive to exact
        terminology than keyword search.
      KEYWORD: Keyword/text search only, based on exact or fuzzy matching of
        specific terms. Works well for precise queries where exact wording
        matters; less effective for conceptual or semantic searches.
    """

    HYBRID = "hybrid"
    VECTOR = "vector"
    KEYWORD = "keyword"
[docs]
class KeywordSearchMetrics(Enum):
    """Metrics available for keyword search.

    Members:
      BM25: Range [0 to ∞). Best Match 25 ranking algorithm for text
        relevance. Combines term frequency, inverse document frequency, and
        document length; higher scores indicate greater relevance. Takes into
        account diminishing returns of term frequency and balances between
        exact matching and semantic relevance.
    """

    BM25 = "BM25"
[docs]
class VectorSearchMetrics(Enum):
    """Metrics available for vector search.

    Members:
      COSINE: Range [-1 to 1]; higher values indicate greater similarity.
        1 means the vectors point in the identical direction, 0 means they are
        perpendicular (no relationship), and -1 means they point in exactly
        opposite directions.
      EUCLIDEAN_DISTANCE (L2): Range [0 to ∞); lower values indicate greater
        similarity. 0 means the vectors are identical, and larger values mean
        more dissimilarity between them.
      INNER_PRODUCT (IP): Range varies with the vector magnitudes; higher
        values indicate greater similarity. 0 means the vectors are
        perpendicular, positive values mean they share some directional
        component, and negative values mean they point in opposing
        directions.
    """

    COSINE = "COSINE"
    EUCLIDEAN_DISTANCE = "L2"
    INNER_PRODUCT = "IP"
[docs]
class MilvusBaseRanker:
    """Base class for ranking algorithms in Milvus hybrid search strategy.

    Subclasses are expected to override :meth:`dict` to describe their
    configuration; the base implementation describes an empty configuration.
    """

    def __init__(self):
        # The base ranker holds no state. This was previously misspelled as
        # ``__int__``, which defined an int-conversion hook returning None
        # instead of a constructor.
        return

    def dict(self):
        """Return the ranker configuration as a dictionary (empty here)."""
        return {}

    def __str__(self):
        """Return the string form of the dictionary produced by ``dict()``."""
        return str(self.dict())
[docs]
@dataclass
class MilvusConnectionParameters:
    """Parameters for establishing connections to Milvus servers.

    ``password`` and ``token`` are excluded from the generated ``repr`` so
    that credentials are not written out when an instance is logged or
    printed. They remain ordinary attributes and still take part in equality.

    Args:
      uri: URI endpoint for connecting to Milvus server in the format
        "http(s)://hostname:port".
      user: Username for authentication. Required if authentication is
        enabled and not using token authentication.
      password: Password for authentication. Required if authentication is
        enabled and not using token authentication.
      db_id: Database ID to connect to. Specifies which Milvus database to
        use. Defaults to 'default'.
      token: Authentication token as an alternative to username/password.
      timeout: Connection timeout in seconds. Uses client default if None.
      kwargs: Optional keyword arguments for additional connection
        parameters. Enables forward compatibility.

    Raises:
      ValueError: If ``uri`` is empty.
    """

    uri: str
    user: str = ""
    # Secrets: kept out of repr() to avoid leaking them into logs.
    password: str = field(default="", repr=False)
    db_id: str = "default"
    token: str = field(default="", repr=False)
    timeout: Optional[float] = None
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        if not self.uri:
            raise ValueError("URI must be provided for Milvus connection")
[docs]
@dataclass
class BaseSearchParameters:
    """Settings shared by the vector and keyword search parameter classes.

    Args:
      anns_field: Name of the approximate nearest neighbor search field, i.e.
        the field containing the embedding to search. Required for both
        vector and keyword search.
      limit: Maximum number of results to return per query. Must be positive.
        Defaults to 3 search results.
      filter: Boolean expression string for filtering search results.
        Example: 'price <= 1000 AND category == "electronics"'.
      search_params: Additional search parameters specific to the search
        type. Example: {"metric_type": VectorSearchMetrics.EUCLIDEAN_DISTANCE}.
      consistency_level: Consistency level for read operations.
        Options: "Strong", "Session", "Bounded", "Eventually". Defaults to
        "Bounded" if not specified when creating the collection.

    Raises:
      ValueError: If ``anns_field`` is empty or ``limit`` is not positive.
    """

    anns_field: str
    limit: int = 3
    filter: str = field(default_factory=str)
    search_params: Dict[str, Any] = field(default_factory=dict)
    consistency_level: Optional[str] = None

    def __post_init__(self):
        # The field name is validated before the limit.
        field_missing = not self.anns_field
        if field_missing:
            raise ValueError(
                "Approximate Nearest Neighbor Search (ANNS) field "
                "must be provided")
        limit_not_positive = self.limit <= 0
        if limit_not_positive:
            raise ValueError(
                f"Search limit must be positive, got {self.limit}")
[docs]
@dataclass
class VectorSearchParameters(BaseSearchParameters):
    """Settings for a vector similarity search.

    Every parameter of BaseSearchParameters is inherited with unchanged
    semantics. For this search type the anns_field should contain dense
    vector embeddings.

    Args:
      kwargs: Optional keyword arguments for additional vector search
        parameters. Enables forward compatibility.

    Note:
      The inherited parameters are documented on BaseSearchParameters.
    """

    kwargs: Dict[str, Any] = field(default_factory=dict)
[docs]
@dataclass
class KeywordSearchParameters(BaseSearchParameters):
    """Settings for a keyword/text search.

    Every parameter of BaseSearchParameters is inherited with unchanged
    semantics. For this search type the anns_field should contain sparse
    vector embeddings content.

    Args:
      kwargs: Optional keyword arguments for additional keyword search
        parameters. Enables forward compatibility.

    Note:
      The inherited parameters are documented on BaseSearchParameters.
    """

    kwargs: Dict[str, Any] = field(default_factory=dict)
[docs]
@dataclass
class HybridSearchParameters:
    """Settings for a hybrid (vector + keyword) search.

    Args:
      vector: Parameters for the vector search component.
      keyword: Parameters for the keyword search component.
      ranker: Ranker for combining vector and keyword search results.
        Example: RRFRanker(k=100).
      limit: Maximum number of results to return per query. Defaults to 3
        search results.
      kwargs: Optional keyword arguments for additional hybrid search
        parameters. Enables forward compatibility.

    Raises:
      ValueError: If ``vector``, ``keyword`` or ``ranker`` is falsy, or if
        ``limit`` is not positive.
    """

    vector: VectorSearchParameters
    keyword: KeywordSearchParameters
    ranker: MilvusBaseRanker
    limit: int = 3
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Checks run in this order: components, ranker, then limit.
        both_components_given = self.vector and self.keyword
        if not both_components_given:
            raise ValueError(
                "Vector and keyword search parameters must be provided "
                "for hybrid search")
        ranker_missing = not self.ranker
        if ranker_missing:
            raise ValueError("Ranker must be provided for hybrid search")
        limit_not_positive = self.limit <= 0
        if limit_not_positive:
            raise ValueError(
                f"Search limit must be positive, got {self.limit}")
# Type alias: any one of the three search parameter classes
# (vector, keyword, or hybrid).
SearchStrategyType = Union[VectorSearchParameters,
KeywordSearchParameters,
HybridSearchParameters]
[docs]
@dataclass
class MilvusSearchParameters:
    """Settings that configure a Milvus search operation.

    Bundles everything needed to execute searches against a Milvus
    collection, for the vector, keyword, and hybrid search strategies.

    Args:
      collection_name: Name of the collection to search in.
      search_strategy: Type of search to perform (VECTOR, KEYWORD, or HYBRID).
      partition_names: List of partition names to restrict the search to. If
        empty, all partitions will be searched.
      output_fields: List of field names to include in search results. If
        empty, only primary fields including distances will be returned.
      timeout: Search operation timeout in seconds. If not specified, the
        client's default timeout is used.
      round_decimal: Number of decimal places for distance/similarity scores.
        Defaults to -1 means no rounding.

    Raises:
      ValueError: If ``collection_name`` or ``search_strategy`` is falsy.
    """

    collection_name: str
    search_strategy: SearchStrategyType
    partition_names: List[str] = field(default_factory=list)
    output_fields: List[str] = field(default_factory=list)
    timeout: Optional[float] = None
    round_decimal: int = -1

    def __post_init__(self):
        # The collection name is validated before the strategy.
        name_missing = not self.collection_name
        if name_missing:
            raise ValueError("Collection name must be provided")
        strategy_missing = not self.search_strategy
        if strategy_missing:
            raise ValueError("Search strategy must be provided")
[docs]
@dataclass
class MilvusCollectionLoadParameters:
    """Settings that control how Milvus loads a collection into memory.

    Gives fine-grained control over collection loading, which matters most
    in resource-constrained environments. Proper configuration can
    significantly reduce memory usage and improve query performance by
    loading only necessary data.

    Args:
      refresh: If True, forces a reload of the collection even if already
        loaded. Ensures the most up-to-date data is in memory.
      resource_groups: List of resource groups to load the collection into.
        Can be used for load balancing across multiple query nodes.
      load_fields: Specify which fields to load into memory. Loading only
        necessary fields reduces memory usage. If empty, all fields loaded.
      skip_load_dynamic_field: If True, dynamic/growing fields will not be
        loaded into memory. Saves memory when dynamic fields aren't needed.
      kwargs: Optional keyword arguments for additional collection load
        parameters. Enables forward compatibility.
    """

    # Both flags default to False via the ``bool`` factory.
    refresh: bool = field(default_factory=bool)
    resource_groups: List[str] = field(default_factory=list)
    load_fields: List[str] = field(default_factory=list)
    skip_load_dynamic_field: bool = field(default_factory=bool)
    kwargs: Dict[str, Any] = field(default_factory=dict)
[docs]
@dataclass
class MilvusSearchResult:
    """Per-chunk search result returned from Milvus.

    The three lists are parallel: position ``i`` in each describes the same
    returned entity.

    Args:
      id: List of entity IDs returned from the search. Can be either string
        or integer IDs.
      distance: List of distances/similarity scores for each returned entity.
      fields: List of dictionaries containing additional field values for
        each entity. Each dictionary corresponds to one returned entity.
    """

    id: List[Union[str, int]] = field(default_factory=list)
    distance: List[float] = field(default_factory=list)
    fields: List[Dict[str, Any]] = field(default_factory=list)
# Request type: a single chunk or a list of chunks.
InputT = Union[Chunk, List[Chunk]]
# Response type: a list of (chunk, result dictionary) pairs.
OutputT = List[Tuple[Chunk, Dict[str, Any]]]
[docs]
class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
    """Enrichment handler for Milvus vector database searches.

    This handler is designed to work with the
    :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
    transform. It enables enriching data through vector similarity, keyword,
    or hybrid searches against Milvus collections.

    The handler supports different search strategies:
      * Vector search - For finding similar embeddings based on vector
        similarity
      * Keyword search - For text-based retrieval using BM25 or other text
        metrics
      * Hybrid search - For combining vector and keyword search results

    This handler queries the Milvus database per element by default. To
    enable batching for improved performance, set the `min_batch_size` and
    `max_batch_size` parameters. These control the batching behavior in the
    :class:`apache_beam.transforms.utils.BatchElements` transform.

    For memory-intensive operations, the handler allows fine-grained control
    over collection loading through the `collection_load_parameters`.
    """

    def __init__(
        self,
        connection_parameters: MilvusConnectionParameters,
        search_parameters: MilvusSearchParameters,
        *,
        collection_load_parameters: Optional[
            MilvusCollectionLoadParameters] = None,
        min_batch_size: int = 1,
        max_batch_size: int = 1000,
        **kwargs):
        """
        Example Usage:
          connection_parameters = MilvusConnectionParameters(
            uri="http://localhost:19530")
          search_parameters = MilvusSearchParameters(
            collection_name="my_collection",
            search_strategy=VectorSearchParameters(anns_field="embedding"))
          collection_load_parameters = MilvusCollectionLoadParameters(
            load_fields=["embedding", "metadata"])
          milvus_handler = MilvusSearchEnrichmentHandler(
            connection_parameters,
            search_parameters,
            collection_load_parameters=collection_load_parameters,
            min_batch_size=10,
            max_batch_size=100)

        Args:
          connection_parameters (MilvusConnectionParameters): Configuration
            for connecting to the Milvus server, including URI, credentials,
            and connection options.
          search_parameters (MilvusSearchParameters): Configuration for
            search operations, including collection name, search strategy,
            and output fields.
          collection_load_parameters
            (Optional[MilvusCollectionLoadParameters]): Parameters
            controlling how collections are loaded into memory, which can
            significantly impact resource usage and performance. When
            omitted or None, a default MilvusCollectionLoadParameters() is
            used.
          min_batch_size (int): Minimum number of elements to batch together
            when querying Milvus. Default is 1 (no batching when
            max_batch_size is 1).
          max_batch_size (int): Maximum number of elements to batch together.
            Default is 1000. Higher values may improve throughput but
            increase memory usage.
          **kwargs: Additional keyword arguments for Milvus Enrichment
            Handler.

        Note:
          * For large collections, consider setting appropriate values in
            collection_load_parameters to reduce memory usage.
          * The search_strategy in search_parameters determines the type of
            search (vector, keyword, or hybrid) and associated parameters.
          * Batching can significantly improve performance but requires more
            memory.
        """
        self._connection_parameters = connection_parameters
        self._search_parameters = search_parameters
        self._collection_load_parameters = (
            collection_load_parameters or MilvusCollectionLoadParameters())
        self._batching_kwargs = {
            'min_batch_size': min_batch_size,
            'max_batch_size': max_batch_size,
        }
        self.kwargs = kwargs
        self.join_fn = join_fn
        self.use_custom_types = True
        # Created in __enter__; declared here so __exit__ is safe to call
        # even when __enter__ never ran or failed part-way.
        self._client = None

    def __enter__(self):
        """Connect to Milvus and load the configured collection.

        Returns:
          This handler, so it can be bound with ``with ... as handler``.
        """
        connection_params = unpack_dataclass_with_kwargs(
            self._connection_parameters)
        collection_load_params = unpack_dataclass_with_kwargs(
            self._collection_load_parameters)
        client = MilvusClient(**connection_params)
        try:
            client.load_collection(
                collection_name=self.collection_name,
                partition_names=self.partition_names,
                **collection_load_params)
        except Exception:
            # Do not leave the freshly opened connection behind if loading
            # the collection fails.
            client.close()
            raise
        self._client = client
        return self

    def __call__(self, request: Union[Chunk, List[Chunk]], *args,
                 **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]:
        """Search Milvus for one chunk or a batch of chunks.

        Args:
          request: A single chunk or a list of chunks to search for.

        Returns:
          One ``(chunk, result)`` pair per input chunk, where ``result`` is a
          dictionary with the keys ``id``, ``distance`` and ``fields``.
        """
        reqs = request if isinstance(request, list) else [request]
        search_result = self._search_documents(reqs)
        return self._get_call_response(reqs, search_result)

    def _search_documents(self, chunks: List[Chunk]):
        """Run the configured search for ``chunks`` and return the raw result.

        Raises:
          ValueError: If the configured search strategy is not one of the
            supported parameter classes.
        """
        strategy = self.search_strategy
        # Arguments shared by every search call issued below.
        common_args = {
            'collection_name': self.collection_name,
            'partition_names': self.partition_names,
            'output_fields': self.output_fields,
            'timeout': self.timeout,
            'round_decimal': self.round_decimal,
        }
        if isinstance(strategy, HybridSearchParameters):
            return self._client.hybrid_search(
                reqs=self._get_hybrid_search_data(chunks),
                ranker=strategy.ranker,
                limit=strategy.limit,
                **common_args,
                **strategy.kwargs)
        # Vector and keyword searches differ only in how the query payload
        # is extracted from each chunk.
        if isinstance(strategy, VectorSearchParameters):
            to_query = self._get_vector_search_data
        elif isinstance(strategy, KeywordSearchParameters):
            to_query = self._get_keyword_search_data
        else:
            raise ValueError(
                f"Not supported search strategy yet: {strategy}")
        data = [to_query(chunk) for chunk in chunks]
        return self._client.search(
            data=data,
            **common_args,
            **unpack_dataclass_with_kwargs(strategy))

    def _get_hybrid_search_data(self, chunks: List[Chunk]):
        """Build the vector and keyword sub-requests of a hybrid search.

        Returns:
          A two-element list: the vector request followed by the keyword
          request.
        """
        vector_params = self.search_strategy.vector
        keyword_params = self.search_strategy.keyword
        vector_search_req = AnnSearchRequest(
            data=[self._get_vector_search_data(c) for c in chunks],
            anns_field=vector_params.anns_field,
            param=vector_params.search_params,
            limit=vector_params.limit,
            expr=vector_params.filter)
        keyword_search_req = AnnSearchRequest(
            data=[self._get_keyword_search_data(c) for c in chunks],
            anns_field=keyword_params.anns_field,
            param=keyword_params.search_params,
            limit=keyword_params.limit,
            expr=keyword_params.filter)
        return [vector_search_req, keyword_search_req]

    def _get_vector_search_data(self, chunk: Chunk):
        """Return the chunk's dense embedding.

        Raises:
          ValueError: If the chunk has no dense embedding.
        """
        if not chunk.dense_embedding:
            raise ValueError(
                f"Chunk {chunk.id} missing dense embedding required for "
                "vector search")
        return chunk.dense_embedding

    def _get_keyword_search_data(self, chunk: Chunk):
        """Return the keyword query for a chunk.

        The chunk's text is preferred; the sparse embedding, converted to a
        dictionary, is used only when the text is empty.

        Raises:
          ValueError: If the chunk has neither text nor a sparse embedding.
        """
        text = chunk.content.text
        if text:
            return text
        if not chunk.sparse_embedding:
            raise ValueError(
                f"Chunk {chunk.id} missing both text content and sparse "
                "embedding required for keyword search")
        return self.convert_sparse_embedding_to_milvus_format(
            chunk.sparse_embedding)

    def convert_sparse_embedding_to_milvus_format(self, sparse_vector):
        """Convert a sparse embedding into a ``{index: value}`` dictionary.

        Args:
          sparse_vector: A pair of parallel sequences ``(indices, values)``.
            NOTE(review): this layout is assumed from
            apache_beam.ml.rag.types.Embedding.sparse_embedding - confirm.

        Returns:
          A dictionary mapping each index (as ``int``) to its value (as
          ``float``), or None when ``sparse_vector`` is empty or None.
        """
        if not sparse_vector:
            return None
        indices, values = sparse_vector
        return {
            int(index): float(value)
            for index, value in zip(indices, values)
        }

    def _get_call_response(
        self, chunks: List[Chunk], search_result: SearchResult[Hits]):
        """Pair every chunk with its hits, flattened into a result dict.

        Returns:
          One ``(chunk, result)`` pair per chunk, in input order. ``result``
          is the attribute dictionary of a MilvusSearchResult.
        """
        response = []
        for position, chunk in enumerate(chunks):
            # Indexed access keeps the IndexError if Milvus returns fewer
            # result groups than there were chunks.
            hits = search_result[position]
            result = MilvusSearchResult()
            for hit in hits:
                result.id.append(hit.id)
                result.distance.append(hit.distance)
                result.fields.append(
                    self._normalize_milvus_fields(hit.fields))
            response.append((chunk, vars(result)))
        return response

    def _normalize_milvus_fields(self, fields: Dict[str, Any]):
        """Return a copy of ``fields`` with every value normalized."""
        return {
            name: self._normalize_milvus_value(value)
            for name, value in fields.items()
        }

    def _normalize_milvus_value(self, value: Any):
        """Convert Milvus-specific types to Python native types.

        Non-string, non-bytes sequences become lists; objects exposing a
        ``DESCRIPTOR`` attribute are treated as protobuf messages and
        converted with ``MessageToDict``; anything else is returned as is.
        """
        is_plain_sequence = (
            isinstance(value, Sequence) and
            not isinstance(value, (str, dict, bytes)))
        if is_plain_sequence:
            return list(value)
        if hasattr(value, 'DESCRIPTOR'):
            # Handle protobuf messages.
            return MessageToDict(value)
        # Keep other types as they are.
        return value

    @property
    def collection_name(self):
        """Getter method for collection_name property"""
        return self._search_parameters.collection_name

    @property
    def search_strategy(self):
        """Getter method for search_strategy property"""
        return self._search_parameters.search_strategy

    @property
    def partition_names(self):
        """Getter method for partition_names property"""
        return self._search_parameters.partition_names

    @property
    def output_fields(self):
        """Getter method for output_fields property"""
        return self._search_parameters.output_fields

    @property
    def timeout(self):
        """Getter method for search timeout property"""
        return self._search_parameters.timeout

    @property
    def round_decimal(self):
        """Getter method for search round_decimal property"""
        return self._search_parameters.round_decimal

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the collection and close the client, if one is open.

        Returns None, so an exception raised inside the ``with`` body
        propagates.
        """
        client = self._client
        if client is None:
            return
        self._client = None
        try:
            client.release_collection(self.collection_name)
        finally:
            # Close the connection even when releasing the collection fails.
            client.close()

    def batch_elements_kwargs(self) -> Dict[str, int]:
        """Returns kwargs for beam.BatchElements."""
        return self._batching_kwargs
[docs]
def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding:
    """Attach enrichment results to ``left`` and return it.

    ``right`` is stored under the ``'enrichment_data'`` key of
    ``left.metadata``. ``left`` is modified in place and the same object is
    returned.
    """
    metadata = left.metadata
    metadata['enrichment_data'] = right
    return left
[docs]
def unpack_dataclass_with_kwargs(dataclass_instance):
    """Flatten an instance's attributes and its nested ``kwargs`` mapping.

    Returns a new dictionary holding every instance attribute except
    ``kwargs``, followed by the entries of ``kwargs`` itself. When a key
    appears in both, the value from ``kwargs`` wins. The instance is not
    modified.
    """
    attributes = vars(dataclass_instance)
    # Every attribute except the nested mapping itself.
    own_fields = {
        name: value
        for name, value in attributes.items() if name != 'kwargs'
    }
    extras = attributes.get('kwargs', {})
    # Later entries override earlier ones, so ``extras`` takes precedence.
    return {**own_fields, **extras}