Source code for apache_beam.ml.rag.ingestion.test_utils

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import hashlib
import json
from typing import List
from typing import NamedTuple

import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding

TestRow = NamedTuple(
    'TestRow',
    [('id', str), ('embedding', List[float]), ('content', str),
     ('metadata', str)])
registry.register_coder(TestRow, RowCoder)

VECTOR_SIZE = 768


def row_to_chunk(row) -> Chunk:
  # Parse embedding string back to float list
  embedding_list = [float(x) for x in row.embedding.strip('[]').split(',')]
  return Chunk(
      id=row.id,
      content=Content(text=row.content if hasattr(row, 'content') else None),
      embedding=Embedding(dense_embedding=embedding_list),
      metadata=json.loads(row.metadata) if hasattr(row, 'metadata') else {})


class ChunkTestUtils:
  """Helper functions for generating test Chunks."""
  @staticmethod
  def from_seed(seed: int, content_prefix: str, seed_multiplier: int) -> Chunk:
    """Creates a deterministic Chunk from a seed value."""
    return Chunk(
        id=f"id_{seed}",
        content=Content(text=f"{content_prefix}{seed}"),
        embedding=Embedding(
            dense_embedding=[
                float(seed + i * seed_multiplier) / 100
                for i in range(VECTOR_SIZE)
            ]),
        metadata={"seed": str(seed)})

  @staticmethod
  def get_expected_values(
      range_start: int,
      range_end: int,
      content_prefix: str = "Testval",
      seed_multiplier: int = 1) -> List[Chunk]:
    """Returns a range of test Chunks."""
    return [
        ChunkTestUtils.from_seed(i, content_prefix, seed_multiplier)
        for i in range(range_start, range_end)
    ]


class HashingFn(beam.CombineFn):
  """Hashing function for verification.

  Combines chunk text values into a single order-independent MD5 hash.
  """
  def create_accumulator(self):
    return []

  def add_input(self, accumulator, input):
    # Accumulate each chunk's text content (empty string if missing).
    accumulator.append(input.content.text if input.content.text else "")
    return accumulator

  def merge_accumulators(self, accumulators):
    merged = []
    for acc in accumulators:
      merged.extend(acc)
    return merged

  def extract_output(self, accumulator):
    # Sort before hashing so the result does not depend on element order.
    sorted_values = sorted(accumulator)
    return hashlib.md5(''.join(sorted_values).encode()).hexdigest()


def generate_expected_hash(num_records: int) -> str:
  """Computes the hash HashingFn should produce for num_records test Chunks."""
  chunks = ChunkTestUtils.get_expected_values(0, num_records)
  values = sorted(
      chunk.content.text if chunk.content.text else "" for chunk in chunks)
  return hashlib.md5(''.join(values).encode()).hexdigest()


def key_on_id(chunk):
  """Keys a Chunk by the integer portion of its id (e.g. "id_7" -> 7)."""
  return (int(chunk.id.split('_')[1]), chunk)
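
These helpers are typically exercised together in tests: ChunkTestUtils generates deterministic input, HashingFn reduces the chunks in a pipeline to a single order-independent hash, and generate_expected_hash produces the value to compare against. The following is a minimal sketch of that pattern, not part of this module; it assumes the definitions above are importable, uses Beam's TestPipeline and assert_that utilities, and the _check_hash name is hypothetical.

# Sketch only: verify the pipeline-computed hash matches the expected hash.
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


def _check_hash(num_records=10):
  # Deterministic test input built from seeds 0..num_records-1.
  chunks = ChunkTestUtils.get_expected_values(0, num_records)
  with TestPipeline() as p:
    hashed = (
        p
        | beam.Create(chunks)
        | beam.CombineGlobally(HashingFn()))
    # HashingFn sorts text values before hashing, so this matches
    # generate_expected_hash regardless of element order.
    assert_that(hashed, equal_to([generate_expected_hash(num_records)]))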