#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Google Cloud Spanner IO
Experimental; no backwards-compatibility guarantees.
This is an experimental module for reading and writing data from Google Cloud
Spanner. Visit: https://cloud.google.com/spanner for more details.
Reading Data from Cloud Spanner.
To read from Cloud Spanner apply ReadFromSpanner transformation. It will
return a PCollection, where each element represents an individual row returned
from the read operation. Both Query and Read APIs are supported.
ReadFromSpanner relies on the ReadOperation objects which is exposed by the
SpannerIO API. ReadOperation holds the immutable data which is responsible to
execute batch and naive reads on Cloud Spanner. This is done for more
convenient programming.
ReadFromSpanner reads from Cloud Spanner by providing either an 'sql' param
in the constructor or 'table' name with 'columns' as list. For example:::
records = (pipeline
| ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
sql='Select * from users'))
records = (pipeline
| ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
table='users', columns=['id', 'name', 'email']))
You can also perform multiple reads by providing a list of ReadOperations
to the ReadFromSpanner transform constructor. ReadOperation exposes two static
methods. Use 'query' to perform sql based reads, 'table' to perform read from
table name. For example:::
read_operations = [
ReadOperation.table(table='customers', columns=['name',
'email']),
ReadOperation.table(table='vendors', columns=['name',
'email']),
]
all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
read_operations=read_operations)
...OR...
read_operations = [
ReadOperation.query(sql='Select name, email from
customers'),
ReadOperation.query(
sql='Select * from users where id <= @user_id',
params={'user_id': 100},
params_type={'user_id': param_types.INT64}
),
]
# `params_types` are instance of `google.cloud.spanner.param_types`
all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
read_operations=read_operations)
For more information, please review the docs on class ReadOperation.
User can also able to provide the ReadOperation in form of PCollection via
pipeline. For example:::
users = (pipeline
| beam.Create([ReadOperation...])
| ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME))
User may also create cloud spanner transaction from the transform called
`create_transaction` which is available in the SpannerIO API.
The transform is guaranteed to be executed on a consistent snapshot of data,
utilizing the power of read only transactions. Staleness of data can be
controlled by providing the `read_timestamp` or `exact_staleness` param values
in the constructor.
This transform requires root of the pipeline (PBegin) and returns PTransform
which is passed later to the `ReadFromSpanner` constructor. `ReadFromSpanner`
pass this transaction PTransform as a singleton side input to the
`_NaiveSpannerReadDoFn` containing 'session_id' and 'transaction_id'.
For example:::
transaction = (pipeline | create_transaction(TEST_PROJECT_ID,
TEST_INSTANCE_ID,
DB_NAME))
users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
sql='Select * from users', transaction=transaction)
tweets = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
sql='Select * from tweets', transaction=transaction)
For further details of this transform, please review the docs on the
:meth:`create_transaction` method available in the SpannerIO API.
ReadFromSpanner takes this transform in the constructor and pass this to the
read pipeline as the singleton side input.
Writing Data to Cloud Spanner.
The WriteToSpanner transform writes to Cloud Spanner by executing a
collection a input rows (WriteMutation). The mutations are grouped into
batches for efficiency.
WriteToSpanner transform relies on the WriteMutation objects which is exposed
by the SpannerIO API. WriteMutation have five static methods (insert, update,
insert_or_update, replace, delete). These methods returns the instance of the
_Mutator object which contains the mutation type and the Spanner Mutation
object. For more details, review the docs of the class SpannerIO.WriteMutation.
For example:::
mutations = [
WriteMutation.insert(table='user', columns=('name', 'email'),
values=[('sara', 'sara@dev.com')])
]
_ = (p
| beam.Create(mutations)
| WriteToSpanner(
project_id=SPANNER_PROJECT_ID,
instance_id=SPANNER_INSTANCE_ID,
database_id=SPANNER_DATABASE_NAME)
)
You can also create WriteMutation via calling its constructor. For example:::
mutations = [
WriteMutation(insert='users', columns=('name', 'email'),
values=[('sara", 'sara@example.com')])
]
For more information, review the docs available on WriteMutation class.
WriteToSpanner transform also takes three batching parameters (max_number_rows,
max_number_cells and max_batch_size_bytes). By default, max_number_rows is set
to 50 rows, max_number_cells is set to 500 cells and max_batch_size_bytes is
set to 1MB (1048576 bytes). These parameter used to reduce the number of
transactions sent to spanner by grouping the mutation into batches. Setting
these param values either to smaller value or zero to disable batching.
Unlike the Java connector, this connector does not create batches of
transactions sorted by table and primary key.
WriteToSpanner transforms starts with the grouping into batches. The first step
in this process is to make the mutation groups of the WriteMutation
objects and then filtering them into batchable and unbatchable mutation
groups. There are three batching parameters (max_number_cells, max_number_rows
& max_batch_size_bytes). We calculated th mutation byte size from the method
available in the `google.cloud.spanner_v1.proto.mutation_pb2.Mutation.ByteSize`.
if the mutation rows, cells or byte size are larger than value of the any
batching parameters param, it will be tagged as "unbatchable" mutation. After
this all the batchable mutation are merged into a single mutation group whos
size is not larger than the "max_batch_size_bytes", after this process, all the
mutation groups together to process. If the Mutation references a table or
column does not exits, it will cause a exception and fails the entire pipeline.
"""
import typing
from collections import deque
from collections import namedtuple
from apache_beam import Create
from apache_beam import DoFn
from apache_beam import Flatten
from apache_beam import ParDo
from apache_beam import Reshuffle
from apache_beam.internal.metrics.metric import ServiceCallMetric
from apache_beam.io.gcp import resource_identifiers
from apache_beam.metrics import Metrics
from apache_beam.metrics import monitoring_infos
from apache_beam.pvalue import AsSingleton
from apache_beam.pvalue import PBegin
from apache_beam.pvalue import TaggedOutput
from apache_beam.transforms import PTransform
from apache_beam.transforms import ptransform_fn
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.typehints import with_input_types
from apache_beam.typehints import with_output_types
# Protect against environments where spanner library is not available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: disable=unused-import
try:
from google.cloud.spanner import Client
from google.cloud.spanner import KeySet
from google.cloud.spanner_v1 import batch
from google.cloud.spanner_v1.database import BatchSnapshot
from google.api_core.exceptions import ClientError, GoogleAPICallError
from apitools.base.py.exceptions import HttpError
except ImportError:
Client = None
KeySet = None
BatchSnapshot = None
try:
from google.cloud.spanner_v1 import Mutation
except ImportError:
try:
# Remove this and the try clause when we upgrade to google-cloud-spanner
# 3.x.x.
from google.cloud.spanner_v1.proto.mutation_pb2 import Mutation
except ImportError:
# Ignoring for environments where the Spanner library is not available.
pass
__all__ = [
'create_transaction',
'ReadFromSpanner',
'ReadOperation',
'WriteToSpanner',
'WriteMutation',
'MutationGroup'
]
class _SPANNER_TRANSACTION(namedtuple("SPANNER_TRANSACTION", ["transaction"])):
"""
Holds the spanner transaction details.
"""
__slots__ = ()
[docs]
class ReadOperation(namedtuple(
"ReadOperation", ["is_sql", "is_table", "read_operation", "kwargs"])):
"""
Encapsulates a spanner read operation.
"""
__slots__ = ()
[docs]
@classmethod
def query(cls, sql, params=None, param_types=None):
"""
A convenient method to construct ReadOperation from sql query.
Args:
sql: SQL query statement
params: (optional) values for parameter replacement. Keys must match the
names used in sql
param_types: (optional) maps explicit types for one or more param values;
required if parameters are passed.
"""
if params:
assert param_types is not None
return cls(
is_sql=True,
is_table=False,
read_operation="process_query_batch",
kwargs={
'sql': sql, 'params': params, 'param_types': param_types
})
[docs]
@classmethod
def table(cls, table, columns, index="", keyset=None):
"""
A convenient method to construct ReadOperation from table.
Args:
table: name of the table from which to fetch data.
columns: names of columns to be retrieved.
index: (optional) name of index to use, rather than the table's primary
key.
keyset: (optional) `KeySet` keys / ranges identifying rows to be
retrieved.
"""
keyset = keyset or KeySet(all_=True)
if not isinstance(keyset, KeySet):
raise ValueError(
"keyset must be an instance of class "
"google.cloud.spanner.KeySet")
return cls(
is_sql=False,
is_table=True,
read_operation="process_read_batch",
kwargs={
'table': table,
'columns': columns,
'index': index,
'keyset': keyset
})
class _BeamSpannerConfiguration(namedtuple("_BeamSpannerConfiguration",
["project",
"instance",
"database",
"table",
"query_name",
"credentials",
"pool",
"snapshot_read_timestamp",
"snapshot_exact_staleness"])):
"""
A namedtuple holds the immutable data of the connection string to the cloud
spanner.
"""
@property
def snapshot_options(self):
snapshot_options = {}
if self.snapshot_exact_staleness:
snapshot_options['exact_staleness'] = self.snapshot_exact_staleness
if self.snapshot_read_timestamp:
snapshot_options['read_timestamp'] = self.snapshot_read_timestamp
return snapshot_options
@with_input_types(ReadOperation, _SPANNER_TRANSACTION)
@with_output_types(typing.List[typing.Any])
class _NaiveSpannerReadDoFn(DoFn):
def __init__(self, spanner_configuration):
"""
A naive version of Spanner read which uses the transaction API of the
cloud spanner.
https://googleapis.dev/python/spanner/latest/transaction-api.html
In Naive reads, this transform performs single reads, where as the
Batch reads use the spanner partitioning query to create batches.
Args:
spanner_configuration: (_BeamSpannerConfiguration) Connection details to
connect with cloud spanner.
"""
self._spanner_configuration = spanner_configuration
self._snapshot = None
self._session = None
self.base_labels = {
monitoring_infos.SERVICE_LABEL: 'Spanner',
monitoring_infos.METHOD_LABEL: 'Read',
monitoring_infos.SPANNER_PROJECT_ID: (
self._spanner_configuration.project),
monitoring_infos.SPANNER_DATABASE_ID: (
self._spanner_configuration.database),
}
def _table_metric(self, table_id, status):
database_id = self._spanner_configuration.database
project_id = self._spanner_configuration.project
resource = resource_identifiers.SpannerTable(
project_id, database_id, table_id)
labels = {
**self.base_labels,
monitoring_infos.RESOURCE_LABEL: resource,
monitoring_infos.SPANNER_TABLE_ID: table_id
}
service_call_metric = ServiceCallMetric(
request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
base_labels=labels)
service_call_metric.call(str(status))
def _query_metric(self, query_name, status):
project_id = self._spanner_configuration.project
resource = resource_identifiers.SpannerSqlQuery(project_id, query_name)
labels = {
**self.base_labels,
monitoring_infos.RESOURCE_LABEL: resource,
monitoring_infos.SPANNER_QUERY_NAME: query_name
}
service_call_metric = ServiceCallMetric(
request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
base_labels=labels)
service_call_metric.call(str(status))
def _get_session(self):
if self._session is None:
session = self._session = self._database.session()
session.create()
return self._session
def _close_session(self):
if self._session is not None:
self._session.delete()
def setup(self):
# setting up client to connect with cloud spanner
spanner_client = Client(self._spanner_configuration.project)
instance = spanner_client.instance(self._spanner_configuration.instance)
self._database = instance.database(
self._spanner_configuration.database,
pool=self._spanner_configuration.pool)
def process(self, element, spanner_transaction):
# `spanner_transaction` should be the instance of the _SPANNER_TRANSACTION
# object.
if not isinstance(spanner_transaction, _SPANNER_TRANSACTION):
raise ValueError(
"Invalid transaction object: %s. It should be instance "
"of SPANNER_TRANSACTION object created by "
"spannerio.create_transaction transform." % type(spanner_transaction))
transaction_info = spanner_transaction.transaction
# We used batch snapshot to reuse the same transaction passed through the
# side input
self._snapshot = BatchSnapshot.from_dict(self._database, transaction_info)
# getting the transaction from the snapshot's session to run read operation.
# with self._snapshot.session().transaction() as transaction:
with self._get_session().transaction() as transaction:
table_id = self._spanner_configuration.table
query_name = self._spanner_configuration.query_name or ''
if element.is_sql is True:
transaction_read = transaction.execute_sql
metric_action = self._query_metric
metric_id = query_name
elif element.is_table is True:
transaction_read = transaction.read
metric_action = self._table_metric
metric_id = table_id
else:
raise ValueError(
"ReadOperation is improperly configure: %s" % str(element))
try:
for row in transaction_read(**element.kwargs):
yield row
metric_action(metric_id, 'ok')
except (ClientError, GoogleAPICallError) as e:
metric_action(metric_id, e.code.value)
raise
except HttpError as e:
metric_action(metric_id, e)
raise
@with_input_types(ReadOperation)
@with_output_types(typing.Dict[typing.Any, typing.Any])
class _CreateReadPartitions(DoFn):
"""
A DoFn to create partitions. Uses the Partitioning API (PartitionRead /
PartitionQuery) request to start a partitioned query operation. Returns a
list of batch information needed to perform the actual queries.
If the element is the instance of :class:`ReadOperation` is to perform sql
query, `PartitionQuery` API is used the create partitions and returns mappings
of information used perform actual partitioned reads via
:meth:`process_query_batch`.
If the element is the instance of :class:`ReadOperation` is to perform read
from table, `PartitionRead` API is used the create partitions and returns
mappings of information used perform actual partitioned reads via
:meth:`process_read_batch`.
"""
def __init__(self, spanner_configuration):
self._spanner_configuration = spanner_configuration
def setup(self):
spanner_client = Client(
project=self._spanner_configuration.project,
credentials=self._spanner_configuration.credentials)
instance = spanner_client.instance(self._spanner_configuration.instance)
self._database = instance.database(
self._spanner_configuration.database,
pool=self._spanner_configuration.pool)
self._snapshot = self._database.batch_snapshot(
**self._spanner_configuration.snapshot_options)
self._snapshot_dict = self._snapshot.to_dict()
def process(self, element):
if element.is_sql is True:
partitioning_action = self._snapshot.generate_query_batches
elif element.is_table is True:
partitioning_action = self._snapshot.generate_read_batches
else:
raise ValueError(
"ReadOperation is improperly configure: %s" % str(element))
for p in partitioning_action(**element.kwargs):
yield {
"is_sql": element.is_sql,
"is_table": element.is_table,
"read_operation": element.read_operation,
"partitions": p,
"transaction_info": self._snapshot_dict
}
@with_input_types(int)
@with_output_types(_SPANNER_TRANSACTION)
class _CreateTransactionFn(DoFn):
"""
A DoFn to create the transaction of cloud spanner.
It connects to the database and and returns the transaction_id and session_id
by using the batch_snapshot.to_dict() method available in the google cloud
spanner sdk.
https://googleapis.dev/python/spanner/latest/database-api.html?highlight=
batch_snapshot#google.cloud.spanner_v1.database.BatchSnapshot.to_dict
"""
def __init__(
self,
project_id,
instance_id,
database_id,
credentials,
pool,
read_timestamp,
exact_staleness):
self._project_id = project_id
self._instance_id = instance_id
self._database_id = database_id
self._credentials = credentials
self._pool = pool
self._snapshot_options = {}
if read_timestamp:
self._snapshot_options['read_timestamp'] = read_timestamp
if exact_staleness:
self._snapshot_options['exact_staleness'] = exact_staleness
self._snapshot = None
def setup(self):
self._spanner_client = Client(
project=self._project_id, credentials=self._credentials)
self._instance = self._spanner_client.instance(self._instance_id)
self._database = self._instance.database(self._database_id, pool=self._pool)
def process(self, element, *args, **kwargs):
self._snapshot = self._database.batch_snapshot(**self._snapshot_options)
return [_SPANNER_TRANSACTION(self._snapshot.to_dict())]
[docs]
@ptransform_fn
def create_transaction(
pbegin,
project_id,
instance_id,
database_id,
credentials=None,
pool=None,
read_timestamp=None,
exact_staleness=None):
"""
A PTransform method to create a batch transaction.
Args:
pbegin: Root of the pipeline
project_id: Cloud spanner project id. Be sure to use the Project ID,
not the Project Number.
instance_id: Cloud spanner instance id.
database_id: Cloud spanner database id.
credentials: (optional) The authorization credentials to attach to requests.
These credentials identify this application to the service.
If none are specified, the client will attempt to ascertain
the credentials from the environment.
pool: (optional) session pool to be used by database. If not passed,
Spanner Cloud SDK uses the BurstyPool by default.
`google.cloud.spanner.BurstyPool`. Ref:
https://googleapis.dev/python/spanner/latest/database-api.html?#google.
cloud.spanner_v1.database.Database
read_timestamp: (optional) An instance of the `datetime.datetime` object to
execute all reads at the given timestamp.
exact_staleness: (optional) An instance of the `datetime.timedelta`
object. These timestamp bounds execute reads at a user-specified
timestamp.
"""
assert isinstance(pbegin, PBegin)
return (
pbegin | Create([1]) | ParDo(
_CreateTransactionFn(
project_id,
instance_id,
database_id,
credentials,
pool,
read_timestamp,
exact_staleness)))
@with_input_types(typing.Dict[typing.Any, typing.Any])
@with_output_types(typing.List[typing.Any])
class _ReadFromPartitionFn(DoFn):
"""
A DoFn to perform reads from the partition.
"""
def __init__(self, spanner_configuration):
self._spanner_configuration = spanner_configuration
self.base_labels = {
monitoring_infos.SERVICE_LABEL: 'Spanner',
monitoring_infos.METHOD_LABEL: 'Read',
monitoring_infos.SPANNER_PROJECT_ID: (
self._spanner_configuration.project),
monitoring_infos.SPANNER_DATABASE_ID: (
self._spanner_configuration.database),
}
self.service_metric = None
def _table_metric(self, table_id):
database_id = self._spanner_configuration.database
project_id = self._spanner_configuration.project
resource = resource_identifiers.SpannerTable(
project_id, database_id, table_id)
labels = {
**self.base_labels,
monitoring_infos.RESOURCE_LABEL: resource,
monitoring_infos.SPANNER_TABLE_ID: table_id
}
service_call_metric = ServiceCallMetric(
request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
base_labels=labels)
return service_call_metric
def _query_metric(self, query_name):
project_id = self._spanner_configuration.project
resource = resource_identifiers.SpannerSqlQuery(project_id, query_name)
labels = {
**self.base_labels,
monitoring_infos.RESOURCE_LABEL: resource,
monitoring_infos.SPANNER_QUERY_NAME: query_name
}
service_call_metric = ServiceCallMetric(
request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
base_labels=labels)
return service_call_metric
def setup(self):
spanner_client = Client(self._spanner_configuration.project)
instance = spanner_client.instance(self._spanner_configuration.instance)
self._database = instance.database(
self._spanner_configuration.database,
pool=self._spanner_configuration.pool)
self._snapshot = self._database.batch_snapshot(
**self._spanner_configuration.snapshot_options)
def process(self, element):
self._snapshot = BatchSnapshot.from_dict(
self._database, element['transaction_info'])
table_id = self._spanner_configuration.table
query_name = self._spanner_configuration.query_name or ''
if element['is_sql'] is True:
read_action = self._snapshot.process_query_batch
self.service_metric = self._query_metric(query_name)
elif element['is_table'] is True:
read_action = self._snapshot.process_read_batch
self.service_metric = self._table_metric(table_id)
else:
raise ValueError(
"ReadOperation is improperly configure: %s" % str(element))
try:
for row in read_action(element['partitions']):
yield row
self.service_metric.call('ok')
except (ClientError, GoogleAPICallError) as e:
self.service_metric(str(e.code.value))
raise
except HttpError as e:
self.service_metric(str(e))
raise
def teardown(self):
if self._snapshot:
self._snapshot.close()
[docs]
class ReadFromSpanner(PTransform):
"""
A PTransform to perform reads from cloud spanner.
ReadFromSpanner uses BatchAPI to perform all read operations.
"""
def __init__(self, project_id, instance_id, database_id, pool=None,
read_timestamp=None, exact_staleness=None, credentials=None,
sql=None, params=None, param_types=None, # with_query
table=None, query_name=None, columns=None, index="",
keyset=None, # with_table
read_operations=None, # for read all
transaction=None
):
"""
A PTransform that uses Spanner Batch API to perform reads.
Args:
project_id: Cloud spanner project id. Be sure to use the Project ID,
not the Project Number.
instance_id: Cloud spanner instance id.
database_id: Cloud spanner database id.
pool: (optional) session pool to be used by database. If not passed,
Spanner Cloud SDK uses the BurstyPool by default.
`google.cloud.spanner.BurstyPool`. Ref:
https://googleapis.dev/python/spanner/latest/database-api.html?#google.
cloud.spanner_v1.database.Database
read_timestamp: (optional) An instance of the `datetime.datetime` object
to execute all reads at the given timestamp. By default, set to `None`.
exact_staleness: (optional) An instance of the `datetime.timedelta`
object. These timestamp bounds execute reads at a user-specified
timestamp. By default, set to `None`.
credentials: (optional) The authorization credentials to attach to
requests. These credentials identify this application to the service.
If none are specified, the client will attempt to ascertain
the credentials from the environment. By default, set to `None`.
sql: (optional) SQL query statement.
params: (optional) Values for parameter replacement. Keys must match the
names used in sql. By default, set to `None`.
param_types: (optional) maps explicit types for one or more param values;
required if params are passed. By default, set to `None`.
table: (optional) Name of the table from which to fetch data. By
default, set to `None`.
columns: (optional) List of names of columns to be retrieved; required if
the table is passed. By default, set to `None`.
index: (optional) name of index to use, rather than the table's primary
key. By default, set to `None`.
keyset: (optional) keys / ranges identifying rows to be retrieved. By
default, set to `None`.
read_operations: (optional) List of the objects of :class:`ReadOperation`
to perform read all. By default, set to `None`.
transaction: (optional) PTransform of the :meth:`create_transaction` to
perform naive read on cloud spanner. By default, set to `None`.
"""
self._configuration = _BeamSpannerConfiguration(
project=project_id,
instance=instance_id,
database=database_id,
table=table,
query_name=query_name,
credentials=credentials,
pool=pool,
snapshot_read_timestamp=read_timestamp,
snapshot_exact_staleness=exact_staleness)
self._read_operations = read_operations
self._transaction = transaction
if self._read_operations is None:
if table is not None:
if columns is None:
raise ValueError("Columns are required with the table name.")
self._read_operations = [
ReadOperation.table(
table=table, columns=columns, index=index, keyset=keyset)
]
elif sql is not None:
self._read_operations = [
ReadOperation.query(
sql=sql, params=params, param_types=param_types)
]
[docs]
def expand(self, pbegin):
if self._read_operations is not None and isinstance(pbegin, PBegin):
pcoll = pbegin.pipeline | Create(self._read_operations)
elif not isinstance(pbegin, PBegin):
if self._read_operations is not None:
raise ValueError(
"Read operation in the constructor only works with "
"the root of the pipeline.")
pcoll = pbegin
else:
raise ValueError(
"Spanner required read operation, sql or table "
"with columns.")
if self._transaction is None:
# reading as batch read using the spanner partitioning query to create
# batches.
p = (
pcoll
| 'Generate Partitions' >> ParDo(
_CreateReadPartitions(spanner_configuration=self._configuration))
| 'Reshuffle' >> Reshuffle()
| 'Read From Partitions' >> ParDo(
_ReadFromPartitionFn(spanner_configuration=self._configuration)))
else:
# reading as naive read, in which we don't make batches and execute the
# queries as a single read.
p = (
pcoll
| 'Reshuffle' >> Reshuffle().with_input_types(ReadOperation)
| 'Perform Read' >> ParDo(
_NaiveSpannerReadDoFn(spanner_configuration=self._configuration),
AsSingleton(self._transaction)))
return p
[docs]
def display_data(self):
res = {}
sql = []
table = []
if self._read_operations is not None:
for ro in self._read_operations:
if ro.is_sql is True:
sql.append(ro.kwargs)
elif ro.is_table is True:
table.append(ro.kwargs)
if sql:
res['sql'] = DisplayDataItem(str(sql), label='Sql')
if table:
res['table'] = DisplayDataItem(str(table), label='Table')
if self._transaction:
res['transaction'] = DisplayDataItem(
str(self._transaction), label='transaction')
return res
[docs]
class WriteToSpanner(PTransform):
def __init__(
self,
project_id,
instance_id,
database_id,
pool=None,
credentials=None,
max_batch_size_bytes=1048576,
max_number_rows=50,
max_number_cells=500):
"""
A PTransform to write onto Google Cloud Spanner.
Args:
project_id: Cloud spanner project id. Be sure to use the Project ID,
not the Project Number.
instance_id: Cloud spanner instance id.
database_id: Cloud spanner database id.
max_batch_size_bytes: (optional) Split the mutations into batches to
reduce the number of transaction sent to Spanner. By default it is
set to 1 MB (1048576 Bytes).
max_number_rows: (optional) Split the mutations into batches to
reduce the number of transaction sent to Spanner. By default it is
set to 50 rows per batch.
max_number_cells: (optional) Split the mutations into batches to
reduce the number of transaction sent to Spanner. By default it is
set to 500 cells per batch.
"""
self._configuration = _BeamSpannerConfiguration(
project=project_id,
instance=instance_id,
database=database_id,
table=None,
query_name=None,
credentials=credentials,
pool=pool,
snapshot_read_timestamp=None,
snapshot_exact_staleness=None)
self._max_batch_size_bytes = max_batch_size_bytes
self._max_number_rows = max_number_rows
self._max_number_cells = max_number_cells
self._database_id = database_id
self._project_id = project_id
self._instance_id = instance_id
self._pool = pool
[docs]
def display_data(self):
res = {
'project_id': DisplayDataItem(self._project_id, label='Project Id'),
'instance_id': DisplayDataItem(self._instance_id, label='Instance Id'),
'pool': DisplayDataItem(str(self._pool), label='Pool'),
'database': DisplayDataItem(self._database_id, label='Database'),
'batch_size': DisplayDataItem(
self._max_batch_size_bytes, label="Batch Size"),
'max_number_rows': DisplayDataItem(
self._max_number_rows, label="Max Rows"),
'max_number_cells': DisplayDataItem(
self._max_number_cells, label="Max Cells"),
}
return res
[docs]
def expand(self, pcoll):
return (
pcoll
| "make batches" >> _WriteGroup(
max_batch_size_bytes=self._max_batch_size_bytes,
max_number_rows=self._max_number_rows,
max_number_cells=self._max_number_cells)
|
'Writing to spanner' >> ParDo(_WriteToSpannerDoFn(self._configuration)))
class _Mutator(namedtuple('_Mutator',
["mutation", "operation", "kwargs", "rows", "cells"])
):
__slots__ = ()
@property
def byte_size(self):
if hasattr(self.mutation, '_pb'):
# google-cloud-spanner 3.x
return self.mutation._pb.ByteSize()
else:
# google-cloud-spanner 1.x
return self.mutation.ByteSize()
[docs]
class MutationGroup(deque):
"""
A Bundle of Spanner Mutations (_Mutator).
"""
@property
def info(self):
cells = 0
rows = 0
bytes = 0
for m in self.__iter__():
bytes += m.byte_size
rows += m.rows
cells += m.cells
return {"rows": rows, "cells": cells, "byte_size": bytes}
[docs]
def primary(self):
return next(self.__iter__())
[docs]
class WriteMutation(object):
_OPERATION_DELETE = "delete"
_OPERATION_INSERT = "insert"
_OPERATION_INSERT_OR_UPDATE = "insert_or_update"
_OPERATION_REPLACE = "replace"
_OPERATION_UPDATE = "update"
def __init__(
self,
insert=None,
update=None,
insert_or_update=None,
replace=None,
delete=None,
columns=None,
values=None,
keyset=None):
"""
A convenient class to create Spanner Mutations for Write. User can provide
the operation via constructor or via static methods.
Note: If a user passing the operation via construction, make sure that it
will only accept one operation at a time. For example, if a user passing
a table name in the `insert` parameter, and he also passes the `update`
parameter value, this will cause an error.
Args:
insert: (Optional) Name of the table in which rows will be inserted.
update: (Optional) Name of the table in which existing rows will be
updated.
insert_or_update: (Optional) Table name in which rows will be written.
Like insert, except that if the row already exists, then its column
values are overwritten with the ones provided. Any column values not
explicitly written are preserved.
replace: (Optional) Table name in which rows will be replaced. Like
insert, except that if the row already exists, it is deleted, and the
column values provided are inserted instead. Unlike `insert_or_update`,
this means any values not explicitly written become `NULL`.
delete: (Optional) Table name from which rows will be deleted. Succeeds
whether or not the named rows were present.
columns: The names of the columns in table to be written. The list of
columns must contain enough columns to allow Cloud Spanner to derive
values for all primary key columns in the row(s) to be modified.
values: The values to be written. `values` can contain more than one
list of values. If it does, then multiple rows are written, one for
each entry in `values`. Each list in `values` must have exactly as
many entries as there are entries in columns above. Sending multiple
lists is equivalent to sending multiple Mutations, each containing one
`values` entry and repeating table and columns.
keyset: (Optional) The primary keys of the rows within table to delete.
Delete is idempotent. The transaction will succeed even if some or
all rows do not exist.
"""
self._columns = columns
self._values = values
self._keyset = keyset
self._insert = insert
self._update = update
self._insert_or_update = insert_or_update
self._replace = replace
self._delete = delete
if sum([1 for x in [self._insert,
self._update,
self._insert_or_update,
self._replace,
self._delete] if x is not None]) != 1:
raise ValueError(
"No or more than one write mutation operation "
"provided: <%s: %s>" % (self.__class__.__name__, str(self.__dict__)))
def __call__(self, *args, **kwargs):
if self._insert is not None:
return WriteMutation.insert(
table=self._insert, columns=self._columns, values=self._values)
elif self._update is not None:
return WriteMutation.update(
table=self._update, columns=self._columns, values=self._values)
elif self._insert_or_update is not None:
return WriteMutation.insert_or_update(
table=self._insert_or_update,
columns=self._columns,
values=self._values)
elif self._replace is not None:
return WriteMutation.replace(
table=self._replace, columns=self._columns, values=self._values)
elif self._delete is not None:
return WriteMutation.delete(table=self._delete, keyset=self._keyset)
[docs]
@staticmethod
def insert(table, columns, values):
"""Insert one or more new table rows.
Args:
table: Name of the table to be modified.
columns: Name of the table columns to be modified.
values: Values to be modified.
"""
rows = len(values)
cells = len(columns) * len(values)
return _Mutator(
mutation=Mutation(insert=batch._make_write_pb(table, columns, values)),
operation=WriteMutation._OPERATION_INSERT,
rows=rows,
cells=cells,
kwargs={
"table": table, "columns": columns, "values": values
})
[docs]
@staticmethod
def update(table, columns, values):
"""Update one or more existing table rows.
Args:
table: Name of the table to be modified.
columns: Name of the table columns to be modified.
values: Values to be modified.
"""
rows = len(values)
cells = len(columns) * len(values)
return _Mutator(
mutation=Mutation(update=batch._make_write_pb(table, columns, values)),
operation=WriteMutation._OPERATION_UPDATE,
rows=rows,
cells=cells,
kwargs={
"table": table, "columns": columns, "values": values
})
[docs]
@staticmethod
def insert_or_update(table, columns, values):
"""Insert/update one or more table rows.
Args:
table: Name of the table to be modified.
columns: Name of the table columns to be modified.
values: Values to be modified.
"""
rows = len(values)
cells = len(columns) * len(values)
return _Mutator(
mutation=Mutation(
insert_or_update=batch._make_write_pb(table, columns, values)),
operation=WriteMutation._OPERATION_INSERT_OR_UPDATE,
rows=rows,
cells=cells,
kwargs={
"table": table, "columns": columns, "values": values
})
[docs]
@staticmethod
def replace(table, columns, values):
"""Replace one or more table rows.
Args:
table: Name of the table to be modified.
columns: Name of the table columns to be modified.
values: Values to be modified.
"""
rows = len(values)
cells = len(columns) * len(values)
return _Mutator(
mutation=Mutation(replace=batch._make_write_pb(table, columns, values)),
operation=WriteMutation._OPERATION_REPLACE,
rows=rows,
cells=cells,
kwargs={
"table": table, "columns": columns, "values": values
})
[docs]
@staticmethod
def delete(table, keyset):
"""Delete one or more table rows.
Args:
table: Name of the table to be modified.
keyset: Keys/ranges identifying rows to delete.
"""
delete = Mutation.Delete(table=table, key_set=keyset._to_pb())
return _Mutator(
mutation=Mutation(delete=delete),
rows=0,
cells=0,
operation=WriteMutation._OPERATION_DELETE,
kwargs={
"table": table, "keyset": keyset
})
@with_input_types(typing.Union[MutationGroup, TaggedOutput])
@with_output_types(MutationGroup)
class _BatchFn(DoFn):
"""
Batches mutations together.
"""
def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells):
self._max_batch_size_bytes = max_batch_size_bytes
self._max_number_rows = max_number_rows
self._max_number_cells = max_number_cells
def start_bundle(self):
self._batch = MutationGroup()
self._size_in_bytes = 0
self._rows = 0
self._cells = 0
def _reset_count(self):
self._batch = MutationGroup()
self._size_in_bytes = 0
self._rows = 0
self._cells = 0
def process(self, element):
mg_info = element.info
if mg_info['byte_size'] + self._size_in_bytes > self._max_batch_size_bytes \
or mg_info['cells'] + self._cells > self._max_number_cells \
or mg_info['rows'] + self._rows > self._max_number_rows:
# Batch is full, output the batch and resetting the count.
if self._batch:
yield self._batch
self._reset_count()
self._batch.extend(element)
# total byte size of the mutation group.
self._size_in_bytes += mg_info['byte_size']
# total rows in the mutation group.
self._rows += mg_info['rows']
# total cells in the mutation group.
self._cells += mg_info['cells']
def finish_bundle(self):
if self._batch is not None:
yield window.GlobalWindows.windowed_value(self._batch)
self._batch = None
@with_input_types(MutationGroup)
@with_output_types(MutationGroup)
class _BatchableFilterFn(DoFn):
"""
Filters MutationGroups larger than the batch size to the output tagged with
OUTPUT_TAG_UNBATCHABLE.
"""
OUTPUT_TAG_UNBATCHABLE = 'unbatchable'
def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells):
self._max_batch_size_bytes = max_batch_size_bytes
self._max_number_rows = max_number_rows
self._max_number_cells = max_number_cells
self._batchable = None
self._unbatchable = None
def process(self, element):
if element.primary().operation == WriteMutation._OPERATION_DELETE:
# As delete mutations are not batchable.
yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element)
else:
mg_info = element.info
if mg_info['byte_size'] > self._max_batch_size_bytes \
or mg_info['cells'] > self._max_number_cells \
or mg_info['rows'] > self._max_number_rows:
yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element)
else:
yield element
class _WriteToSpannerDoFn(DoFn):
def __init__(self, spanner_configuration):
self._spanner_configuration = spanner_configuration
self._db_instance = None
self.batches = Metrics.counter(self.__class__, 'SpannerBatches')
self.base_labels = {
monitoring_infos.SERVICE_LABEL: 'Spanner',
monitoring_infos.METHOD_LABEL: 'Write',
monitoring_infos.SPANNER_PROJECT_ID: spanner_configuration.project,
monitoring_infos.SPANNER_DATABASE_ID: spanner_configuration.database,
}
# table_id to metrics
self.service_metrics = {}
def _register_table_metric(self, table_id):
if table_id in self.service_metrics:
return
database_id = self._spanner_configuration.database
project_id = self._spanner_configuration.project
resource = resource_identifiers.SpannerTable(
project_id, database_id, table_id)
labels = {
**self.base_labels,
monitoring_infos.RESOURCE_LABEL: resource,
monitoring_infos.SPANNER_TABLE_ID: table_id
}
service_call_metric = ServiceCallMetric(
request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
base_labels=labels)
self.service_metrics[table_id] = service_call_metric
def setup(self):
spanner_client = Client(self._spanner_configuration.project)
instance = spanner_client.instance(self._spanner_configuration.instance)
self._db_instance = instance.database(
self._spanner_configuration.database,
pool=self._spanner_configuration.pool)
def start_bundle(self):
self.service_metrics = {}
def process(self, element):
self.batches.inc()
try:
with self._db_instance.batch() as b:
for m in element:
table_id = m.kwargs['table']
self._register_table_metric(table_id)
if m.operation == WriteMutation._OPERATION_DELETE:
batch_func = b.delete
elif m.operation == WriteMutation._OPERATION_REPLACE:
batch_func = b.replace
elif m.operation == WriteMutation._OPERATION_INSERT_OR_UPDATE:
batch_func = b.insert_or_update
elif m.operation == WriteMutation._OPERATION_INSERT:
batch_func = b.insert
elif m.operation == WriteMutation._OPERATION_UPDATE:
batch_func = b.update
else:
raise ValueError("Unknown operation action: %s" % m.operation)
batch_func(**m.kwargs)
except (ClientError, GoogleAPICallError) as e:
for service_metric in self.service_metrics.values():
service_metric.call(str(e.code.value))
raise
except HttpError as e:
for service_metric in self.service_metrics.values():
service_metric.call(str(e))
raise
else:
for service_metric in self.service_metrics.values():
service_metric.call('ok')
@with_input_types(typing.Union[MutationGroup, _Mutator])
@with_output_types(MutationGroup)
class _MakeMutationGroupsFn(DoFn):
"""
Make Mutation group object if the element is the instance of _Mutator.
"""
def process(self, element):
if isinstance(element, MutationGroup):
yield element
elif isinstance(element, _Mutator):
yield MutationGroup([element])
else:
raise ValueError(
"Invalid object type: %s. Object must be an instance of "
"MutationGroup or WriteMutations" % str(element))
class _WriteGroup(PTransform):
def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells):
self._max_batch_size_bytes = max_batch_size_bytes
self._max_number_rows = max_number_rows
self._max_number_cells = max_number_cells
def expand(self, pcoll):
filter_batchable_mutations = (
pcoll
| 'Making mutation groups' >> ParDo(_MakeMutationGroupsFn())
| 'Filtering Batchable Mutations' >> ParDo(
_BatchableFilterFn(
max_batch_size_bytes=self._max_batch_size_bytes,
max_number_rows=self._max_number_rows,
max_number_cells=self._max_number_cells)).with_outputs(
_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, main='batchable')
)
batching_batchables = (
filter_batchable_mutations['batchable']
| ParDo(
_BatchFn(
max_batch_size_bytes=self._max_batch_size_bytes,
max_number_rows=self._max_number_rows,
max_number_cells=self._max_number_cells)))
return ((
batching_batchables,
filter_batchable_mutations[_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE])
| 'Merging batchable and unbatchable' >> Flatten())