import contextlib import functools import warnings from typing import Callable, Optional import torch from torch._library.utils import Kernel, RegistrationHandle class AbstractImplHolder: """A holder where one can register an abstract impl to.""" def __init__(self, qualname: str): self.qualname: str = qualname self.kernel: Optional[Kernel] = None self.lib: Optional[torch.library.Library] = None def register(self, func: Callable, source: str) -> RegistrationHandle: """Register an abstract impl. Returns a RegistrationHandle that one can use to de-register this abstract impl. """ if self.kernel is not None: raise RuntimeError( f"impl_abstract(...): the operator {self.qualname} " f"already has an abstract impl registered at " f"{self.kernel.source}." ) if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): raise RuntimeError( f"impl_abstract(...): the operator {self.qualname} " f"already has an DispatchKey::Meta implementation via a " f"pre-existing torch.library or TORCH_LIBRARY registration. " f"Please either remove that registration or don't call " f"impl_abstract." ) if torch._C._dispatch_has_kernel_for_dispatch_key( self.qualname, "CompositeImplicitAutograd" ): raise RuntimeError( f"impl_abstract(...): the operator {self.qualname} " f"already has an implementation for this device type via a " f"pre-existing registration to " f"DispatchKey::CompositeImplicitAutograd." f"CompositeImplicitAutograd operators do not need an abstract " f"impl; " f"instead, the operator will decompose into its constituents " f"and those " f"can have abstract impls defined on them." ) # Store the kernel in this holder self.kernel = Kernel(func, source) # Also register the abstract impl to Meta key if self.lib is None: ns = self.qualname.split("::")[0] self.lib = torch.library.Library(ns, "FRAGMENT") meta_kernel = construct_meta_kernel(self.qualname, self) self.lib.impl(self.qualname, meta_kernel, "Meta") def deregister_abstract_impl(): if self.lib: self.lib._destroy() self.lib = None self.kernel = None return RegistrationHandle(deregister_abstract_impl) def construct_meta_kernel( qualname: str, abstract_impl_holder: AbstractImplHolder ) -> Callable: assert abstract_impl_holder.kernel is not None @functools.wraps(abstract_impl_holder.kernel.func) def meta_kernel(*args, **kwargs): assert abstract_impl_holder.kernel is not None source = abstract_impl_holder.kernel.source def error_on_ctx(): raise RuntimeError( f"Attempted to call get_ctx() for the meta implementation " f"for {qualname} (implemented at {source})" f"You have presumably called get_ctx() because the operator " f"has a data-dependent output shape; if so, there is no " f"such meta implementation and this error is the correct " f"behavior." ) with set_ctx_getter(error_on_ctx): return abstract_impl_holder.kernel(*args, **kwargs) return meta_kernel def get_none(): return None global_ctx_getter: Callable = get_none @contextlib.contextmanager def set_ctx_getter(ctx_getter): global global_ctx_getter prev = global_ctx_getter try: global_ctx_getter = ctx_getter yield finally: global_ctx_getter = prev class AbstractImplCtx: """ Context object for writing abstract implementations for custom operators. """ def __init__(self, _shape_env, _op): self._shape_env = _shape_env self._op = _op def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: warnings.warn( "create_unbacked_symint is deprecated, please use new_dynamic_size instead" ) return self.new_dynamic_size(min=min, max=max) def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: """Constructs a new symint (symbolic int) representing a data-dependent value. This is useful for writing the abstract implementation (which is necessary for torch.compile) for a CustomOp where an output Tensor has a size that depends on the data of the input Tensors. Args: min (int): A statically known inclusive lower bound for this symint. Default: 0 max (Optional[int]): A statically known inclusive upper bound for this symint. Default: None .. warning: It is important that the ``min`` and ``max`` (if not None) values are set correctly, otherwise, there will be undefined behavior under torch.compile. The default value of ``min`` is 2 due to torch.compile specializing on 0/1 sizes. You must also verify that your implementation on concrete Tensors (e.g. CPU/CUDA) only returns Tensors where the size that corresponds to the symint also has respects these constraint. The easiest way to do this is to add an assertion in the CPU/CUDA/etc implementation that the size follows these bounds. Example:: >>> # An operator with data-dependent output shape >>> lib = torch.library.Library("mymodule", "FRAGMENT") >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor") >>> >>> @torch.library.impl_abstract("mymodule::custom_nonzero") >>> def custom_nonzero_abstract(x): >>> # Number of nonzero-elements is data-dependent. >>> # Since we cannot peek at the data in an abstract impl, >>> # we use the ctx object to construct a new symint that >>> # represents the data-dependent size. >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] >>> result = x.new_empty(shape, dtype=torch.int64) >>> return result >>> >>> @torch.library.impl(lib, "custom_nonzero", "CPU") >>> def custom_nonzero_cpu(x): >>> x_np = x.numpy() >>> res = np.stack(np.nonzero(x_np), axis=1) >>> return torch.tensor(res, device=x.device) """ if ( self._shape_env is None or not self._shape_env.allow_dynamic_output_shape_ops ): raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op) if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt): raise ValueError( f"ctx.new_dynamic_size(min={min}, max={max}): expected " f"min and max to be statically known ints but got SymInt. " f"This is not supported." ) if min < 0: raise ValueError( f"ctx.new_dynamic_size(min={min}, ...): expected min to be " f"greater than or equal to 0: this API can only create " f"non-negative sizes." ) result = self._shape_env.create_unbacked_symint() torch.fx.experimental.symbolic_shapes._constrain_range_for_size( result, min=min, max=max ) return result