import functools
import pickle
from typing import Dict, Callable, Optional, TypeVar, Generic, Iterator

from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes.utils.common import (
    _deprecation_warning,
    _iter_deprecated_functional_names,
    _map_deprecated_functional_names,
)
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils._import_utils import import_dill

dill = import_dill()
HAS_DILL = dill is not None

__all__ = [
    "DataChunk",
    "DFIterDataPipe",
    "IterDataPipe",
    "MapDataPipe",
]

T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)

UNTRACABLE_DATAFRAME_PIPES = ['batch',  # As it returns DataChunks
                              'groupby',  # As it returns DataChunks
                              '_dataframes_as_tuples',  # As it unpacks DF
                              'trace_as_dataframe',  # As it used to mark DF for tracing
                              ]


class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
    r"""
    Iterable-style DataPipe.

    All DataPipes that represent an iterable of data samples should subclass this.
    This style of DataPipes is particularly useful when data come from a stream, or
    when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
    elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
    method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
    override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
    and various state variables within the custom ``IterDataPipe``.

    Note:
        Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
        and the creation of a second iterator will invalidate the first one. This constraint is necessary because
        some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
        The code example below presents details on how this constraint looks in practice.
        If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.

    These DataPipes can be invoked in two ways, using the class constructor or applying their
    functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
    You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
    operations in succession.

    .. _GitHub IterDataPipe Single Iterator Issue:
        https://github.com/pytorch/data/issues/45

    Note:
        When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
        item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
        iterator. When :attr:`num_workers > 0`, each worker process will have a
        different copy of the DataPipe object, so it is often desired to configure
        each copy independently to avoid having duplicate data returned from the
        workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
        process, returns information about the worker. It can be used in either the
        dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
        :attr:`worker_init_fn` option to modify each copy's behavior.

    Examples:
        General Usage:
            >>> # xdoctest: +SKIP
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> dp = IterableWrapper(range(10))
            >>> map_dp_1 = Mapper(dp, lambda x: x + 1)  # Using class constructor
            >>> map_dp_2 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
            >>> list(map_dp_1)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> list(map_dp_2)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
            >>> list(filter_dp)
            [2, 4, 6, 8, 10]
        Single Iterator Constraint Example:
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> source_dp = IterableWrapper(range(10))
            >>> it1 = iter(source_dp)
            >>> list(it1)
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
            >>> it1 = iter(source_dp)
            >>> it2 = iter(source_dp)  # The creation of a new iterator invalidates `it1`
            >>> next(it2)
            0
            >>> next(it1)  # Further usage of `it1` will raise a `RuntimeError`
    """

    functions: Dict[str, Callable] = {}
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None
    str_hook: Optional[Callable] = None
    repr_hook: Optional[Callable] = None
    _valid_iterator_id: Optional[int] = None
    _number_of_samples_yielded: int = 0
    _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
    _fast_forward_iterator: Optional[Iterator] = None

    def __iter__(self) -> Iterator[T_co]:
        return self

    def __getattr__(self, attribute_name):
        if attribute_name in IterDataPipe.functions:
            if attribute_name in _iter_deprecated_functional_names:
                kwargs = _iter_deprecated_functional_names[attribute_name]
                _deprecation_warning(**kwargs)
            f = IterDataPipe.functions[attribute_name]
            function = functools.partial(f, self)
            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
            return function
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'")

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception(f"Unable to add DataPipe function name {function_name} as it is already taken")

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, IterDataPipe):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(
            class_function, cls_to_register, enable_df_api_tracing
        )
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
        )
        cls.functions[function_name] = function

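    # A minimal sketch of how registration is typically triggered: the public
    # ``functional_datapipe`` decorator calls ``register_datapipe_as_function``
    # at class-definition time (the pipe name and class here are hypothetical):
    #
    # >>> from torch.utils.data import functional_datapipe
    # >>> @functional_datapipe("double")
    # ... class DoublerIterDataPipe(IterDataPipe):
    # ...     def __init__(self, source_dp):
    # ...         self.source_dp = source_dp
    # ...     def __iter__(self):
    # ...         yield from (2 * x for x in self.source_dp)
    #
    # Afterwards, every ``IterDataPipe`` instance gains a chainable ``.double()`` method.
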
    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
        """
        state = self.__dict__
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        if IterDataPipe.reduce_ex_hook is not None:
            try:
                return IterDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        if IterDataPipe.getstate_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing getstate_hook")
        IterDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing reduce_ex_hook")
        IterDataPipe.reduce_ex_hook = hook_fn

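    # A minimal sketch of the hook API (the hook function below is hypothetical):
    # hooks apply class-wide, and an existing hook must be cleared (set to ``None``)
    # before a different one can be installed.
    #
    # >>> def _noop_reduce_ex_hook(datapipe):
    # ...     raise NotImplementedError  # fall back to the default `__reduce_ex__`
    # >>> IterDataPipe.set_reduce_ex_hook(_noop_reduce_ex_hook)
    # >>> IterDataPipe.set_reduce_ex_hook(None)  # clear before installing another hook
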
    def __repr__(self):
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())

    def reset(self) -> None:
        r"""
        Reset the `IterDataPipe` to the initial state.

        By default, this is a no-op. Depending on their functionality, subclasses of
        `IterDataPipe` may want to override this method with an implementation that
        clears the buffers and resets the pointers of the DataPipe.
        The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
        """
        pass
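
    # A minimal sketch (hypothetical class and attribute names) of a custom
    # ``IterDataPipe`` that overrides ``reset()`` to clear internal state each
    # time a new iterator is created:
    #
    # >>> class BufferedIterDataPipe(IterDataPipe):
    # ...     def __init__(self, source_dp):
    # ...         self.source_dp = source_dp
    # ...         self.buffer = []
    # ...     def __iter__(self):
    # ...         for x in self.source_dp:
    # ...             self.buffer.append(x)
    # ...             yield x
    # ...     def reset(self):
    # ...         self.buffer = []  # drop stale items before the next iteration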


class DFIterDataPipe(IterDataPipe):
    def _is_dfpipe(self):
        return True


class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
    r"""
    Map-style DataPipe.

    All datasets that represent a map from keys to data samples should subclass this.
    Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given, unique key. Subclasses can also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    These DataPipes can be invoked in two ways, using the class constructor or applying their
    functional form onto an existing `MapDataPipe` (recommended, available to most but not all DataPipes).

    Note:
        :class:`~torch.utils.data.DataLoader` by default constructs an index
        sampler that yields integral indices. To make it work with a map-style
        DataPipe with non-integral indices/keys, a custom sampler must be provided.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)  # Using class constructor
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> batch_dp = map_dp_1.batch(batch_size=2)
        >>> list(batch_dp)
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    """

    functions: Dict[str, Callable] = {}
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None
    str_hook: Optional[Callable] = None
    repr_hook: Optional[Callable] = None

    def __getattr__(self, attribute_name):
        if attribute_name in MapDataPipe.functions:
            if attribute_name in _map_deprecated_functional_names:
                kwargs = _map_deprecated_functional_names[attribute_name]
                _deprecation_warning(**kwargs)
            f = MapDataPipe.functions[attribute_name]
            function = functools.partial(f, self)
            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
            return function
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'")

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register):
        if function_name in cls.functions:
            raise Exception(f"Unable to add DataPipe function name {function_name} as it is already taken")

        def class_function(cls, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            return result_pipe

        function = functools.partial(class_function, cls_to_register)
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
        )
        cls.functions[function_name] = function

    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
        """
        state = self.__dict__
        if MapDataPipe.getstate_hook is not None:
            return MapDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        if MapDataPipe.reduce_ex_hook is not None:
            try:
                return MapDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        if MapDataPipe.getstate_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing getstate_hook")
        MapDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing reduce_ex_hook")
        MapDataPipe.reduce_ex_hook = hook_fn

    def __repr__(self):
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())


class _DataPipeSerializationWrapper:
    def __init__(self, datapipe):
        self._datapipe = datapipe

    def __getstate__(self):
        use_dill = False
        try:
            value = pickle.dumps(self._datapipe)
        except Exception:
            # Fall back to `dill` for objects (e.g. lambdas) that `pickle` cannot serialize.
            if HAS_DILL:
                value = dill.dumps(self._datapipe)
                use_dill = True
            else:
                raise
        return (value, use_dill)

    def __setstate__(self, state):
        # Deserialize with whichever serializer produced the payload.
        value, use_dill = state
        if use_dill:
            self._datapipe = dill.loads(value)
        else:
            self._datapipe = pickle.loads(value)

    def __len__(self):
        try:
            return len(self._datapipe)
        except Exception as e:
            raise TypeError(
                f"{type(self).__name__} instance doesn't have valid length"
            ) from e


class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
    def __init__(self, datapipe: IterDataPipe[T_co]):
        super().__init__(datapipe)
        self._datapipe_iter: Optional[Iterator[T_co]] = None

    def __iter__(self) -> "_IterDataPipeSerializationWrapper":
        self._datapipe_iter = iter(self._datapipe)
        return self

    def __next__(self) -> T_co:  # type: ignore[type-var]
        assert self._datapipe_iter is not None
        return next(self._datapipe_iter)


class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
    def __getitem__(self, idx):
        return self._datapipe[idx]

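# A minimal sketch of the wrappers' intended round trip (assumes `dill` is
# installed; `IterableWrapper` here comes from `torchdata.datapipes.iter`):
# plain `pickle` cannot serialize the lambda, so `__getstate__` falls back to `dill`.
#
# >>> # xdoctest: +SKIP
# >>> from torchdata.datapipes.iter import IterableWrapper
# >>> dp = IterableWrapper(range(3)).map(lambda x: x + 1)
# >>> wrapped = _IterDataPipeSerializationWrapper(dp)
# >>> restored = pickle.loads(pickle.dumps(wrapped))
# >>> list(restored)
# [1, 2, 3]
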
class DataChunk(list, Generic[T]):
    def __init__(self, items):
        super().__init__(items)
        self.items = items

    def as_str(self, indent=''):
        res = indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
        return res

    def __iter__(self) -> Iterator[T]:
        yield from super().__iter__()

    def raw_iterator(self) -> T:  # type: ignore[misc]
        yield from self.items
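
# A minimal sketch of ``DataChunk`` usage: it behaves like a ``list`` while letting
# DataPipes (e.g. ``batch``) distinguish wrapped batches from plain lists:
#
# >>> chunk = DataChunk([1, 2, 3])
# >>> chunk.as_str()
# '[1, 2, 3]'
# >>> list(chunk.raw_iterator())
# [1, 2, 3]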