# mypy: ignore-errors import collections import dataclasses import functools import inspect import sys from typing import Dict, List, Optional from torch._subclasses.fake_tensor import is_fake from .. import variables from ..bytecode_transformation import ( create_call_function, create_call_method, create_instruction, ) from ..eval_frame import skip_code from ..exc import unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource from ..utils import dict_keys, dict_values, istype, specialize_symnode from .base import MutableLocal, VariableTracker from .constant import ConstantVariable # [Adding a new supported class within the keys of ConstDictVarialble] # - Add its tracker type to is_hashable # - (perhaps) Define how it is compared in _HashableTracker._eq_impl def is_hashable(x): if isinstance(x, variables.TensorVariable): # Tensors are hashable if they have an example_value (a fake tensor) # Most VT's should have one. # It'd be nice if at some point we could assert that they all have one return x.as_proxy().node.meta.get("example_value") is not None elif isinstance(x, variables.TupleVariable): return all(is_hashable(e) for e in x.items) else: return isinstance( x, ( variables.BuiltinVariable, variables.SymNodeVariable, variables.ConstantVariable, variables.EnumVariable, variables.user_defined.UserDefinedClassVariable, variables.UserFunctionVariable, variables.SkipFunctionVariable, variables.misc.NumpyVariable, variables.NNModuleVariable, variables.MethodWrapperVariable, variables.TorchInGraphFunctionVariable, variables.TypingVariable, variables.FunctoolsPartialVariable, ), ) class ConstDictVariable(VariableTracker): class _HashableTracker: """ Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable This should not be seen or touched by anything outside of ConstDictVariable and its children Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing """ def __init__(self, vt): # We specialize SymNodes vt = specialize_symnode(vt) # TODO Temorarily remove to figure out what keys are we breaking on # and add proper support for them if not is_hashable(vt): unimplemented(f"Dict key of type {type(vt)}. Key: {vt}") self.vt = vt @property def underlying_value(self): if isinstance(self.vt, variables.TensorVariable): x = self.vt.as_proxy().node.meta["example_value"] elif isinstance(self.vt, variables.TupleVariable): Hashable = ConstDictVariable._HashableTracker x = tuple(Hashable(e).underlying_value for e in self.vt.items) elif isinstance(self.vt, variables.NNModuleVariable): return self.vt.module elif isinstance(self.vt, variables.UserFunctionVariable): return self.vt.get_function() else: x = self.vt.as_python_constant() return x def __hash__(self): return hash(self.underlying_value) @staticmethod def _eq_impl(a, b): # TODO: Put this in utils and share it between variables/builtin.py and here if type(a) != type(b): return False elif isinstance(a, tuple): Hashable = ConstDictVariable._HashableTracker return len(a) == len(b) and all( Hashable._eq_impl(u, v) for u, v in zip(a, b) ) elif is_fake(a): return a is b else: return a == b def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: Hashable = ConstDictVariable._HashableTracker assert isinstance(other, Hashable) or ConstantVariable.is_literal( other ), type(other) if isinstance(other, Hashable): return Hashable._eq_impl(self.underlying_value, other.underlying_value) # constant return Hashable._eq_impl(self.underlying_value, other) def __init__( self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs ): super().__init__(**kwargs) Hashable = ConstDictVariable._HashableTracker # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers assert all( isinstance(x, (VariableTracker, Hashable)) and isinstance(v, VariableTracker) for x, v in items.items() ) def make_hashable(key): return key if isinstance(key, Hashable) else Hashable(key) self.items = {make_hashable(x): v for x, v in items.items()} self.user_cls = user_cls def as_proxy(self): return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} def as_python_constant(self): return { k.vt.as_python_constant(): v.as_python_constant() for k, v in self.items.items() } def keys_as_python_constant(self): return {k.vt.as_python_constant(): v for k, v in self.items.items()} def python_type(self): return self.user_cls def __contains__(self, vt): assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return is_hashable(vt) and Hashable(vt) in self.items def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.extend_output( [ codegen.create_load_python_module(collections, True), codegen.create_load_attr("OrderedDict"), ] ) # instructions to build the dict keys and values for key, value in self.items.items(): codegen(key.vt) codegen(value) # BUILD_MAP and calling collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.extend_output( [ create_instruction("BUILD_MAP", arg=len(self.items)), *create_call_function(1, False), ] ) # BUILD_MAP only if user_cls is dict else: codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items))) def getitem_const(self, arg: VariableTracker): key = ConstDictVariable._HashableTracker(arg) if key not in self.items: raise KeyError(arg.value) return self.items[key] def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from . import ( BuiltinVariable, ConstantVariable, ListIteratorVariable, ListVariable, TupleVariable, ) Hashable = ConstDictVariable._HashableTracker arg_hashable = args and is_hashable(args[0]) if name == "__getitem__": assert len(args) == 1 return self.getitem_const(args[0]) elif name == "items": assert not (args or kwargs) return TupleVariable( [TupleVariable([k.vt, v]) for k, v in self.items.items()] ) elif name == "keys": assert not (args or kwargs) return DictKeys(self) elif name == "values": assert not (args or kwargs) return DictValues(self) elif name == "copy": assert not (args or kwargs) return self.clone(items=self.items.copy(), mutable_local=MutableLocal()) elif name == "__len__": assert not (args or kwargs) return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and arg_hashable and self.mutable_local: assert not kwargs and len(args) == 2 tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self: # missing item, return the default value if len(args) == 1: return ConstantVariable(None) else: return args[1] elif name == "pop" and arg_hashable and self.mutable_local: tx.output.side_effects.mutation(self) return self.items.pop(Hashable(args[0])) elif name == "clear": tx.output.side_effects.mutation(self) self.items.clear() return ConstantVariable.create(None) elif ( name == "update" and len(args) == 1 and isinstance( args[0], ( ConstDictVariable, ListVariable, TupleVariable, ListIteratorVariable, ), ) and self.mutable_local ): tx.output.side_effects.mutation(self) if isinstance(args[0], ConstDictVariable): dict_vt = args[0] else: dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) self.items.update(dict_vt.items) # Wrap strings kwargs = { Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() } self.items.update(kwargs) return ConstantVariable.create(None) elif name in ("get", "__getattr__") and args[0] in self: return self.getitem_const(args[0]) elif name == "__contains__" and len(args) == 1: return ConstantVariable.create(args[0] in self) else: return super().call_method(tx, name, args, kwargs) def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs): super().__init__(items, user_cls, **kwargs) assert user_cls is collections.defaultdict self.default_factory = default_factory def is_python_constant(self): # Return false for unsupported defaults. This ensures that a bad handler # path is not taken in BuiltinVariable for getitem. if self.default_factory not in [list, tuple, dict] and not self.items: return False return super().is_python_constant() @staticmethod def is_supported_arg(arg): if isinstance(arg, variables.BuiltinVariable): return arg.fn in [list, tuple, dict] else: return isinstance(arg, variables.functions.BaseUserFunctionVariable) def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": if name == "__getitem__": assert len(args) == 1 if args[0] in self: return self.getitem_const(args[0]) else: if self.default_factory is None: raise KeyError(f"{args[0]}") else: default_var = self.default_factory.call_function(tx, [], {}) super().call_method( tx, "__setitem__", (args[0], default_var), kwargs ) return default_var else: return super().call_method(tx, name, args, kwargs) class SetVariable(ConstDictVariable): """We model a sets as dictonary with None values""" def __init__( self, items: List[VariableTracker], **kwargs, ): items = dict.fromkeys(items, SetVariable._default_value()) super().__init__(items, **kwargs) @property def set_items(self): return set(self.items.keys()) @staticmethod def _default_value(): # Variable to fill in he keys of the dictinary return ConstantVariable.create(None) def as_proxy(self): return {k.vt.as_proxy() for k in self.set_items} def python_type(self): return set def as_python_constant(self): return {k.vt.as_python_constant() for k in self.set_items} def reconstruct(self, codegen): codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) def call_method( self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker], ) -> "VariableTracker": # We foward the calls to the dictionary model if name == "add": assert not kwargs assert len(args) == 1 name = "__setitem__" args = (args[0], SetVariable._default_value()) elif name == "pop": assert not kwargs assert not args # Choose an item at random and pop it via the Dict.pop method result = self.set_items.pop().vt super().call_method(tx, name, (result,), kwargs) return result return super().call_method(tx, name, args, kwargs) def getitem_const(self, arg: VariableTracker): raise RuntimeError("Illegal to getitem on a set") class DictView(VariableTracker): """ Models _PyDictViewObject This is an "abstract" class. Subclasses will override kv and the items method """ kv: Optional[str] = None def __init__(self, dv_dict: ConstDictVariable, **kwargs): super().__init__(**kwargs) assert self.kv in ("keys", "values") assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict @property def view_items(self): return getattr(self.dv_dict.items, self.kv)() @property def view_items_vt(self): # Returns an iterable of the unpacked items # Implement in the subclasses raise NotImplementedError() def unpack_var_sequence(self, tx): def unwrap(x): return x.vt if self.kv == "keys" else x return [unwrap(x) for x in self.view_items] def reconstruct(self, codegen): codegen(self.dv_dict) codegen.extend_output( [ create_instruction("LOAD_METHOD", argval=self.kv), *create_call_method(0), ] ) def call_method( self, tx, name, args: List["VariableTracker"], kwargs: Dict[str, "VariableTracker"], ) -> "VariableTracker": if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs) class DictKeys(DictView): kv = "keys" @property def set_items(self): return set(self.view_items) @property def view_items_vt(self): # Returns an iterable of the unpacked items return [x.vt for x in self.view_items] def python_type(self): return dict_keys def call_method( self, tx, name, args: List["VariableTracker"], kwargs: Dict[str, "VariableTracker"], ) -> "VariableTracker": if name == "__contains__": return self.dv_dict.call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs) class DictValues(DictView): # DictValues is an iterable but cannot be compared. kv = "values" @property def view_items_vt(self): return list(self.view_items) def python_type(self): return dict_values def _is_matching_transformers_cls(cls) -> bool: mod = sys.modules.get("transformers.file_utils") return mod is not None and issubclass(cls, mod.ModelOutput) def _is_matching_diffusers_cls(cls) -> bool: mod = sys.modules.get("diffusers.utils") return mod is not None and issubclass(cls, mod.BaseOutput) def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": """Shared method between DataClassVariable and CustomizedDictVariable where items are attrs""" if name in self.items or hasattr(self.user_cls, name): return ConstantVariable(True) elif istype(self.mutable_local, MutableLocal) and self.source is None: # Something created locally can't have any extra fields on it return ConstantVariable(False) elif self.mutable_local is None and self.source: # Maybe add a guard try: example = tx.output.root_tx.get_example_value(self.source) install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) return ConstantVariable(hasattr(example, name)) except KeyError: pass unimplemented( f"hasattr({self.__class__.__name__}, {name}) {self.mutable_local} {self.source}" ) class DataClassVariable(ConstDictVariable): """ This is a bit of a hack to deal with transformers.file_utils.ModelOutput() from huggingface. ModelOutput causes trouble because it a a mix of a dataclass and a OrderedDict and it calls super() methods implemented in C. """ # ModelOutput() excludes None, though generic datclasses don't include_none = False @staticmethod @functools.lru_cache(None) def _patch_once(): try: from transformers.file_utils import ModelOutput for obj in ModelOutput.__dict__.values(): if callable(obj): skip_code(obj.__code__) except ImportError: pass try: from diffusers.utils import BaseOutput for obj in BaseOutput.__dict__.values(): if callable(obj): skip_code(obj.__code__) except ImportError: pass @staticmethod def is_matching_cls(cls): return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) @classmethod def is_matching_object(cls, obj): return cls.is_matching_cls(type(obj)) @classmethod def create(cls, user_cls, args, kwargs, options): DataClassVariable._patch_once() skip_code(user_cls.__init__.__code__) keys = [f.name for f in dataclasses.fields(user_cls)] bound = inspect.signature(user_cls).bind(*args, **kwargs) bound.apply_defaults() assert set(bound.arguments.keys()) == set(keys) items = {} for key in keys: val = bound.arguments[key] key = ConstantVariable.create(key) if isinstance(val, VariableTracker): items[key] = val else: if cls.include_none: assert variables.ConstantVariable.is_literal(val) items[key] = variables.ConstantVariable.create(val) else: assert val is None, f"unexpected {val}" if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable): unimplemented("DataClassVariable iterator constructor") # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__ return cls(items, user_cls, **options) @classmethod def wrap(cls, builder, obj): user_cls = type(obj) keys = [f.name for f in dataclasses.fields(user_cls)] excluded = [] items = {} for key in keys: # __init__ function of a dataclass might not have yet defined the key if hasattr(obj, key): val = getattr(obj, key) var = builder.__class__( tx=builder.tx, source=AttrSource(builder.source, key) )(val) if val is not None or cls.include_none: key = ConstantVariable.create(key) items[key] = var else: excluded.append(var) return cls(items, user_cls) def __init__(self, items, user_cls, **options): super().__init__(items, user_cls, **options) assert self.is_matching_cls(user_cls) def as_proxy(self): raise NotImplementedError() def reconstruct(self, codegen): codegen.extend_output([codegen._create_load_const(self.user_cls)]) # All the keys are just wrapped strings d = self.keys_as_python_constant() codegen.foreach(d.values()) keys = tuple(d.keys()) codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True)) def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": if name == "__getitem__": assert not kwargs and len(args) == 1 val = args[0] if val.python_type() == str: return self.getitem_const(val) else: return self.call_method(tx, "to_tuple", [], {}).call_method( tx, "__getitem__", args, kwargs ) elif name == "to_tuple": assert not (args or kwargs) return variables.TupleVariable(list(self.items.values())) elif name == "__setattr__": name = "__setitem__" return super().call_method(tx, name, args, kwargs) def var_getattr(self, tx, name: str) -> "VariableTracker": name_vt = ConstantVariable.create(name) if name_vt in self: return self.call_method(tx, "__getitem__", [name_vt], {}) elif not self.include_none: defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} if name in defaults: assert variables.ConstantVariable.is_literal(defaults[name]) return variables.ConstantVariable.create(defaults[name]) super().var_getattr(tx, name) call_hasattr = _call_hasattr_customobj class CustomizedDictVariable(ConstDictVariable): @staticmethod def is_matching_cls(cls): # True if using default OrderedDict.__init__ and did not implement __post_init__ if ( issubclass(cls, collections.OrderedDict) and cls.__init__ is collections.OrderedDict.__init__ and not hasattr(cls, "__post_init__") ): return True # hack for HF usecase: # assume dataclass annotation for ModelOutput subclass # assume self.create is AA to ModelOutput.__post_init__ return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) @classmethod def is_matching_object(cls, obj): return cls.is_matching_cls(type(obj)) # called from user_defined.py # when is_matching_cls(cls) is true @classmethod def create(cls, user_cls, args, kwargs, options): # avoid tracing when returning ModelOutput from forward func for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"): if hasattr(user_cls, attr_name): fn = getattr(user_cls, attr_name) assert callable(fn), f"expect callable attr {attr_name}" if hasattr(fn, "__code__"): skip_code(fn.__code__) if dataclasses.is_dataclass(user_cls): # @dataclass CustomDict(a=1, b=2) bound = inspect.signature(user_cls).bind(*args, **kwargs) bound.apply_defaults() def make_var(x): if isinstance(x, VariableTracker): return x elif ConstantVariable.is_literal(x): return ConstantVariable.create(x) else: unimplemented( "expect VariableTracker or ConstantVariable.is_literal" ) items = { ConstantVariable.create(k): make_var(v) for k, v in bound.arguments.items() } elif not args: # CustomDict(a=1, b=2) in the general (non-dataclass) case. items = {ConstantVariable.create(k): v for k, v in kwargs.items()} elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs: # CustomDict({'a': 1, 'b': 2}) items = args[0].items else: unimplemented("custom dict init with args/kwargs unimplemented") return cls(items, user_cls, **options) # called from builder.py @classmethod def wrap(cls, builder, obj): raise NotImplementedError() def __init__(self, items, user_cls, **options): super().__init__(items, user_cls, **options) assert self.is_matching_cls(user_cls) def as_proxy(self): raise NotImplementedError() # 'RETURN_VALUE triggered compile' # called from torch/_dynamo/codegen.py def reconstruct(self, codegen): codegen.extend_output([codegen._create_load_const(self.user_cls)]) # All the keys are just wrapped strings d = self.keys_as_python_constant() codegen.foreach(d.values()) keys = tuple(d.keys()) codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True)) def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": fn = getattr(self.user_cls, name) source = None if self.source is None else AttrSource(self.source, name) if hasattr(fn, "__objclass__") and fn.__objclass__ in ( dict, collections.OrderedDict, ): # for python dict method without overridden return super().call_method(tx, name, args, kwargs) elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"): # for user overridden method return tx.inline_user_function_return( variables.UserFunctionVariable(fn, source=source), [self] + list(args), kwargs, ) unimplemented("custom dict: call_method unimplemented name=%s", name) def var_getattr(self, tx, name: str) -> "VariableTracker": name_vt = ConstantVariable.create(name) if name_vt in self: return self.call_method(tx, "__getitem__", [name_vt], {}) super().var_getattr(tx, name) call_hasattr = _call_hasattr_customobj @functools.lru_cache(None) def _install_PretrainedConfig_patch(): import transformers # We need to monkeypatch transformers here, sadly. # TODO(voz): Upstream to transformers lib def _dynamo_overriden_transformers_eq(self, other): if not hasattr(other, "__dict__"): return False return self.__dict__ == other.__dict__ transformers.configuration_utils.PretrainedConfig.__eq__ = ( _dynamo_overriden_transformers_eq ) class HFPretrainedConfigVariable(VariableTracker): """ Hack for HuggingFace PretrainedConfig """ @staticmethod def is_matching_cls(cls): mod = sys.modules.get("transformers.configuration_utils") is_match = mod is not None and issubclass(cls, mod.PretrainedConfig) # Lazily install monkeypatch the first time we see it in dynamo if is_match: _install_PretrainedConfig_patch() return is_match @classmethod def is_matching_object(cls, obj): return cls.is_matching_cls(type(obj)) def __init__(self, obj, **kwargs): super().__init__(**kwargs) self.obj = obj assert self.is_matching_cls(type(obj)) def var_getattr(self, tx, name: str) -> "VariableTracker": from . import ConstantVariable return ConstantVariable.create(getattr(self.obj, name)) def call_hasattr(self, tx, name: str) -> "VariableTracker": return variables.ConstantVariable.create(hasattr(self.obj, name)) class PythonSysModulesVariable(VariableTracker): """Special case for sys.modules. Without this we will guard on the exact set of modules imported in the lifetime of the python program. """ def python_type(self): return dict def reconstruct(self, codegen): codegen.extend_output( [ codegen.create_load_python_module(sys, True), codegen.create_load_attr("modules"), ] ) def call_method( self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] ): from .builder import VariableBuilder if name == "__getitem__": return self.call_getitem(tx, *args, **kwargs) elif name == "get": return self.call_get(tx, *args, **kwargs) elif name == "__contains__": return self.call_contains(tx, *args, **kwargs) # Fallback to dict implementation real_dict = VariableBuilder(tx, self.source)(sys.modules) return real_dict.call_method(tx, name, args, kwargs) def _contains_helper(self, tx, key: VariableTracker): k = key.as_python_constant() has_key = k in sys.modules install_guard( self.make_guard( functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key) ) ) return k, has_key def call_contains(self, tx, key: VariableTracker): k, has_key = self._contains_helper(tx, key) return ConstantVariable.create(value=has_key) def call_get( self, tx, key: VariableTracker, default: Optional[VariableTracker] = None ): from .builder import VariableBuilder k, has_key = self._contains_helper(tx, key) if has_key: return VariableBuilder( tx, GetItemSource(self.source, k), )(sys.modules[k]) if default is not None: return default return ConstantVariable.create(value=None) def call_getitem(self, tx, key: VariableTracker): from .builder import VariableBuilder k, has_key = self._contains_helper(tx, key) return VariableBuilder( tx, GetItemSource(self.source, k), )(sys.modules[k])