#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module implements IO classes to read and write data on MongoDB.
Read from MongoDB
-----------------
:class:`ReadFromMongoDB` is a ``PTransform`` that reads from a configured
MongoDB source and returns a ``PCollection`` of dict representing MongoDB
documents.
To configure MongoDB source, the URI to connect to MongoDB server, database
name, collection name needs to be provided.
Example usage::
pipeline | ReadFromMongoDB(uri='mongodb://localhost:27017',
db='testdb',
coll='input')
To read from MongoDB Atlas, use ``bucket_auto`` option to enable
``@bucketAuto`` MongoDB aggregation instead of ``splitVector``
command which is a high-privilege function that cannot be assigned
to any user in Atlas.
Example usage::
pipeline | ReadFromMongoDB(uri='mongodb+srv://user:pwd@cluster0.mongodb.net',
db='testdb',
coll='input',
bucket_auto=True)
Write to MongoDB:
-----------------
:class:`WriteToMongoDB` is a ``PTransform`` that writes MongoDB documents to
configured sink, and the write is conducted through a mongodb bulk_write of
``ReplaceOne`` operations. If the document's _id field already existed in the
MongoDB collection, it results in an overwrite, otherwise, a new document
will be inserted.
Example usage::
pipeline | WriteToMongoDB(uri='mongodb://localhost:27017',
db='testdb',
coll='output',
batch_size=10)
No backward compatibility guarantees. Everything in this module is experimental.
"""
# pytype: skip-file
import itertools
import json
import logging
import math
import struct
from typing import Union
import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.io.range_trackers import OrderedPositionRangeTracker
from apache_beam.transforms import DoFn
from apache_beam.transforms import PTransform
from apache_beam.transforms import Reshuffle
from apache_beam.utils.annotations import experimental
_LOGGER = logging.getLogger(__name__)
try:
# Mongodb has its own bundled bson, which is not compatible with bson package.
# (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
# it fails because bson package is installed, MongoDB IO will not work but at
# least rest of the SDK will work.
from bson import json_util
from bson import objectid
from bson.objectid import ObjectId
# pymongo also internally depends on bson.
from pymongo import ASCENDING
from pymongo import DESCENDING
from pymongo import MongoClient
from pymongo import ReplaceOne
except ImportError:
objectid = None
json_util = None
ObjectId = None
ASCENDING = 1
DESCENDING = -1
MongoClient = None
ReplaceOne = None
_LOGGER.warning("Could not find a compatible bson package.")
__all__ = ["ReadFromMongoDB", "WriteToMongoDB"]
[docs]@experimental()
class ReadFromMongoDB(PTransform):
"""A ``PTransform`` to read MongoDB documents into a ``PCollection``."""
def __init__(
self,
uri="mongodb://localhost:27017",
db=None,
coll=None,
filter=None,
projection=None,
extra_client_params=None,
bucket_auto=False,
):
"""Initialize a :class:`ReadFromMongoDB`
Args:
uri (str): The MongoDB connection string following the URI format.
db (str): The MongoDB database name.
coll (str): The MongoDB collection name.
filter: A `bson.SON
<https://api.mongodb.com/python/current/api/bson/son.html>`_ object
specifying elements which must be present for a document to be included
in the result set.
projection: A list of field names that should be returned in the result
set or a dict specifying the fields to include or exclude.
extra_client_params(dict): Optional `MongoClient
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
parameters.
bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto` aggregation
to split collection into bundles instead of `splitVector` command,
which does not work with MongoDB Atlas.
If :data:`False` (the default), use `splitVector` command for bundling.
Returns:
:class:`~apache_beam.transforms.ptransform.PTransform`
"""
if extra_client_params is None:
extra_client_params = {}
if not isinstance(db, str):
raise ValueError("ReadFromMongDB db param must be specified as a string")
if not isinstance(coll, str):
raise ValueError(
"ReadFromMongDB coll param must be specified as a string")
self._mongo_source = _BoundedMongoSource(
uri=uri,
db=db,
coll=coll,
filter=filter,
projection=projection,
extra_client_params=extra_client_params,
bucket_auto=bucket_auto,
)
[docs] def expand(self, pcoll):
return pcoll | iobase.Read(self._mongo_source)
class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
"""RangeTracker for tracking mongodb _id of bson ObjectId type."""
def position_to_fraction(
self,
pos: ObjectId,
start: ObjectId,
end: ObjectId,
):
"""Returns the fraction of keys in the range [start, end) that
are less than the given key.
"""
pos_number = _ObjectIdHelper.id_to_int(pos)
start_number = _ObjectIdHelper.id_to_int(start)
end_number = _ObjectIdHelper.id_to_int(end)
return (pos_number - start_number) / (end_number - start_number)
def fraction_to_position(
self,
fraction: float,
start: ObjectId,
end: ObjectId,
):
"""Converts a fraction between 0 and 1
to a position between start and end.
"""
start_number = _ObjectIdHelper.id_to_int(start)
end_number = _ObjectIdHelper.id_to_int(end)
total = end_number - start_number
pos = int(total * fraction + start_number)
# make sure split position is larger than start position and smaller than
# end position.
if pos <= start_number:
return _ObjectIdHelper.increment_id(start, 1)
if pos >= end_number:
return _ObjectIdHelper.increment_id(end, -1)
return _ObjectIdHelper.int_to_id(pos)
class _BoundedMongoSource(iobase.BoundedSource):
"""A MongoDB source that reads a finite amount of input records.
This class defines following operations which can be used to read
MongoDB source efficiently.
* Size estimation - method ``estimate_size()`` may return an accurate
estimation in bytes for the size of the source.
* Splitting into bundles of a given size - method ``split()`` can be used to
split the source into a set of sub-sources (bundles) based on a desired
bundle size.
* Getting a RangeTracker - method ``get_range_tracker()`` should return a
``RangeTracker`` object for a given position range for the position type
of the records returned by the source.
* Reading the data - method ``read()`` can be used to read data from the
source while respecting the boundaries defined by a given
``RangeTracker``.
A runner will perform reading the source in two steps.
(1) Method ``get_range_tracker()`` will be invoked with start and end
positions to obtain a ``RangeTracker`` for the range of positions the
runner intends to read. Source must define a default initial start and end
position range. These positions must be used if the start and/or end
positions passed to the method ``get_range_tracker()`` are ``None``
(2) Method read() will be invoked with the ``RangeTracker`` obtained in the
previous step.
**Mutability**
A ``_BoundedMongoSource`` object should not be mutated while
its methods (for example, ``read()``) are being invoked by a runner. Runner
implementations may invoke methods of ``_BoundedMongoSource`` objects through
multi-threaded and/or reentrant execution modes.
"""
def __init__(
self,
uri=None,
db=None,
coll=None,
filter=None,
projection=None,
extra_client_params=None,
bucket_auto=False,
):
if extra_client_params is None:
extra_client_params = {}
if filter is None:
filter = {}
self.uri = uri
self.db = db
self.coll = coll
self.filter = filter
self.projection = projection
self.spec = extra_client_params
self.bucket_auto = bucket_auto
def estimate_size(self):
with MongoClient(self.uri, **self.spec) as client:
return client[self.db].command("collstats", self.coll).get("size")
def _estimate_average_document_size(self):
with MongoClient(self.uri, **self.spec) as client:
return client[self.db].command("collstats", self.coll).get("avgObjSize")
def split(
self,
desired_bundle_size: int,
start_position: Union[int, str, bytes, ObjectId] = None,
stop_position: Union[int, str, bytes, ObjectId] = None,
):
"""Splits the source into a set of bundles.
Bundles should be approximately of size ``desired_bundle_size`` bytes.
Args:
desired_bundle_size: the desired size (in bytes) of the bundles returned.
start_position: if specified the given position must be used as the
starting position of the first bundle.
stop_position: if specified the given position must be used as the ending
position of the last bundle.
Returns:
an iterator of objects of type 'SourceBundle' that gives information about
the generated bundles.
"""
desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
# for desired bundle size, if desired chunk size smaller than 1mb, use
# MongoDB default split size of 1mb.
desired_bundle_size_in_mb = max(desired_bundle_size_in_mb, 1)
is_initial_split = start_position is None and stop_position is None
start_position, stop_position = self._replace_none_positions(
start_position, stop_position
)
if self.bucket_auto:
# Use $bucketAuto for bundling
split_keys = []
weights = []
for bucket in self._get_auto_buckets(
desired_bundle_size_in_mb,
start_position,
stop_position,
is_initial_split,
):
split_keys.append({"_id": bucket["_id"]["max"]})
weights.append(bucket["count"])
else:
# Use splitVector for bundling
split_keys = self._get_split_keys(
desired_bundle_size_in_mb, start_position, stop_position)
weights = itertools.cycle((desired_bundle_size_in_mb, ))
bundle_start = start_position
for split_key_id, weight in zip(split_keys, weights):
if bundle_start >= stop_position:
break
bundle_end = min(stop_position, split_key_id["_id"])
yield iobase.SourceBundle(
weight=weight,
source=self,
start_position=bundle_start,
stop_position=bundle_end,
)
bundle_start = bundle_end
# add range of last split_key to stop_position
if bundle_start < stop_position:
# bucket_auto mode can come here if not split due to single document
weight = 1 if self.bucket_auto else desired_bundle_size_in_mb
yield iobase.SourceBundle(
weight=weight,
source=self,
start_position=bundle_start,
stop_position=stop_position,
)
def get_range_tracker(
self,
start_position: Union[int, str, ObjectId] = None,
stop_position: Union[int, str, ObjectId] = None,
) -> Union[
_ObjectIdRangeTracker, OffsetRangeTracker, LexicographicKeyRangeTracker]:
"""Returns a RangeTracker for a given position range depending on type.
Args:
start_position: starting position of the range. If 'None' default start
position of the source must be used.
stop_position: ending position of the range. If 'None' default stop
position of the source must be used.
Returns:
a ``_ObjectIdRangeTracker``, ``OffsetRangeTracker``
or ``LexicographicKeyRangeTracker`` depending on the given position range.
"""
start_position, stop_position = self._replace_none_positions(
start_position, stop_position
)
if isinstance(start_position, ObjectId):
return _ObjectIdRangeTracker(start_position, stop_position)
if isinstance(start_position, int):
return OffsetRangeTracker(start_position, stop_position)
if isinstance(start_position, str):
return LexicographicKeyRangeTracker(start_position, stop_position)
raise NotImplementedError(
f"RangeTracker for {type(start_position)} not implemented!")
def read(self, range_tracker):
"""Returns an iterator that reads data from the source.
The returned set of data must respect the boundaries defined by the given
``RangeTracker`` object. For example:
* Returned set of data must be for the range
``[range_tracker.start_position, range_tracker.stop_position)``. Note
that a source may decide to return records that start after
``range_tracker.stop_position``. See documentation in class
``RangeTracker`` for more details. Also, note that framework might
invoke ``range_tracker.try_split()`` to perform dynamic split
operations. range_tracker.stop_position may be updated
dynamically due to successful dynamic split operations.
* Method ``range_tracker.try_split()`` must be invoked for every record
that starts at a split point.
* Method ``range_tracker.record_current_position()`` may be invoked for
records that do not start at split points.
Args:
range_tracker: a ``RangeTracker`` whose boundaries must be respected
when reading data from the source. A runner that reads this
source muss pass a ``RangeTracker`` object that is not
``None``.
Returns:
an iterator of data read by the source.
"""
with MongoClient(self.uri, **self.spec) as client:
all_filters = self._merge_id_filter(
range_tracker.start_position(), range_tracker.stop_position())
docs_cursor = (
client[self.db][self.coll].find(
filter=all_filters,
projection=self.projection).sort([("_id", ASCENDING)]))
for doc in docs_cursor:
if not range_tracker.try_claim(doc["_id"]):
return
yield doc
def display_data(self):
"""Returns the display data associated to a pipeline component."""
res = super().display_data()
res["database"] = self.db
res["collection"] = self.coll
res["filter"] = json.dumps(self.filter, default=json_util.default)
res["projection"] = str(self.projection)
res["bucket_auto"] = self.bucket_auto
return res
@staticmethod
def _range_is_not_splittable(
start_pos: Union[int, str, ObjectId],
end_pos: Union[int, str, ObjectId],
):
"""Return `True` if splitting range doesn't make sense
(single document is not splittable),
Return `False` otherwise.
"""
return ((
isinstance(start_pos, ObjectId) and
start_pos >= _ObjectIdHelper.increment_id(end_pos, -1)) or
(isinstance(start_pos, int) and start_pos >= end_pos - 1) or
(isinstance(start_pos, str) and start_pos >= end_pos))
def _get_split_keys(
self,
desired_chunk_size_in_mb: int,
start_pos: Union[int, str, ObjectId],
end_pos: Union[int, str, ObjectId],
):
"""Calls MongoDB `splitVector` command
to get document ids at split position.
"""
# single document not splittable
if self._range_is_not_splittable(start_pos, end_pos):
return []
with MongoClient(self.uri, **self.spec) as client:
name_space = "%s.%s" % (self.db, self.coll)
return client[self.db].command(
"splitVector",
name_space,
keyPattern={"_id": 1}, # Ascending index
min={"_id": start_pos},
max={"_id": end_pos},
maxChunkSize=desired_chunk_size_in_mb,
)["splitKeys"]
def _get_auto_buckets(
self,
desired_chunk_size_in_mb: int,
start_pos: Union[int, str, ObjectId],
end_pos: Union[int, str, ObjectId],
is_initial_split: bool,
) -> list:
"""Use MongoDB `$bucketAuto` aggregation to split collection into bundles
instead of `splitVector` command, which does not work with MongoDB Atlas.
"""
# single document not splittable
if self._range_is_not_splittable(start_pos, end_pos):
return []
if is_initial_split and not self.filter:
# total collection size in MB
size_in_mb = self.estimate_size() / float(1 << 20)
else:
# size of documents within start/end id range and possibly filtered
documents_count = self._count_id_range(start_pos, end_pos)
avg_document_size = self._estimate_average_document_size()
size_in_mb = documents_count * avg_document_size / float(1 << 20)
if size_in_mb == 0:
# no documents not splittable (maybe a result of filtering)
return []
bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb)
with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client:
pipeline = [
{
# filter by positions and by the custom filter if any
"$match": self._merge_id_filter(start_pos, end_pos)
},
{
"$bucketAuto": {
"groupBy": "$_id", "buckets": bucket_count
}
},
]
buckets = list(
# Use `allowDiskUse` option to avoid aggregation limit of 100 Mb RAM
client[self.db][self.coll].aggregate(pipeline, allowDiskUse=True))
if buckets:
buckets[-1]["_id"]["max"] = end_pos
return buckets
def _merge_id_filter(
self,
start_position: Union[int, str, bytes, ObjectId],
stop_position: Union[int, str, bytes, ObjectId] = None,
) -> dict:
"""Merge the default filter (if any) with refined _id field range
of range_tracker.
$gte specifies start position (inclusive)
and $lt specifies the end position (exclusive),
see more at
https://docs.mongodb.com/manual/reference/operator/query/gte/ and
https://docs.mongodb.com/manual/reference/operator/query/lt/
"""
if stop_position is None:
id_filter = {"_id": {"$gte": start_position}}
else:
id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}}
if self.filter:
all_filters = {
# see more at
# https://docs.mongodb.com/manual/reference/operator/query/and/
"$and": [self.filter.copy(), id_filter]
}
else:
all_filters = id_filter
return all_filters
def _get_head_document_id(self, sort_order):
with MongoClient(self.uri, **self.spec) as client:
cursor = (
client[self.db][self.coll].find(filter={}, projection=[]).sort([
("_id", sort_order)
]).limit(1))
try:
return cursor[0]["_id"]
except IndexError:
raise ValueError("Empty Mongodb collection")
def _replace_none_positions(self, start_position, stop_position):
if start_position is None:
start_position = self._get_head_document_id(ASCENDING)
if stop_position is None:
last_doc_id = self._get_head_document_id(DESCENDING)
# increment last doc id binary value by 1 to make sure the last document
# is not excluded
if isinstance(last_doc_id, ObjectId):
stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
elif isinstance(last_doc_id, int):
stop_position = last_doc_id + 1
elif isinstance(last_doc_id, str):
stop_position = last_doc_id + '\x00'
return start_position, stop_position
def _count_id_range(self, start_position, stop_position):
"""Number of documents between start_position (inclusive)
and stop_position (exclusive), respecting the custom filter if any.
"""
with MongoClient(self.uri, **self.spec) as client:
return client[self.db][self.coll].count_documents(
filter=self._merge_id_filter(start_position, stop_position))
class _ObjectIdHelper:
"""A Utility class to manipulate bson object ids."""
@classmethod
def id_to_int(cls, _id: Union[int, ObjectId]) -> int:
"""
Args:
_id: ObjectId required for each MongoDB document _id field.
Returns: Converted integer value of ObjectId's 12 bytes binary value.
"""
if isinstance(_id, int):
return _id
# converts object id binary to integer
# id object is bytes type with size of 12
ints = struct.unpack(">III", _id.binary)
return (ints[0] << 64) + (ints[1] << 32) + ints[2]
@classmethod
def int_to_id(cls, number):
"""
Args:
number(int): The integer value to be used to convert to ObjectId.
Returns: The ObjectId that has the 12 bytes binary converted from the
integer value.
"""
# converts integer value to object id. Int value should be less than
# (2 ^ 96) so it can be convert to 12 bytes required by object id.
if number < 0 or number >= (1 << 96):
raise ValueError("number value must be within [0, %s)" % (1 << 96))
ints = [
(number & 0xFFFFFFFF0000000000000000) >> 64,
(number & 0x00000000FFFFFFFF00000000) >> 32,
number & 0x0000000000000000FFFFFFFF,
]
number_bytes = struct.pack(">III", *ints)
return ObjectId(number_bytes)
@classmethod
def increment_id(
cls,
_id: ObjectId,
inc: int,
) -> ObjectId:
"""
Increment object_id binary value by inc value and return new object id.
Args:
_id: The `_id` to change.
inc(int): The incremental int value to be added to `_id`.
Returns:
`_id` incremented by `inc` value
"""
id_number = _ObjectIdHelper.id_to_int(_id)
new_number = id_number + inc
if new_number < 0 or new_number >= (1 << 96):
raise ValueError(
"invalid incremental, inc value must be within ["
"%s, %s)" % (0 - id_number, 1 << 96 - id_number))
return _ObjectIdHelper.int_to_id(new_number)
[docs]@experimental()
class WriteToMongoDB(PTransform):
"""WriteToMongoDB is a ``PTransform`` that writes a ``PCollection`` of
mongodb document to the configured MongoDB server.
In order to make the document writes idempotent so that the bundles are
retry-able without creating duplicates, the PTransform added 2 transformations
before final write stage:
a ``GenerateId`` transform and a ``Reshuffle`` transform.::
-----------------------------------------------
Pipeline --> |GenerateId --> Reshuffle --> WriteToMongoSink|
-----------------------------------------------
(WriteToMongoDB)
The ``GenerateId`` transform adds a random and unique*_id* field to the
documents if they don't already have one, it uses the same format as MongoDB
default. The ``Reshuffle`` transform makes sure that no fusion happens between
``GenerateId`` and the final write stage transform,so that the set of
documents and their unique IDs are not regenerated if final write step is
retried due to a failure. This prevents duplicate writes of the same document
with different unique IDs.
"""
def __init__(
self,
uri="mongodb://localhost:27017",
db=None,
coll=None,
batch_size=100,
extra_client_params=None,
):
"""
Args:
uri (str): The MongoDB connection string following the URI format
db (str): The MongoDB database name
coll (str): The MongoDB collection name
batch_size(int): Number of documents per bulk_write to MongoDB,
default to 100
extra_client_params(dict): Optional `MongoClient
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
parameters as keyword arguments
Returns:
:class:`~apache_beam.transforms.ptransform.PTransform`
"""
if extra_client_params is None:
extra_client_params = {}
if not isinstance(db, str):
raise ValueError("WriteToMongoDB db param must be specified as a string")
if not isinstance(coll, str):
raise ValueError(
"WriteToMongoDB coll param must be specified as a string")
self._uri = uri
self._db = db
self._coll = coll
self._batch_size = batch_size
self._spec = extra_client_params
[docs] def expand(self, pcoll):
return (
pcoll
| beam.ParDo(_GenerateObjectIdFn())
| Reshuffle()
| beam.ParDo(
_WriteMongoFn(
self._uri, self._db, self._coll, self._batch_size, self._spec)))
class _GenerateObjectIdFn(DoFn):
def process(self, element, *args, **kwargs):
# if _id field already exist we keep it as it is, otherwise the ptransform
# generates a new _id field to achieve idempotent write to mongodb.
if "_id" not in element:
# object.ObjectId() generates a unique identifier that follows mongodb
# default format, if _id is not present in document, mongodb server
# generates it with this same function upon write. However the
# uniqueness of generated id may not be guaranteed if the work load are
# distributed across too many processes. See more on the ObjectId format
# https://docs.mongodb.com/manual/reference/bson-types/#objectid.
element["_id"] = objectid.ObjectId()
yield element
class _WriteMongoFn(DoFn):
def __init__(
self, uri=None, db=None, coll=None, batch_size=100, extra_params=None):
if extra_params is None:
extra_params = {}
self.uri = uri
self.db = db
self.coll = coll
self.spec = extra_params
self.batch_size = batch_size
self.batch = []
def finish_bundle(self):
self._flush()
def process(self, element, *args, **kwargs):
self.batch.append(element)
if len(self.batch) >= self.batch_size:
self._flush()
def _flush(self):
if len(self.batch) == 0:
return
with _MongoSink(self.uri, self.db, self.coll, self.spec) as sink:
sink.write(self.batch)
self.batch = []
def display_data(self):
res = super().display_data()
res["database"] = self.db
res["collection"] = self.coll
res["batch_size"] = self.batch_size
return res
class _MongoSink:
def __init__(self, uri=None, db=None, coll=None, extra_params=None):
if extra_params is None:
extra_params = {}
self.uri = uri
self.db = db
self.coll = coll
self.spec = extra_params
self.client = None
def write(self, documents):
if self.client is None:
self.client = MongoClient(host=self.uri, **self.spec)
requests = []
for doc in documents:
# match document based on _id field, if not found in current collection,
# insert new one, otherwise overwrite it.
requests.append(
ReplaceOne(
filter={"_id": doc.get("_id", None)},
replacement=doc,
upsert=True))
resp = self.client[self.db][self.coll].bulk_write(requests)
_LOGGER.debug(
"BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, "
"nMatched:%d, Errors:%s" % (
resp.modified_count,
resp.upserted_count,
resp.matched_count,
resp.bulk_api_result.get("writeErrors"),
))
def __enter__(self):
if self.client is None:
self.client = MongoClient(host=self.uri, **self.spec)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.client is not None:
self.client.close()