# torch/_subclasses/fake_utils.py
# mypy: ignore-errors

import functools
import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch._ops.ops.aten


def outputs_alias_inputs(outputs, inputs):
    # True if any output tensor shares its storage with any input tensor.
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )


def outputs_are_inputs(outputs, inputs):
    # True if any output tensor is the very same Python object as an input tensor.
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))


def output_alias_each_other(outputs):
    # True if any two output tensors share the same storage.
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False
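

# Illustrative note (not part of the upstream file): a view shares storage with
# its base, so the helpers above report it as aliasing, while a fresh clone
# does not. For example:
#
#     base = torch.ones(4)
#     outputs_alias_inputs((base.view(2, 2),), (base,))  # True: shared storage
#     outputs_alias_inputs((base.clone(),), (base,))     # False: new storage
#     outputs_are_inputs((base,), (base,))                # True: same object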


def is_sdpa_error(func, idx, e):
    # Known benign discrepancies: the scaled-dot-product-attention kernels can
    # report a device mismatch on certain auxiliary outputs, so those specific
    # comparisons are skipped rather than treated as fake-tensor bugs.
    if (
        (
            func is aten._scaled_dot_product_flash_attention.default
            or func is aten._flash_attention_forward.default
        )
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    if (
        (
            func is aten._scaled_dot_product_efficient_attention.default
            or func is aten._efficient_attention_forward.default
        )
        and idx in (2, 3)
        and "Devices" in repr(e)
    ):
        return True
    return False


class CrossRefFakeMode(TorchDispatchMode):
    """Dispatch mode that runs each intercepted op twice, once on the real
    inputs and once on FakeTensor copies, and asserts that the two results
    agree on output count, metadata, and (optionally) strides and aliasing."""

    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        # ignore_op_fn lets callers skip cross-checking for specific overloads.
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        fake_r = None

        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and torch.Tag.data_dependent_output not in func.tags
        ):
            # Do not import symbolic_shapes at the top of the module as it
            # imports sympy and that's slow
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # TODO: enable_python_dispatcher() here
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    # Re-run the op on FakeTensor copies of the real inputs.
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor,
                        functools.partial(fake_mode.from_tensor, static_shapes=True),
                        (args, kwargs),
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )
        # Run the op on the real (concrete) inputs.
        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat = pytree.tree_leaves(r)
            f_flat = pytree.tree_leaves(fake_r)
            assert len(f_flat) == len(
                r_flat
            ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"

            if self.check_aliasing:
                # The fake outputs must alias inputs and each other exactly as
                # the real outputs do.
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"

                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"

                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert r_output_alias_each_other == f_output_alias_each_other, (
                    f"{context} mismatch in outputs_alias_each_other check "
                    f"{f_output_alias_each_other} != {r_output_alias_each_other}"
                )

            # Compare the real and fake outputs leaf by leaf.
            for idx, (r_out, fake_out) in enumerate(
                zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
            ):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"{context} mismatched number of tensor outputs"
                if r_is_ten:
                    assert r_out.requires_grad == fake_out.requires_grad, (
                        f"{context} mismatched requires_grad-ness of outputs. "
                        f"This usually means that you have added autograd support "
                        f"for your operator at a dispatch key other than Autograd, "
                        f"which will lead to problems"
                    )
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"{context} mismatched storage offset"

                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out,
                            fake_out,
                            check_strides=self.check_strides,
                            allow_rhs_unbacked=True,
                        )
                    except Exception as e:
                        if is_sdpa_error(func, idx, e):
                            continue
                        error_message = (
                            f"{context} mismatched tensor metadata: {e}"
                            if len(r_flat) == 1
                            else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                        )
                        raise RuntimeError(error_message) from e

        return r
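

# Minimal usage sketch (illustrative only, not part of the upstream file): run
# a few concrete ops under CrossRefFakeMode so that every intercepted op is
# re-executed on FakeTensors and cross-checked against the real result. Any
# divergence in output count, metadata, strides, or aliasing raises an error.
if __name__ == "__main__":
    with CrossRefFakeMode(check_strides=True, check_aliasing=True):
        x = torch.randn(4, 4)
        y = torch.mm(x, x)
        z = y.t().contiguous()
    print("cross-ref checks passed:", z.shape)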