Source code for apache_beam.io.filebasedsink

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""File-based sink."""

from __future__ import absolute_import

import logging
import os
import re
import time
import uuid

from apache_beam.internal import util
from apache_beam.io import iobase
from apache_beam.io.filesystem import BeamIOError
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.options.value_provider import ValueProvider
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.options.value_provider import check_accessible

DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'

__all__ = ['FileBasedSink']


[docs]class FileBasedSink(iobase.Sink): """A sink to a GCS or local files. To implement a file-based sink, extend this class and override either ``write_record()`` or ``write_encoded_record()``. If needed, also overwrite ``open()`` and/or ``close()`` to customize the file handling or write headers and footers. The output of this write is a PCollection of all written shards. """ # Max number of threads to be used for renaming. _MAX_RENAME_THREADS = 64 def __init__(self, file_path_prefix, coder, file_name_suffix='', num_shards=0, shard_name_template=None, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): """ Raises: TypeError: if file path parameters are not a string or ValueProvider, or if compression_type is not member of CompressionTypes. ValueError: if shard_name_template is not of expected format. """ if not isinstance(file_path_prefix, (basestring, ValueProvider)): raise TypeError('file_path_prefix must be a string or ValueProvider;' 'got %r instead' % file_path_prefix) if not isinstance(file_name_suffix, (basestring, ValueProvider)): raise TypeError('file_name_suffix must be a string or ValueProvider;' 'got %r instead' % file_name_suffix) if not CompressionTypes.is_valid_compression_type(compression_type): raise TypeError('compression_type must be CompressionType object but ' 'was %s' % type(compression_type)) if shard_name_template is None: shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE elif shard_name_template == '': num_shards = 1 if isinstance(file_path_prefix, basestring): file_path_prefix = StaticValueProvider(str, file_path_prefix) if isinstance(file_name_suffix, basestring): file_name_suffix = StaticValueProvider(str, file_name_suffix) self.file_path_prefix = file_path_prefix self.file_name_suffix = file_name_suffix self.num_shards = num_shards self.coder = coder self.shard_name_format = self._template_to_format(shard_name_template) self.compression_type = compression_type self.mime_type = mime_type
[docs] def display_data(self): return {'shards': DisplayDataItem(self.num_shards, label='Number of Shards').drop_if_default(0), 'compression': DisplayDataItem(str(self.compression_type)), 'file_pattern': DisplayDataItem('{}{}{}'.format(self.file_path_prefix, self.shard_name_format, self.file_name_suffix), label='File Pattern')}
@check_accessible(['file_path_prefix'])
[docs] def open(self, temp_path): """Opens ``temp_path``, returning an opaque file handle object. The returned file handle is passed to ``write_[encoded_]record`` and ``close``. """ return FileSystems.create(temp_path, self.mime_type, self.compression_type)
[docs] def write_record(self, file_handle, value): """Writes a single record go the file handle returned by ``open()``. By default, calls ``write_encoded_record`` after encoding the record with this sink's Coder. """ self.write_encoded_record(file_handle, self.coder.encode(value))
[docs] def write_encoded_record(self, file_handle, encoded_value): """Writes a single encoded record to the file handle returned by ``open()``. """ raise NotImplementedError
[docs] def close(self, file_handle): """Finalize and close the file handle returned from ``open()``. Called after all records are written. By default, calls ``file_handle.close()`` iff it is not None. """ if file_handle is not None: file_handle.close()
@check_accessible(['file_path_prefix', 'file_name_suffix'])
[docs] def initialize_write(self): file_path_prefix = self.file_path_prefix.get() tmp_dir = self._create_temp_dir(file_path_prefix) FileSystems.mkdirs(tmp_dir) return tmp_dir
def _create_temp_dir(self, file_path_prefix): base_path, last_component = FileSystems.split(file_path_prefix) if not last_component: # Trying to re-split the base_path to check if it's a root. new_base_path, _ = FileSystems.split(base_path) if base_path == new_base_path: raise ValueError('Cannot create a temporary directory for root path ' 'prefix %s. Please specify a file path prefix with ' 'at least two components.', file_path_prefix) path_components = [base_path, 'beam-temp-' + last_component + '-' + uuid.uuid1().hex] return FileSystems.join(*path_components) @check_accessible(['file_path_prefix', 'file_name_suffix'])
[docs] def open_writer(self, init_result, uid): # A proper suffix is needed for AUTO compression detection. # We also ensure there will be no collisions with uid and a # (possibly unsharded) file_path_prefix and a (possibly empty) # file_name_suffix. file_path_prefix = self.file_path_prefix.get() file_name_suffix = self.file_name_suffix.get() suffix = ( '.' + os.path.basename(file_path_prefix) + file_name_suffix) return FileBasedSinkWriter(self, os.path.join(init_result, uid) + suffix)
@check_accessible(['file_path_prefix', 'file_name_suffix'])
[docs] def finalize_write(self, init_result, writer_results): file_path_prefix = self.file_path_prefix.get() file_name_suffix = self.file_name_suffix.get() writer_results = sorted(writer_results) num_shards = len(writer_results) min_threads = min(num_shards, FileBasedSink._MAX_RENAME_THREADS) num_threads = max(1, min_threads) source_files = [] destination_files = [] chunk_size = FileSystems.get_chunk_size(file_path_prefix) for shard_num, shard in enumerate(writer_results): final_name = ''.join([ file_path_prefix, self.shard_name_format % dict( shard_num=shard_num, num_shards=num_shards), file_name_suffix ]) source_files.append(shard) destination_files.append(final_name) source_file_batch = [source_files[i:i + chunk_size] for i in xrange(0, len(source_files), chunk_size)] destination_file_batch = [destination_files[i:i + chunk_size] for i in xrange(0, len(destination_files), chunk_size)] logging.info( 'Starting finalize_write threads with num_shards: %d, ' 'batches: %d, num_threads: %d', num_shards, len(source_file_batch), num_threads) start_time = time.time() # Use a thread pool for renaming operations. def _rename_batch(batch): """_rename_batch executes batch rename operations.""" source_files, destination_files = batch exceptions = [] try: FileSystems.rename(source_files, destination_files) return exceptions except BeamIOError as exp: if exp.exception_details is None: raise for (src, dest), exception in exp.exception_details.iteritems(): if exception: logging.warning('Rename not successful: %s -> %s, %s', src, dest, exception) should_report = True if isinstance(exception, IOError): # May have already been copied. try: if FileSystems.exists(dest): should_report = False except Exception as exists_e: # pylint: disable=broad-except logging.warning('Exception when checking if file %s exists: ' '%s', dest, exists_e) if should_report: logging.warning(('Exception in _rename_batch. src: %s, ' 'dest: %s, err: %s'), src, dest, exception) exceptions.append(exception) else: logging.debug('Rename successful: %s -> %s', src, dest) return exceptions exception_batches = util.run_using_threadpool( _rename_batch, zip(source_file_batch, destination_file_batch), num_threads) all_exceptions = [e for exception_batch in exception_batches for e in exception_batch] if all_exceptions: raise Exception('Encountered exceptions in finalize_write: %s', all_exceptions) for final_name in destination_files: yield final_name logging.info('Renamed %d shards in %.2f seconds.', num_shards, time.time() - start_time) try: FileSystems.delete([init_result]) except IOError: # May have already been removed. pass
@staticmethod def _template_to_format(shard_name_template): if not shard_name_template: return '' m = re.search('S+', shard_name_template) if m is None: raise ValueError("Shard number pattern S+ not found in template '%s'" % shard_name_template) shard_name_format = shard_name_template.replace( m.group(0), '%%(shard_num)0%dd' % len(m.group(0))) m = re.search('N+', shard_name_format) if m: shard_name_format = shard_name_format.replace( m.group(0), '%%(num_shards)0%dd' % len(m.group(0))) return shard_name_format def __eq__(self, other): # TODO: Clean up workitem_test which uses this. # pylint: disable=unidiomatic-typecheck return type(self) == type(other) and self.__dict__ == other.__dict__
class FileBasedSinkWriter(iobase.Writer): """The writer for FileBasedSink. """ def __init__(self, sink, temp_shard_path): self.sink = sink self.temp_shard_path = temp_shard_path self.temp_handle = self.sink.open(temp_shard_path) def write(self, value): self.sink.write_record(self.temp_handle, value) def close(self): self.sink.close(self.temp_handle) return self.temp_shard_path