# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Cloud Spanner vector store writer for RAG pipelines.

This module provides a writer for storing embeddings and associated metadata
in Google Cloud Spanner. It supports flexible schema configuration with the
ability to flatten metadata fields into dedicated columns.

Example usage:

    Default schema (id, embedding, content, metadata):
    >>> config = SpannerVectorWriterConfig(
    ...     project_id="my-project",
    ...     instance_id="my-instance",
    ...     database_id="my-db",
    ...     table_name="embeddings"
    ... )

    Flattened metadata fields:
    >>> specs = (
    ...     SpannerColumnSpecsBuilder()
    ...     .with_id_spec()
    ...     .with_embedding_spec()
    ...     .with_content_spec()
    ...     .add_metadata_field("source", str)
    ...     .add_metadata_field("page_number", int, default=0)
    ...     .with_metadata_spec()
    ...     .build()
    ... )
    >>> config = SpannerVectorWriterConfig(
    ...     project_id="my-project",
    ...     instance_id="my-instance",
    ...     database_id="my-db",
    ...     table_name="embeddings",
    ...     column_specs=specs
    ... )
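
    Writing chunks in a pipeline (a sketch; assumes chunks is an in-memory
    list of Chunk objects whose dense embeddings are already populated):
    >>> with beam.Pipeline() as p:
    ...     _ = (
    ...         p
    ...         | beam.Create(chunks)
    ...         | config.create_write_transform())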

Spanner schema example:

    CREATE TABLE embeddings (
        id STRING(1024) NOT NULL,
        embedding ARRAY<FLOAT32>(vector_length=>768),
        content STRING(MAX),
        source STRING(MAX),
        page_number INT64,
        metadata JSON
    ) PRIMARY KEY (id)
"""

import functools
import json
from dataclasses import dataclass
from typing import Any
from typing import Callable
from typing import List
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import Type

import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.gcp import spanner
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.types import Chunk


@dataclass
class SpannerColumnSpec:
  """Column specification for Spanner vector writes.

  Defines how to extract and format values from Chunks for insertion into
  Spanner table columns. Each spec maps to one column in the target table.

  Attributes:
    column_name: Name of the Spanner table column
    python_type: Python type for the NamedTuple field (required for RowCoder)
    value_fn: Function to extract value from a Chunk

  Examples:
    String column:
    >>> SpannerColumnSpec(
    ...     column_name="id",
    ...     python_type=str,
    ...     value_fn=lambda chunk: chunk.id
    ... )

    Array column with conversion:
    >>> SpannerColumnSpec(
    ...     column_name="embedding",
    ...     python_type=List[float],
    ...     value_fn=lambda chunk: chunk.embedding.dense_embedding
    ... )
  """
  column_name: str
  python_type: Type
  value_fn: Callable[[Chunk], Any]


def _extract_and_convert(extract_fn, convert_fn, chunk):
  """Apply extract_fn to a chunk, then convert_fn if one was provided."""
  if convert_fn:
    return convert_fn(extract_fn(chunk))
  return extract_fn(chunk)
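
# Illustrative only: the builder methods below combine _extract_and_convert
# with functools.partial to build single-argument column value functions;
# str.upper stands in for any user-supplied convert_fn.
#
#   value_fn = functools.partial(
#       _extract_and_convert, lambda chunk: chunk.id, str.upper)
#   value_fn(chunk)  # == str.upper(chunk.id)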


class SpannerColumnSpecsBuilder:
  """Builder for creating Spanner column specifications.

  Provides a fluent API for defining table schemas and how to populate them
  from Chunk objects. Supports standard Chunk fields (id, embedding, content,
  metadata) and flattening metadata fields into dedicated columns.

  Example:
    >>> specs = (
    ...     SpannerColumnSpecsBuilder()
    ...     .with_id_spec()
    ...     .with_embedding_spec()
    ...     .with_content_spec()
    ...     .add_metadata_field("source", str)
    ...     .with_metadata_spec()
    ...     .build()
    ... )
  """
  def __init__(self):
    self._specs: List[SpannerColumnSpec] = []

  @staticmethod
  def with_defaults() -> 'SpannerColumnSpecsBuilder':
    """Create builder with default schema.

    Default schema includes:
    - id (STRING): Chunk ID
    - embedding (ARRAY<FLOAT32>): Dense embedding vector
    - content (STRING): Chunk content text
    - metadata (JSON): Full metadata as JSON

    Returns:
      Builder with default column specifications
    """
    return (
        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
        with_content_spec().with_metadata_spec())

  def with_id_spec(
      self,
      column_name: str = "id",
      python_type: Type = str,
      convert_fn: Optional[Callable[[str], Any]] = None
  ) -> 'SpannerColumnSpecsBuilder':
    """Add ID column specification.

    Args:
      column_name: Column name (default: "id")
      python_type: Python type (default: str)
      convert_fn: Optional converter (e.g., to cast to int)

    Returns:
      Self for method chaining

    Examples:
      Default string ID:
      >>> builder.with_id_spec()

      Integer ID with conversion:
      >>> builder.with_id_spec(
      ...     python_type=int,
      ...     convert_fn=lambda id: int(id.split('_')[1])
      ... )
    """
    self._specs.append(
        SpannerColumnSpec(
            column_name=column_name,
            python_type=python_type,
            value_fn=functools.partial(
                _extract_and_convert, lambda chunk: chunk.id, convert_fn)))
    return self

  def with_embedding_spec(
      self,
      column_name: str = "embedding",
      convert_fn: Optional[Callable[[List[float]], List[float]]] = None
  ) -> 'SpannerColumnSpecsBuilder':
    """Add embedding array column (ARRAY<FLOAT32> or ARRAY<FLOAT64>).

    Args:
      column_name: Column name (default: "embedding")
      convert_fn: Optional converter (e.g., normalize, quantize)

    Returns:
      Self for method chaining

    Examples:
      Default embedding:
      >>> builder.with_embedding_spec()

      Normalized embedding:
      >>> def normalize(vec):
      ...     norm = (sum(x**2 for x in vec) ** 0.5) or 1.0
      ...     return [x/norm for x in vec]
      >>> builder.with_embedding_spec(convert_fn=normalize)

      Rounded precision:
      >>> builder.with_embedding_spec(
      ...     convert_fn=lambda vec: [round(x, 4) for x in vec]
      ... )
    """
    def extract_fn(chunk: Chunk) -> List[float]:
      if chunk.embedding is None or chunk.embedding.dense_embedding is None:
        raise ValueError(f'Chunk must contain embedding: {chunk}')
      return chunk.embedding.dense_embedding

    self._specs.append(
        SpannerColumnSpec(
            column_name=column_name,
            python_type=List[float],
            value_fn=functools.partial(
                _extract_and_convert, extract_fn, convert_fn)))
    return self

  def with_content_spec(
      self,
      column_name: str = "content",
      python_type: Type = str,
      convert_fn: Optional[Callable[[str], Any]] = None
  ) -> 'SpannerColumnSpecsBuilder':
    """Add content column.

    Args:
      column_name: Column name (default: "content")
      python_type: Python type (default: str)
      convert_fn: Optional converter

    Returns:
      Self for method chaining

    Examples:
      Default text content:
      >>> builder.with_content_spec()

      Content length as integer:
      >>> builder.with_content_spec(
      ...     column_name="content_length",
      ...     python_type=int,
      ...     convert_fn=lambda text: len(text.split())
      ... )

      Truncated content:
      >>> builder.with_content_spec(
      ...     convert_fn=lambda text: text[:1000]
      ... )
    """
    def extract_fn(chunk: Chunk) -> str:
      if chunk.content.text is None:
        raise ValueError(f'Chunk must contain content: {chunk}')
      return chunk.content.text

    self._specs.append(
        SpannerColumnSpec(
            column_name=column_name,
            python_type=python_type,
            value_fn=functools.partial(
                _extract_and_convert, extract_fn, convert_fn)))
    return self

  def with_metadata_spec(
      self, column_name: str = "metadata") -> 'SpannerColumnSpecsBuilder':
    """Add metadata JSON column.

    Stores the full metadata dictionary as a JSON string in Spanner.

    Args:
      column_name: Column name (default: "metadata")

    Returns:
      Self for method chaining

    Note:
      Metadata is automatically converted to JSON string using json.dumps()
    """
    value_fn = lambda chunk: json.dumps(chunk.metadata)

    self._specs.append(
        SpannerColumnSpec(
            column_name=column_name, python_type=str, value_fn=value_fn))
    return self

  def add_metadata_field(
      self,
      field: str,
      python_type: Type,
      column_name: Optional[str] = None,
      convert_fn: Optional[Callable[[Any], Any]] = None,
      default: Any = None) -> 'SpannerColumnSpecsBuilder':
    """Flatten a metadata field into its own column.

    Extracts a specific field from chunk.metadata and stores it in a
    dedicated table column.

    Args:
      field: Key in chunk.metadata to extract
      python_type: Python type (must be explicitly specified)
      column_name: Column name (default: same as field)
      convert_fn: Optional converter for type casting/transformation
      default: Default value if field is missing from metadata

    Returns:
      Self for method chaining

    Examples:
      String field:
      >>> builder.add_metadata_field("source", str)

      Integer with default:
      >>> builder.add_metadata_field(
      ...     "page_number",
      ...     int,
      ...     default=0
      ... )

      Float with conversion:
      >>> builder.add_metadata_field(
      ...     "confidence",
      ...     float,
      ...     convert_fn=lambda x: round(float(x), 2),
      ...     default=0.0
      ... )

      List of strings:
      >>> builder.add_metadata_field(
      ...     "tags",
      ...     List[str],
      ...     default=[]
      ... )

      Timestamp with conversion:
      >>> builder.add_metadata_field(
      ...     "created_at",
      ...     str,
      ...     convert_fn=lambda ts: ts.isoformat()
      ... )
    """
    name = column_name or field

    def value_fn(chunk: Chunk) -> Any:
      return chunk.metadata.get(field, default)

    self._specs.append(
        SpannerColumnSpec(
            column_name=name,
            python_type=python_type,
            value_fn=functools.partial(
                _extract_and_convert, value_fn, convert_fn)))
    return self

  def add_column(
      self,
      column_name: str,
      python_type: Type,
      value_fn: Callable[[Chunk], Any]) -> 'SpannerColumnSpecsBuilder':
    """Add a custom column with full control.

    Args:
      column_name: Column name
      python_type: Python type (required)
      value_fn: Value extraction function

    Returns:
      Self for method chaining

    Examples:
      Boolean flag:
      >>> builder.add_column(
      ...     column_name="has_code",
      ...     python_type=bool,
      ...     value_fn=lambda chunk: "```" in chunk.content.text
      ... )

      Computed value:
      >>> builder.add_column(
      ...     column_name="word_count",
      ...     python_type=int,
      ...     value_fn=lambda chunk: len(chunk.content.text.split())
      ... )
    """
    self._specs.append(
        SpannerColumnSpec(
            column_name=column_name,
            python_type=python_type,
            value_fn=value_fn))
    return self

  def build(self) -> List[SpannerColumnSpec]:
    """Build the final list of column specifications.

    Returns:
      List of SpannerColumnSpec objects
    """
    return self._specs.copy()


class _SpannerSchemaBuilder:
  """Internal: Builds NamedTuple schema and registers RowCoder.

  Creates a NamedTuple type from column specifications and registers it with
  Beam's RowCoder for serialization.
  """
  def __init__(self, table_name: str, column_specs: List[SpannerColumnSpec]):
    """Initialize schema builder.

    Args:
      table_name: Table name (used in NamedTuple type name)
      column_specs: List of column specifications

    Raises:
      ValueError: If duplicate column names are found
    """
    self.table_name = table_name
    self.column_specs = column_specs

    # Validate no duplicates
    names = [col.column_name for col in column_specs]
    duplicates = set(name for name in names if names.count(name) > 1)
    if duplicates:
      raise ValueError(f"Duplicate column names: {duplicates}")

    # Create NamedTuple type
    fields = [(col.column_name, col.python_type) for col in column_specs]
    type_name = f"SpannerVectorRecord_{table_name}"
    self.record_type = NamedTuple(type_name, fields)  # type: ignore

    # Register coder
    registry.register_coder(self.record_type, RowCoder)

  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
    """Create converter function from Chunk to NamedTuple record.

    Returns:
      Function that converts a Chunk to a NamedTuple record
    """
    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
      values = {
          col.column_name: col.value_fn(chunk)
          for col in self.column_specs
      }
      return self.record_type(**values)  # type: ignore

    return convert
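
# For illustration only: with the default column specs and table name
# "embeddings", _SpannerSchemaBuilder generates (and registers a RowCoder
# for) a NamedTuple roughly equivalent to this hand-written sketch:
#
#   class SpannerVectorRecord_embeddings(NamedTuple):
#     id: str
#     embedding: List[float]
#     content: str
#     metadata: str  # full metadata dict serialized with json.dumps()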


class SpannerVectorWriterConfig(VectorDatabaseWriteConfig):
  """Configuration for writing vectors to Cloud Spanner.

  Supports flexible schema configuration through column specifications and
  provides control over Spanner-specific write parameters.

  Examples:
    Default schema:
    >>> config = SpannerVectorWriterConfig(
    ...     project_id="my-project",
    ...     instance_id="my-instance",
    ...     database_id="my-db",
    ...     table_name="embeddings"
    ... )

    Custom schema with flattened metadata:
    >>> specs = (
    ...     SpannerColumnSpecsBuilder()
    ...     .with_id_spec()
    ...     .with_embedding_spec()
    ...     .with_content_spec()
    ...     .add_metadata_field("source", str)
    ...     .add_metadata_field("page_number", int, default=0)
    ...     .with_metadata_spec()
    ...     .build()
    ... )
    >>> config = SpannerVectorWriterConfig(
    ...     project_id="my-project",
    ...     instance_id="my-instance",
    ...     database_id="my-db",
    ...     table_name="embeddings",
    ...     column_specs=specs
    ... )

    With emulator:
    >>> config = SpannerVectorWriterConfig(
    ...     project_id="test-project",
    ...     instance_id="test-instance",
    ...     database_id="test-db",
    ...     table_name="embeddings",
    ...     emulator_host="http://localhost:9010"
    ... )
  """
  def __init__(
      self,
      project_id: str,
      instance_id: str,
      database_id: str,
      table_name: str,
      *,
      # Schema configuration
      column_specs: Optional[List[SpannerColumnSpec]] = None,
      # Write operation type
      write_mode: Literal["INSERT", "UPDATE", "REPLACE",
                          "INSERT_OR_UPDATE"] = "INSERT_OR_UPDATE",
      # Batching configuration
      max_batch_size_bytes: Optional[int] = None,
      max_number_mutations: Optional[int] = None,
      max_number_rows: Optional[int] = None,
      grouping_factor: Optional[int] = None,
      # Networking
      host: Optional[str] = None,
      emulator_host: Optional[str] = None,
      expansion_service: Optional[str] = None,
      # Retry/deadline configuration
      commit_deadline: Optional[int] = None,
      max_cumulative_backoff: Optional[int] = None,
      # Error handling
      failure_mode: Optional[
          spanner.FailureMode] = spanner.FailureMode.REPORT_FAILURES,
      high_priority: bool = False,
      # Additional Spanner arguments
      **spanner_kwargs):
    """Initialize Spanner vector writer configuration.

    Args:
      project_id: GCP project ID
      instance_id: Spanner instance ID
      database_id: Spanner database ID
      table_name: Target table name
      column_specs: Schema configuration using SpannerColumnSpecsBuilder.
        If None, uses default schema (id, embedding, content, metadata)
      write_mode: Spanner write operation type:
        - INSERT: Fail if row exists
        - UPDATE: Fail if row doesn't exist
        - REPLACE: Delete then insert
        - INSERT_OR_UPDATE: Insert or update if exists (default)
      max_batch_size_bytes: Maximum bytes per mutation batch (default: 1MB)
      max_number_mutations: Maximum cell mutations per batch (default: 5000)
      max_number_rows: Maximum rows per batch (default: 500)
      grouping_factor: Multiple of max mutation for sorting (default: 1000)
      host: Spanner host URL (usually not needed)
      emulator_host: Spanner emulator host (e.g., "http://localhost:9010")
      expansion_service: Java expansion service address (host:port)
      commit_deadline: Commit API deadline in seconds (default: 15)
      max_cumulative_backoff: Max retry backoff seconds (default: 900)
      failure_mode: Error handling strategy:
        - FAIL_FAST: Throw exception for any failure
        - REPORT_FAILURES: Continue processing (default)
      high_priority: Use high priority for operations (default: False)
      **spanner_kwargs: Additional keyword arguments to pass to the
        underlying Spanner write transform. Use this to pass any
        Spanner-specific parameters not explicitly exposed by this config.
    """
    self.project_id = project_id
    self.instance_id = instance_id
    self.database_id = database_id
    self.table_name = table_name
    self.write_mode = write_mode
    self.max_batch_size_bytes = max_batch_size_bytes
    self.max_number_mutations = max_number_mutations
    self.max_number_rows = max_number_rows
    self.grouping_factor = grouping_factor
    self.host = host
    self.emulator_host = emulator_host
    self.expansion_service = expansion_service
    self.commit_deadline = commit_deadline
    self.max_cumulative_backoff = max_cumulative_backoff
    self.failure_mode = failure_mode
    self.high_priority = high_priority
    self.spanner_kwargs = spanner_kwargs

    # Use defaults if not provided
    specs = column_specs or SpannerColumnSpecsBuilder.with_defaults().build()

    # Create schema builder (NamedTuple + RowCoder registration)
    self.schema_builder = _SpannerSchemaBuilder(table_name, specs)

  def create_write_transform(self) -> beam.PTransform:
    """Create the Spanner write PTransform.

    Returns:
      PTransform for writing to Spanner
    """
    return _WriteToSpannerVectorDatabase(self)


class _WriteToSpannerVectorDatabase(beam.PTransform):
  """Internal: PTransform for writing to a Spanner vector database."""
  def __init__(self, config: SpannerVectorWriterConfig):
    """Initialize write transform.

    Args:
      config: Spanner writer configuration
    """
    self.config = config
    self.schema_builder = config.schema_builder

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Expand the transform.

    Args:
      pcoll: PCollection of Chunks to write
    """
    # Select appropriate Spanner write transform based on write_mode
    write_transform_class = {
        "INSERT": spanner.SpannerInsert,
        "UPDATE": spanner.SpannerUpdate,
        "REPLACE": spanner.SpannerReplace,
        "INSERT_OR_UPDATE": spanner.SpannerInsertOrUpdate,
    }[self.config.write_mode]

    return (
        pcoll
        | "Convert to Records" >> beam.Map(
            self.schema_builder.create_converter()).with_output_types(
                self.schema_builder.record_type)
        | "Write to Spanner" >> write_transform_class(
            project_id=self.config.project_id,
            instance_id=self.config.instance_id,
            database_id=self.config.database_id,
            table=self.config.table_name,
            max_batch_size_bytes=self.config.max_batch_size_bytes,
            max_number_mutations=self.config.max_number_mutations,
            max_number_rows=self.config.max_number_rows,
            grouping_factor=self.config.grouping_factor,
            host=self.config.host,
            emulator_host=self.config.emulator_host,
            commit_deadline=self.config.commit_deadline,
            max_cumulative_backoff=self.config.max_cumulative_backoff,
            failure_mode=self.config.failure_mode,
            expansion_service=self.config.expansion_service,
            high_priority=self.config.high_priority,
            **self.config.spanner_kwargs))