import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools


# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# along and register an actual backward formula.
def autograd_kernel_indirection(custom_op):
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            missing = (
                'save_for_backward' if custom_op._has_impl('backward')
                else 'backward'
            )
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return autograd_fallback(*args, **kwargs)
    return inner


# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
    def kernel(*args, **kwargs):
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad,
            (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
        with torch._C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
    return kernel
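
# A minimal usage sketch (illustration only; commented out so it never runs at
# import time). `_FakeImpl` and `_FakeCustomOp` are hypothetical stand-ins for
# the real CustomOp object, shown only to make the indirection above concrete:
# `inner` routes to a user-registered 'autograd' impl when one exists and
# otherwise falls through to `autograd_not_implemented`.
#
#   class _FakeImpl:
#       def __init__(self, func, location="<example>"):
#           self.func = func
#           self.location = location
#
#   class _FakeCustomOp:
#       _opname = "example_op"
#       def __init__(self):
#           self._impls = {}
#       def _has_impl(self, kind):
#           return kind in self._impls
#       def _get_impl(self, kind):
#           return self._impls[kind]
#       def __call__(self, x):
#           return x * 2
#
#   op = _FakeCustomOp()
#   kernel = autograd_kernel_indirection(op)
#   kernel(torch.randn(3))                       # no grad needed: fallback just runs the op
#   op._impls['autograd'] = _FakeImpl(lambda x: x * 2)
#   kernel(torch.randn(3, requires_grad=True))   # now routes to the 'autograd' impl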
" f"At idx {idx}, we received an object of type {type(out)} that " f"is not a Tensor, so it cannot have be marked as differentiable in " f"output_differentiability.") if non_differentiable_tensors: ctx.mark_non_differentiable(*non_differentiable_tensors) def construct_autograd_kernel( schema, output_differentiability, custom_op, op_overload, save_for_backward_fn, backward_fn): def apply(*args): flat_args, spec = pytree.tree_flatten(args) out_spec = None def forward(ctx, *flat_args): ctx.set_materialize_grads(True) args = pytree.tree_unflatten(list(flat_args), spec) with torch._C._AutoDispatchBelowAutograd(): output = op_overload(*args) # We use the info about args to give better error messages in backward args_info = namedtuple_args( schema, pytree.tree_map(type, args)) save_for_backward_fn_inputs = namedtuple_args(schema, args) to_save = save_for_backward_fn(save_for_backward_fn_inputs, output) save_pytree_for_backward(ctx, (to_save, args_info)) mark_non_differentiable(ctx, output, output_differentiability) nonlocal out_spec flat_output, out_spec = pytree.tree_flatten(output) return tuple(flat_output) def backward(ctx, *flat_grad_output): assert out_spec is not None grads = pytree.tree_unflatten(list(flat_grad_output), out_spec) saved, args_info = unpack_saved(ctx) # There is nothing on the ctx object for now, it is just there so # that we can add additional things in the future. inner_ctx = object() if not isinstance(grads, tuple): grads = (grads,) grad_inputs_dict = backward_fn(inner_ctx, saved, *grads) # Massage the grad_inputs_dict to a form acceptable by # autograd.Function. validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info) return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info) generated_cls = gen_autograd_function( custom_op._opname + '_customop', forward, backward) flat_output = generated_cls.apply(*flat_args) assert out_spec is not None return pytree.tree_unflatten(list(flat_output), out_spec) return apply def gen_autograd_function(name, forward, backward): generated_cls = type( name, (torch.autograd.Function,), { 'forward': staticmethod(forward), 'backward': staticmethod(backward), } ) return generated_cls @functools.lru_cache def namedtuple_args_cls(schema): attribs = [arg.name for arg in schema.arguments.flat_all] name = str(schema.name) + "_args" # mypy doesn't support dynamic namedtuple name tuple_cls = namedtuple(name, attribs) # type: ignore[misc] return tuple_cls def namedtuple_args(schema, args): assert isinstance(args, tuple) tuple_cls = namedtuple_args_cls(schema) return tuple_cls(*args) def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info): def error(what): backward = forward_op._get_impl('backward') raise RuntimeError( f"In the backward function defined for {forward_op} at " f"{backward.location} using the CustomOp API, {what}") if not isinstance(grad_inputs_dict, dict): error(f"expected the output of the backward function to be a dict but " f"got {type(grad_inputs_dict)}") expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all if arg.type.is_tensor_like()} actual_keys = grad_inputs_dict.keys() if expected_keys != actual_keys: error(f"expected the returned grad_input dict to have keys " f"{expected_keys} but got {actual_keys}. The backward " f"function must return a gradient (can be None) for each arg " f"to the CustomOp that may be a Tensor or Sequence[Tensor]. 
" f"Args declared to be non-Tensor-like types should not appear " f"in the grad_input dict") for name, grad in grad_inputs_dict.items(): arg_info = getattr(args_info, name) if isinstance(arg_info, list): if not isinstance(grad, (tuple, list)): error(f"for input '{name}' expected the grad_input dict to " f"hold a list of gradients but got object of type " f"{type(grad)}.") if not len(grad) == len(arg_info): error(f"for input '{name}' expected the grad_input dict to " f"hold a list of {len(arg_info)} gradients but got " f"{len(grad)}") for idx, (g, info) in enumerate(zip(grad, arg_info)): if g is None: continue if not isinstance(g, torch.Tensor): error(f"for input '{name}' expected the grad_input dict to " f"hold a list of None or Tensor gradients but got " f"object of {type(g)} at index {idx}") if not issubclass(info, torch.Tensor): error(f"for input '{name}', got a Tensor as the gradient " f"for the {idx}-th value but expected None because " f"the {idx}-th value was not a Tensor (it was " f"type {arg_info}") continue if grad is None: continue if not isinstance(grad, torch.Tensor): error(f"got object of type {type(grad)} as the gradient for input " f"'{name}', " f"but expected the gradient to be either None or a Tensor") if not issubclass(arg_info, torch.Tensor): error(f"got a Tensor as the gradient for input '{name}' but " f"expected None as the gradient because input '{name}' " f"was not a Tensor (it was type {arg_info}).") def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): result = [] for name, arg_info in args_info._asdict().items(): if name not in grad_inputs_dict: result.append(pytree.tree_map(lambda x: None, arg_info)) continue result.append(grad_inputs_dict[name]) return tuple(pytree.tree_leaves(result)) # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it. # autograd.Function prefers that users use ctx.save_for_backward to # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the # ctx object. def save_pytree_for_backward(ctx, stuff): flat_stuff, spec = pytree.tree_flatten(stuff) num_elts = len(flat_stuff) tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) if isinstance(thing, torch.Tensor)] non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) if not isinstance(thing, torch.Tensor)] tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)] non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)] ctx.spec = spec ctx.num_elts = num_elts ctx.save_for_backward(*tensors) ctx.tensor_idxs = tensor_idxs ctx.saved_non_tensors = non_tensors ctx.non_tensor_idxs = non_tensor_idxs # Inverse operation to save_pytree_for_backward def unpack_saved(ctx): flat_stuff = [None] * ctx.num_elts for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs): flat_stuff[idx] = tensor for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs): flat_stuff[idx] = non_tensor stuff = pytree.tree_unflatten(flat_stuff, ctx.spec) return stuff