
148 lines
4.5 KiB

# -*- coding: utf-8 -*-
# Copyright (C) 2020 Radim Rehurek <>
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
"""Implements the compression layer of the ``smart_open`` library."""
import logging
import os.path
logger = logging.getLogger(__name__)

# Maps a lower-cased file extension (leading period included, e.g. ".gz") to
# the callback that wraps a file object with the matching (de)compressor.
# Populated through register_compressor(); read by the functions below.
_COMPRESSOR_REGISTRY = {}

NO_COMPRESSION = 'disable'
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = 'infer_from_extension'
"""Determine the compression to use from the file extension.

See get_supported_extensions().
"""
def get_supported_compression_types():
    """Return the list of supported compression types available to open.

    The list starts with ``NO_COMPRESSION`` and ``INFER_FROM_EXTENSION``,
    followed by every extension that currently has a registered compressor
    (see get_supported_extensions()). These are the accepted values of the
    ``compression`` parameter of smart_open's ``open``.
    """
    return [NO_COMPRESSION, INFER_FROM_EXTENSION] + get_supported_extensions()
def get_supported_extensions():
    """List every file extension that currently has a registered compressor.

    Returns
    -------
    list of str
        The registered extensions (leading period included), sorted.
    """
    registered = list(_COMPRESSOR_REGISTRY)
    registered.sort()
    return registered
def register_compressor(ext, callback):
    """Register a callback for transparently decompressing files with a specific extension.

    Parameters
    ----------
    ext: str
        The extension. Must include the leading period, e.g. ``.gz``.
        It is stored lower-cased.
    callback: callable
        The callback. It must accept two position arguments, file_obj and mode.
        This function will be called when ``smart_open`` is opening a file with
        the specified extension.

    Raises
    ------
    ValueError
        If `ext` is not a string starting with a period.

    Examples
    --------
    Instruct smart_open to use the `lzma` module whenever opening a file
    with a .xz extension (see README.rst for the complete example showing I/O):

    >>> def _handle_xz(file_obj, mode):
    ...     import lzma
    ...     return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
    >>> register_compressor('.xz', _handle_xz)
    """
    if not (isinstance(ext, str) and ext.startswith('.')):
        # Wrap in a 1-tuple so %-formatting works even when ext is itself a tuple.
        raise ValueError('ext must be a string starting with ., not %r' % (ext,))
    ext = ext.lower()
    # Warn only when an existing handler is really being replaced; otherwise
    # every registration (including the built-in ones at import) would warn.
    if ext in _COMPRESSOR_REGISTRY:
        logger.warning('overriding existing compression handler for %r', ext)
    _COMPRESSOR_REGISTRY[ext] = callback
def tweak_close(outer, inner):
    """Ensure that closing the `outer` stream closes the `inner` stream as well.

    Use this when your compression library's `close` method does not
    automatically close the underlying filestream. Otherwise, closing the
    wrapper returned to the caller would leave the underlying stream open.

    `outer.close` is replaced with a function that first runs the original
    `outer.close`, then closes `inner`. `inner` is closed even if the original
    close raises, and it is closed at most once no matter how often
    `outer.close` is called.
    """
    outer_close = outer.close

    def close_both(*args):
        nonlocal inner
        try:
            outer_close()
        finally:
            if inner:
                # Detach before closing so repeated close() calls are no-ops for inner.
                inner, fp = None, inner
                fp.close()

    outer.close = close_both
def _handle_bz2(file_obj, mode):
    """Wrap `file_obj` in a ``bz2.BZ2File`` opened with `mode`.

    The returned object's ``close`` is patched via tweak_close() so that it
    also closes `file_obj`.
    """
    import bz2
    wrapped = bz2.BZ2File(file_obj, mode)
    tweak_close(wrapped, file_obj)
    return wrapped
def _handle_gzip(file_obj, mode):
    """Wrap `file_obj` in a ``gzip.GzipFile`` opened with `mode`.

    The returned object's ``close`` is patched via tweak_close() so that it
    also closes `file_obj`.
    """
    from gzip import GzipFile
    wrapped = GzipFile(fileobj=file_obj, mode=mode)
    tweak_close(wrapped, file_obj)
    return wrapped
def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None):
    """Wrap `file_obj` with an appropriate [de]compression mechanism.

    If the compression cannot be determined, or no compressor is registered
    for it, simply return the original `file_obj` unchanged.

    `file_obj` must either be a filehandle object, or a class which behaves like one.

    Parameters
    ----------
    file_obj: file-like
        The stream to wrap.
    mode: str
        The mode `file_obj` was opened with; passed on to the compressor callback.
    compression: str, optional
        ``NO_COMPRESSION`` returns `file_obj` as-is. ``INFER_FROM_EXTENSION``
        (the default) picks the compressor from the extension of `filename`,
        or of ``file_obj.name`` when `filename` is not given. Any other value
        is looked up directly as an extension key, e.g. ``.gz``.
    filename: str or path-like, optional
        If specified, it is used to extract the extension instead of the
        ``file_obj.name`` attribute.

    Returns
    -------
    file-like
        The wrapped stream, or `file_obj` itself when no compressor applies.

    Raises
    ------
    ValueError
        If a registered compressor applies but `mode` is an update mode
        (contains ``+``).
    """
    if compression == NO_COMPRESSION:
        return file_obj

    if compression == INFER_FROM_EXTENSION:
        try:
            # os.fspath accepts str/bytes/path-like and raises TypeError for
            # anything else (e.g. an int file descriptor as .name). Lower-casing
            # matches register_compressor, which stores keys lower-cased.
            filename = os.fspath(filename or file_obj.name).lower()
        except (AttributeError, TypeError):
            logger.warning(
                'unable to transparently decompress %r because it '
                'seems to lack a string-like .name', file_obj
            )
            return file_obj
        _, compression = os.path.splitext(filename)

    # Update modes are rejected anywhere in the mode string ("rb+" and "r+b").
    if compression in _COMPRESSOR_REGISTRY and '+' in mode:
        raise ValueError('transparent (de)compression unsupported for mode %r' % mode)

    try:
        callback = _COMPRESSOR_REGISTRY[compression]
    except KeyError:
        return file_obj
    # Called outside the try so a KeyError raised by the callback itself propagates.
    return callback(file_obj, mode)
# NB. avoid using lambda here to make stack traces more readable.
# Built-in handlers, registered in this order: bz2 first, then gzip.
for _ext, _handler in (('.bz2', _handle_bz2), ('.gz', _handle_gzip)):
    register_compressor(_ext, _handler)
del _ext, _handler