from typing import List, Optional import torch from torch.backends._nnapi.serializer import _NnapiSerializer ANEURALNETWORKS_PREFER_LOW_POWER = 0 ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1 ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2 class NnapiModule(torch.nn.Module): """Torch Module that wraps an NNAPI Compilation. This module handles preparing the weights, initializing the NNAPI TorchBind object, and adjusting the memory formats of all inputs and outputs. """ # _nnapi.Compilation is defined comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined] weights: List[torch.Tensor] out_templates: List[torch.Tensor] def __init__( self, shape_compute_module: torch.nn.Module, ser_model: torch.Tensor, weights: List[torch.Tensor], inp_mem_fmts: List[int], out_mem_fmts: List[int], compilation_preference: int, relax_f32_to_f16: bool, ): super().__init__() self.shape_compute_module = shape_compute_module self.ser_model = ser_model self.weights = weights self.inp_mem_fmts = inp_mem_fmts self.out_mem_fmts = out_mem_fmts self.out_templates = [] self.comp = None self.compilation_preference = compilation_preference self.relax_f32_to_f16 = relax_f32_to_f16 @torch.jit.export def init(self, args: List[torch.Tensor]): assert self.comp is None self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator] self.weights = [w.contiguous() for w in self.weights] comp = torch.classes._nnapi.Compilation() comp.init2( self.ser_model, self.weights, self.compilation_preference, self.relax_f32_to_f16, ) self.comp = comp def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]: if self.comp is None: self.init(args) comp = self.comp assert comp is not None outs = [torch.empty_like(out) for out in self.out_templates] assert len(args) == len(self.inp_mem_fmts) fixed_args = [] for idx in range(len(args)): fmt = self.inp_mem_fmts[idx] # These constants match the values in DimOrder in serializer.py # TODO: See if it's possible to use those directly. if fmt == 0: fixed_args.append(args[idx].contiguous()) elif fmt == 1: fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous()) else: raise Exception("Invalid mem_fmt") comp.run(fixed_args, outs) assert len(outs) == len(self.out_mem_fmts) for idx in range(len(self.out_templates)): fmt = self.out_mem_fmts[idx] # These constants match the values in DimOrder in serializer.py # TODO: See if it's possible to use those directly. if fmt in (0, 2): pass elif fmt == 1: outs[idx] = outs[idx].permute(0, 3, 1, 2) else: raise Exception("Invalid mem_fmt") return outs def convert_model_to_nnapi( model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False, compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED, relax_f32_to_f16=False, ): ( shape_compute_module, ser_model_tensor, used_weights, inp_mem_fmts, out_mem_fmts, retval_count, ) = process_for_nnapi( model, inputs, serializer, return_shapes, use_int16_for_qint16 ) nnapi_model = NnapiModule( shape_compute_module, ser_model_tensor, used_weights, inp_mem_fmts, out_mem_fmts, compilation_preference, relax_f32_to_f16, ) class NnapiInterfaceWrapper(torch.nn.Module): """NNAPI list-ifying and de-list-ifying wrapper. NNAPI always expects a list of inputs and provides a list of outputs. This module allows us to accept inputs as separate arguments. It returns results as either a single tensor or tuple, matching the original module. """ def __init__(self, mod): super().__init__() self.mod = mod wrapper_model_py = NnapiInterfaceWrapper(nnapi_model) wrapper_model = torch.jit.script(wrapper_model_py) # TODO: Maybe make these names match the original. arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs))) if retval_count < 0: ret_expr = "retvals[0]" else: ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count)) wrapper_model.define( f"def forward(self, {arg_list}):\n" f" retvals = self.mod([{arg_list}])\n" f" return {ret_expr}\n" ) return wrapper_model def process_for_nnapi( model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False ): model = torch.jit.freeze(model) if isinstance(inputs, torch.Tensor): inputs = [inputs] serializer = serializer or _NnapiSerializer( config=None, use_int16_for_qint16=use_int16_for_qint16 ) ( ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines, retval_count, ) = serializer.serialize_model(model, inputs, return_shapes) ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32) # We have to create a new class here every time this function is called # because module.define adds a method to the *class*, not the instance. class ShapeComputeModule(torch.nn.Module): """Code-gen-ed module for tensor shape computation. module.prepare will mutate ser_model according to the computed operand shapes, based on the shapes of args. Returns a list of output templates. """ pass shape_compute_module = torch.jit.script(ShapeComputeModule()) real_shape_compute_lines = [ "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n", ] + [f" {line}\n" for line in shape_compute_lines] shape_compute_module.define("".join(real_shape_compute_lines)) return ( shape_compute_module, ser_model_tensor, used_weights, inp_mem_fmts, out_mem_fmts, retval_count, )