148 lines
4.5 KiB
Python
148 lines
4.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com>
|
|
#
|
|
# This code is distributed under the terms and conditions
|
|
# from the MIT License (MIT).
|
|
#
|
|
"""Implements the compression layer of the ``smart_open`` library."""
|
|
import logging
|
|
import os.path
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_COMPRESSOR_REGISTRY = {}
|
|
|
|
NO_COMPRESSION = 'disable'
|
|
"""Use no compression. Read/write the data as-is."""
|
|
INFER_FROM_EXTENSION = 'infer_from_extension'
|
|
"""Determine the compression to use from the file extension.
|
|
|
|
See get_supported_extensions().
|
|
"""
|
|
|
|
|
|
def get_supported_compression_types():
|
|
"""Return the list of supported compression types available to open.
|
|
|
|
See compression paratemeter to smart_open.open().
|
|
"""
|
|
return [NO_COMPRESSION, INFER_FROM_EXTENSION] + get_supported_extensions()
|
|
|
|
|
|
def get_supported_extensions():
|
|
"""Return the list of file extensions for which we have registered compressors."""
|
|
return sorted(_COMPRESSOR_REGISTRY.keys())
|
|
|
|
|
|
def register_compressor(ext, callback):
|
|
"""Register a callback for transparently decompressing files with a specific extension.
|
|
|
|
Parameters
|
|
----------
|
|
ext: str
|
|
The extension. Must include the leading period, e.g. ``.gz``.
|
|
callback: callable
|
|
The callback. It must accept two position arguments, file_obj and mode.
|
|
This function will be called when ``smart_open`` is opening a file with
|
|
the specified extension.
|
|
|
|
Examples
|
|
--------
|
|
|
|
Instruct smart_open to use the `lzma` module whenever opening a file
|
|
with a .xz extension (see README.rst for the complete example showing I/O):
|
|
|
|
>>> def _handle_xz(file_obj, mode):
|
|
... import lzma
|
|
... return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
|
|
>>>
|
|
>>> register_compressor('.xz', _handle_xz)
|
|
|
|
"""
|
|
if not (ext and ext[0] == '.'):
|
|
raise ValueError('ext must be a string starting with ., not %r' % ext)
|
|
ext = ext.lower()
|
|
if ext in _COMPRESSOR_REGISTRY:
|
|
logger.warning('overriding existing compression handler for %r', ext)
|
|
_COMPRESSOR_REGISTRY[ext] = callback
|
|
|
|
|
|
def tweak_close(outer, inner):
|
|
"""Ensure that closing the `outer` stream closes the `inner` stream as well.
|
|
|
|
Use this when your compression library's `close` method does not
|
|
automatically close the underlying filestream. See
|
|
https://github.com/RaRe-Technologies/smart_open/issues/630 for an
|
|
explanation why that is a problem for smart_open.
|
|
"""
|
|
outer_close = outer.close
|
|
|
|
def close_both(*args):
|
|
nonlocal inner
|
|
try:
|
|
outer_close()
|
|
finally:
|
|
if inner:
|
|
inner, fp = None, inner
|
|
fp.close()
|
|
|
|
outer.close = close_both
|
|
|
|
|
|
def _handle_bz2(file_obj, mode):
|
|
from bz2 import BZ2File
|
|
result = BZ2File(file_obj, mode)
|
|
tweak_close(result, file_obj)
|
|
return result
|
|
|
|
|
|
def _handle_gzip(file_obj, mode):
|
|
import gzip
|
|
result = gzip.GzipFile(fileobj=file_obj, mode=mode)
|
|
tweak_close(result, file_obj)
|
|
return result
|
|
|
|
|
|
def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None):
|
|
"""
|
|
Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension.
|
|
|
|
If the filename extension isn't recognized, simply return the original `file_obj` unchanged.
|
|
|
|
`file_obj` must either be a filehandle object, or a class which behaves like one.
|
|
|
|
If `filename` is specified, it will be used to extract the extension.
|
|
If not, the `file_obj.name` attribute is used as the filename.
|
|
|
|
"""
|
|
if compression == NO_COMPRESSION:
|
|
return file_obj
|
|
elif compression == INFER_FROM_EXTENSION:
|
|
try:
|
|
filename = (filename or file_obj.name).lower()
|
|
except (AttributeError, TypeError):
|
|
logger.warning(
|
|
'unable to transparently decompress %r because it '
|
|
'seems to lack a string-like .name', file_obj
|
|
)
|
|
return file_obj
|
|
_, compression = os.path.splitext(filename)
|
|
|
|
if compression in _COMPRESSOR_REGISTRY and mode.endswith('+'):
|
|
raise ValueError('transparent (de)compression unsupported for mode %r' % mode)
|
|
|
|
try:
|
|
callback = _COMPRESSOR_REGISTRY[compression]
|
|
except KeyError:
|
|
return file_obj
|
|
else:
|
|
return callback(file_obj, mode)
|
|
|
|
|
|
#
|
|
# NB. avoid using lambda here to make stack traces more readable.
|
|
#
|
|
register_compressor('.bz2', _handle_bz2)
|
|
register_compressor('.gz', _handle_gzip)
|