#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import List
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import EmbeddingTypeAdapter
[docs]
def create_rag_adapter() -> EmbeddingTypeAdapter[Chunk, Chunk]:
  """Builds the type adapter that lets an embedding transform work on Chunks.

  The adapter is wired with this module's two helpers:
    - ``_extract_chunk_text`` as ``input_fn``: pulls ``content.text`` out of
      each Chunk so the model receives plain strings.
    - ``_add_embedding_fn`` as ``output_fn``: wraps each model output vector
      in an ``Embedding`` and stores it on the matching Chunk's ``embedding``
      attribute.

  Returns:
    An EmbeddingTypeAdapter configured with the helpers described above.
  """
  adapter = EmbeddingTypeAdapter(
      output_fn=_add_embedding_fn, input_fn=_extract_chunk_text)
  return adapter
def _extract_chunk_text(chunks: Sequence[Chunk]) -> List[str]:
  """Collects the text of every chunk, in order, for the embedding model.

  Args:
    chunks: Chunks whose ``content.text`` is to be embedded.

  Returns:
    One string per chunk, in the same order as ``chunks``.

  Raises:
    ValueError: If a chunk's ``content.text`` is falsy (for example an empty
      string or None).
  """
  texts: List[str] = []
  for item in chunks:
    text = item.content.text
    if text:
      texts.append(text)
    else:
      # Fail on the first chunk with nothing to embed.
      raise ValueError("Expected chunk text content.")
  return texts
def _add_embedding_fn(
    chunks: Sequence[Chunk], embeddings: Sequence[List[float]]) -> List[Chunk]:
  """Attaches each embedding vector to its positionally matching chunk.

  Chunks and vectors are paired with ``zip``, so pairing stops at the shorter
  of the two inputs. Each paired chunk is mutated in place: its ``embedding``
  attribute is set to a new ``Embedding`` whose ``dense_embedding`` is the
  vector.

  Args:
    chunks: Chunks to receive embeddings.
    embeddings: Embedding vectors, in the same order as ``chunks``.

  Returns:
    A new list containing all of the input chunks.
  """
  for target, vector in zip(chunks, embeddings):
    wrapped = Embedding(dense_embedding=vector)
    target.embedding = wrapped
  return [*chunks]