# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
#
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
"""Implements file-like objects for reading from http."""
|
||
|
||
import io
|
||
import logging
|
||
import os.path
|
||
import urllib.parse
|
||
|
||
try:
|
||
import requests
|
||
except ImportError:
|
||
MISSING_DEPS = True
|
||
|
||
from smart_open import bytebuffer, constants
|
||
import smart_open.utils
|
||
|
||
logger = logging.getLogger(__name__)

# URI schemes handled by this transport.
SCHEMES = ('http', 'https')

# Default chunk size, in bytes, for streaming the response body (128 KiB).
DEFAULT_BUFFER_SIZE = 128 * 1024


_HEADERS = {'Accept-Encoding': 'identity'}
"""The headers sent to the server when the caller does not supply its own.

They ask the server to deliver the file exactly as stored.  Servers sometimes
compress a file to make the transfer more efficient, and the client (us) would
then have to decompress it with the matching algorithm.
"""
|
||
|
||
|
||
def parse_uri(uri_as_string):
    """Split an HTTP(S) URI into its scheme and a rooted path.

    The network location and the path are concatenated, and the result is
    normalised so that it starts with exactly one ``/``.

    :param str uri_as_string: The URI to parse.  Its scheme is asserted to be
        one of ``SCHEMES``.
    :returns: A dict with the keys ``scheme`` and ``uri_path``.
    :rtype: dict
    """
    parts = urllib.parse.urlsplit(uri_as_string)
    assert parts.scheme in SCHEMES

    # Join host and path, then collapse any leading slashes into a single one.
    rooted_path = "/" + (parts.netloc + parts.path).lstrip("/")
    return {"scheme": parts.scheme, "uri_path": rooted_path}
|
||
|
||
|
||
def open_uri(uri, mode, transport_params):
    """Open ``uri`` by delegating to :func:`open` in this module.

    ``transport_params`` is first passed through
    ``smart_open.utils.check_kwargs`` together with :func:`open`; the mapping
    that call returns is forwarded to :func:`open` as keyword arguments.
    NOTE(review): ``check_kwargs`` presumably keeps only the keyword arguments
    that :func:`open` accepts -- confirm against ``smart_open.utils``.

    :param str uri: The URI to open.
    :param str mode: The mode to open with.
    :param dict transport_params: Candidate keyword arguments for :func:`open`.
    :returns: Whatever :func:`open` returns.
    """
    accepted = smart_open.utils.check_kwargs(open, transport_params)
    return open(uri, mode, **accepted)
|
||
|
||
|
||
def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
         headers=None, timeout=None, buffer_size=DEFAULT_BUFFER_SIZE):
    """Implement streamed reader from a web site.

    Supports Kerberos and Basic HTTP authentication.

    Parameters
    ----------
    uri: str
        The URL to open.
    mode: str
        The mode to open using.  Only binary reading is implemented.
    kerberos: boolean, optional
        If True, will attempt to use the local Kerberos credentials
    user: str, optional
        The username for authenticating over HTTP
    password: str, optional
        The password for authenticating over HTTP
    cert: str/tuple, optional
        If a string, path to an SSL client cert file (.pem).
        If a tuple, ('cert', 'key').
    headers: dict, optional
        Any headers to send in the request. If ``None``, the default headers are sent:
        ``{'Accept-Encoding': 'identity'}``. To use no headers at all,
        set this variable to an empty dict, ``{}``.
    timeout: optional
        Handed to the reader, which forwards it to ``requests.get`` as its
        ``timeout`` argument.
    buffer_size: int, optional
        The buffer size to use when performing I/O.

    Returns
    -------
    SeekableBufferedInputBase
        The reader, with its ``name`` attribute set to the last component of
        the URL path.

    Raises
    ------
    NotImplementedError
        If ``mode`` is anything other than binary read.

    Note
    ----
    If neither kerberos nor (user, password) are set, will connect
    unauthenticated, unless set separately in headers.

    """
    # Guard clause: binary reading is the only mode this transport offers.
    if mode != constants.READ_BINARY:
        raise NotImplementedError('http support for mode %r not implemented' % mode)

    reader = SeekableBufferedInputBase(
        uri,
        mode,
        buffer_size=buffer_size,
        kerberos=kerberos,
        user=user,
        password=password,
        cert=cert,
        headers=headers,
        timeout=timeout,
    )

    # Name the stream after the final component of the URL path.
    url_path = urllib.parse.urlparse(uri).path
    reader.name = os.path.basename(url_path)
    return reader
|
||
|
||
|
||
class BufferedInputBase(io.BufferedIOBase):
    """Forward-only binary reader over the body of a streamed HTTP GET.

    The request is issued by the constructor with ``stream=True``.  Bounded
    reads pull chunks of ``buffer_size`` bytes from the response into an
    internal byte buffer; unbounded reads drain the response's raw stream
    directly.
    """

    def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
                 kerberos=False, user=None, password=None, cert=None,
                 headers=None, timeout=None):
        """Issue the GET request and prepare the read buffer.

        :param str url: The URL to fetch.
        :param str mode: Stored on the instance as ``mode``; not otherwise
            used by this class.
        :param int buffer_size: Chunk size used for ``iter_content`` and for
            the internal byte buffer.
        :param bool kerberos: If True, authenticate with
            ``requests_kerberos.HTTPKerberosAuth`` (imported lazily).
        :param str user: Username for HTTP authentication; only used when
            ``password`` is given as well.
        :param str password: Password for HTTP authentication.
        :param cert: Passed to ``requests.get`` as ``cert``.
        :param dict headers: Request headers.  ``None`` selects a copy of the
            module defaults.
        :param timeout: Passed to ``requests.get`` as ``timeout``.

        Raises whatever ``response.raise_for_status()`` raises when the
        response is not OK.
        """
        # Kerberos takes precedence over user/password credentials.
        if kerberos:
            import requests_kerberos
            auth = requests_kerberos.HTTPKerberosAuth()
        elif user is not None and password is not None:
            auth = (user, password)
        else:
            auth = None

        self.buffer_size = buffer_size
        self.mode = mode

        if headers is None:
            self.headers = _HEADERS.copy()
        else:
            self.headers = headers

        self.timeout = timeout

        self.response = requests.get(
            url,
            auth=auth,
            cert=cert,
            stream=True,
            headers=self.headers,
            timeout=self.timeout,
        )

        if not self.response.ok:
            self.response.raise_for_status()

        self._read_iter = self.response.iter_content(self.buffer_size)
        self._read_buffer = bytebuffer.ByteBuffer(buffer_size)
        self._current_pos = 0

        #
        # This member is part of the io.BufferedIOBase interface.
        #
        self.raw = None

    #
    # Override some methods from io.IOBase.
    #
    def close(self):
        """Release the HTTP response and mark this stream as exhausted.

        Safe to call more than once, and on an instance whose constructor
        failed before the response was assigned.  After closing, ``read``
        returns ``b''``.
        """
        logger.debug("close: called")
        # getattr: the constructor may have raised before setting the attribute.
        response = getattr(self, 'response', None)
        try:
            if response is not None:
                # A streamed response keeps its connection checked out until
                # it is fully consumed or explicitly closed.
                response.close()
        finally:
            self.response = None
            self._read_iter = None

    def readable(self):
        """Return True if the stream can be read from."""
        return True

    def seekable(self):
        """Return False: this reader cannot seek."""
        return False

    #
    # io.BufferedIOBase methods.
    #
    def detach(self):
        """Unsupported."""
        raise io.UnsupportedOperation

    def read(self, size=-1):
        """Read and return up to ``size`` bytes.

        :param size: Maximum number of bytes to return.  ``None`` or a
            negative value reads everything that remains.
        :returns: The bytes read.  Fewer than ``size`` bytes are returned when
            the data runs out first, and ``b''`` when ``size`` is 0 or there
            is no response to read from.
        :rtype: bytes
        """
        # Follow the io convention that None means "read everything".
        if size is None:
            size = -1

        logger.debug("reading with size: %d", size)
        if self.response is None:
            return b''

        if size == 0:
            return b''
        elif size < 0 and len(self._read_buffer) == 0:
            # Nothing buffered: drain the raw stream in a single call.
            retval = self.response.raw.read()
        elif size < 0:
            # Hand out what is already buffered, then the rest of the stream.
            retval = self._read_buffer.read() + self.response.raw.read()
        else:
            while len(self._read_buffer) < size:
                logger.debug(
                    "http reading more content at current_pos: %d with size: %d",
                    self._current_pos, size,
                )
                bytes_read = self._read_buffer.fill(self._read_iter)
                if bytes_read == 0:
                    # Oops, ran out of data early.
                    retval = self._read_buffer.read()
                    self._current_pos += len(retval)

                    return retval

            # If we got here, it means we have enough data in the buffer
            # to return to the caller.
            retval = self._read_buffer.read(size)

        self._current_pos += len(retval)
        return retval

    def read1(self, size=-1):
        """This is the same as read()."""
        return self.read(size=size)

    def readinto(self, b):
        """Read up to len(b) bytes into b, and return the number of bytes
        read."""
        data = self.read(len(b))
        if not data:
            return 0
        b[:len(data)] = data
        return len(data)
|
||
|
||
|
||
class SeekableBufferedInputBase(BufferedInputBase):
    """
    Implement seekable streamed reader from a web site.
    Supports Kerberos, client certificate and Basic HTTP authentication.

    Seeking is implemented by issuing a new GET request that carries a
    ``range`` header starting at the target position.
    """

    def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
                 kerberos=False, user=None, password=None, cert=None,
                 headers=None, timeout=None):
        """
        If Kerberos is True, will attempt to use the local Kerberos credentials.
        If cert is set, will try to use a client certificate
        Otherwise, will try to use "basic" HTTP authentication via username/password.

        If none of those are set, will connect unauthenticated.

        :param str url: The URL to fetch.
        :param str mode: Stored on the instance as ``mode``.
        :param int buffer_size: Chunk size used for ``iter_content`` and for
            the internal byte buffer.
        :param dict headers: Request headers.  ``None`` selects a copy of the
            module defaults.  The mapping is copied, so the caller's object
            is never modified.
        :param timeout: Passed to ``requests.get`` as ``timeout``.

        Raises whatever ``response.raise_for_status()`` raises when the
        initial response is not OK.
        """
        self.url = url

        if kerberos:
            import requests_kerberos
            self.auth = requests_kerberos.HTTPKerberosAuth()
        elif user is not None and password is not None:
            self.auth = (user, password)
        else:
            self.auth = None

        if headers is None:
            self.headers = _HEADERS.copy()
        else:
            # Work on a private copy: _partial_request() adds a "range" entry
            # and must not leak it into the dict the caller passed in.
            self.headers = dict(headers)

        self.cert = cert
        self.timeout = timeout

        self.buffer_size = buffer_size
        self.mode = mode
        self.response = self._partial_request()

        if not self.response.ok:
            self.response.raise_for_status()

        logger.debug('self.response: %r, raw: %r', self.response, self.response.raw)

        # -1 marks an unknown length (no Content-Length header).
        self.content_length = int(self.response.headers.get("Content-Length", -1))
        #
        # We assume the HTTP stream is seekable unless the server explicitly
        # tells us it isn't.  It's better to err on the side of "seekable"
        # because we don't want to prevent users from seeking a stream that
        # does not appear to be seekable but really is.
        #
        self._seekable = self.response.headers.get("Accept-Ranges", "").lower() != "none"

        self._read_iter = self.response.iter_content(self.buffer_size)
        self._read_buffer = bytebuffer.ByteBuffer(buffer_size)
        self._current_pos = 0

        #
        # This member is part of the io.BufferedIOBase interface.
        #
        self.raw = None

    def seek(self, offset, whence=0):
        """Seek to the specified position.

        :param int offset: The offset in bytes.
        :param int whence: Where the offset is from.

        Returns the position after seeking.

        :raises ValueError: If ``whence`` is not one of
            ``constants.WHENCE_CHOICES``.
        :raises OSError: If the stream is not seekable, or if seeking relative
            to the end is requested while the content length is unknown.
        """
        logger.debug('seeking to offset: %r whence: %r', offset, whence)
        if whence not in constants.WHENCE_CHOICES:
            raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES)

        if not self.seekable():
            raise OSError('stream is not seekable')

        if whence == constants.WHENCE_START:
            new_pos = offset
        elif whence == constants.WHENCE_CURRENT:
            new_pos = self._current_pos + offset
        elif whence == constants.WHENCE_END:
            if self.content_length == -1:
                # Without a Content-Length there is no end to measure from;
                # computing -1 + offset would silently seek somewhere wrong.
                raise OSError('cannot seek relative to the end: content length is unknown')
            new_pos = self.content_length + offset

        # With an unknown length only the lower bound can be enforced.
        if self.content_length == -1:
            new_pos = smart_open.utils.clamp(new_pos, maxval=None)
        else:
            new_pos = smart_open.utils.clamp(new_pos, maxval=self.content_length)

        if self._current_pos == new_pos:
            return self._current_pos

        logger.debug("http seeking from current_pos: %d to new_pos: %d", self._current_pos, new_pos)

        self._current_pos = new_pos

        # Remember the response being replaced so that its connection can be
        # released once the new state is in place.
        previous = self.response

        if new_pos == self.content_length:
            # Positioned at EOF: nothing left to fetch.
            self.response = None
            self._read_iter = None
            self._read_buffer.empty()
        else:
            response = self._partial_request(new_pos)
            if response.ok:
                self.response = response
                self._read_iter = self.response.iter_content(self.buffer_size)
                self._read_buffer.empty()
            else:
                # The ranged request failed; subsequent reads return b''.
                response.close()
                self.response = None

        if previous is not None:
            # A streamed response holds its connection until closed.
            previous.close()

        return self._current_pos

    def tell(self):
        """Return the current position in the stream, in bytes."""
        return self._current_pos

    def seekable(self, *args, **kwargs):
        """Return True unless the server answered ``Accept-Ranges: none``."""
        return self._seekable

    def truncate(self, size=None):
        """Unsupported."""
        raise io.UnsupportedOperation

    def _partial_request(self, start_pos=None):
        """Issue a streamed GET for ``self.url``, optionally from an offset.

        :param start_pos: If not ``None``, a ``range`` header built by
            ``smart_open.utils.make_range_string(start_pos)`` is stored in
            ``self.headers`` before the request is sent.
        :returns: The response returned by ``requests.get``; its status is not
            checked here.
        """
        if start_pos is not None:
            self.headers.update({"range": smart_open.utils.make_range_string(start_pos)})

        response = requests.get(
            self.url,
            auth=self.auth,
            stream=True,
            cert=self.cert,
            headers=self.headers,
            timeout=self.timeout,
        )
        return response
|