# Source code for apache_beam.io.gcp.datastore.v1new.query_splitter

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Implements a Cloud Datastore query splitter.

For internal use only. No backwards compatibility guarantees.
"""
from __future__ import absolute_import
from __future__ import division

import numbers
from builtins import range
from builtins import round

from apache_beam.io.gcp.datastore.v1new import types

__all__ = ['QuerySplitterError', 'SplitNotPossibleError', 'get_splits']

SCATTER_PROPERTY_NAME = '__scatter__'
KEY_PROPERTY_NAME = '__key__'
# The number of keys to sample for each split.
KEYS_PER_SPLIT = 32


class QuerySplitterError(Exception):
    """Top-level error type."""
class SplitNotPossibleError(QuerySplitterError):
    """Raised when some parameter of the query does not allow splitting."""
def get_splits(client, query, num_splits):
    """Returns a list of sharded queries for the given Cloud Datastore query.

    This will create up to the desired number of splits, however it may return
    fewer splits if the desired number of splits is unavailable. This will
    happen if the number of split points provided by the underlying Datastore
    is less than the desired number, which will occur if the number of results
    for the query is too small.

    This implementation of the QuerySplitter uses the __scatter__ property to
    gather random split points for a query.

    Note: This implementation is derived from the java query splitter in
    https://github.com/GoogleCloudPlatform/google-cloud-datastore/blob/master/java/datastore/src/main/java/com/google/datastore/v1/client/QuerySplitterImpl.java

    Args:
      client: the datastore client.
      query: the query to split.
      num_splits: the desired number of splits.

    Returns:
      A list of split queries, of a max length of `num_splits`

    Raises:
      QuerySplitterError: if split could not be performed owing to query or
        split parameters.
    """
    if num_splits <= 1:
        raise SplitNotPossibleError(
            'num_splits must be > 1, got: %d' % num_splits)
    validate_split(query)

    splits = []
    client_scatter_keys = _get_scatter_keys(client, query, num_splits)
    # Walk the chosen split keys in order; each consecutive pair bounds one
    # split, with None standing in for the open start of the first split.
    last_client_key = None
    for next_client_key in _get_split_key(client_scatter_keys, num_splits):
        splits.append(_create_split(last_client_key, next_client_key, query))
        last_client_key = next_client_key

    # The final split is open-ended: everything from the last split key on.
    splits.append(_create_split(last_client_key, None, query))
    return splits
def validate_split(query):
    """Verifies that the given query can be properly scattered.

    Note that equality and ancestor filters are allowed, however they may
    result in inefficient sharding.

    Args:
      query: the query to check; its ``order``, ``limit`` and ``filters``
        attributes are inspected.

    Raises:
      SplitNotPossibleError: (a QuerySplitterError) if split could not be
        performed owing to query parameters.
    """
    if query.order:
        raise SplitNotPossibleError('Query cannot have any sort orders.')

    if query.limit is not None:
        raise SplitNotPossibleError('Query cannot have a limit set.')

    # The second item of each filter is its comparison operator.
    for query_filter in query.filters:
        if query_filter[1] in ('<', '<=', '>', '>='):
            raise SplitNotPossibleError(
                'Query cannot have any inequality filters.')


def _create_scatter_query(query, num_splits):
    """Creates a scatter query from the given user query."""
    # There is a split containing entities before and after each scatter
    # entity:
    # ||---*------*------*------*------*------*------*---||  * = scatter entity
    # If we represent each split as a region before a scatter entity, there is
    # an extra region following the last scatter point. Thus, we do not need
    # the scatter entity for the last region.
    limit = (num_splits - 1) * KEYS_PER_SPLIT
    scatter_query = types.Query(
        kind=query.kind,
        project=query.project,
        namespace=query.namespace,
        order=[SCATTER_PROPERTY_NAME],
        projection=[KEY_PROPERTY_NAME],
        limit=limit)
    return scatter_query


def _path_element_sort_key(element):
    """Returns a sort key for a single element of a key's flat path.

    Integer IDs compare numerically and sort before names; every other element
    (kinds and names) compares by its own value.
    """
    if isinstance(element, numbers.Integral):
        return (0, element, '')
    return (1, 0, element)


def client_key_sort_key(client_key):
    """Key function for sorting lists of ``google.cloud.datastore.key.Key``.

    Keys are ordered by project, then namespace (a missing namespace sorts as
    the empty string), then by each element of ``flat_path`` in turn. Integer
    IDs are compared as numbers rather than as strings, so that e.g. ID 9
    sorts before ID 10, and IDs sort before names.

    Args:
      client_key: an object exposing ``project``, ``namespace`` and
        ``flat_path``.

    Returns:
      A list usable as a sort key for ``client_key``.
    """
    sort_key = [client_key.project, client_key.namespace or '']
    sort_key.extend(
        _path_element_sort_key(element) for element in client_key.flat_path)
    return sort_key


def _get_scatter_keys(client, query, num_splits):
    """Gets a list of split keys given a desired number of splits.

    This list will contain multiple split keys for each split. Only a single
    split key will be chosen as the split point, however providing multiple
    keys allows for more uniform sharding.

    Args:
      client: the client to datastore containing the data.
      query: the user query.
      num_splits: the number of desired splits.

    Returns:
      A list of scatter keys returned by Datastore, sorted with
      ``client_key_sort_key``.
    """
    scatter_point_query = _create_scatter_query(query, num_splits)
    client_query = scatter_point_query._to_client_query(client)
    client_key_splits = [
        client_entity.key
        for client_entity in client_query.fetch(
            client=client, limit=scatter_point_query.limit)
    ]
    # The scatter query is ordered by __scatter__, not by key, so put the
    # sampled keys into key order before picking split points from them.
    client_key_splits.sort(key=client_key_sort_key)
    return client_key_splits


def _get_split_key(client_keys, num_splits):
    """Given a list of keys and a number of splits find the keys to split on.

    Args:
      client_keys: the list of keys.
      num_splits: the number of splits.

    Returns:
      A list of keys to split on.
    """
    # If the number of keys is less than the number of splits, we are limited
    # in the number of splits we can make.
    if not client_keys or (len(client_keys) < (num_splits - 1)):
        return client_keys

    # Calculate the number of keys per split. This should be KEYS_PER_SPLIT,
    # but may be less if there are not KEYS_PER_SPLIT * (numSplits - 1) scatter
    # entities.
    #
    # Consider the following dataset, where - represents an entity and
    # * represents an entity that is returned as a scatter entity:
    # ||---*-----*----*-----*-----*------*----*----||
    # If we want 4 splits in this data, the optimal split would look like:
    # ||---*-----*----*-----*-----*------*----*----||
    #            |          |            |
    # The scatter keys in the last region are not useful to us, so we never
    # request them:
    # ||---*-----*----*-----*-----*------*---------||
    #            |          |            |
    # With 6 scatter keys we want to set scatter points at indexes: 1, 3, 5.
    #
    # We keep this as a float so that any "fractional" keys per split get
    # distributed throughout the splits and don't make the last split
    # significantly larger than the rest.
    num_keys_per_split = max(1.0, float(len(client_keys)) / (num_splits - 1))

    split_client_keys = []
    # Grab the last sample for each split, otherwise the first split will be
    # too small.
    for i in range(1, num_splits):
        split_index = int(round(i * num_keys_per_split) - 1)
        split_client_keys.append(client_keys[split_index])

    return split_client_keys


def _create_split(last_client_key, next_client_key, query):
    """Creates a new query restricted to the given key range.

    Args:
      last_client_key: the previous key. If None then assumed to be the
        beginning.
      next_client_key: the next key. If None then assumed to be the end.
      query: query to base the split query on.

    Returns:
      A split query which fetches entities in the range
      [last_client_key, next_client_key). If both keys are missing, ``query``
      itself is returned unchanged.
    """
    if not (last_client_key or next_client_key):
        return query

    split_query = query.clone()
    # Copy filters and possibly convert the default empty tuple to an empty
    # list.
    filters = list(split_query.filters)

    if last_client_key:
        filters.append((KEY_PROPERTY_NAME, '>=', last_client_key))
    if next_client_key:
        filters.append((KEY_PROPERTY_NAME, '<', next_client_key))

    split_query.filters = filters
    return split_query