from typing import List, Optional

import torch
from torch.backends._nnapi.serializer import _NnapiSerializer

ANEURALNETWORKS_PREFER_LOW_POWER = 0
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
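# The three preference values above mirror the ANEURALNETWORKS_PREFER_*
# compilation-preference codes from the Android NNAPI C API; they are
# forwarded unchanged to the underlying Compilation object (see
# NnapiModule.init below).

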
class NnapiModule(torch.nn.Module):
    """Torch Module that wraps an NNAPI Compilation.

    This module handles preparing the weights, initializing the
    NNAPI TorchBind object, and adjusting the memory formats
    of all inputs and outputs.
    """

    # torch.classes._nnapi.Compilation is a TorchBind class defined in
    # native code, which mypy cannot see; hence the ignore below.
    comp: Optional[torch.classes._nnapi.Compilation]  # type: ignore[name-defined]
    weights: List[torch.Tensor]
    out_templates: List[torch.Tensor]

    def __init__(
        self,
        shape_compute_module: torch.nn.Module,
        ser_model: torch.Tensor,
        weights: List[torch.Tensor],
        inp_mem_fmts: List[int],
        out_mem_fmts: List[int],
        compilation_preference: int,
        relax_f32_to_f16: bool,
    ):
        super().__init__()
        self.shape_compute_module = shape_compute_module
        self.ser_model = ser_model
        self.weights = weights
        self.inp_mem_fmts = inp_mem_fmts
        self.out_mem_fmts = out_mem_fmts
        self.out_templates = []
        self.comp = None
        self.compilation_preference = compilation_preference
        self.relax_f32_to_f16 = relax_f32_to_f16

    @torch.jit.export
    def init(self, args: List[torch.Tensor]):
        assert self.comp is None
        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args)  # type: ignore[operator]
        self.weights = [w.contiguous() for w in self.weights]
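        # Compilation is a TorchBind class. init2 (unlike the older init,
        # which took only the serialized model and weights) additionally
        # accepts the compilation preference and the fp32->fp16 relaxation
        # flag.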
        comp = torch.classes._nnapi.Compilation()
        comp.init2(
            self.ser_model,
            self.weights,
            self.compilation_preference,
            self.relax_f32_to_f16,
        )

        self.comp = comp

    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
        if self.comp is None:
            self.init(args)
        comp = self.comp
        assert comp is not None
        outs = [torch.empty_like(out) for out in self.out_templates]

        assert len(args) == len(self.inp_mem_fmts)
        fixed_args = []
        for idx in range(len(args)):
            fmt = self.inp_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
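            # In practice 0 means contiguous NCHW and 1 means the operand is
            # consumed by NNAPI in NHWC order, so those inputs are permuted
            # from NCHW to NHWC below.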
            if fmt == 0:
                fixed_args.append(args[idx].contiguous())
            elif fmt == 1:
                fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
            else:
                raise ValueError("Invalid mem_fmt")
        comp.run(fixed_args, outs)
        assert len(outs) == len(self.out_mem_fmts)
        for idx in range(len(self.out_templates)):
            fmt = self.out_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
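            # Formats 0 and 2 (presumably contiguous and scalar/vector) need
            # no layout fix; format-1 outputs come back NHWC and are permuted
            # back to NCHW.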
            if fmt in (0, 2):
                pass
            elif fmt == 1:
                outs[idx] = outs[idx].permute(0, 3, 1, 2)
            else:
                raise ValueError("Invalid mem_fmt")
        return outs


def convert_model_to_nnapi(
    model,
    inputs,
    serializer=None,
    return_shapes=None,
    use_int16_for_qint16=False,
    compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
    relax_f32_to_f16=False,
):
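    """Convert a scripted or traced model to an NNAPI-backed ScriptModule.

    Usage sketch (``MyModel`` and the 1x3x224x224 input shape are
    hypothetical placeholders; any model the serializer supports, traced
    with fixed input shapes, works the same way):

        example = torch.zeros(1, 3, 224, 224)
        traced = torch.jit.trace(MyModel().eval(), example)
        nnapi_model = convert_model_to_nnapi(traced, example)
        nnapi_model.save("model_nnapi.pt")
    """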
    (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    ) = process_for_nnapi(
        model, inputs, serializer, return_shapes, use_int16_for_qint16
    )

    nnapi_model = NnapiModule(
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        compilation_preference,
        relax_f32_to_f16,
    )

    class NnapiInterfaceWrapper(torch.nn.Module):
        """NNAPI list-ifying and de-list-ifying wrapper.

        NNAPI always expects a list of inputs and provides a list of outputs.
        This module allows us to accept inputs as separate arguments.
        It returns results as either a single tensor or tuple,
        matching the original module.
        """

        def __init__(self, mod):
            super().__init__()
            self.mod = mod

    wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
    wrapper_model = torch.jit.script(wrapper_model_py)
    # TODO: Maybe make these names match the original.
    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
    if retval_count < 0:
        ret_expr = "retvals[0]"
    else:
        ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
    wrapper_model.define(
        f"def forward(self, {arg_list}):\n"
        f"    retvals = self.mod([{arg_list}])\n"
        f"    return {ret_expr}\n"
    )
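    # For example, for a model with two inputs that returns a single tensor
    # (retval_count < 0), the generated method is:
    #
    #     def forward(self, arg_0, arg_1):
    #         retvals = self.mod([arg_0, arg_1])
    #         return retvals[0]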
    return wrapper_model


def process_for_nnapi(
    model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
):
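    # Freezing inlines parameters and submodule attributes into the graph,
    # so the serializer below sees a single self-contained graph.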
    model = torch.jit.freeze(model)

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]

    serializer = serializer or _NnapiSerializer(
        config=None, use_int16_for_qint16=use_int16_for_qint16
    )
    (
        ser_model,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        shape_compute_lines,
        retval_count,
    ) = serializer.serialize_model(model, inputs, return_shapes)
    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
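    # The serialized NNAPI model travels as a plain int32 tensor so it can be
    # saved with the scripted module and handed through TorchScript to the
    # TorchBind Compilation object.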

    # We have to create a new class here every time this function is called
    # because module.define adds a method to the *class*, not the instance.
    class ShapeComputeModule(torch.nn.Module):
        """Code-gen-ed module for tensor shape computation.

        module.prepare will mutate ser_model according to the computed operand
        shapes, based on the shapes of args. Returns a list of output templates.
        """

        pass

    shape_compute_module = torch.jit.script(ShapeComputeModule())
    real_shape_compute_lines = [
        "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
    ] + [f"    {line}\n" for line in shape_compute_lines]
    shape_compute_module.define("".join(real_shape_compute_lines))

    return (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    )