from typing import Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch import _prims
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator

from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._prims_common.wrappers import backwards_not_supported
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.types import _device, _dtype


rngprim_namespace = "rngprims"
rngprim = torch.library.Library(rngprim_namespace, "DEF")
rngprim_impl = torch.library.Library(
    rngprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
rngprim_autograd_impl = torch.library.Library(rngprim_namespace, "IMPL", "Autograd")
rngprim_meta_impl = torch.library.Library(rngprim_namespace, "IMPL", "Meta")


def throw_on_non_cuda(device):
    raise RuntimeError(
        f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )


def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
    rngprim.define(schema)
    rngprim_impl.impl(name, impl_aten)
    rngprim_meta_impl.impl(name, impl_meta)

    prim_packet = getattr(torch._ops.ops.rngprims, name)
    prim = prim_packet.default
    if tags:
        prim._tags = tags

    rngprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = doc
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]

        p.schema = schema
        p.impl_aten = impl_aten
        p.prim_meta_impl = impl_meta


# Philox rand offsets could be shared in future with other philox ops, so
# keeping these functions in global scope.
def philox_rand_offset_meta(
    shape: torch.Size,
):
    return _prims.TensorLike(torch.tensor(0, dtype=torch.int64))


def philox_rand_offset(
    shape: torch.Size,
):
    # For impl, look at the function calc_execution_policy in the file
    # aten/src/ATen/native/cuda/DistributionTemplates.h. The impl was copied at
    # commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6
    numel_scalar = 1
    for dim_size in shape:
        numel_scalar *= dim_size
    numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64)

    block_size = 256
    unroll = 4
    curand4_engine_calls = 4
    device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
    blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
    grid_size = (numel + block_size - 1) // block_size
    grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
    offset = (
        (numel - 1) // (block_size * grid_size * unroll) + 1
    ) * curand4_engine_calls
    return offset
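
# Worked illustration (assumed numbers, not from the original source): for a
# tensor with numel = 2**20 on a hypothetical device with
# max_threads_per_multi_processor = 2048 and multi_processor_count = 108:
#   blocks_per_sm = 2048 // 256 = 8
#   grid_size     = min((2**20 + 255) // 256, 108 * 8) = min(4096, 864) = 864
#   offset        = ((2**20 - 1) // (256 * 864 * 4) + 1) * 4 = (1 + 1) * 4 = 8
# Every offset returned is a multiple of curand4_engine_calls (4).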


def register_philox_rand():
    name = "philox_rand"
    schema = "philox_rand(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)"  # noqa: B950

    def _philox_rand_meta(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # The stride arg will be useful for the distributed use case. Currently, it is unused.
        assert stride is None
        stride = make_contiguous_strides_for(shape)
        random_values = _prims.TensorMeta(
            shape=shape, strides=stride, dtype=dtype, device=device
        )
        offset = philox_rand_offset_meta(shape)
        return (random_values, offset)

    def _philox_rand(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # The stride arg will be useful for the distributed use case. Currently, it is unused.
        assert stride is None
        if device.type == "cpu":
            devices = []
        else:
            devices = [device]

        if device.type != "cuda":
            raise throw_on_non_cuda(device)

        with torch.random.fork_rng(devices):
            CUDARngStateHelper.set_torch_state_tensor(seed, offset)
            random_values = torch.rand(shape, device=device, dtype=dtype)

        return random_values, philox_rand_offset(shape)

    register_rng_prim(
        name=name,
        schema=schema,
        impl_aten=_philox_rand,
        impl_meta=_philox_rand_meta,
        doc="Philox based stateless rand operator",
        tags=(torch.Tag.nondeterministic_seeded,),
    )
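

# Illustrative sketch (not part of the original module): one way the registered
# philox_rand prim might be invoked, assuming register_rng_prims() (defined at
# the bottom of this file) has already been called and a CUDA device is
# available. The zero seed/offset tensors are placeholder values; in practice
# they come from the functionalized RNG state machinery.
def _example_philox_rand_call():
    if not torch.cuda.is_available():
        return None
    seed = torch.tensor(0, dtype=torch.int64)
    offset = torch.tensor(0, dtype=torch.int64)
    # Returns the random values plus the Philox offset consumed, so a caller
    # can advance its own offset without mutating global RNG state.
    values, consumed = torch.ops.rngprims.philox_rand(
        (4, 4), seed, offset, None, torch.device("cuda"), torch.float32
    )
    return values, consumed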


def get_device(args, kwargs):
    if kwargs.get("device"):
        device = kwargs.get("device")
        if isinstance(device, str):
            device = torch.device(device)
        return device.type

    devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
    if any(dev == "cuda" for dev in devices):
        return "cuda"
    elif any(dev == "cpu" for dev in devices):
        return "cpu"
    return None
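
# Illustration (not from the original module): get_device prefers an explicit
# "device" kwarg, then falls back to the devices of tensor args. For example,
# get_device((), {"device": "cpu"}) returns "cpu", and passing a CUDA tensor in
# args with no device kwarg returns "cuda".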


def register_run_and_save_rng_state_op():
    run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")

    run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
    )

    @run_and_save_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(op, *args, **kwargs):
        return torch.cuda.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(op, *args, **kwargs):
        return torch.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(op, *args, **kwargs):
        impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, op, *args, **kwargs):
        # Check the device to call the right impl.
        with mode:
            return impl_backend_select(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
        if mode.enable_tracing:
            out = impl_backend_select(op, *args, **kwargs)
            proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
            out_proxy = mode.tracer.create_proxy(
                "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
            )
            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
        else:
            return run_and_save_rng_state(op, *args, **kwargs)

    return run_and_save_rng_state


def register_run_with_rng_state_op():
    run_with_rng_state = HigherOrderOperator("run_with_rng_state")

    run_with_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_with_rng_state, deferred_error=True)
    )

    @run_with_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(rng_state, op, *args, **kwargs):
        current_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(rng_state.cpu())
        out = op(*args, **kwargs)
        torch.cuda.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(rng_state, op, *args, **kwargs):
        current_state = torch.get_rng_state()
        torch.set_rng_state(rng_state)
        out = op(*args, **kwargs)
        torch.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
        if mode.enable_tracing:
            with disable_proxy_modes_tracing():
                out = run_with_rng_state(rng_state, op, *args, **kwargs)
            proxy_args = pytree.tree_map(
                mode.tracer.unwrap_proxy, (rng_state, op, *args)
            )
            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
            out_proxy = mode.tracer.create_proxy(
                "call_function", run_with_rng_state, proxy_args, proxy_kwargs
            )
            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
        else:
            return run_with_rng_state(rng_state, op, *args, **kwargs)

    @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(rng_state, op, *args, **kwargs):
        impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(rng_state, op, *args, **kwargs)

    @run_with_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
        # Skip restoring the RNG state: it does not work well with fake tensors
        # and does not matter in fake tensor mode.
        with mode:
            return op(*args, **kwargs)

    return run_with_rng_state


run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()
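

# Illustrative sketch (not part of the original module): run an op while saving
# the RNG state, then replay it bit-exactly from that state. The choice of
# torch.ops.aten.rand_like.default on a CUDA tensor is an assumption made for
# demonstration; AOTAutograd wires these higher-order ops up during tracing
# rather than calling them like this.
def _example_rng_state_round_trip():
    if not torch.cuda.is_available():
        return None
    x = torch.empty(4, device="cuda")
    # Capture the CUDA RNG state and run the op once.
    rng_state, first = run_and_save_rng_state(torch.ops.aten.rand_like.default, x)
    # Replaying under the saved state reproduces the same random values and
    # then restores the current state afterwards.
    second = run_with_rng_state(rng_state, torch.ops.aten.rand_like.default, x)
    assert torch.equal(first, second)
    return first, second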


def register_rng_prims():
    register_philox_rand()