988 lines
46 KiB
Python
988 lines
46 KiB
Python
import contextlib
|
|
import warnings
|
|
import weakref
|
|
from typing import ContextManager, Dict, List, Optional, Tuple, TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch._C._functorch import (
|
|
_add_batch_dim,
|
|
_unwrap_functional_tensor,
|
|
_wrap_functional_tensor,
|
|
current_level,
|
|
get_unwrapped,
|
|
is_batchedtensor,
|
|
is_functorch_wrapped_tensor,
|
|
is_gradtrackingtensor,
|
|
maybe_get_bdim,
|
|
maybe_get_level,
|
|
peek_interpreter_stack,
|
|
TransformType,
|
|
)
|
|
from torch._guards import Source
|
|
|
|
from torch.multiprocessing.reductions import StorageWeakRef
|
|
from torch.utils._python_dispatch import (
|
|
is_traceable_wrapper_subclass,
|
|
transform_subclass,
|
|
)
|
|
from torch.utils.weak import WeakIdRef
|
|
|
|
if TYPE_CHECKING:
|
|
# Import the following modules during type checking to enable code intelligence features,
|
|
# Do not import unconditionally, as they import sympy and importing sympy is very slow
|
|
from torch.fx.experimental.symbolic_shapes import SymbolicContext
|
|
|
|
DimList = List
|
|
|
|
|
|
def safe_is_leaf(t):
    """Return ``t.is_leaf``, reporting False for tensors that refuse the query.

    Reading ``is_leaf`` can raise ``RuntimeError`` (inference-mode tensors
    trigger this); treat such tensors as non-leaf instead of propagating.
    """
    try:
        result = t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False
    return result
|
|
|
|
|
|
def safe_grad(t):
    """Read ``t.grad`` with the non-leaf ``.grad`` access warning silenced."""
    with warnings.catch_warnings():
        # Accessing .grad on a non-leaf tensor emits a UserWarning; we only
        # want the value, so suppress that specific message.
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        grad = t.grad
    return grad
|
|
|
|
|
|
def assert_eq(a, b):
    """Assert that ``a == b``; the failure message shows both operands."""
    assert a == b, f"{a} != {b}"
|
|
|
|
|
|
def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False):
    """Recursively assert that two tensors agree on (non-data) metadata.

    Compares dtype, autograd state (requires_grad, leaf-ness, grad_fn
    presence), sparsity, inference/conj/neg bits and, unless
    ``skip_symbolic`` is True, shapes, strides and storage offsets.
    Recurses into ``.grad`` tensors and view bases.  ``assert_eq`` is the
    two-argument comparison callback used for every individual check.
    """

    def check(a, b):
        assert_eq(a.dtype, b.dtype)
        if not skip_symbolic:
            assert_eq(a.shape, b.shape)
        assert_eq(a.requires_grad, b.requires_grad)
        assert_eq(a.is_leaf, b.is_leaf)
        assert_eq(a.grad_fn is None, b.grad_fn is None)
        assert_eq(a.is_sparse, b.is_sparse)
        assert_eq(a.is_inference(), b.is_inference())
        assert_eq(a.is_conj(), b.is_conj())
        assert_eq(a.is_neg(), b.is_neg())
        assert_eq(safe_grad(a) is not None, safe_grad(b) is not None)
        if safe_grad(a) is not None:
            check(safe_grad(a), safe_grad(b))
        if a.is_sparse:
            assert_eq(a.dense_dim(), b.dense_dim())
            assert_eq(a.sparse_dim(), b.sparse_dim())
            assert_eq(a.is_coalesced(), b.is_coalesced())
        else:
            if not skip_symbolic:
                assert_eq(a.stride(), b.stride())
                assert_eq(a.storage_offset(), b.storage_offset())
            assert_eq(a._is_view(), b._is_view())
            if a._is_view():
                check(a._base, b._base)
        # TODO: test if is resizable (no direct query for this atm)
        # TODO: audit AutogradMeta to see if it matches
        # TODO: test forward AD

    return check(m1, m2)
|
|
|
|
|
|
def is_sparse_coo(t):
    """True iff *t* is a tensor with the sparse COO layout."""
    if not isinstance(t, torch.Tensor):
        return False
    return t.layout is torch.sparse_coo
|
|
|
|
|
|
def is_sparse_compressed(t):
    """True iff *t* is a tensor with a compressed sparse layout (CSR/CSC/BSR/BSC)."""
    if not isinstance(t, torch.Tensor):
        return False
    compressed_layouts = (
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    )
    return t.layout in compressed_layouts
|
|
|
|
|
|
def is_sparse_any(t):
    """True iff *t* is a tensor with any sparse layout (COO or compressed).

    Equivalent to ``is_sparse_coo(t) or is_sparse_compressed(t)`` with the
    two layout checks folded into a single membership test.
    """
    sparse_layouts = {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }
    return isinstance(t, torch.Tensor) and t.layout in sparse_layouts
|
|
|
|
|
|
# This is a class for converting multiple tensors into meta tensors which
|
|
# share the same view/storage structure. The operation model is you allocate
|
|
# one of these, and then call it repeatedly on all the tensors you want to
|
|
# convert. It's important to use the same object for tensors you want to
|
|
# share storage because this is how we correlate shared storages to the same
|
|
# meta storages. This class will hold weak references to cached tensors
|
|
# and tensor storages.
|
|
class MetaConverter:
|
|
def __init__(self):
|
|
self.storage_memo = {}
|
|
self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
|
self.maybe_storages_to_delete = []
|
|
self.check_expired_frequency = 128
|
|
self.check_expired_count = 0
|
|
self.hit = 0
|
|
self.miss = 0
|
|
self.del_hook = None
|
|
self.arg_cnt = 0
|
|
|
|
def successful(self):
|
|
return self.hit > 0 and self.miss == 0
|
|
|
|
def check_for_expired_weak_storages(self):
|
|
new_li = []
|
|
stor_to_delete = []
|
|
for obj in self.maybe_storages_to_delete:
|
|
if not obj.expired():
|
|
new_li.append(obj)
|
|
else:
|
|
stor_to_delete.append(obj)
|
|
for obj in stor_to_delete:
|
|
self.storage_memo.pop(obj, None)
|
|
self.maybe_storages_to_delete = new_li
|
|
|
|
# if for some reason we have aquired many storages which have not expired
|
|
# even though a tensor with their storage has expired (aliasing or otherwise)
|
|
# check for expired storages less often so as to bound the amount of work we
|
|
# do checking for expired storages
|
|
self.check_expired_frequency = max(
|
|
self.check_expired_frequency, len(self.maybe_storages_to_delete)
|
|
)
|
|
|
|
def get_tensor_memo(self, t):
|
|
return self.tensor_memo.get(WeakIdRef(t), None)
|
|
|
|
def set_tensor_memo(self, t, v):
    """Memoize ``v`` as the converted counterpart of tensor ``t``.

    Also installs a finalizer on ``t`` that evicts the memo entry (and,
    when possible, the corresponding storage memo entry) once ``t`` dies.
    """
    # hold a weak ref to self, otherwise it will be kept alive
    # by the del_ten closure
    self_weak_ref = weakref.ref(self)
    # Sparse / mkldnn / functorch-wrapped tensors don't expose a usable
    # typed storage here, so no storage-level cleanup is tracked for them.
    if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
        weak_st = None
    else:
        weak_st = StorageWeakRef(t._typed_storage())
    tensor_ref_key = WeakIdRef(t)

    def del_ten():
        # Finalizer run when t is collected.
        # tensor outlives the converter
        self_ref = self_weak_ref()
        if self_ref is None:
            return
        # on shutdown, tensor_ref_key may not be in memo
        self_ref.tensor_memo.pop(tensor_ref_key, None)
        if weak_st and weak_st.expired():
            self_ref.storage_memo.pop(weak_st, None)
        elif weak_st is not None:
            # [expired-storages]
            # NB: even though the tensor has died,
            # the deallocation of its storage can take longer,
            # even when the storage has no other uses/views.
            # In this case, the StorageWeakRef object will be kept alive
            # longer than it needs to be, however the storage itself
            # will be deallocated. We retain the possibly dead storages
            # and periodically check if any of them are expired and
            # can be freed.
            self_ref.maybe_storages_to_delete.append(weak_st)

    weakref.finalize(t, del_ten)
    self.tensor_memo[tensor_ref_key] = v
|
|
|
|
# NB: doesn't actually return a storage, because meta storage is
|
|
# not supported
|
|
def meta_storage(self, s, callback):
|
|
# NB: TypedStorage is freshly allocated and cannot be used as hash
|
|
# key index.
|
|
|
|
# Use a Weak Ref to s in order to not leak memory
|
|
swr = StorageWeakRef(s)
|
|
if swr not in self.storage_memo:
|
|
self.storage_memo[swr] = callback(
|
|
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
|
|
).untyped_storage()
|
|
return self.storage_memo[swr]
|
|
|
|
# This function assumes that it's possible to do the conversion
# NB: name here is used in a conventional way by Dynamo; it corresponds
# precisely to the Source.name() of the tensor we're fakeifying and
# corresponds to a valid Python expression. When we construct sub-names
# as part of this process, we will maintain this invariant! (Even though
# other users of this may not need this property to be upheld.)
def meta_tensor(
    self,
    t,
    shape_env=None,
    callback=lambda t: t(),
    source: Optional[Source] = None,
    symbolic_context: Optional["SymbolicContext"] = None,
):
    """Recursively build a meta-device counterpart of ``t``.

    Mirrors ``t``'s metadata structure — storage aliasing, view
    relationships, autograd state (leaf-ness / requires_grad / ``.grad``),
    conj/neg bits, and the sparse / mkldnn / functorch-wrapped / traceable
    wrapper subclass cases — onto tensors allocated on the meta device.
    Results are memoized in ``self.tensor_memo``; tensors sharing storage
    reuse entries in ``self.storage_memo``.

    Args:
        t: tensor to convert.
        shape_env: if provided, sizes/strides/offsets are allocated as
            symbolic values in this ShapeEnv; otherwise they are copied
            concretely (and ``symbolic_context`` must be None).
        callback: wrapper invoked around each raw meta allocation; it must
            return a detached (leaf) tensor.
        source: Dynamo Source naming ``t``; synthesized when absent.
        symbolic_context: policy for symbolic dimension allocation.

    Returns:
        The memoized meta tensor, or ``NotImplemented`` for unsupported
        functorch-wrapped inputs.
    """
    if source is None:
        from torch._dynamo.source import ConstantSource

        # TODO: make a dedicated UnknownSource for this?
        source = ConstantSource(
            f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
        )

    # This indicates you set no_dispatch() before calling into this
    # function. This is an error: we may be creating fake tensors and
    # will perform operations on them which need fake tensor mode to
    # be active. You will segfault if you are in a no_dispatch() block.
    assert not torch._C._dispatch_tls_local_exclude_set().has(
        torch._C.DispatchKey.Python
    )
    # NOTE(review): arg_cnt is captured but never read again in this method.
    arg_cnt = self.arg_cnt
    self.arg_cnt += 1

    # When we make as_strided calls, we end up generating a guard
    # that the new as_strided tensor is in bounds for the old storage
    # for the base (since as_strided calls can "bust" out of their
    # bounding box.) This guard is unnecessary: if a user is able
    # to provide us a tensor with the view base setup this way, we
    # don't need to produce a guard, because the fact that they
    # were able to produce the view base means its in bounds.
    #
    # Now, ordinarily, this guard would be harmless. However, the
    # generated guard refers to variables bound on the base variable.
    # At the moment, Dynamo doesn't actually guard on x._base, because
    # according to Voz this results in a lot of spurious invalidations,
    # and also if the user doesn't directly make use of _base, its
    # pointless anyway (because programs should be parametric over
    # whether or not the input tensor is a view or not--unless you're
    # mutating the input, but that's a whole 'nother ballgame). So
    # for expediency, we suppress these guards so we don't have to
    # deal with this (yet, anyway.)
    #
    # NB: An old version of this code suppressed guards for ALL operations
    # happening during meta conversion, not just as_strided calls.
    # This is too aggressive: we do duck sizing and 0/1 simplification
    # as we allocate variables, and we do need to register guards for
    # these cases.
    maybe_suppress = contextlib.nullcontext
    if shape_env is not None:
        maybe_suppress = shape_env.suppress_guards

    # Allocate (or reuse) sizes/strides/offset for t, symbolically when a
    # shape env is in play.
    def sym_sizes_strides_storage_offset(
        t, src, symbolic_context=symbolic_context
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
        if shape_env is not None:
            fake_mode = torch._subclasses.fake_tensor.maybe_get_fake_mode(t)
            if fake_mode is not None and fake_mode.shape_env is shape_env:
                # Don't reallocate the sizes; the shape envs are the same,
                # so reuse the old sizes/strides/etc
                return (t.size(), t.stride(), t.storage_offset())
            else:
                return shape_env.create_symbolic_sizes_strides_storage_offset(
                    t,
                    src,
                    symbolic_context=symbolic_context,
                )
        else:
            assert symbolic_context is None
            return (t.size(), t.stride(), t.storage_offset())

    # Allocate a bare meta tensor matching inner_t's (possibly symbolic)
    # sizes and strides.
    def empty_create(inner_t, inner_src, symbolic_context=symbolic_context):
        (
            inner_sizes,
            inner_strides,
            inner_storage_offset,
        ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
        return torch.empty_strided(
            inner_sizes,
            inner_strides,
            dtype=inner_t.dtype,
            device="meta",
        )

    # Creates a subclass instance with empty inner tensors according to the specified
    # symbolic context.
    def empty_create_subclass(
        t,
        outer_size,
        outer_stride,
        symbolic_context=symbolic_context,
        callback=callback,
        source=source,
    ):
        from torch._dynamo.source import AttrSource
        from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext

        assert symbolic_context is None or isinstance(
            symbolic_context, SubclassSymbolicContext
        )

        # Note: transform_subclass will use __tensor_unflatten__ to generate
        # a fresh subclass wrapper with outer sizes / strides according to the
        # outer symbolic context (passed in to this function). Inner size / stride
        # / storage offset symbols are allocated according to the appropriate inner
        # symbolic contexts, after which the checks in transform_subclass() will
        # relate them to the outer metadata as possible.
        return transform_subclass(
            t,
            lambda attr, inner_t: callback(
                lambda: empty_create(
                    inner_t,
                    AttrSource(source, attr),
                    symbolic_context=(
                        None
                        if symbolic_context is None
                        else symbolic_context.inner_contexts[attr]
                    ),
                )
            ),
            outer_size=outer_size,
            outer_stride=outer_stride,
        )

    # Returns an all-dynamic symbolic context used for metafying the given tensor with
    # fully dynamic dims. This is useful when fake-ifying intermediate tensors in
    # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
    # don't want to over-specialize during view replay.
    def all_dynamic_symbolic_context(t, source, shape_env, callback):
        from torch._dynamo.source import AttrSource
        from torch.fx.experimental.symbolic_shapes import (
            DimDynamic,
            StatelessSymbolicContext,
            SubclassSymbolicContext,
            SymbolicContext,
        )

        view_base_context: Optional[SymbolicContext] = None
        if t._is_view():
            view_base_context = all_dynamic_symbolic_context(
                t._base, AttrSource(source, "_base"), shape_env, callback
            )

        t_symbolic_context: SymbolicContext
        t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.dim()
        if is_traceable_wrapper_subclass(t):
            inner_contexts: Dict[str, SymbolicContext] = {}
            attrs, _ = t.__tensor_flatten__()
            for attr in attrs:
                assert isinstance(attr, str)
                inner = getattr(t, attr)
                inner_contexts[attr] = all_dynamic_symbolic_context(
                    inner, AttrSource(source, attr), shape_env, callback
                )
            t_symbolic_context = SubclassSymbolicContext(
                dynamic_sizes=t_dynamic_sizes,
                constraint_sizes=[None] * t.dim(),
                inner_contexts=inner_contexts,
                tensor_source=source,
                view_base_context=view_base_context,
            )
        else:
            t_symbolic_context = StatelessSymbolicContext(
                dynamic_sizes=t_dynamic_sizes,
                constraint_sizes=[None] * t.dim(),
                view_base_context=view_base_context,
            )

        return t_symbolic_context

    # Returns a fake-ified version of an input view tensor t, given an already fake-ified
    # base. At a high level, we want two things:
    # 1. fake_t should have the same view relationship to the given fake base as the
    #    input t has to its _base.
    # 2. fake_t should have symbolic sizes / strides / storage offset according to the
    #    appropriate symbolic context (i.e. from the automatic dynamic algorithm).
    #
    # We currently take different strategies across view types:
    # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
    #   as_strided() call on the fake-ified base, passing symbolic metadata.
    # * For views involving subclasses, perform view replay using view funcs to
    #   achieve (1). It's necessary for (2) to swap out any closed-over state in
    #   the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
    #   avoids specialization (and thus over-eager simplification of symbols) that
    #   could occur during view replay on the fake-ified base.
    #
    # Examples:
    # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
    #   with an as_strided() call on the fake base passing symbolic metadata.
    # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
    #   is made symbolic to avoid invalid specialization and view replay is then
    #   done to reconstruct the view.
    # * _nested_from_jagged(values, offsets) is a dense -> subclass view
    #   that returns a subclass instance from a dense values tensor. The offsets
    #   tensor is closed over in the view func, as it can be considered view metadata.
    #   First, the offsets tensor is fake-ified according to the inner symbolic
    #   context and with the correct relationship to the outer size / stride metadata.
    #   Then view replay is done, swapping in the fake offsets so the view replay output
    #   is fully fake with no invalid specialization.
    def view_from_base(base, t, source=source, shape_env=shape_env):
        # fake-ify t's metadata according to the outer symbolic context
        (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
            t, source
        )
        if not is_traceable_wrapper_subclass(
            t
        ) and not is_traceable_wrapper_subclass(base):
            # Dense -> Dense view case uses as_strided() to construct view relationship.
            # TODO: Change this logic to use view replay for consistency?
            # It's likely there is no view func available.
            return base.as_strided(sizes, strides, storage_offset)

        from torch._dynamo.source import EphemeralSource
        from torch.fx.experimental.symbolic_shapes import sym_eq

        def symint_visitor_fn(s):
            if shape_env is None:
                return s

            # NB: The symbol here is expected to be simplified out because we a priori
            # allocate inner and outer symbols according to the appropriate symbolic
            # contexts and prefer those over this symbol during symbol simplification
            # (via usage of EphemeralSource below). This -shouldn't- happen, but if
            # this symbol somehow leaks out beyond the view tensor's shape metadata, our
            # assumption of it being simplified out will fail and it may be guarded on,
            # which will hard error.
            sym_source = EphemeralSource("symint_visitor_fn")
            symbol = shape_env.create_symbol(s, sym_source)
            return shape_env.create_symintnode(symbol, hint=s, source=sym_source)

        real_to_fake_mapping = {}
        if is_traceable_wrapper_subclass(t):
            # Fake-ify t naively here; this is only done so we can get fake-ified inner
            # tensors with the correct relationships to the outer sizes / strides for use
            # in view replay. It's done beforehand here because it's not easy to do when
            # visiting tensors one-by-one during view replay.
            #
            # Example:
            # Consider a Dense -> NJT view. NJT has (values, offsets) components and we
            # want a view of values with the offsets closed over. As the offsets component
            # is needed to describe the output view, it's important that it's fakeified
            # correctly.
            fake_t = empty_create_subclass(
                t, outer_size=sizes, outer_stride=strides
            )
            attrs, _ = fake_t.__tensor_flatten__()
            for attr in attrs:
                real_to_fake_mapping[getattr(t, attr)] = getattr(fake_t, attr)

        def tensor_visitor_fn(
            visited_t, shape_env=shape_env, callback=callback, source=source
        ):
            # It's possible to close over an undefined tensor (e.g. NJT's lengths).
            if visited_t is None:
                return None

            # Fake inner tensors of view subclasses will come from the mapping built above.
            fake_visited_t = real_to_fake_mapping.get(visited_t, None)
            if fake_visited_t is not None:
                return fake_visited_t

            # For other closed-over tensor state, fake-ify it as all dynamic with an
            # ephemeral source. This avoids invalid specialization during view replay.
            # If we find that in practice the usage of ephemeral sources isn't enough
            # to guarantee that we don't have guards on these symbols, we may need to
            # explicitly suppress guards (as is done for _base in the dense -> dense
            # view case).
            temp_source = EphemeralSource("tensor_visitor_fn")
            return self.meta_tensor(
                visited_t,
                shape_env,
                callback,
                source=temp_source,
                symbolic_context=all_dynamic_symbolic_context(
                    visited_t, temp_source, shape_env, callback
                ),
            )

        # Replay the view, swapping out any non-symbolic SymInts or real tensors
        # for symbolic SymInts or fake tensors.
        fake_t = t._view_func_unsafe(base, symint_visitor_fn, tensor_visitor_fn)

        # Ensure the output has symbolic shapes according to the outer symbolic context.
        # These checks should simplify out any symbols created for closed-over view func
        # SymInts.
        torch._check(sym_eq(fake_t.size(), sizes))
        torch._check(sym_eq(fake_t.stride(), strides))
        torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
        return fake_t

    # see expired-storages
    self.check_expired_count += 1
    if self.check_expired_count >= self.check_expired_frequency:
        self.check_for_expired_weak_storages()
        self.check_expired_count = 0

    if self.get_tensor_memo(t) is None:
        with torch.inference_mode(t.is_inference()):
            if t.is_sparse:
                is_leaf = safe_is_leaf(t)

                # The lambda function below is similar to
                # `t.to(device='meta')` except the latter
                # preserves nnz value
                r = callback(
                    lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                        t.sparse_dim(),
                        t.dense_dim(),
                        t.shape,
                        dtype=t.dtype,
                        layout=torch.sparse_coo,
                        device="meta",
                    )
                )
                assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                # Note [is_coalesced is dispatched]
                # Strangely enough, is_coalesced() is a dispatched operator,
                # which means that it will get caught by fake tensor mode.
                # Ordinarily this would error, but there's some logic in
                # fake tensor ensure this doesn't happen.
                r._coalesced_(t.is_coalesced())
                if t.requires_grad:
                    r.requires_grad = True
                if t.requires_grad and not is_leaf:
                    with torch.enable_grad():
                        r = r.clone()
                        r._coalesced_(t.is_coalesced())
            elif is_sparse_compressed(t):
                is_leaf = safe_is_leaf(t)

                def mk_meta():
                    # Build an all-zero-nnz meta analogue of the compressed
                    # sparse tensor t.
                    nnz = 0
                    batch_dim = t.ndim - t.sparse_dim() - t.dense_dim()
                    # NOTE(review): batch_size is computed but unused below.
                    batch_size = t.shape[:batch_dim]
                    if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                        index_dtype = t.crow_indices().dtype
                        compressed_indices = torch.empty(
                            t.crow_indices().shape, device="meta", dtype=index_dtype
                        )
                        plain_indices = torch.empty(
                            (*t.col_indices().shape[:-1], nnz),
                            device="meta",
                            dtype=index_dtype,
                        )
                    else:
                        index_dtype = t.ccol_indices().dtype
                        compressed_indices = torch.empty(
                            t.ccol_indices().shape, device="meta", dtype=index_dtype
                        )
                        plain_indices = torch.empty(
                            (*t.row_indices().shape[:-1], nnz),
                            device="meta",
                            dtype=index_dtype,
                        )
                    values_shape = t.values().shape
                    values = torch.empty(
                        (
                            *values_shape[:batch_dim],
                            nnz,
                            *values_shape[batch_dim + 1 :],
                        ),
                        dtype=t.dtype,
                        device="meta",
                    )
                    return torch.ops.aten.sparse_compressed_tensor(
                        compressed_indices,
                        plain_indices,
                        values,
                        t.shape,
                        layout=t.layout,
                        dtype=t.dtype,
                        device="meta",
                    )

                # `mk_meta()` is similar to `t.to(device='meta'))`
                # except `to('meta')` preserves nnz value while
                # `mk_meta` result has nnz == 0.
                r = callback(mk_meta)

                assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                if t.requires_grad:
                    r.requires_grad = True
                if t.requires_grad and not is_leaf:
                    with torch.enable_grad():
                        r = r.clone()
            elif t.is_nested and not is_traceable_wrapper_subclass(t):
                # TODO: Handle this better in Dynamo?
                # There are checks there now, but this can still be triggered by a dense
                # tensor graph input that is a view of a strided NT.
                from torch._dynamo.exc import unimplemented

                unimplemented(
                    "strided nested tensors are not supported by meta conversion"
                )
            elif t.is_mkldnn:
                is_leaf = safe_is_leaf(t)
                sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
                    t, source
                )
                r = callback(
                    lambda: torch.empty_strided(
                        sizes, strides, dtype=t.dtype, device="meta"
                    )
                )
                assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                if t.requires_grad:
                    r.requires_grad = True
                if t.requires_grad and not is_leaf:
                    with torch.enable_grad():
                        r = r.clone()
            elif is_functorch_wrapped_tensor(t):
                if t._is_view():
                    from torch._dynamo.exc import unimplemented

                    unimplemented(
                        "view functorch tensors are not supported by meta conversion"
                    )

                # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
                # in a FakeTensor
                def _to_fake_tensor(t):
                    if is_batchedtensor(t):
                        ft = _to_fake_tensor(get_unwrapped(t))
                        lvl = maybe_get_level(t)
                        bdim = maybe_get_bdim(t)
                        r = _add_batch_dim(ft, bdim, lvl)
                    elif is_gradtrackingtensor(t):
                        disable_functorch = torch._C._DisableFuncTorch
                        with disable_functorch():
                            ft = _to_fake_tensor(get_unwrapped(t))
                        lvl = torch._C._functorch.maybe_get_level(t)
                        r = torch._C._functorch._wrap_for_grad(ft, lvl)

                        is_leaf = safe_is_leaf(t)
                        if t.requires_grad and safe_is_leaf(r):
                            r.requires_grad = True
                        elif t.requires_grad and not is_leaf:
                            with torch.enable_grad():
                                r = r.clone()
                    else:
                        # Plain tensor at the bottom of the wrapper stack.
                        sizes = t.size()
                        strides = t.stride()
                        r = callback(
                            lambda: torch.empty_strided(
                                sizes,
                                strides,
                                dtype=t.dtype,
                                device="meta",
                            )
                        )
                    return r

                r = _to_fake_tensor(t)

            elif t._is_view():
                # Construct views in two steps: recursively meta-fy their
                # base, and then create view(s) off that. NB: doing it
                # directly from storage is WRONG because this won't cause
                # version counters to get shared.
                assert t._is_view()

                base_symbolic_context = None
                if shape_env and symbolic_context is not None:
                    from torch.fx.experimental.symbolic_shapes import (
                        StatelessSymbolicContext,
                    )

                    assert isinstance(symbolic_context, StatelessSymbolicContext)
                    # NB: This should generally be set when the input is a view,
                    # but the exception right now is for fake-ifying grads, which is
                    # a work in progress.
                    if symbolic_context.view_base_context is not None:
                        base_symbolic_context = symbolic_context.view_base_context

                base = self.meta_tensor(
                    t._base,
                    shape_env,
                    callback,
                    source=torch._dynamo.source.AttrSource(source, "_base"),
                    symbolic_context=base_symbolic_context,
                )

                def is_c_of_r(complex_dtype, real_dtype):
                    return (
                        utils.is_complex_dtype(complex_dtype)
                        and utils.corresponding_real_dtype(complex_dtype)
                        == real_dtype
                    )

                # In some situations, MetaConverter may be called in a
                # context where autograd is disabled. For the _is_view
                # assert to pass, we have to setup the autograd view
                # metadata anyway. Do this by reenabling the
                # ADInplaceOrView key. This is kind of a hack.
                old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
                    torch._C.DispatchKey.ADInplaceOrView
                )
                torch._C._dispatch_tls_set_dispatch_key_excluded(
                    torch._C.DispatchKey.ADInplaceOrView, False
                )
                try:
                    if base.dtype == t.dtype:
                        pass
                    elif is_c_of_r(base.dtype, t.dtype):
                        base = torch.view_as_real(base)
                    elif is_c_of_r(t.dtype, base.dtype):
                        base = torch.view_as_complex(base)
                    else:
                        # This is not guaranteed to succeed. If it fails, it
                        # means there is another dtype-converting view function
                        # that hasn't been handled here
                        base = base.view(t.dtype)

                    # This is very tricky. Naively, you might expect this
                    # to hold:
                    #
                    #   if t.requires_grad and not safe_is_leaf(t)
                    #       assert t._base.requires_grad
                    #
                    # But it's not true! As you can see in the following
                    # program:
                    #
                    #   x = torch.zeros(4)
                    #   y = x.view(1, 4)
                    #   y.requires_grad = True
                    #   z = y.view(1, 1, 4)
                    #   assert z._base is x
                    #
                    # So we may have to do *two* views out of the base to
                    # recreate this situation.
                    if safe_is_leaf(t):
                        # Leaf views that track view metadata are created by
                        # creating a view inside a no_grad block
                        with torch.no_grad(), maybe_suppress():
                            r = view_from_base(base, t)
                        # As it's a leaf, we can directly assign requires_grad
                        r.requires_grad = t.requires_grad
                    else:
                        if t._base.requires_grad == t.requires_grad:
                            # Easy case, just run the view op
                            with torch.enable_grad(), maybe_suppress():
                                r = view_from_base(base, t)

                            # NB: We don't actually faithfully replicate
                            # autograd connectivity, but that doesn't matter
                            # today. See following for more info:
                            # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                        else:
                            # Obscure case. Create a leaf view and give it the
                            # correct requires_grad, then do the final view.
                            # NB: Can't have a non-leaf without requiring grad!
                            assert t.requires_grad
                            with torch.no_grad():
                                mid = base.view(base.shape)
                            mid.requires_grad = t.requires_grad
                            with torch.enable_grad(), maybe_suppress():
                                r = view_from_base(mid, t)
                        # The CreationMeta influences whether or not inplace
                        # mutation is an error or not. So we need to make
                        # sure we properly propagate this as well.
                        torch._C._autograd._set_creation_meta(
                            r, torch._C._autograd._get_creation_meta(t)
                        )
                finally:
                    torch._C._dispatch_tls_set_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView, old_exclude
                    )

            else:
                is_leaf = safe_is_leaf(t)

                (
                    sizes,
                    strides,
                    storage_offset,
                ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)

                # If we have a subclass that desugars into dense tensors,
                # perform our callback on each inner tensor.
                if is_traceable_wrapper_subclass(t):
                    r = empty_create_subclass(
                        t, outer_size=sizes, outer_stride=strides
                    )
                else:
                    r = callback(
                        lambda: torch.empty_strided(
                            sizes,
                            strides,
                            dtype=t.dtype,
                            device="meta",
                        )
                    )

                assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                if t.requires_grad:
                    r.requires_grad = t.requires_grad
                    if not is_leaf:
                        # Fake up some autograd history.
                        with torch.enable_grad():
                            # preserve_format is the default, but we want to
                            # emphasize how important it is to preserve
                            # format here
                            r = r.clone(memory_format=torch.preserve_format)

                # Graph-Break for wrapped tensors
                if not (
                    is_batchedtensor(t) or is_gradtrackingtensor(t)
                ) and torch._C._functorch.is_functorch_wrapped_tensor(t):
                    return NotImplemented

                s = t.untyped_storage()
                swr = StorageWeakRef(s)
                if swr not in self.storage_memo and (
                    r.is_nested
                    or (
                        r.stride() == strides
                        and r.storage_offset() == storage_offset
                    )
                ):
                    # You're normal and happy, install the fresh storage into the memo
                    self.storage_memo[swr] = r.untyped_storage()
                else:
                    # You're in crazy town; somehow you gave us a tensor
                    # that wasn't a view, but had nonzero storage offset,
                    # nontrivial strides (such that clone() couldn't
                    # preserve them), or already aliases with another
                    # tensor's storage. The most typical way to end
                    # up here is with set_. So use set_ to bludgeon this
                    # in.
                    r_s = self.meta_storage(s, callback=callback)
                    # NB: In principle, this should always work, but there
                    # is some subtle difference in the autograd metadata
                    # that means we will backprop the set_ call, even if
                    # r is declared as an input to grad.
                    # See https://github.com/pytorch/pytorch/issues/87956
                    # for the reproducer.
                    # NB: The in_kernel_invocation_manager here is necessary
                    # for fake tensor. If we run the set_ call with fake
                    # tensor on, r will improperly report that it is NOT a
                    # meta tensor but a cpu tensor, and then the set_ call
                    # will fail due to device mismatch. no_dispatch() is
                    # not enough, because the fake tensor will still claim
                    # to be a CPU tensor and you'll end up in the CPU
                    # kernel. Arguably this is a hack; a cleaner way to
                    # solve this is to have a FakeStorage concept which
                    # would report it's CPU device--no problem now! But
                    # this is difficult to do because we don't have storage
                    # subclasses. Relevant test is
                    # DynamicShapesFunctionTests::test_add_dynamic_shapes in
                    # test/dynamo/test_dynamic_shapes.py
                    maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
                    from torch._subclasses.fake_tensor import (
                        in_kernel_invocation_manager,
                        maybe_get_fake_mode,
                    )

                    mb_fake_mode = maybe_get_fake_mode(r)
                    if mb_fake_mode is not None:
                        maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
                    with maybe_fake_mgr, torch.no_grad():
                        r.set_(r_s, storage_offset, sizes, strides)

            if safe_grad(t) is not None:
                from torch._dynamo.source import AttrSource

                # TODO: Use a valid grad-specific symbolic context instead of recycling
                # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
                r.grad = self.meta_tensor(
                    safe_grad(t),
                    shape_env,
                    callback,
                    source=AttrSource(source, "grad"),
                    symbolic_context=symbolic_context,
                )
            torch._C._set_conj(r, t.is_conj())
            torch._C._set_neg(r, t.is_neg())
            # This can be skipped if necessary for performance reasons
            assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
            self.set_tensor_memo(t, r)

    return self.get_tensor_memo(t)
|
|
|
|
def __call__(
|
|
self,
|
|
t,
|
|
shape_env=None,
|
|
*,
|
|
callback=lambda t: t(),
|
|
source=None,
|
|
symbolic_context=None,
|
|
):
|
|
# TODO: zero tensors? We appear to have eliminated them by
|
|
# excluding complex for now
|
|
|
|
if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
|
|
if t.device.type != "xla" and any(
|
|
[
|
|
t.is_quantized,
|
|
t._is_view() and t._base is not None and t._base.is_sparse,
|
|
torch._is_functional_tensor(t),
|
|
t.device.type in ("lazy"),
|
|
# We need a way to test if a tensor is batched but there
|
|
# is no official APi to do it
|
|
# torch._C._is_batched(t),
|
|
]
|
|
):
|
|
# TODO: sparse should support meta
|
|
# NB technically to('meta') does work but our logging
|
|
# instrumentation will see the meta conversions and the
|
|
# tests all break so we just exclude this. In any case
|
|
# the to conversion isn't really right anyhow.
|
|
|
|
if torch._is_functional_tensor(t) and t.device.type != "lazy":
|
|
if t._is_view():
|
|
raise RuntimeError(
|
|
"Cannot safely fakify a view because this process drops the view information right now."
|
|
)
|
|
|
|
st = peek_interpreter_stack()
|
|
assert (
|
|
st is None or st.key() == TransformType.Functionalize
|
|
), "Expect st to be either None or have Functionalize transform key."
|
|
if st is None:
|
|
# the case of AOTAutograd
|
|
torch._sync(t)
|
|
unwrap_t = torch._from_functional_tensor(t)
|
|
with torch._dispatch.python.suspend_functionalization():
|
|
fake_t = self.meta_tensor(
|
|
unwrap_t,
|
|
shape_env=shape_env,
|
|
callback=callback,
|
|
source=source,
|
|
symbolic_context=symbolic_context,
|
|
)
|
|
out = torch._to_functional_tensor(fake_t)
|
|
torch._mirror_autograd_meta_to(fake_t, out)
|
|
return out
|
|
else:
|
|
# torch.func.functionalize
|
|
reapply_views = torch._C._functionalization_reapply_views_tls()
|
|
unwrap_t = _unwrap_functional_tensor(t, reapply_views)
|
|
pop_st_ctx = (
|
|
torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack()
|
|
)
|
|
with pop_st_ctx:
|
|
fake_t = self.meta_tensor(
|
|
unwrap_t,
|
|
shape_env=shape_env,
|
|
callback=callback,
|
|
source=source,
|
|
symbolic_context=symbolic_context,
|
|
)
|
|
return _wrap_functional_tensor(fake_t, current_level())
|
|
self.miss += 1
|
|
return NotImplemented
|
|
else:
|
|
self.hit += 1
|
|
|
|
disable_functorch = torch._C._DisableFuncTorch
|
|
with disable_functorch():
|
|
r = self.meta_tensor(
|
|
t,
|
|
shape_env=shape_env,
|
|
callback=callback,
|
|
source=source,
|
|
symbolic_context=symbolic_context,
|
|
)
|
|
if type(t) is torch.nn.Parameter:
|
|
# NB: Cannot directly use Parameter constructor
|
|
# because that would force a detach, not desirable
|
|
r._is_param = True
|
|
return r
|
|
elif torch.overrides.is_tensor_like(t):
|
|
self.miss += 1
|
|
return NotImplemented
|
|
else:
|
|
# non-Tensor types don't count as hit or miss
|
|
return t
|
|
|
|
|
|
import torch._prims_common as utils
|