#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import datetime
import time
import pytz
from apache_beam.io.aws.clients.s3 import messages
class FakeFile(object):
    """In-memory stand-in for a single stored S3 object.

    Instances carry plain attributes: ``bucket``, ``key``, ``contents``,
    ``etag`` and ``last_modified`` (the ``time.time()`` value taken when the
    instance is constructed).
    """

    def __init__(self, bucket, key, contents, etag=None):
        """Record where the object lives and what it contains.

        Args:
          bucket: name of the bucket that holds the object.
          key: key of the object inside the bucket.
          contents: payload of the object.
          etag: optional entity tag. Any falsy value (``None``, ``''``)
            selects a placeholder tag made of 32 ``x`` characters followed
            by ``-1``, wrapped in double quotes.
        """
        self.bucket = bucket
        self.key = key
        self.contents = contents
        # Creation time doubles as the modification time of the fake object.
        self.last_modified = time.time()
        # Fall back to the fixed placeholder tag when no etag was supplied.
        self.etag = etag if etag else '"%s-1"' % ('x' * 32)
class FakeS3Client(object):
    """In-memory fake of an S3 client for use in tests.

    Stored objects live in ``self.files``, keyed by ``(bucket, key)`` tuples.
    Errors are reported with ``messages.S3ClientError`` carrying an HTTP-like
    status code, except where a method's docstring says otherwise.
    """

    def __init__(self):
        # (bucket, key) -> stored file object.
        self.files = {}
        # Continuation token -> index of the first item of the next page.
        self.list_continuation_tokens = {}
        # Upload id -> {part number: bytes of that part}.
        self.multipart_uploads = {}
        # boto3 has different behavior when running some operations against a
        # bucket that exists vs. against one that doesn't. To emulate that
        # behavior, the mock client keeps a set of bucket names that it knows
        # "exist".
        self.known_buckets = set()

    def add_file(self, f):
        """Store ``f`` under ``(f.bucket, f.key)`` and mark its bucket known.

        An existing entry with the same bucket and key is replaced.
        """
        self.files[(f.bucket, f.key)] = f
        # set.add is idempotent, so no membership test is needed first.
        self.known_buckets.add(f.bucket)

    def get_file(self, bucket, obj):
        """Return the file stored under ``(bucket, obj)``.

        Raises:
          messages.S3ClientError: with code 404 when no such file is stored.
        """
        try:
            return self.files[bucket, obj]
        except KeyError:
            # Only a missing key means "not found"; anything else propagates.
            raise messages.S3ClientError('Not Found', 404)

    def delete_file(self, bucket, obj):
        """Remove the file stored under ``(bucket, obj)``.

        Raises:
          KeyError: when no such file is stored.
        """
        del self.files[(bucket, obj)]

    def list(self, request):
        """Return one page of metadata for files matching a bucket and prefix.

        Matching files are those in ``request.bucket`` whose key starts with
        ``request.prefix`` (a falsy prefix matches every key). They are
        ordered by key and returned five per page. Each stored file must
        provide a ``get_metadata()`` method, whose result becomes the listed
        item.

        Args:
          request: object with ``bucket``, ``prefix`` and
            ``continuation_token`` attributes.

        Returns:
          A ``messages.ListResponse`` holding the page's items. When more
          items remain, its ``next_token`` attribute is set to a single-use
          token for the following page.

        Raises:
          messages.S3ClientError: with code 404 when nothing matches.
          ValueError: when ``request.continuation_token`` is not a token
            previously handed out (or has already been used).
        """
        bucket = request.bucket
        prefix = request.prefix or ''

        matching_files = [
            self.get_file(file_bucket, file_name).get_metadata()
            for file_bucket, file_name in sorted(self.files)
            if bucket == file_bucket and file_name.startswith(prefix)
        ]
        if not matching_files:
            message = 'Tried to list nonexistent S3 path: s3://%s/%s' % (
                bucket, prefix)
            raise messages.S3ClientError(message, 404)

        # Handle pagination.
        items_per_page = 5
        token = request.continuation_token
        if not token:
            range_start = 0
        elif token in self.list_continuation_tokens:
            # Tokens are single-use: consume the entry while reading it.
            range_start = self.list_continuation_tokens.pop(token)
        else:
            raise ValueError('Invalid page token.')

        range_end = range_start + items_per_page
        result = messages.ListResponse(
            items=matching_files[range_start:range_end])

        if range_end < len(matching_files):
            next_continuation_token = '_page_token_%s_%s_%d' % (
                bucket, prefix, range_end)
            self.list_continuation_tokens[next_continuation_token] = range_end
            result.next_token = next_continuation_token
        return result

    def get_range(self, request, start, end):
        r"""Retrieves a byte range of an object.

        Args:
          request: object with ``bucket`` and ``object`` attributes naming
            the stored file.
          start: index of the first byte to return.
          end: index one past the last byte to return.

        Returns:
          ``contents[start:end]`` of the stored file, or the whole contents
          when ``start`` is negative or ``end <= start``.

        Raises:
          messages.S3ClientError: with code 404 when the file is not stored.
        """
        file_ = self.get_file(request.bucket, request.object)

        # Replicates S3's behavior, per the spec here:
        # https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
        if start < 0 or end <= start:
            return file_.contents

        return file_.contents[start:end]

    def delete(self, request):
        """Delete the file named by ``request.bucket`` / ``request.object``.

        Raises:
          messages.S3ClientError: with code 404 when the bucket is not one
            of the known buckets.
        """
        if request.bucket not in self.known_buckets:
            raise messages.S3ClientError(
                'The specified bucket does not exist', 404)

        # S3 doesn't raise an error if you try to delete a nonexistent file
        # from an extant bucket, so a missing key is silently ignored.
        if (request.bucket, request.object) in self.files:
            self.delete_file(request.bucket, request.object)

    def delete_batch(self, request):
        """Delete every key in ``request.objects`` from ``request.bucket``.

        Each key is deleted through ``delete``; an ``S3ClientError`` raised
        for one key is recorded and does not stop the rest of the batch.

        Returns:
          A ``messages.DeleteBatchResponse`` built from three lists: the
          deleted keys, the failed keys, and the errors for the failed keys
          (in the same order as the failed keys).
        """
        deleted, failed, errors = [], [], []
        for key in request.objects:
            try:
                delete_request = messages.DeleteRequest(request.bucket, key)
                self.delete(delete_request)
                deleted.append(key)
            except messages.S3ClientError as e:
                failed.append(key)
                errors.append(e)
        return messages.DeleteBatchResponse(deleted, failed, errors)

    def copy(self, request):
        """Copy a stored file's contents to a new bucket/key.

        Reads ``request.src_bucket`` / ``request.src_key`` and stores a new
        ``FakeFile`` with the same contents (and the default etag) under
        ``request.dest_bucket`` / ``request.dest_key``.

        Raises:
          messages.S3ClientError: with code 404 when the source is missing.
        """
        src_file = self.get_file(request.src_bucket, request.src_key)
        dest_file = FakeFile(
            request.dest_bucket, request.dest_key, src_file.contents)
        self.add_file(dest_file)

    def create_multipart_upload(self, request):
        """Start a multipart upload and return its id.

        Any parts previously recorded under the same upload id are
        discarded.

        Returns:
          A ``messages.UploadResponse`` constructed with the upload id.
        """
        # The upload id is simply the bucket name concatenated with the key.
        upload_id = request.bucket + request.object
        self.multipart_uploads[upload_id] = {}
        return messages.UploadResponse(upload_id)

    def upload_part(self, request):
        """Record the bytes of one part of a multipart upload.

        Args:
          request: object with ``upload_id``, ``part_number`` and ``bytes``
            attributes.

        Returns:
          A ``messages.UploadPartResponse`` built from a placeholder etag
          and the part number.

        Raises:
          messages.S3ClientError: with code 400 when the part number is not
            a non-negative int, or 404 when the upload id is unknown.
        """
        upload_id, part_number = request.upload_id, request.part_number

        # Check the type first: comparing a non-int (e.g. a str) with 0
        # would raise TypeError before the validation error could be
        # reported.
        if not isinstance(part_number, int) or part_number < 0:
            raise messages.S3ClientError(
                'Param validation failed on part number', 400)

        if upload_id not in self.multipart_uploads:
            raise messages.S3ClientError(
                'The specified upload does not exist', 404)

        # Save off the bytes in the internal data store.
        self.multipart_uploads[upload_id][part_number] = request.bytes
        etag = '"%s"' % ('x' * 32)
        return messages.UploadPartResponse(etag, part_number)

    def complete_multipart_upload(self, request):
        """Assemble the uploaded parts into a stored file.

        The parts are concatenated in ascending part-number order and stored
        as a ``FakeFile`` under ``request.bucket`` / ``request.object``. Its
        etag ends in ``-<number of parts>``.

        Args:
          request: object with ``upload_id``, ``bucket``, ``object`` and
            ``parts`` attributes; each entry of ``parts`` is indexed with
            ``'PartNumber'``.

        Raises:
          KeyError: when ``request.upload_id`` is not a known upload.
          messages.S3ClientError: with code 400 when the part numbers in
            ``request.parts`` differ from the parts received, or when any
            part other than the last is smaller than the minimum part size.
        """
        MIN_PART_SIZE = 5 * 2**10  # 5 KiB

        parts_received = self.multipart_uploads[request.upload_id]

        # The caller must confirm exactly the set of parts that was uploaded.
        part_numbers_to_confirm = {
            part['PartNumber'] for part in request.parts}
        if part_numbers_to_confirm != set(parts_received):
            raise messages.S3ClientError(
                'One or more of the specified parts could not be found', 400)

        # Order the payloads by part number.
        sorted_bytes = [
            parts_received[number] for number in sorted(parts_received)]

        # Make sure that the parts aren't too small (except the last part).
        part_sizes = [len(bytes_) for bytes_ in sorted_bytes]
        if any(size < MIN_PART_SIZE for size in part_sizes[:-1]):
            e_message = (
                '\nAll parts but the last must be larger than %d bytes\n' %
                MIN_PART_SIZE)
            raise messages.S3ClientError(e_message, 400)

        # String together all bytes for the given upload.
        final_contents = b''.join(sorted_bytes)

        # The etag suffix records how many parts made up the object.
        num_parts = len(parts_received)
        etag = '"%s-%d"' % ('x' * 32, num_parts)
        file_ = FakeFile(
            request.bucket, request.object, final_contents, etag=etag)

        # Store the assembled file in self.files.
        self.add_file(file_)