ai-content-maker/.venv/Lib/site-packages/torch/utils/data/graph_settings.py

161 lines
5.4 KiB
Python

import inspect
import warnings
from typing import Any, List, Optional, Set
import torch
from torch.utils.data.datapipes.iter.sharding import (
_ShardingIterDataPipe,
SHARDING_PRIORITIES,
)
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
# Public API of this module: the names exported by `from ... import *`.
# Underscore-prefixed helpers defined below are intentionally not listed.
__all__ = [
"apply_random_seed",
"apply_sharding",
"apply_shuffle_seed",
"apply_shuffle_settings",
"get_all_graph_pipes",
]
def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]:
    r"""
    Flatten a ``DataPipeGraph`` into a list of the DataPipes it contains.

    Args:
        graph: Mapping of graph id to a ``(datapipe, sub_graph)`` pair

    Returns:
        Every DataPipe found in ``graph``, each graph id contributing at most once.
    """
    # Start from an empty id cache; the helper fills it while recursing.
    visited_ids: Set[int] = set()
    return _get_all_graph_pipes_helper(graph, visited_ids)
def _get_all_graph_pipes_helper(graph: DataPipeGraph, id_cache: Set[int]) -> List[DataPipe]:
    r"""
    Collect DataPipes from ``graph`` depth-first, skipping ids already in ``id_cache``.

    Args:
        graph: Mapping of graph id to a ``(datapipe, sub_graph)`` pair
        id_cache: Ids that were already collected; updated in place

    Returns:
        DataPipes in visiting order: a node first, then the nodes of its sub-graph.
    """
    collected: List[DataPipe] = []
    for node_id, node in graph.items():
        pipe, children = node
        if node_id not in id_cache:
            # Record the id before descending so a node reachable through
            # several branches is only reported once.
            id_cache.add(node_id)
            collected.append(pipe)
            collected += _get_all_graph_pipes_helper(children, id_cache)
    return collected
def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
    r"""
    Tell whether ``datapipe`` can be sharded.

    A DataPipe qualifies when it is a ``_ShardingIterDataPipe`` instance, or when
    it exposes ``apply_sharding`` as a bound method.
    """
    return isinstance(datapipe, _ShardingIterDataPipe) or inspect.ismethod(
        getattr(datapipe, "apply_sharding", None)
    )
def apply_sharding(datapipe: DataPipe,
                   num_of_instances: int,
                   instance_id: int,
                   sharding_group=SHARDING_PRIORITIES.DEFAULT) -> DataPipe:
    r"""
    Call ``apply_sharding`` on every sharding DataPipe in the graph of ``datapipe``.

    The graph is obtained from ``traverse_dps`` and walked branch by branch.
    A RuntimeError is raised when a sharding DataPipe is found in the sub-graph
    of another sharding DataPipe, i.e. sharding would be applied twice on one branch.

    Args:
        datapipe: DataPipe whose graph is searched for sharding DataPipes
        num_of_instances: Forwarded as the first argument of ``apply_sharding``
        instance_id: Forwarded as the second argument of ``apply_sharding``
        sharding_group: Forwarded as ``sharding_group`` when the target accepts it

    Returns:
        The same ``datapipe`` object that was passed in.
    """
    def _shard_branch(branch, sharded_above=None):
        for node, children in branch.values():
            if not _is_sharding_datapipe(node):
                # Nothing to do here; keep carrying whatever was sharded above.
                _shard_branch(children, sharded_above)
                continue
            if sharded_above is not None:
                raise RuntimeError("Sharding twice on a single pipeline is likely unintended and will cause data loss. "
                                   f"Sharding already applied to {sharded_above} while trying to apply to {node}")
            # For BC, only pass sharding_group to implementations that accept it.
            accepts_group = len(inspect.signature(node.apply_sharding).parameters) >= 3
            if accepts_group:
                node.apply_sharding(num_of_instances, instance_id, sharding_group=sharding_group)
            else:
                node.apply_sharding(num_of_instances, instance_id)
            # Everything below this node now counts as already sharded.
            _shard_branch(children, node)

    _shard_branch(traverse_dps(datapipe))
    return datapipe
def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
    r"""
    Tell whether ``datapipe`` offers the shuffle API.

    Both ``set_shuffle`` and ``set_seed`` must exist and be bound methods.
    """
    has_both_attrs = hasattr(datapipe, "set_shuffle") and hasattr(datapipe, "set_seed")
    return (
        has_both_attrs
        and inspect.ismethod(datapipe.set_shuffle)
        and inspect.ismethod(datapipe.set_seed)
    )
def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) -> DataPipe:
    r"""
    Set the shuffle flag on every DataPipe in the graph that has the shuffle API.

    A DataPipe is targeted when it provides both ``set_shuffle`` and ``set_seed``
    as methods. When ``shuffle`` is truthy and no such DataPipe exists, a warning
    is issued and ``datapipe.shuffle()`` is appended and used as the only target.

    Args:
        datapipe: DataPipe that needs to set shuffle attribute
        shuffle: Shuffle option (default: ``None`` and no-op to the graph)

    Returns:
        ``datapipe`` itself, or the result of ``datapipe.shuffle()`` when one was added.
    """
    if shuffle is None:
        return datapipe

    found = [dp for dp in get_all_graph_pipes(traverse_dps(datapipe)) if _is_shuffle_datapipe(dp)]

    if found or not shuffle:
        targets = found
    else:
        warnings.warn(
            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
            "Be aware that the default buffer size might not be sufficient for your task."
        )
        datapipe = datapipe.shuffle()
        targets = [datapipe, ]  # type: ignore[list-item]

    for target in targets:
        target.set_shuffle(shuffle)

    return datapipe
def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    r"""
    Deprecated alias of ``apply_random_seed``.

    Emits a warning and then forwards ``datapipe`` and ``rng`` unchanged.
    """
    deprecation_notice = (
        "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases."
        "\nPlease use `apply_random_seed` instead."
    )
    warnings.warn(deprecation_notice)
    return apply_random_seed(datapipe, rng)
def _is_random_datapipe(datapipe: DataPipe) -> bool:
    r"""
    Tell whether ``datapipe`` exposes ``set_seed`` as a bound method.
    """
    return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)
def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    r"""
    Seed every DataPipe in the graph that provides a ``set_seed`` method.

    One int64 value is drawn from ``rng`` per seeded DataPipe and handed to its
    ``set_seed``; each DataPipe object is seeded at most once.

    Args:
        datapipe: DataPipe that needs to set randomness
        rng: Random number generator to generate random seeds

    Returns:
        The same ``datapipe`` object that was passed in.
    """
    # Track objects by `id` so unhashable DataPipes can be de-duplicated too.
    seeded_ids = set()
    seedable = []
    for candidate in get_all_graph_pipes(traverse_dps(datapipe)):
        candidate_id = id(candidate)
        if candidate_id not in seeded_ids and _is_random_datapipe(candidate):
            seeded_ids.add(candidate_id)
            seedable.append(candidate)

    for target in seedable:
        # Draw lazily, one seed per target, right before it is applied.
        seed = int(torch.empty((), dtype=torch.int64).random_(generator=rng).item())
        target.set_seed(seed)

    return datapipe