#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import json
from typing import List
import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
# Number of components in every dense embedding built by
# ChunkTestUtils.from_seed.
VECTOR_SIZE: int = 768
def row_to_chunk(row) -> Chunk:
  """Convert a row object back into a ``Chunk``.

  Args:
    row: Object exposing ``id`` and ``embedding`` attributes, and optionally
      ``content`` and ``metadata``. ``embedding`` is the textual form of a
      vector such as ``"[0.1,0.2]"``. ``metadata``, when present and not
      ``None``, is a JSON string.

  Returns:
    A ``Chunk`` carrying the parsed dense embedding. ``content.text`` is
    ``None`` when the row has no ``content`` attribute. ``metadata`` is ``{}``
    when the row has no ``metadata`` attribute or its value is ``None``.

  Raises:
    ValueError: If an embedding component is not a valid float, or the
      metadata string is not valid JSON.
  """
  # Parse the embedding string back into a float list. Outer whitespace is
  # trimmed first, so " [1,2] " still loses its brackets. An empty vector
  # ("[]") yields [] instead of failing on float('').
  embedding_text = row.embedding.strip().strip('[]').strip()
  embedding_list = (
      [float(x) for x in embedding_text.split(',')] if embedding_text else [])
  metadata_json = getattr(row, 'metadata', None)
  return Chunk(
      id=row.id,
      content=Content(text=getattr(row, 'content', None)),
      embedding=Embedding(dense_embedding=embedding_list),
      # A NULL metadata value must not reach json.loads (TypeError).
      metadata=json.loads(metadata_json) if metadata_json is not None else {})
class ChunkTestUtils:
  """Helper functions for generating deterministic test ``Chunk`` objects."""

  @staticmethod
  def from_seed(seed: int, content_prefix: str, seed_multiplier: int) -> Chunk:
    """Create a deterministic ``Chunk`` from a seed value.

    Args:
      seed: Integer from which the id (``"id_<seed>"``), the text
        (``content_prefix`` followed by the seed), and the embedding values
        are derived.
      content_prefix: Prefix placed before the seed to form the chunk text.
      seed_multiplier: Step between consecutive embedding components
        (before the division by 100).

    Returns:
      A ``Chunk`` with the following parts:
        * A dense embedding of ``VECTOR_SIZE`` components. The i-th
          component is ``(seed + i * seed_multiplier) / 100``.
        * Metadata ``{"seed": str(seed)}``.
    """
    return Chunk(
        id=f"id_{seed}",
        content=Content(text=f"{content_prefix}{seed}"),
        embedding=Embedding(
            dense_embedding=[
                float(seed + i * seed_multiplier) / 100
                for i in range(VECTOR_SIZE)
            ]),
        metadata={"seed": str(seed)})

  @staticmethod
  def get_expected_values(
      range_start: int,
      range_end: int,
      content_prefix: str = "Testval",
      seed_multiplier: int = 1) -> List[Chunk]:
    """Return one ``from_seed`` chunk per seed in ``[range_start, range_end)``.

    Args:
      range_start: First seed, inclusive.
      range_end: Last seed, exclusive.
      content_prefix: Forwarded to ``from_seed``.
      seed_multiplier: Forwarded to ``from_seed``.

    Returns:
      The chunks, ordered by ascending seed. The list is empty when
      ``range_end <= range_start``.
    """
    return [
        ChunkTestUtils.from_seed(i, content_prefix, seed_multiplier)
        for i in range(range_start, range_end)
    ]
 
class HashingFn(beam.CombineFn):
  """Hashing function for verification.

  Combines chunks into a single MD5 hex digest of their texts. The texts are
  sorted before hashing, so the result does not depend on element order.
  Texts are collected the same way ``generate_expected_hash`` does it: a
  missing or empty text counts as ``""``. The digest therefore matches that
  function's output for chunks carrying the same texts.
  """

  def create_accumulator(self):
    # The accumulator is the list of chunk texts seen so far.
    return []

  def add_input(self, accumulator, element):
    # Assumes `element` is a Chunk. A falsy text is recorded as "".
    accumulator.append(element.content.text if element.content.text else "")
    return accumulator

  def merge_accumulators(self, accumulators):
    merged = []
    for acc in accumulators:
      merged.extend(acc)
    return merged

  def extract_output(self, accumulator):
    # Sort first so the digest is independent of element/bundle ordering.
    return hashlib.md5(''.join(sorted(accumulator)).encode()).hexdigest()
 
def generate_expected_hash(num_records: int) -> str:
  """Return the MD5 hex digest of the first ``num_records`` test chunk texts.

  The texts come from the chunks for seeds ``0 .. num_records - 1`` (default
  prefix and multiplier). A falsy text counts as ``""``. The texts are sorted,
  concatenated with no separator, UTF-8 encoded, and hashed.
  """
  texts = []
  for chunk in ChunkTestUtils.get_expected_values(0, num_records):
    texts.append(chunk.content.text or "")
  texts.sort()
  digest = hashlib.md5("".join(texts).encode())
  return digest.hexdigest()
def key_on_id(chunk):
  """Return ``(key, chunk)``, where ``key`` is the integer in the chunk's id.

  The key is the text between the first and second underscore of
  ``chunk.id``, parsed with ``int``. For example, ``"id_7"`` gives ``7``.
  """
  id_parts = chunk.id.split('_')
  numeric_key = int(id_parts[1])
  return numeric_key, chunk