ai-content-maker/.venv/Lib/site-packages/torch/_library/abstract_impl.py

207 lines
7.7 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import contextlib
import functools
import warnings
from typing import Callable, Optional
import torch
from torch._library.utils import Kernel, RegistrationHandle
class AbstractImplHolder:
"""A holder where one can register an abstract impl to."""
def __init__(self, qualname: str):
self.qualname: str = qualname
self.kernel: Optional[Kernel] = None
self.lib: Optional[torch.library.Library] = None
def register(self, func: Callable, source: str) -> RegistrationHandle:
"""Register an abstract impl.
Returns a RegistrationHandle that one can use to de-register this
abstract impl.
"""
if self.kernel is not None:
raise RuntimeError(
f"impl_abstract(...): the operator {self.qualname} "
f"already has an abstract impl registered at "
f"{self.kernel.source}."
)
if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
raise RuntimeError(
f"impl_abstract(...): the operator {self.qualname} "
f"already has an DispatchKey::Meta implementation via a "
f"pre-existing torch.library or TORCH_LIBRARY registration. "
f"Please either remove that registration or don't call "
f"impl_abstract."
)
if torch._C._dispatch_has_kernel_for_dispatch_key(
self.qualname, "CompositeImplicitAutograd"
):
raise RuntimeError(
f"impl_abstract(...): the operator {self.qualname} "
f"already has an implementation for this device type via a "
f"pre-existing registration to "
f"DispatchKey::CompositeImplicitAutograd."
f"CompositeImplicitAutograd operators do not need an abstract "
f"impl; "
f"instead, the operator will decompose into its constituents "
f"and those "
f"can have abstract impls defined on them."
)
# Store the kernel in this holder
self.kernel = Kernel(func, source)
# Also register the abstract impl to Meta key
if self.lib is None:
ns = self.qualname.split("::")[0]
self.lib = torch.library.Library(ns, "FRAGMENT")
meta_kernel = construct_meta_kernel(self.qualname, self)
self.lib.impl(self.qualname, meta_kernel, "Meta")
def deregister_abstract_impl():
if self.lib:
self.lib._destroy()
self.lib = None
self.kernel = None
return RegistrationHandle(deregister_abstract_impl)
def construct_meta_kernel(
    qualname: str, abstract_impl_holder: AbstractImplHolder
) -> Callable:
    """Build the callable that is registered on the operator's Meta key.

    The returned wrapper forwards its arguments to the kernel currently
    stored on ``abstract_impl_holder``. While the kernel runs, the
    module-level ctx getter is replaced by one that raises, so a call to
    ``get_ctx()`` from inside a meta implementation fails with an
    explanatory RuntimeError.

    Args:
        qualname: Operator name, used only in the error message.
        abstract_impl_holder: Holder whose ``kernel`` must already be set.

    Returns:
        The wrapper, carrying the metadata of the held kernel's ``func``
        via ``functools.wraps``.
    """
    assert abstract_impl_holder.kernel is not None

    @functools.wraps(abstract_impl_holder.kernel.func)
    def meta_kernel(*args, **kwargs):
        # The kernel is looked up at call time, not captured at construction.
        assert abstract_impl_holder.kernel is not None
        source = abstract_impl_holder.kernel.source

        def error_on_ctx():
            # ". " after the source keeps the two sentences from running
            # together in the rendered message.
            raise RuntimeError(
                f"Attempted to call get_ctx() for the meta implementation "
                f"for {qualname} (implemented at {source}). "
                f"You have presumably called get_ctx() because the operator "
                f"has a data-dependent output shape; if so, there is no "
                f"such meta implementation and this error is the correct "
                f"behavior."
            )

        with set_ctx_getter(error_on_ctx):
            return abstract_impl_holder.kernel(*args, **kwargs)

    return meta_kernel
def get_none():
    """Default ctx getter: there is no active context, so return None."""
    return None


# Module-level slot holding the callable that produces the current context.
global_ctx_getter: Callable = get_none


@contextlib.contextmanager
def set_ctx_getter(ctx_getter):
    """Temporarily install ``ctx_getter`` as the module-level ctx getter.

    The getter that was active on entry is put back on exit, including
    when the managed body raises.
    """
    global global_ctx_getter
    previous_getter = global_ctx_getter
    global_ctx_getter = ctx_getter
    try:
        yield
    finally:
        global_ctx_getter = previous_getter
class AbstractImplCtx:
    """
    Context object for writing abstract implementations for custom operators.
    """

    def __init__(self, _shape_env, _op):
        # Shape environment used to create unbacked symints; may be None,
        # in which case new_dynamic_size() raises.
        self._shape_env = _shape_env
        # The operator this context belongs to; passed to the exception
        # raised when dynamic output shapes are not allowed.
        self._op = _op

    def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
        """Deprecated alias for :meth:`new_dynamic_size`.

        Emits a warning and forwards to ``new_dynamic_size``. Note that this
        alias keeps its own default of ``min=2``, whereas
        ``new_dynamic_size`` defaults to ``min=0``.
        """
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this module.
        warnings.warn(
            "create_unbacked_symint is deprecated, please use new_dynamic_size instead",
            stacklevel=2,
        )
        return self.new_dynamic_size(min=min, max=max)

    def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
        """Constructs a new symint (symbolic int) representing a data-dependent value.

        This is useful for writing the abstract implementation (which is necessary
        for torch.compile) for a CustomOp where an output Tensor has a size
        that depends on the data of the input Tensors.

        Args:
            min (int): A statically known inclusive lower bound for this symint. Default: 0
            max (Optional[int]): A statically known inclusive upper bound for this
                symint. Default: None

        Raises:
            DynamicOutputShapeException: If there is no shape environment or
                it does not allow dynamic output shape ops.
            ValueError: If ``min`` or ``max`` is a SymInt, or if ``min`` is
                negative.

        .. warning::

            It is important that the ``min`` and ``max`` (if not None) values are set
            correctly, otherwise, there will be undefined behavior under
            torch.compile. The default value of ``min`` is 0.

            You must also verify that your implementation on concrete Tensors
            (e.g. CPU/CUDA) only returns Tensors where the size that corresponds
            to the symint also respects these constraints.
            The easiest way to do this is to add an assertion in the CPU/CUDA/etc
            implementation that the size follows these bounds.

        Example::

            >>> # An operator with data-dependent output shape
            >>> lib = torch.library.Library("mymodule", "FRAGMENT")
            >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
            >>>
            >>> @torch.library.impl_abstract("mymodule::custom_nonzero")
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch.library.get_ctx()
            >>>     nnz = ctx.new_dynamic_size()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.int64)
            >>>     return result
            >>>
            >>> @torch.library.impl(lib, "custom_nonzero", "CPU")
            >>> def custom_nonzero_cpu(x):
            >>>     x_np = x.numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     return torch.tensor(res, device=x.device)

        """
        if (
            self._shape_env is None
            or not self._shape_env.allow_dynamic_output_shape_ops
        ):
            raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)

        if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, max={max}): expected "
                f"min and max to be statically known ints but got SymInt. "
                f"This is not supported."
            )

        if min < 0:
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
                f"greater than or equal to 0: this API can only create "
                f"non-negative sizes."
            )

        # Create the unbacked symint, then record the caller-provided bounds
        # on it.
        result = self._shape_env.create_unbacked_symint()
        torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
            result, min=min, max=max
        )
        return result