from itertools import chain
from typing import Any, Dict, Optional, Type

from torch import nn
from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations

__all__ = [
    "module_contains_param",
    "swap_module",
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]


def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool:
    if is_parametrized(module):
        # see if any of the module tensors have a parametrization attached that matches the one passed in
        return any(
            any(isinstance(param, parametrization) for param in param_list)
            for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
        )
    return False


def swap_module(
    mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
) -> nn.Module:
    r"""Swaps the module using from_dense according to the mapping passed in.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to sparse nn module

    Return:
        The corresponding sparse module of `mod` according to mapping, created using from_dense
    """
    if type_before_parametrizations(mod) in mapping:
        sparse_mod = mapping[type_before_parametrizations(mod)]

        # TODO Fix this typing, as Type[Module] has no attribute "from_dense"
        new_mod = sparse_mod.from_dense(mod)  # type: ignore[attr-defined]

        # Preserve module's pre forward hooks. They'll be called on quantized input
        for pre_hook_fn in mod._forward_pre_hooks.values():
            new_mod.register_forward_pre_hook(pre_hook_fn)
        # Preserve module's post forward hooks except _observer_forward_hook
        # After convert they'll work with quantized output
        for hook_fn in mod._forward_hooks.values():
            new_mod.register_forward_hook(hook_fn)

        # respect device affinity when swapping modules
        devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
        assert len(devices) <= 1, (
            f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
        )
        device = next(iter(devices)) if len(devices) > 0 else None
        if device:
            new_mod.to(device)

        return new_mod
    else:
        return mod


def module_to_fqn(
    model: nn.Module, module: nn.Module, prefix: str = ""
) -> Optional[str]:
    """
    Returns the fqn for a module or None if module is not a descendant of model.
    """
    if module is model:
        return ""
    for name, child in model.named_children():
        fqn = module_to_fqn(child, module, ".")
        if isinstance(fqn, str):
            return prefix + name + fqn
    return None


def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path`
    doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors.
    """
    if path != "":
        for name in path.split("."):
            model = getattr(model, name, None)
    return model


def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name
    """
    # string manipulation to split tensor_fqn into module_fqn and tensor_name
    # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
    # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
    tensor_name = tensor_fqn.split(".")[-1]
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]

    module = fqn_to_module(model, module_fqn)

    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }
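
# Illustrative sketch only (not part of the original module): the toy
# nn.Sequential model and the _fqn_helpers_demo name are assumptions, used to
# show that module_to_fqn and fqn_to_module are inverses for submodules of a
# model, and how get_arg_info_from_tensor_fqn splits a tensor FQN.
def _fqn_helpers_demo() -> None:
    model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.ReLU(), nn.Linear(2, 2)))
    inner = model[1][1]

    # module_to_fqn walks named_children recursively; the inner Linear lives at
    # "1.1" relative to the toy model, and fqn_to_module resolves it back.
    fqn = module_to_fqn(model, inner)
    assert fqn == "1.1"
    assert fqn_to_module(model, fqn) is inner

    # get_arg_info_from_tensor_fqn splits "1.1.weight" into module_fqn "1.1"
    # and tensor_name "weight", then resolves the owning module.
    info = get_arg_info_from_tensor_fqn(model, "1.1.weight")
    assert info["module"] is inner
    assert info["module_fqn"] == "1.1"
    assert info["tensor_name"] == "weight"
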

# Parametrizations
class FakeSparsity(nn.Module):
    r"""Parametrization for the weights. Should be attached to the 'weight' or
    any other parameter that requires a mask applied to it.

    Note::

        Once the mask is passed in, its id should not change. The contents of
        the mask can change, but the mask object itself should not be replaced.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # We don't want the parametrizations to save the mask.
        # That way we make sure that the linear module doesn't store the masks
        # alongside their parametrizations.
        return {}
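
# Illustrative sketch only (not part of the original module): it assumes a toy
# nn.Linear and attaches FakeSparsity with
# torch.nn.utils.parametrize.register_parametrization, so that the observed
# weight becomes the element-wise product of the mask and the original parameter.
if __name__ == "__main__":
    import torch
    from torch.nn.utils import parametrize

    linear = nn.Linear(4, 4)
    mask = torch.ones_like(linear.weight)
    mask[::2] = 0  # zero out every other output row

    parametrize.register_parametrization(linear, "weight", FakeSparsity(mask))

    # The parametrization is now visible to module_contains_param, and the
    # masked rows of the observed weight read back as zeros.
    assert module_contains_param(linear, FakeSparsity)
    assert torch.all(linear.weight[::2] == 0)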