Source code for apache_beam.io.gcp.datastore.v1.query_splitter

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Implements a Cloud Datastore query splitter."""

from apache_beam.io.gcp.datastore.v1 import helper

# Protect against environments where datastore library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
  from google.cloud.proto.datastore.v1 import datastore_pb2
  from google.cloud.proto.datastore.v1 import query_pb2
  from google.cloud.proto.datastore.v1.query_pb2 import PropertyFilter
  from google.cloud.proto.datastore.v1.query_pb2 import CompositeFilter
  from googledatastore import helper as datastore_helper
  UNSUPPORTED_OPERATORS = [PropertyFilter.LESS_THAN,
                           PropertyFilter.LESS_THAN_OR_EQUAL,
                           PropertyFilter.GREATER_THAN,
                           PropertyFilter.GREATER_THAN_OR_EQUAL]
except ImportError:
  UNSUPPORTED_OPERATORS = None
# pylint: enable=wrong-import-order, wrong-import-position


__all__ = [
    'get_splits',
]

# Name of the property the scatter query is ordered by when sampling
# candidate split points.
SCATTER_PROPERTY_NAME = '__scatter__'
# Name of the key property: projected by the scatter query and used as the
# range-filter property of each split query.
KEY_PROPERTY_NAME = '__key__'
# The number of keys to sample for each split.
KEYS_PER_SPLIT = 32


def get_splits(datastore, query, num_splits, partition=None):
  """Returns a list of sharded queries for the given Cloud Datastore query.

  This will create up to the desired number of splits, however it may return
  fewer splits if the desired number of splits is unavailable. This will
  happen if the number of split points provided by the underlying Datastore is
  less than the desired number, which will occur if the number of results for
  the query is too small.

  This implementation of the QuerySplitter uses the __scatter__ property to
  gather random split points for a query.

  Note: This implementation is derived from the java query splitter in
  https://github.com/GoogleCloudPlatform/google-cloud-datastore/blob/master/java/datastore/src/main/java/com/google/datastore/v1/client/QuerySplitterImpl.java

  Args:
    datastore: the datastore client.
    query: the query to split.
    num_splits: the desired number of splits.
    partition: the partition the query is running in.

  Returns:
    A list of split queries, of a max length of `num_splits`. When
    `num_splits` is 1 the list holds the unmodified input query.

  Raises:
    ValueError: if `num_splits` is less than 1, or if the query fails
      validation (see `_validate_query`).
  """
  # Validate that the number of splits is not out of bounds.
  if num_splits < 1:
    raise ValueError('The number of splits must be greater than 0.')

  # A single split is the query itself; no validation or scatter lookup is
  # performed in that case.
  if num_splits == 1:
    return [query]

  _validate_query(query)

  scatter_keys = _get_scatter_keys(datastore, query, num_splits, partition)

  # Each split covers the key range [last_key, next_key); the first split has
  # no lower bound.
  splits = []
  last_key = None
  for next_key in _get_split_key(scatter_keys, num_splits):
    splits.append(_create_split(last_key, next_key, query))
    last_key = next_key

  # The final split has no upper bound and covers everything past the last
  # split key.
  splits.append(_create_split(last_key, None, query))
  return splits
def _validate_query(query): """ Verifies that the given query can be properly scattered.""" if len(query.kind) != 1: raise ValueError('Query must have exactly one kind.') if query.order: raise ValueError('Query cannot have any sort orders.') if query.HasField('limit'): raise ValueError('Query cannot have a limit set.') if query.offset > 0: raise ValueError('Query cannot have an offset set.') _validate_filter(query.filter) def _validate_filter(filter): """Validates that we only have allowable filters. Note that equality and ancestor filters are allowed, however they may result in inefficient sharding. """ if filter.HasField('composite_filter'): for sub_filter in filter.composite_filter.filters: _validate_filter(sub_filter) elif filter.HasField('property_filter'): if filter.property_filter.op in UNSUPPORTED_OPERATORS: raise ValueError('Query cannot have any inequality filters.') else: pass def _create_scatter_query(query, num_splits): """Creates a scatter query from the given user query.""" scatter_query = query_pb2.Query() for kind in query.kind: scatter_kind = scatter_query.kind.add() scatter_kind.CopyFrom(kind) # ascending order datastore_helper.add_property_orders(scatter_query, SCATTER_PROPERTY_NAME) # There is a split containing entities before and after each scatter entity: # ||---*------*------*------*------*------*------*---|| * = scatter entity # If we represent each split as a region before a scatter entity, there is an # extra region following the last scatter point. Thus, we do not need the # scatter entity for the last region. scatter_query.limit.value = (num_splits - 1) * KEYS_PER_SPLIT datastore_helper.add_projection(scatter_query, KEY_PROPERTY_NAME) return scatter_query def _get_scatter_keys(datastore, query, num_splits, partition): """Gets a list of split keys given a desired number of splits. This list will contain multiple split keys for each split. 
Only a single split key will be chosen as the split point, however providing multiple keys allows for more uniform sharding. Args: numSplits: the number of desired splits. query: the user query. partition: the partition to run the query in. datastore: the client to datastore containing the data. Returns: A list of scatter keys returned by Datastore. """ scatter_point_query = _create_scatter_query(query, num_splits) key_splits = [] while True: req = datastore_pb2.RunQueryRequest() if partition: req.partition_id.CopyFrom(partition) req.query.CopyFrom(scatter_point_query) resp = datastore.run_query(req) for entity_result in resp.batch.entity_results: key_splits.append(entity_result.entity.key) if resp.batch.more_results != query_pb2.QueryResultBatch.NOT_FINISHED: break scatter_point_query.start_cursor = resp.batch.end_cursor scatter_point_query.limit.value -= len(resp.batch.entity_results) key_splits.sort(helper.key_comparator) return key_splits def _get_split_key(keys, num_splits): """Given a list of keys and a number of splits find the keys to split on. Args: keys: the list of keys. num_splits: the number of splits. Returns: A list of keys to split on. """ # If the number of keys is less than the number of splits, we are limited # in the number of splits we can make. if not keys or (len(keys) < (num_splits - 1)): return keys # Calculate the number of keys per split. This should be KEYS_PER_SPLIT, # but may be less if there are not KEYS_PER_SPLIT * (numSplits - 1) scatter # entities. 
# # Consider the following dataset, where - represents an entity and # * represents an entity that is returned as a scatter entity: # ||---*-----*----*-----*-----*------*----*----|| # If we want 4 splits in this data, the optimal split would look like: # ||---*-----*----*-----*-----*------*----*----|| # | | | # The scatter keys in the last region are not useful to us, so we never # request them: # ||---*-----*----*-----*-----*------*---------|| # | | | # With 6 scatter keys we want to set scatter points at indexes: 1, 3, 5. # # We keep this as a float so that any "fractional" keys per split get # distributed throughout the splits and don't make the last split # significantly larger than the rest. num_keys_per_split = max(1.0, float(len(keys)) / (num_splits - 1)) split_keys = [] # Grab the last sample for each split, otherwise the first split will be too # small. for i in range(1, num_splits): split_index = int(round(i * num_keys_per_split) - 1) split_keys.append(keys[split_index]) return split_keys def _create_split(last_key, next_key, query): """Create a new {@link Query} given the query and range.. Args: last_key: the previous key. If null then assumed to be the beginning. next_key: the next key. If null then assumed to be the end. query: the desired query. 
Returns: A split query with fetches entities in the range [last_key, next_key) """ if not (last_key or next_key): return query split_query = query_pb2.Query() split_query.CopyFrom(query) composite_filter = split_query.filter.composite_filter composite_filter.op = CompositeFilter.AND if query.HasField('filter'): composite_filter.filters.add().CopyFrom(query.filter) if last_key: lower_bound = composite_filter.filters.add() lower_bound.property_filter.property.name = KEY_PROPERTY_NAME lower_bound.property_filter.op = PropertyFilter.GREATER_THAN_OR_EQUAL lower_bound.property_filter.value.key_value.CopyFrom(last_key) if next_key: upper_bound = composite_filter.filters.add() upper_bound.property_filter.property.name = KEY_PROPERTY_NAME upper_bound.property_filter.op = PropertyFilter.LESS_THAN upper_bound.property_filter.value.key_value.CopyFrom(next_key) return split_query