#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module for tracking the chain of beam_sql magics that have been applied.

For internal use only; no backwards-compatibility guarantees.
"""
# pytype: skip-file
import importlib
import logging
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Set
from typing import Union
import apache_beam as beam
from apache_beam.internal import pickler
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.utils import create_var_in_main
from apache_beam.runners.interactive.utils import pcoll_by_name
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.transforms.sql import SqlTransform
from apache_beam.utils.interactive_utils import is_in_ipython
# Module-scoped logger, named after this module.
_LOGGER = logging.getLogger(name=__name__)
@dataclass
class SqlNode:
  """Each SqlNode represents one applied beam_sql magic.

  Attributes:
    output_name: the watched unique name of the beam_sql output. Can be used as
      an identifier.
    source: the inputs consumed by this node. Can be a pipeline or a set of
      PCollections represented by their watched variable names. When it's a
      pipeline, the node computes from raw values in the query, so the output
      can be consumed by any SqlNode in any SqlChain.
    query: the SQL query applied by this node.
    schemas: the schemas (NamedTuple classes) used by this node.
    evaluated: the pipelines this node has been evaluated for.
    next: the next SqlNode applied chronologically.
    execution_count: the execution count if in an IPython env.
  """
  output_name: str
  source: Union[beam.Pipeline, Set[str]]
  query: str
  schemas: Set[Any] = None
  evaluated: Set[beam.Pipeline] = None
  next: Optional['SqlNode'] = None
  execution_count: int = 0

  def __post_init__(self):
    # Replace falsy defaults with fresh per-instance sets so that instances
    # never share mutable state.
    if not self.schemas:
      self.schemas = set()
    if not self.evaluated:
      self.evaluated = set()
    if is_in_ipython():
      from IPython import get_ipython
      # Record the IPython execution count; it is embedded in the transform
      # labels built by to_pipeline.
      self.execution_count = get_ipython().execution_count

  def __hash__(self):
    # A set of variable names is unhashable, so hash an immutable snapshot of
    # it instead. A pipeline source is hashed as-is.
    source = self.source
    if not isinstance(source, beam.Pipeline):
      source = frozenset(source)
    return hash((self.output_name, source, self.query, self.execution_count))

  def to_pipeline(self, pipeline: Optional[beam.Pipeline]) -> beam.Pipeline:
    """Converts the chain into an executable pipeline.

    Applies this node's query (unless it has already been applied to the
    target pipeline), then recurses into the next node with the same target
    pipeline so that the whole chain forms a single pipeline.

    Args:
      pipeline: the pipeline to build the chain into. If None and this node's
        source is a pipeline, that source pipeline is used instead.

    Returns:
      The pipeline the chain was built into. This is None only when no
      pipeline was given and no node in the chain sources from a pipeline.
    """
    # Resolve the pipeline this node will actually be applied to BEFORE the
    # "already evaluated" check. Otherwise a pipeline-sourced node called with
    # pipeline=None would never be seen as evaluated and would be re-applied
    # on every call, producing duplicate transform labels.
    if pipeline is None and isinstance(self.source, beam.Pipeline):
      pipeline = self.source

    if pipeline not in self.evaluated:
      if isinstance(self.source, beam.Pipeline):
        # Compute from raw values in the query, inside the target pipeline.
        output = pipeline | 'beam_sql_{}_{}'.format(
            self.output_name, self.execution_count) >> SqlTransform(self.query)
      else:
        # Look up the watched PCollections by their variable names: a single
        # input is passed directly, several are passed as a name->PCollection
        # dict.
        name_to_pcoll = pcoll_by_name()
        if len(self.source) == 1:
          source = name_to_pcoll.get(next(iter(self.source)))
        else:
          source = {s: name_to_pcoll.get(s) for s in self.source}
        output = source | 'schema_loaded_beam_sql_{}_{}'.format(
            self.output_name, self.execution_count
        ) >> SchemaLoadedSqlTransform(
            self.output_name, self.query, self.schemas, self.execution_count)
      # Expose the output under its watched name in __main__.
      _ = create_var_in_main(self.output_name, output)
      self.evaluated.add(pipeline)

    if self.next:
      return self.next.to_pipeline(pipeline)
    return pipeline
@dataclass
class SqlChain:
  """A chain of SqlNodes.

  Attributes:
    nodes: all nodes by their output_names.
    root: the first SqlNode applied chronologically.
    current: the last node applied.
    user_pipeline: the user defined pipeline this chain originates from. If
      None, the whole chain just computes from raw values in queries.
      Otherwise, at least some of the nodes in the chain have queried against
      PCollections.
  """
  nodes: Dict[str, SqlNode] = None
  root: Optional[SqlNode] = None
  current: Optional[SqlNode] = None
  user_pipeline: Optional[beam.Pipeline] = None

  def __post_init__(self):
    # Replace the falsy default with a per-instance dict.
    if not self.nodes:
      self.nodes = {}

  @progress_indicated
  def to_pipeline(self) -> beam.Pipeline:
    """Converts the chain into a beam pipeline.

    Returns:
      The pipeline built by walking the chain from its root node.

    Raises:
      ValueError: if no node has been appended to the chain yet.
    """
    if self.root is None:
      # Fail with a clear message instead of an AttributeError on None.
      raise ValueError('Cannot build a pipeline from an empty SqlChain.')
    pipeline_to_execute = self.root.to_pipeline(self.user_pipeline)
    # The pipeline definitely contains external transform: SqlTransform.
    pipeline_to_execute.contains_external_transforms = True
    return pipeline_to_execute

  def append(self, node: SqlNode) -> 'SqlChain':
    """Appends a node to the chain.

    The node is linked after the current last node (or becomes the root of an
    empty chain), becomes the new current node, and is indexed by its
    output_name.

    Returns:
      This chain, to allow call chaining.
    """
    if self.current:
      self.current.next = node
    else:
      self.root = node
    self.current = node
    self.nodes[node.output_name] = node
    return self

  def get(self, output_name: str) -> Optional[SqlNode]:
    """Gets a node from the chain based on the given output_name.

    Returns:
      The node with that output_name, or None if there is no such node.
    """
    return self.nodes.get(output_name)