import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools


# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# and register something that is actually a backward formula
def autograd_kernel_indirection(custom_op):
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            missing = (
                'save_for_backward' if custom_op._has_impl('backward')
                else 'backward'
            )
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return autograd_fallback(*args, **kwargs)
    return inner

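
# Illustrative sketch (hypothetical, documentation only): the NOTE above
# describes swapping the effective autograd kernel behind a stable closure.
# This stand-alone example mimics that pattern with a plain dict in place of
# the CustomOp._has_impl/_get_impl registry; none of these names are part of
# the real API.
def _example_kernel_indirection_sketch():
    registry = {}  # hypothetical stand-in for the CustomOp impl registry

    def fallback(*args, **kwargs):
        raise RuntimeError("Autograd has not been implemented for operator")

    def make_inner(name):
        def inner(*args, **kwargs):
            # A kernel registered after `inner` was created still wins.
            if name in registry:
                return registry[name](*args, **kwargs)
            return fallback(*args, **kwargs)
        return inner

    inner = make_inner("my_op")
    registry["my_op"] = lambda x: x * 2  # "registered" after the fact
    return inner(torch.ones(3))
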
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
    def kernel(*args, **kwargs):
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
        with torch._C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
    return kernel

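
# Illustrative sketch (hypothetical, documentation only): how the guard in
# `autograd_not_implemented` decides whether to error -- grad mode must be on
# AND at least one Tensor in args/kwargs must require grad.
def _example_autograd_not_implemented_check():
    def any_requires_grad(args, kwargs):
        return torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad,
            (args, kwargs),
        )

    a = torch.randn(2)                              # plain tensor
    b = torch.randn(2, requires_grad=True)          # autograd leaf
    assert not any_requires_grad((a,), {})          # falls through to the op
    assert any_requires_grad((a,), {"weight": b})   # would raise in `kernel`
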
def mark_non_differentiable(ctx, output, output_differentiability):
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is not None:
        if not isinstance(output, tuple):
            tuple_output = (output,)
        else:
            tuple_output = output  # type: ignore[assignment]
        assert len(output_differentiability) == len(tuple_output)
        non_differentiable_tensors = []
        for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
            if isinstance(out, torch.Tensor):
                if not differentiable:
                    non_differentiable_tensors.append(out)
                continue
            if isinstance(out, list):
                if not differentiable:
                    non_differentiable_tensors.extend(out)
                continue
            if differentiable:
                raise RuntimeError(
                    f"With output_differentiability={output_differentiability}. "
                    f"At idx {idx}, we received an object of type {type(out)} that "
                    f"is not a Tensor, so it cannot be marked as differentiable in "
                    f"output_differentiability.")
        if non_differentiable_tensors:
            ctx.mark_non_differentiable(*non_differentiable_tensors)

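
# Illustrative sketch (hypothetical, documentation only): what
# mark_non_differentiable does for a two-Tensor output with
# output_differentiability=[True, False]. `_FakeCtx` is a made-up stand-in for
# the autograd.Function ctx.
def _example_mark_non_differentiable():
    class _FakeCtx:
        def __init__(self):
            self.non_differentiable = ()

        def mark_non_differentiable(self, *tensors):
            self.non_differentiable = tensors

    ctx = _FakeCtx()
    out0 = torch.randn(2)
    out1 = torch.randn(2)
    mark_non_differentiable(ctx, (out0, out1), [True, False])
    # Only the second output gets marked non-differentiable.
    assert len(ctx.non_differentiable) == 1 and ctx.non_differentiable[0] is out1
    return ctx
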
def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):

    def apply(*args):
        flat_args, spec = pytree.tree_flatten(args)
        out_spec = None

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply

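
# Illustrative sketch (hypothetical, documentation only): a hand-written
# equivalent of the machinery `construct_autograd_kernel` builds dynamically --
# flatten the (possibly nested) inputs, run an autograd.Function over the flat
# arguments, then unflatten the output. The op here (multiply by a Python
# float) is made up for the example.
def _example_flattened_autograd_function(x, scale=2.0):
    flat_args, spec = pytree.tree_flatten((x, scale))
    out_spec = None

    def forward(ctx, *flat_args):
        x, scale = pytree.tree_unflatten(list(flat_args), spec)
        ctx.scale = scale
        output = x * scale

        nonlocal out_spec
        flat_output, out_spec = pytree.tree_flatten(output)
        return tuple(flat_output)

    def backward(ctx, *flat_grad_output):
        grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
        # One gradient per flat input: a Tensor for `x`, None for the float.
        return grads * ctx.scale, None

    generated_cls = gen_autograd_function("scale_example_customop", forward, backward)
    flat_output = generated_cls.apply(*flat_args)
    return pytree.tree_unflatten(list(flat_output), out_spec)
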
def gen_autograd_function(name, forward, backward):
    generated_cls = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        }
    )
    return generated_cls

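
# Illustrative usage sketch (hypothetical, documentation only):
# gen_autograd_function just assembles a torch.autograd.Function subclass from
# two plain functions; the op and names here are made up.
def _example_gen_autograd_function():
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

    Square = gen_autograd_function("Square_example", forward, backward)
    x = torch.randn(3, requires_grad=True)
    Square.apply(x).sum().backward()
    return x.grad  # equals 2 * x
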
@functools.lru_cache
def namedtuple_args_cls(schema):
    attribs = [arg.name for arg in schema.arguments.flat_all]
    name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    tuple_cls = namedtuple(name, attribs)  # type: ignore[misc]
    return tuple_cls


def namedtuple_args(schema, args):
    assert isinstance(args, tuple)
    tuple_cls = namedtuple_args_cls(schema)
    return tuple_cls(*args)

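
# Illustrative sketch (hypothetical, documentation only): namedtuple_args only
# needs an object exposing `.name` and `.arguments.flat_all`, so made-up
# stand-in classes are enough to show the behavior; real callers pass the
# parsed operator schema.
def _example_namedtuple_args():
    class _FakeArg:
        def __init__(self, name):
            self.name = name

    class _FakeArguments:
        def __init__(self, names):
            self.flat_all = [_FakeArg(n) for n in names]

    class _FakeSchema:  # hashable by identity, which the lru_cache above needs
        def __init__(self, name, arg_names):
            self.name = name
            self.arguments = _FakeArguments(arg_names)

    schema = _FakeSchema("my_op", ["x", "weight"])
    args = namedtuple_args(schema, (torch.ones(2), torch.zeros(2)))
    return args.x, args.weight  # fields are named after the schema arguments
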
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if not len(grad) == len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {arg_info}).")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")

def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
    result = []
    for name, arg_info in args_info._asdict().items():
        if name not in grad_inputs_dict:
            result.append(pytree.tree_map(lambda x: None, arg_info))
            continue
        result.append(grad_inputs_dict[name])
    return tuple(pytree.tree_leaves(result))

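
# Illustrative sketch (hypothetical, documentation only): how a backward's
# grad_inputs dict becomes the flat positional tuple that
# autograd.Function.backward must return; the arg names and types are made up.
def _example_grad_inputs_dict_to_flat_tuple():
    ExampleArgsInfo = namedtuple("example_op_args", ["x", "ys", "n"])
    # args_info records the *types* of the forward args: a Tensor, a list of
    # two Tensors, and a non-Tensor int (which gets no gradient entry).
    args_info = ExampleArgsInfo(torch.Tensor, [torch.Tensor, torch.Tensor], int)
    gx = torch.randn(3)
    gy0 = torch.randn(3)
    grad_inputs_dict = {"x": gx, "ys": [gy0, None]}
    flat = grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
    # -> (gx, gy0, None, None): one entry per flat forward input.
    assert flat[0] is gx and flat[1] is gy0 and flat[2:] == (None, None)
    return flat
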
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    flat_stuff, spec = pytree.tree_flatten(stuff)
    num_elts = len(flat_stuff)
    tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                   if isinstance(thing, torch.Tensor)]
    non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                       if not isinstance(thing, torch.Tensor)]
    tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
    non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]

    ctx.spec = spec
    ctx.num_elts = num_elts
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs

# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    flat_stuff = [None] * ctx.num_elts
    for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
        flat_stuff[idx] = tensor
    for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
        flat_stuff[idx] = non_tensor
    stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
    return stuff
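
# Illustrative round-trip sketch (hypothetical, documentation only):
# save_pytree_for_backward/unpack_saved used from a hand-written
# autograd.Function. Tensors flow through ctx.save_for_backward, everything
# else rides on the ctx object. The op here is made up for the example.
def _example_save_and_unpack_pytree():
    class _Demo(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            save_pytree_for_backward(ctx, {"x": x, "scale": 3.0, "note": "demo"})
            return x * 3.0

        @staticmethod
        def backward(ctx, grad_output):
            saved = unpack_saved(ctx)  # the same dict forward saved
            return grad_output * saved["scale"]

    x = torch.randn(4, requires_grad=True)
    _Demo.apply(x).sum().backward()
    return x.grad  # equals 3.0 everywhere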