import inspect
from typing import Any, Dict, List, Optional, Union

import torch.nn

from . import utils, variables
from .bytecode_transformation import (
    create_call_function,
    create_call_method,
    create_instruction,
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import LocalSource, Source
from .utils import nn_module_new, object_new
from .variables.base import (
    is_side_effect_safe,
    MutableLocalBase,
    MutableLocalSource,
    VariableTracker,
)


class MutableSideEffects(MutableLocalBase):
    """
    VariableTracker.mutable_local marker for a mutable input (e.g. a list):
    if we mutate it, those mutations must be re-applied after the graph runs.
    """

    def __init__(self, source: Source, is_modified: bool = False):
        super().__init__(MutableLocalSource.Existing)
        self.source = source
        self.is_modified = is_modified


class AttributeMutation(MutableLocalBase):
    """
    VariableTracker.mutable_local marker to track changes to attributes
    """

    def __init__(self, typ: MutableLocalSource, source: Optional[Source]):
        super().__init__(typ)
        self.source = source


class AttributeMutationExisting(AttributeMutation):
    def __init__(self, source: Source):
        super().__init__(MutableLocalSource.Existing, source)
        self.source = source


class AttributeMutationNew(AttributeMutation):
    def __init__(self, source: Optional[Source], cls_source: Optional[Source]):
        super().__init__(MutableLocalSource.Local, source)
        self.cls_source = cls_source


class SideEffects:
    """
    Track side effects (list mutation, setattr, etc) that need to be
    applied after an FX graph is run.
    """
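
    # Illustrative sketch (assumed user code, not part of this module): the kind of
    # side effects tracked here, for a function traced by torch.compile:
    #
    #     @torch.compile
    #     def fn(xs, mod):
    #         xs.append(1)          # list mutation      -> MutableSideEffects
    #         mod.counter += 1      # attribute mutation -> AttributeMutationExisting
    #         return xs[-1]
    #
    # Neither mutation appears in the FX graph; both are replayed by the residual
    # bytecode emitted in codegen_update_mutated() below.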

    id_to_variable: Dict[int, VariableTracker]
    store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]]
    keepalive: List[Any]

    def __init__(
        self,
        id_to_variable=None,
        store_attr_mutations=None,
        keepalive=None,
        save_for_backward=None,
        tensor_hooks=None,
    ):
        super().__init__()
        self.id_to_variable = id_to_variable or {}
        self.store_attr_mutations = store_attr_mutations or {}
        self.keepalive = keepalive or []
        self.save_for_backward = save_for_backward or []
        self.tensor_hooks = tensor_hooks or {}

    def __eq__(self, other: object) -> bool:
        assert isinstance(other, SideEffects)
        # NB: do NOT test keepalive
        return (
            self.id_to_variable == other.id_to_variable
            and self.store_attr_mutations == other.store_attr_mutations
            and self.save_for_backward == other.save_for_backward
            and self.tensor_hooks == other.tensor_hooks
        )

    def diff(self, other: "SideEffects") -> Optional[str]:
        if self.id_to_variable != other.id_to_variable:
            sk_itv = self.id_to_variable.keys()
            ok_itv = other.id_to_variable.keys()
            if sk_itv != ok_itv:
                return f"id_to_variable keys: {sk_itv} != {ok_itv}"
            # Feel free to augment this with more fancy diffing logic
            # if needed for debugging
            return "id_to_variable: unknown diff"
        elif self.store_attr_mutations != other.store_attr_mutations:
            sk_sam = self.store_attr_mutations.keys()
            ok_sam = other.store_attr_mutations.keys()
            if sk_sam != ok_sam:
                return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
            return "store_attr_mutations: unknown diff"
        elif self.save_for_backward != other.save_for_backward:
            return "save_for_backward"
        elif self.tensor_hooks != other.tensor_hooks:
            return "tensor_hooks"
        else:
            return None

    def clone(self):
        """Create a shallow copy"""
        return self.__class__(
            id_to_variable=dict(self.id_to_variable),
            store_attr_mutations={
                k: dict(v) for k, v in self.store_attr_mutations.items()
            },
            keepalive=list(self.keepalive),
            save_for_backward=self.save_for_backward,
            tensor_hooks=self.tensor_hooks,
        )

    def apply(self, fn, cache=None, skip_fn=lambda _: False):
        if cache is None:
            cache = dict()
        self.id_to_variable = {
            k: VariableTracker.apply(fn, v, cache, skip_fn)
            for k, v in self.id_to_variable.items()
        }
        self.store_attr_mutations = {
            k: VariableTracker.apply(fn, v, cache, skip_fn)
            for k, v in self.store_attr_mutations.items()
        }
        self.save_for_backward = VariableTracker.apply(
            fn, self.save_for_backward, cache, skip_fn
        )
        self.tensor_hooks = VariableTracker.apply(fn, self.tensor_hooks, cache, skip_fn)

    def __contains__(self, item):
        return id(item) in self.id_to_variable

    def __getitem__(self, item):
        return self.id_to_variable[id(item)]

    def check_allowed_side_effect(self, item):
        from torch._dynamo.variables.misc import AutogradFunctionContextVariable

        # People do things like self.dim = dim inside autograd.Function.
        # These are benign.
        if isinstance(item, AutogradFunctionContextVariable):
            return True
        if not is_side_effect_safe(item.mutable_local):
            unimplemented(
                "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
            )

    def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
        assert self.is_attribute_mutation(item)
        self.check_allowed_side_effect(item)
        if item.mutable_local not in self.store_attr_mutations:
            self.store_attr_mutations[item.mutable_local] = {}
        self.store_attr_mutations[item.mutable_local][name] = value

    def load_attr(self, item, name, deleted_ok=False):
        assert self.is_attribute_mutation(item)
        result = self.store_attr_mutations[item.mutable_local][name]
        if not deleted_ok and isinstance(result, variables.DeletedVariable):
            unimplemented("read deleted attribute")
        return result

    def store_cell(self, cellvar, value):
        assert isinstance(cellvar, variables.NewCellVariable)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(cellvar, "cell_contents", value)

    def load_cell(self, cellvar):
        assert isinstance(cellvar, variables.NewCellVariable)
        return self.load_attr(cellvar, "cell_contents")
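
    # Illustrative sketch (assumed user code, not part of this module): closed-over
    # variables ("cells") read or written in the compiled region go through
    # store_cell/load_cell above, i.e. they are tracked as mutations of the cell's
    # `cell_contents` attribute:
    #
    #     def fn(x):
    #         counter = 0
    #         def bump():
    #             nonlocal counter      # `counter` lives in a cell shared by fn and bump
    #             counter += 1
    #         bump()
    #         return x + counter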

    def load_global(self, gvar: VariableTracker, name: str):
        assert isinstance(gvar, variables.VariableTracker)
        return self.load_attr(gvar, name)

    def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
        assert isinstance(gvar, variables.VariableTracker)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(gvar, name, value)

    @staticmethod
    def cls_supports_mutation_side_effects(cls):
        return inspect.getattr_static(cls, "__setattr__", None) in (
            object.__setattr__,
            torch.nn.Module.__setattr__,
        )

    def is_attribute_mutation(self, item):
        return isinstance(item.mutable_local, AttributeMutation)

    def has_pending_mutation(self, item):
        return self.is_attribute_mutation(item) and bool(
            self.store_attr_mutations.get(item.mutable_local)
        )

    def is_modified(self, item):
        if isinstance(item.mutable_local, AttributeMutationNew):
            return True
        if self.is_attribute_mutation(item):
            return item.mutable_local in self.store_attr_mutations
        return item.mutable_local.is_modified

    def _track_obj(
        self,
        item: Any,
        variable: VariableTracker,
        mutable_cls=MutableSideEffects,
    ):
        """Start tracking a new variable for mutation"""
        assert variable.source is not None
        variable.mutable_local = mutable_cls(variable.source)
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    track_mutable = _track_obj

    def track_object_existing(
        self,
        item: Any,
        variable: VariableTracker,
    ):
        return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting)

    def track_object_new(
        self,
        cls_source: Source,
        user_cls: Any,
        variable_cls: Any,
        options,
    ):
        if user_cls is torch.autograd.function.FunctionCtx:
            obj = torch.autograd.Function()
        elif issubclass(user_cls, torch.nn.Module):
            obj = nn_module_new(user_cls)
        else:
            obj = object_new(user_cls)
        variable = variable_cls(
            obj,
            mutable_local=AttributeMutationNew(None, cls_source),
            **options,
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_cell_new(
        self,
    ):
        obj = object()
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationNew(None, None),
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_cell_existing(self, source: Source, item: Any):
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_global_existing(self, source: Source, item: Any):
        variable = variables.NewGlobalVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_save_for_backward(self, ctx, args):
        assert isinstance(ctx, variables.AutogradFunctionContextVariable)
        self.save_for_backward.append((ctx, args))
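
    # Illustrative sketch (assumed user code, not part of this module): what
    # track_save_for_backward records, so that codegen_save_tempvars() can replay
    # ctx.save_for_backward(...) on the real ctx object after the graph runs:
    #
    #     class MyFn(torch.autograd.Function):
    #         @staticmethod
    #         def forward(ctx, x):
    #             ctx.save_for_backward(x)   # recorded as (ctx, [x]) in self.save_for_backward
    #             return x.sin()
    #
    #         @staticmethod
    #         def backward(ctx, grad):
    #             (x,) = ctx.saved_tensors
    #             return grad * x.cos()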

    def track_tensor_variables_from_runahead_side_effects(self, other):
        # In higher order ops we want to keep track of tensors seen in the
        # speculate_subgraph so that we don't lift them again as a new input in
        # other speculate_subgraph or in the root tracer.
        for other_item in other.keepalive:
            other_id = id(other_item)
            other_variable = other.id_to_variable[other_id]
            if other_id not in self.id_to_variable and isinstance(
                other_variable, variables.TensorVariable
            ):
                self.track_object_existing(other_item, other_variable)

    def prune_dead_object_new(self, tx):
        live_new_objects = set()
        skip_obj = None

        def visit(var: VariableTracker):
            if (
                isinstance(var.mutable_local, AttributeMutationNew)
                and var.mutable_local is not skip_obj
            ):
                live_new_objects.add(var.mutable_local)
            return var

        def is_live(var: Union[MutableLocalBase, VariableTracker]):
            if isinstance(var, AttributeMutationNew):
                return var in live_new_objects
            if isinstance(var, VariableTracker):
                return is_live(var.mutable_local)
            return True

        VariableTracker.apply(visit, (tx.stack, tx.symbolic_locals))
        for var in self.id_to_variable.values():
            if not isinstance(var.mutable_local, AttributeMutationNew):
                VariableTracker.apply(visit, var)

        for skip_obj, setattrs in self.store_attr_mutations.items():
            VariableTracker.apply(visit, setattrs)

        self.id_to_variable = {
            k: v for k, v in self.id_to_variable.items() if is_live(v)
        }
        self.store_attr_mutations = {
            k: v for k, v in self.store_attr_mutations.items() if is_live(k)
        }

    def mutation(self, var):
        self.check_allowed_side_effect(var)
        if isinstance(var.mutable_local, MutableSideEffects):
            var.mutable_local = MutableSideEffects(var.mutable_local.source, True)

    def _get_modified_vars(self):
        return [var for var in self.id_to_variable.values() if self.is_modified(var)]

    def codegen_save_tempvars(self, cg: PyCodegen):
        for var in self._get_modified_vars():
            if isinstance(
                var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
            ) and isinstance(var, variables.NewCellVariable):
                cg.load_import_from(utils.__name__, "make_cell")
                cg.extend_output(create_call_function(0, True))
                cg.add_cache(var)
                if isinstance(var.mutable_local, AttributeMutationNew):
                    var.mutable_local.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
            elif isinstance(var.mutable_local, AttributeMutationNew):
                if isinstance(var, variables.AutogradFunctionContextVariable):
                    unimplemented("AutogradFunctionContextVariable escaped")
                if "__call_nn_module_init" in self.store_attr_mutations.get(
                    var.mutable_local, {}
                ):
                    assert isinstance(var, variables.UnspecializedNNModuleVariable)
                    cg.load_import_from(utils.__name__, "nn_module_new")
                else:
                    cg.load_import_from(utils.__name__, "object_new")
                cg(var.mutable_local.cls_source)
                cg.extend_output(create_call_function(1, True))
                cg.add_cache(var)
                var.mutable_local.source = LocalSource(cg.tempvars[var])
            elif var in cg.tempvars:
                assert cg.tempvars.get(var) is None
                # subsequent usage should point to the original variable
                cg(var.mutable_local.source)
                cg.add_cache(var)

        for ctx, args in self.save_for_backward:
            cg(ctx.source)
            cg.extend_output(
                [create_instruction("LOAD_METHOD", argval="save_for_backward")]
            )
            for arg in args:
                cg(arg)
            cg.extend_output(
                [
                    *create_call_method(len(args)),
                    create_instruction("POP_TOP"),
                ]
            )

    def register_hook(self, tensor, hook, handle, name):
        assert isinstance(tensor, variables.TensorVariable)
        assert isinstance(hook, variables.VariableTracker)
        assert (
            isinstance(handle, variables.RemovableHandleVariable)
            and handle.mutable_local
        )
        assert hasattr(torch.Tensor, name)
        idx = len(self.tensor_hooks.keys())
        # duplicate index possible because of self.remove_hook()
        while idx in self.tensor_hooks:
            idx += 1
        self.tensor_hooks[idx] = (tensor, hook, handle, name)
        assert not handle.idx
        handle.idx = idx

    def remove_hook(self, idx):
        del self.tensor_hooks[idx]

    def codegen_hooks(self, cg):
        for (
            tensor,
            hook,
            handle,
            name,
        ) in self.tensor_hooks.values():
            # Note: [On tensor.register_hook]
            #
            # register_hook on a tensor (aka a backward hook) is implemented slightly differently
            # for tensors with a source (inputs, params) than for tensors without a source
            # (intermediaries).
            #
            # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
            # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
            # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
            # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
            # because we are running residuals firmly before .backward() can be run, it is sound to invoke
            # `register_hook` on a known tensor.
            #
            # For tensors without a source, we support a limited subset of hooks: global functions only, and
            # compiled_autograd must be enabled or we will graph break.
            #
            # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
            # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
            # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
            # stack intact.
            #
            # Dynamo Tensor Hooks Workflow:
            # - Functions passed to register_hook are lifted globally.
            # - For tensors with sources:
            #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
            #     - Generate the tensor.
            #     - Issue a register_hook call on the tensor, linking to the globally stored function.
            #     - Incorporate a handle if one was established in the eager phase.
            # - For tensors without sources:
            #   - We don't generate any instructions for registering a hook.
            #   - Handles from intermediary hooks are NYI.
            #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
            #   - We then manually insert the call function above into the graph.
            # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
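            #
            # Illustrative sketch (assumed user code, not verbatim from this note): the
            # "tensors with sources" path handled below corresponds roughly to
            #
            #     @torch.compile
            #     def fn(x):                                    # x is a graph input, so it has a source
            #         handle = x.register_hook(lambda g: g * 2)
            #         return x.sin()
            #
            # The register_hook call is not traced into the graph; it is re-emitted below in
            # the residual bytecode (generate the tensor, load "register_hook", load the hook,
            # call it), and cg.add_cache(handle) associates the returned RemovableHandle with
            # the user's `handle`.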
            assert tensor.source, "Hooks on non input tensors NYI - should not get here"
            cg(tensor)
            cg.extend_output([cg.create_load_attr(name)])
            cg(hook)
            cg.extend_output(create_call_function(1, True))

            # Adding the handle to the cache means RemovableHandleVariable().reconstruct()
            # will be associated with the return value of register_hook().  This consumes
            # the top of stack.
            cg.add_cache(handle)

    def codegen_update_mutated(self, cg: PyCodegen):
        suffixes = []
        for var in self._get_modified_vars():
            if isinstance(var, variables.ListVariable):
                # old[:] = new
                cg(var, allow_cache=False)
                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", arg=2),
                    ]
                )
                suffixes.append([create_instruction("STORE_SUBSCR")])
            elif isinstance(var, variables.ConstDictVariable):
                cg.tx.output.update_co_names("clear")
                cg.tx.output.update_co_names("update")

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output([create_instruction("LOAD_METHOD", argval="update")])
                cg(var, allow_cache=False)

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")])

                suffixes.append(
                    [
                        *create_call_method(0),  # clear
                        create_instruction("POP_TOP"),
                        *create_call_method(1),  # update
                        create_instruction("POP_TOP"),
                    ]
                )
            elif self.is_attribute_mutation(var):
                for name, value in self.store_attr_mutations.get(
                    var.mutable_local, {}
                ).items():
                    if isinstance(var, variables.NewGlobalVariable):
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        suffixes.append(
                            [create_instruction("STORE_GLOBAL", argval=name)]
                        )
                    elif name == "__call_nn_module_init":
                        pass  # handled in codegen_save_tempvars
                    elif isinstance(value, variables.DeletedVariable):
                        if isinstance(
                            var.mutable_local, AttributeMutationExisting
                        ) and hasattr(getattr(var, "value", None), name):
                            cg.tx.output.update_co_names(name)
                            cg(var.mutable_local.source)
                            suffixes.append(
                                [create_instruction("DELETE_ATTR", argval=name)]
                            )
                    else:
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        cg(var.mutable_local.source)
                        suffixes.append(
                            [create_instruction("STORE_ATTR", argval=name)]
                        )
            elif isinstance(var, variables.TupleIteratorVariable):
                for _ in range(var.index):
                    cg.load_import_from(utils.__name__, "iter_next")
                    cg(var.mutable_local.source)  # type: ignore[attr-defined]
                    cg.extend_output(create_call_function(1, True))
                    cg.append_output(create_instruction("POP_TOP"))
            else:
                raise AssertionError(type(var))

        # do all the actual mutations at the very end to handle dependencies
        for suffix in reversed(suffixes):
            cg.extend_output(suffix)

    def is_empty(self):
        return not (
            any(map(self.is_modified, self.id_to_variable.values()))
            or self.tensor_hooks
            or self.save_for_backward
        )

    def clear(self):
        self.keepalive.clear()
        self.id_to_variable.clear()