#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
from typing import List
from typing import NamedTuple
import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
class TestRow(NamedTuple):
  """Flattened row form of a Chunk, encoded by Beam's RowCoder."""
  # Chunk identifier.
  id: str
  # Dense embedding values.
  embedding: List[float]
  # Chunk text.
  content: str
  # Chunk metadata (parsed with json.loads by row_to_chunk).
  metadata: str


# Let Beam encode TestRow elements with the schema-aware RowCoder.
registry.register_coder(TestRow, RowCoder)

# Length of the dense embeddings produced by the test helpers.
VECTOR_SIZE = 768
[docs]
def row_to_chunk(row) -> Chunk:
  """Convert a row object back into a Chunk.

  Args:
    row: Object exposing ``id`` and ``embedding`` attributes, and optionally
      ``content`` and ``metadata``. ``embedding`` may be a string such as
      ``"[0.1, 0.2]"`` or an already-parsed sequence of numbers (TestRow
      declares it as ``List[float]``). ``metadata``, when present and not
      ``None``, must be a JSON string.

  Returns:
    A Chunk built from the row. ``content.text`` is ``None`` when the row has
    no ``content`` attribute. ``metadata`` is ``{}`` when the row has no
    ``metadata`` attribute or its value is ``None``.

  Raises:
    ValueError: if an element of the embedding cannot be converted to float,
      or the metadata string is not valid JSON.
  """
  raw_embedding = row.embedding
  if isinstance(raw_embedding, str):
    # Textual embeddings look like "[v0,v1,...]". Strip the brackets and
    # treat an empty body ("[]") as an empty vector rather than failing on
    # float('').
    body = raw_embedding.strip().strip('[]').strip()
    embedding_list = [float(x) for x in body.split(',')] if body else []
  else:
    # Already-parsed sequence (matches TestRow's List[float] annotation).
    embedding_list = [float(x) for x in raw_embedding]

  raw_metadata = getattr(row, 'metadata', None)
  return Chunk(
      id=row.id,
      content=Content(text=getattr(row, 'content', None)),
      embedding=Embedding(dense_embedding=embedding_list),
      metadata=json.loads(raw_metadata) if raw_metadata is not None else {})
[docs]
class ChunkTestUtils:
  """Static factories for deterministic, seed-derived test Chunks."""
  @staticmethod
  def from_seed(seed: int, content_prefix: str, seed_multiplier: int) -> Chunk:
    """Build a single Chunk whose every field is derived from ``seed``.

    Args:
      seed: Value that determines the id, text, metadata and embedding.
      content_prefix: Text placed in front of ``seed`` in the content.
      seed_multiplier: Step between consecutive embedding entries, applied
        before the division by 100.

    Returns:
      A Chunk with id ``"id_<seed>"``, text ``"<content_prefix><seed>"``,
      metadata ``{"seed": "<seed>"}`` and a VECTOR_SIZE-long dense embedding
      whose ``i``-th entry is ``float(seed + i * seed_multiplier) / 100``.
    """
    chunk_id = f"id_{seed}"
    text = f"{content_prefix}{seed}"
    dense_values = [
        float(seed + offset * seed_multiplier) / 100
        for offset in range(VECTOR_SIZE)
    ]
    return Chunk(
        id=chunk_id,
        content=Content(text=text),
        embedding=Embedding(dense_embedding=dense_values),
        metadata={"seed": str(seed)})

  @staticmethod
  def get_expected_values(
      range_start: int,
      range_end: int,
      content_prefix: str = "Testval",
      seed_multiplier: int = 1) -> List[Chunk]:
    """Build one Chunk per seed in ``range(range_start, range_end)``.

    ``content_prefix`` and ``seed_multiplier`` are forwarded unchanged to
    ``from_seed``. The result is ordered by ascending seed.
    """
    seeds = range(range_start, range_end)
    return [
        ChunkTestUtils.from_seed(s, content_prefix, seed_multiplier)
        for s in seeds
    ]
[docs]
class HashingFn(beam.CombineFn):
  """Combine Chunks into an order-independent MD5 digest of their texts.

  Each input Chunk contributes ``chunk.content.text``, or ``""`` when that
  text is falsy. The collected texts are sorted before hashing, so the
  digest does not depend on element or bundle order. The digest is computed
  the same way as ``generate_expected_hash`` in this module: sorted texts,
  concatenated, encoded, then MD5 hex digest.
  """
  def create_accumulator(self):
    """Return an empty list that will collect content texts."""
    return []

  def add_input(self, accumulator, element):
    """Append ``element``'s content text (``""`` if falsy); return the list."""
    accumulator.append(element.content.text if element.content.text else "")
    return accumulator

  def merge_accumulators(self, accumulators):
    """Concatenate the partial text lists into a new list."""
    return [text for acc in accumulators for text in acc]

  def extract_output(self, accumulator):
    """Return the hex MD5 digest of the sorted, concatenated texts."""
    return hashlib.md5(''.join(sorted(accumulator)).encode()).hexdigest()
[docs]
def generate_expected_hash(num_records: int) -> str:
  """Return the MD5 hex digest expected for seeds ``0 .. num_records - 1``.

  The expected Chunks come from ``ChunkTestUtils.get_expected_values`` with
  its default prefix and multiplier. Their content texts (a falsy text counts
  as ``""``) are sorted, concatenated, encoded and hashed.
  """
  texts = []
  for expected_chunk in ChunkTestUtils.get_expected_values(0, num_records):
    texts.append(expected_chunk.content.text or "")
  texts.sort()
  digest = hashlib.md5()
  digest.update(''.join(texts).encode())
  return digest.hexdigest()
[docs]
def key_on_id(chunk):
  """Key a Chunk by the integer suffix of its id.

  Ids are expected to look like ``"<prefix>_<int>"``, for example the
  ``"id_<seed>"`` ids built by ``ChunkTestUtils.from_seed``. The text after
  the *last* underscore is parsed, so prefixes that themselves contain
  underscores (e.g. ``"doc_id_7"``) still key correctly. If the id has no
  underscore, the whole id is parsed.

  Returns:
    An ``(int_id, chunk)`` tuple.

  Raises:
    ValueError: if the text after the last underscore is not an integer.
  """
  _, _, numeric_suffix = chunk.id.rpartition('_')
  return (int(numeric_suffix), chunk)