#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Implements a Cloud Datastore query splitter."""
from __future__ import absolute_import
from __future__ import division
from builtins import range
from builtins import round
from apache_beam.io.gcp.datastore.v1 import helper
# Protect against environments where datastore library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
  from google.cloud.proto.datastore.v1 import datastore_pb2
  from google.cloud.proto.datastore.v1 import query_pb2
  from google.cloud.proto.datastore.v1.query_pb2 import PropertyFilter
  from google.cloud.proto.datastore.v1.query_pb2 import CompositeFilter
  from googledatastore import helper as datastore_helper

  UNSUPPORTED_OPERATORS = [PropertyFilter.LESS_THAN,
                           PropertyFilter.LESS_THAN_OR_EQUAL,
                           PropertyFilter.GREATER_THAN,
                           PropertyFilter.GREATER_THAN_OR_EQUAL]
except ImportError:
  UNSUPPORTED_OPERATORS = None
# pylint: enable=wrong-import-order, wrong-import-position

__all__ = [
    'get_splits',
]

SCATTER_PROPERTY_NAME = '__scatter__'
KEY_PROPERTY_NAME = '__key__'

# The number of keys to sample for each split.
KEYS_PER_SPLIT = 32


def get_splits(datastore, query, num_splits, partition=None):
  """Returns a list of sharded queries for the given Cloud Datastore query.

  This will create up to the desired number of splits; however, it may return
  fewer splits if the desired number of splits is unavailable. This happens
  when the underlying Datastore provides fewer split points than desired,
  which occurs when the number of results for the query is too small.

  This implementation of the QuerySplitter uses the __scatter__ property to
  gather random split points for a query.

  Note: This implementation is derived from the Java query splitter in
  https://github.com/GoogleCloudPlatform/google-cloud-datastore/blob/master/java/datastore/src/main/java/com/google/datastore/v1/client/QuerySplitterImpl.java

  Args:
    datastore: the datastore client.
    query: the query to split.
    num_splits: the desired number of splits.
    partition: the partition the query is running in.

  Returns:
    A list of split queries, of a max length of `num_splits`.
  """
  # Validate that the number of splits is not out of bounds.
  if num_splits < 1:
    raise ValueError('The number of splits must be greater than 0.')

  if num_splits == 1:
    return [query]

  _validate_query(query)

  splits = []
  scatter_keys = _get_scatter_keys(datastore, query, num_splits, partition)
  last_key = None
  for next_key in _get_split_key(scatter_keys, num_splits):
    splits.append(_create_split(last_key, next_key, query))
    last_key = next_key

  splits.append(_create_split(last_key, None, query))
  return splits
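
# Example usage (a minimal sketch, not executed on import): given a
# configured v1 `datastore` client, `get_splits` returns up to `num_splits`
# disjoint queries that together cover the original result set. The client
# and the kind name 'MyKind' below are assumptions for illustration only:
#
#   query = query_pb2.Query()
#   query.kind.add().name = 'MyKind'
#   shards = get_splits(datastore, query, 8)
#   assert 1 <= len(shards) <= 8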


def _validate_query(query):
  """Verifies that the given query can be properly scattered."""
  if len(query.kind) != 1:
    raise ValueError('Query must have exactly one kind.')

  if query.order:
    raise ValueError('Query cannot have any sort orders.')

  if query.HasField('limit'):
    raise ValueError('Query cannot have a limit set.')

  if query.offset > 0:
    raise ValueError('Query cannot have an offset set.')

  _validate_filter(query.filter)


def _validate_filter(filter):
  """Validates that we only have allowable filters.

  Note that equality and ancestor filters are allowed, however they may
  result in inefficient sharding.
  """
  if filter.HasField('composite_filter'):
    for sub_filter in filter.composite_filter.filters:
      _validate_filter(sub_filter)
  elif filter.HasField('property_filter'):
    if filter.property_filter.op in UNSUPPORTED_OPERATORS:
      raise ValueError('Query cannot have any inequality filters.')
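
# For example (a sketch using googledatastore's helper; the property name
# 'foo' and value 'bar' are hypothetical placeholders): an equality filter
# passes validation, while an inequality filter raises, since the split
# queries add their own __key__ inequality bounds (see _create_split):
#
#   datastore_helper.set_property_filter(
#       query.filter, 'foo', PropertyFilter.EQUAL, 'bar')
#   _validate_filter(query.filter)  # OK
#   datastore_helper.set_property_filter(
#       query.filter, 'foo', PropertyFilter.GREATER_THAN, 'bar')
#   _validate_filter(query.filter)  # raises ValueError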


def _create_scatter_query(query, num_splits):
  """Creates a scatter query from the given user query."""
  scatter_query = query_pb2.Query()
  for kind in query.kind:
    scatter_kind = scatter_query.kind.add()
    scatter_kind.CopyFrom(kind)

  # Ascending order by the scatter property.
  datastore_helper.add_property_orders(scatter_query, SCATTER_PROPERTY_NAME)

  # There is a split containing entities before and after each scatter entity:
  # ||---*------*------*------*------*------*------*---||  * = scatter entity
  # If we represent each split as a region before a scatter entity, there is
  # an extra region following the last scatter point. Thus, we do not need the
  # scatter entity for the last region.
  scatter_query.limit.value = (num_splits - 1) * KEYS_PER_SPLIT
  datastore_helper.add_projection(scatter_query, KEY_PROPERTY_NAME)
  return scatter_query
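
# For example, with num_splits = 4 the scatter query requests
# (4 - 1) * KEYS_PER_SPLIT = 3 * 32 = 96 keys, projected down to __key__
# and ordered ascending by __scatter__.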


def _get_scatter_keys(datastore, query, num_splits, partition):
  """Gets a list of split keys given a desired number of splits.

  This list will contain multiple split keys for each split. Only a single
  split key will be chosen as the split point, however providing multiple
  keys allows for more uniform sharding.

  Args:
    datastore: the datastore client containing the data.
    query: the user query.
    num_splits: the number of desired splits.
    partition: the partition to run the query in.

  Returns:
    A list of scatter keys returned by Datastore.
  """
  scatter_point_query = _create_scatter_query(query, num_splits)

  key_splits = []
  while True:
    req = datastore_pb2.RunQueryRequest()
    if partition:
      req.partition_id.CopyFrom(partition)
    req.query.CopyFrom(scatter_point_query)

    resp = datastore.run_query(req)
    for entity_result in resp.batch.entity_results:
      key_splits.append(entity_result.entity.key)

    if resp.batch.more_results != query_pb2.QueryResultBatch.NOT_FINISHED:
      break

    scatter_point_query.start_cursor = resp.batch.end_cursor
    scatter_point_query.limit.value -= len(resp.batch.entity_results)

  # helper.key_comparator is a cmp-style comparator, so wrap it with
  # functools.cmp_to_key so the sort works on both Python 2 and Python 3.
  key_splits.sort(key=functools.cmp_to_key(helper.key_comparator))
  return key_splits
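
# Paging example: if the scatter query's limit is 96 and the first batch
# returns 60 results with more_results == NOT_FINISHED, the loop re-issues
# the query from the batch's end_cursor with the limit lowered to
# 96 - 60 = 36.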


def _get_split_key(keys, num_splits):
  """Given a list of keys and a number of splits, finds the keys to split on.

  Args:
    keys: the list of keys.
    num_splits: the number of splits.

  Returns:
    A list of keys to split on.
  """
  # If the number of keys is less than the number of splits, we are limited
  # in the number of splits we can make.
  if not keys or (len(keys) < (num_splits - 1)):
    return keys

  # Calculate the number of keys per split. This should be KEYS_PER_SPLIT,
  # but may be less if there are not KEYS_PER_SPLIT * (num_splits - 1)
  # scatter entities.
  #
  # Consider the following dataset, where - represents an entity and
  # * represents an entity that is returned as a scatter entity:
  # ||---*-----*----*-----*-----*------*----*----||
  # If we want 4 splits in this data, the optimal split would look like:
  # ||---*-----*----*-----*-----*------*----*----||
  #            |          |            |
  # The scatter keys in the last region are not useful to us, so we never
  # request them:
  # ||---*-----*----*-----*-----*------*---------||
  #            |          |            |
  # With 6 scatter keys we want to set scatter points at indexes: 1, 3, 5.
  #
  # We keep this as a float so that any "fractional" keys per split get
  # distributed throughout the splits and don't make the last split
  # significantly larger than the rest.
  num_keys_per_split = max(1.0, float(len(keys)) / (num_splits - 1))

  split_keys = []

  # Grab the last sample for each split, otherwise the first split will be
  # too small.
  for i in range(1, num_splits):
    split_index = int(round(i * num_keys_per_split) - 1)
    split_keys.append(keys[split_index])

  return split_keys
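
# Worked example: with 6 scatter keys and num_splits = 4,
# num_keys_per_split = max(1.0, 6 / 3) = 2.0, so the chosen indexes are
# round(1 * 2.0) - 1 = 1, round(2 * 2.0) - 1 = 3 and round(3 * 2.0) - 1 = 5,
# matching the diagram above.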


def _create_split(last_key, next_key, query):
  """Creates a new query that fetches entities in the given key range.

  Args:
    last_key: the previous key. If None, the range is unbounded below.
    next_key: the next key. If None, the range is unbounded above.
    query: the desired query.

  Returns:
    A split query that fetches entities in the range [last_key, next_key).
  """
  if not (last_key or next_key):
    return query

  split_query = query_pb2.Query()
  split_query.CopyFrom(query)
  composite_filter = split_query.filter.composite_filter
  composite_filter.op = CompositeFilter.AND

  if query.HasField('filter'):
    composite_filter.filters.add().CopyFrom(query.filter)

  if last_key:
    lower_bound = composite_filter.filters.add()
    lower_bound.property_filter.property.name = KEY_PROPERTY_NAME
    lower_bound.property_filter.op = PropertyFilter.GREATER_THAN_OR_EQUAL
    lower_bound.property_filter.value.key_value.CopyFrom(last_key)

  if next_key:
    upper_bound = composite_filter.filters.add()
    upper_bound.property_filter.property.name = KEY_PROPERTY_NAME
    upper_bound.property_filter.op = PropertyFilter.LESS_THAN
    upper_bound.property_filter.value.key_value.CopyFrom(next_key)

  return split_query
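
# For example, a middle split produced by _create_split is equivalent to the
# user's query with its filter (if any) replaced by:
#
#   original_filter AND __key__ >= last_key AND __key__ < next_key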