# Source code for apache_beam.runners.interactive.sql.sql_chain

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for tracking a chain of beam_sql magics applied.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import importlib
import logging
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Set
from typing import Union

import apache_beam as beam
from apache_beam.internal import pickler
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.utils import create_var_in_main
from apache_beam.runners.interactive.utils import pcoll_by_name
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.transforms.sql import SqlTransform
from apache_beam.utils.interactive_utils import is_in_ipython

_LOGGER = logging.getLogger(__name__)

@dataclass
class SqlNode:
  """Each SqlNode represents a beam_sql magic applied.

  Attributes:
    output_name: the watched unique name of the beam_sql output. Can be used as
      an identifier.
    source: the inputs consumed by this node. Can be a pipeline or a set of
      PCollections represented by their variable names watched. When it's a
      pipeline, the node computes from raw values in the query, so the output
      can be consumed by any SqlNode in any SqlChain.
    query: the SQL query applied by this node.
    schemas: the schemas (NamedTuple classes) used by this node.
    evaluated: the pipelines this node has been evaluated for.
    next: the next SqlNode applied chronologically.
    execution_count: the execution count if in an IPython env.
  """
  output_name: str
  source: Union[beam.Pipeline, Set[str]]
  query: str
  schemas: Optional[Set[Any]] = None
  evaluated: Optional[Set[beam.Pipeline]] = None
  next: Optional['SqlNode'] = None
  execution_count: int = 0

  def __post_init__(self):
    # Dataclass fields cannot default to a shared mutable set, so each
    # instance gets its own fresh set here.
    if not self.schemas:
      self.schemas = set()
    if not self.evaluated:
      self.evaluated = set()
    if is_in_ipython():
      from IPython import get_ipython
      self.execution_count = get_ipython().execution_count

  def __hash__(self):
    # `source` may be a set of PCollection variable names, and a plain set is
    # unhashable. Hash an immutable frozenset snapshot instead. Equal sets give
    # equal frozensets, so this stays consistent with the dataclass __eq__.
    source = self.source
    if isinstance(source, set):
      source = frozenset(source)
    return hash((self.output_name, source, self.query, self.execution_count))

  def to_pipeline(self, pipeline: Optional[beam.Pipeline]) -> beam.Pipeline:
    """Converts the chain into an executable pipeline.

    Applies this node's SQL query to `pipeline` (at most once per pipeline),
    then recursively does the same for every following node in the chain.

    Args:
      pipeline: the pipeline to build on. If falsy and this node's source is a
        pipeline, the source pipeline is used instead.

    Returns:
      The pipeline that this node and all following nodes have been applied to.
    """
    if pipeline not in self.evaluated:
      # The whole chain should form a single pipeline.
      source = self.source
      if isinstance(self.source, beam.Pipeline):
        if pipeline:  # use the known pipeline
          source = pipeline
        else:  # use the source pipeline
          pipeline = self.source
      else:
        # Resolve the watched variable names into their PCollections.
        name_to_pcoll = pcoll_by_name()
        if len(self.source) == 1:
          source = name_to_pcoll.get(next(iter(self.source)))
        else:
          source = {s: name_to_pcoll.get(s) for s in self.source}
      if isinstance(source, beam.Pipeline):
        output = source | 'beam_sql_{}_{}'.format(
            self.output_name, self.execution_count) >> SqlTransform(self.query)
      else:
        # PCollection inputs may carry schemas defined in __main__ that remote
        # workers cannot see; the wrapper restores them before running SQL.
        output = source | 'schema_loaded_beam_sql_{}_{}'.format(
            self.output_name, self.execution_count
        ) >> SchemaLoadedSqlTransform(
            self.output_name, self.query, self.schemas, self.execution_count)
      _ = create_var_in_main(self.output_name, output)
      self.evaluated.add(pipeline)
    if self.next:
      return self.next.to_pipeline(pipeline)
    return pipeline
class SchemaLoadedSqlTransform(beam.PTransform):
  """PTransform that loads schema before executing SQL.

  When submitting a pipeline to remote runner for execution, schemas defined in
  the main module are not available without save_main_session. However,
  save_main_session might fail when there is anything unpicklable. This
  PTransform makes sure only the schemas needed are pickled locally and
  restored later on workers.
  """
  def __init__(self, output_name, query, schemas, execution_count):
    super().__init__()
    self.output_name = output_name
    self.query = query
    self.schemas = schemas
    self.execution_count = execution_count
    # TODO(BEAM-8123): clean up this attribute or the whole wrapper PTransform.
    # Dill does not preserve everything. On the other hand, save_main_session
    # is not stable. Until cloudpickle replaces dill in Beam, we work around
    # it by explicitly pickling annotations and load schemas in remote main
    # sessions.
    self.schema_annotations = [s.__annotations__ for s in self.schemas]

  class _SqlTransformDoFn(beam.DoFn):
    """The DoFn yields all its input without any transform but a setup to
    configure the main session."""
    def __init__(self, schemas, annotations):
      super().__init__()
      # Schemas and their annotations are pickled separately, position by
      # position, and zipped back together in setup().
      self.pickled_schemas = [pickler.dumps(s) for s in schemas]
      self.pickled_annotations = [pickler.dumps(a) for a in annotations]

    def setup(self):
      """Restores missing schemas into __main__ and registers their coders."""
      main_session = importlib.import_module('__main__')
      for pickled_schema, pickled_annotation in zip(
          self.pickled_schemas, self.pickled_annotations):
        schema = pickler.loads(pickled_schema)
        schema.__annotations__ = pickler.loads(pickled_annotation)
        if not hasattr(main_session, schema.__name__) or not hasattr(
            getattr(main_session, schema.__name__), '__annotations__'):
          # Restore the schema in the main session on the [remote] worker.
          setattr(main_session, schema.__name__, schema)
        register_coder_for_schema(schema)

    def process(self, e):
      yield e

  def _maybe_load_schemas(self, pcoll, label):
    """Returns `pcoll`, routed through the schema-loading DoFn when needed.

    Only PCollections whose element type is one of the tracked schemas get the
    extra ParDo; others are returned unchanged.
    """
    if pcoll.element_type not in self.schemas:
      return pcoll
    return pcoll | label >> beam.ParDo(
        self._SqlTransformDoFn(self.schemas, self.schema_annotations))

  def expand(self, source):
    """Applies the SQL transform.

    If a PCollection uses a schema defined in the main session, use the
    additional DoFn to restore it on the worker.

    Raises:
      ValueError: if `source` is neither a PCollection nor a dict of them.
    """
    if isinstance(source, dict):
      schema_loaded = {
          tag: self._maybe_load_schemas(
              pcoll,
              'load_schemas_{}_tag_{}_{}'.format(
                  self.output_name, tag, self.execution_count))
          for tag, pcoll in source.items()
      }
    elif isinstance(source, beam.pvalue.PCollection):
      schema_loaded = self._maybe_load_schemas(
          source,
          'load_schemas_{}_{}'.format(self.output_name, self.execution_count))
    else:
      raise ValueError(
          '{} should be either a single PCollection or a dict of named '
          'PCollections.'.format(source))
    return schema_loaded | 'beam_sql_{}_{}'.format(
        self.output_name, self.execution_count) >> SqlTransform(self.query)
@dataclass
class SqlChain:
  """A chain of SqlNodes.

  Attributes:
    nodes: all nodes by their output_names.
    root: the first SqlNode applied chronologically.
    current: the last node applied.
    user_pipeline: the user defined pipeline this chain originates from. If
      None, the whole chain just computes from raw values in queries.
      Otherwise, at least some of the nodes in chain has queried against
      PCollections.
  """
  nodes: Optional[Dict[str, SqlNode]] = None
  root: Optional[SqlNode] = None
  current: Optional[SqlNode] = None
  user_pipeline: Optional[beam.Pipeline] = None

  def __post_init__(self):
    # Avoid a shared mutable default: each chain gets its own dict.
    if not self.nodes:
      self.nodes = {}

  @progress_indicated
  def to_pipeline(self) -> beam.Pipeline:
    """Converts the chain into a beam pipeline.

    Raises:
      ValueError: if no node has been appended to the chain yet.
    """
    if self.root is None:
      raise ValueError(
          'Cannot convert an empty SqlChain into a pipeline; append a SqlNode '
          'first.')
    pipeline_to_execute = self.root.to_pipeline(self.user_pipeline)
    # The pipeline definitely contains external transform: SqlTransform.
    pipeline_to_execute.contains_external_transforms = True
    return pipeline_to_execute

  def append(self, node: SqlNode) -> 'SqlChain':
    """Appends a node to the chain.

    Links `node` after the current tail (or makes it the root of an empty
    chain), makes it the new tail, and indexes it by its output_name.

    Returns:
      This chain, so calls can be chained.
    """
    if self.current:
      self.current.next = node
    else:
      self.root = node
    self.current = node
    self.nodes[node.output_name] = node
    return self

  def get(self, output_name: str) -> Optional[SqlNode]:
    """Gets a node from the chain based on the given output_name.

    Returns None when no node has that output_name.
    """
    return self.nodes.get(output_name)