import array
import enum
import functools
import logging
import operator
import struct
import sys
from typing import List, NamedTuple, Optional, Tuple

import torch

# TODO: Add type annotations
# TODO: Check tensor types for ops

LOG = logging.getLogger("nnapi_serialize")


class NNAPI_OperandCode:
    FLOAT32 = 0
    INT32 = 1
    UINT32 = 2
    TENSOR_FLOAT32 = 3
    TENSOR_INT32 = 4
    TENSOR_QUANT8_ASYMM = 5
    BOOL = 6
    TENSOR_QUANT16_SYMM = 7
    TENSOR_FLOAT16 = 8
    TENSOR_BOOL8 = 9
    FLOAT16 = 10
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
    TENSOR_QUANT16_ASYMM = 12


class NNAPI_OperationCode:
    ADD = 0
    AVERAGE_POOL_2D = 1
    CONCATENATION = 2
    CONV_2D = 3
    DEPTHWISE_CONV_2D = 4
    DEPTH_TO_SPACE = 5
    DEQUANTIZE = 6
    EMBEDDING_LOOKUP = 7
    FLOOR = 8
    FULLY_CONNECTED = 9
    HASHTABLE_LOOKUP = 10
    L2_NORMALIZATION = 11
    L2_POOL_2D = 12
    LOCAL_RESPONSE_NORMALIZATION = 13
    LOGISTIC = 14
    LSH_PROJECTION = 15
    LSTM = 16
    MAX_POOL_2D = 17
    MUL = 18
    RELU = 19
    RELU1 = 20
    RELU6 = 21
    RESHAPE = 22
    RESIZE_BILINEAR = 23
    RNN = 24
    SOFTMAX = 25
    SPACE_TO_DEPTH = 26
    SVDF = 27
    TANH = 28
    BATCH_TO_SPACE_ND = 29
    DIV = 30
    MEAN = 31
    PAD = 32
    SPACE_TO_BATCH_ND = 33
    SQUEEZE = 34
    STRIDED_SLICE = 35
    SUB = 36
    TRANSPOSE = 37
    ABS = 38
    ARGMAX = 39
    ARGMIN = 40
    AXIS_ALIGNED_BBOX_TRANSFORM = 41
    BIDIRECTIONAL_SEQUENCE_LSTM = 42
    BIDIRECTIONAL_SEQUENCE_RNN = 43
    BOX_WITH_NMS_LIMIT = 44
    CAST = 45
    CHANNEL_SHUFFLE = 46
    DETECTION_POSTPROCESSING = 47
    EQUAL = 48
    EXP = 49
    EXPAND_DIMS = 50
    GATHER = 51
    GENERATE_PROPOSALS = 52
    GREATER = 53
    GREATER_EQUAL = 54
    GROUPED_CONV_2D = 55
    HEATMAP_MAX_KEYPOINT = 56
    INSTANCE_NORMALIZATION = 57
    LESS = 58
    LESS_EQUAL = 59
    LOG = 60
    LOGICAL_AND = 61
    LOGICAL_NOT = 62
    LOGICAL_OR = 63
    LOG_SOFTMAX = 64
    MAXIMUM = 65
    MINIMUM = 66
    NEG = 67
    NOT_EQUAL = 68
    PAD_V2 = 69
    POW = 70
    PRELU = 71
    QUANTIZE = 72
    QUANTIZED_16BIT_LSTM = 73
    RANDOM_MULTINOMIAL = 74
    REDUCE_ALL = 75
    REDUCE_ANY = 76
    REDUCE_MAX = 77
    REDUCE_MIN = 78
    REDUCE_PROD = 79
    REDUCE_SUM = 80
    ROI_ALIGN = 81
    ROI_POOLING = 82
    RSQRT = 83
    SELECT = 84
    SIN = 85
    SLICE = 86
    SPLIT = 87
    SQRT = 88
    TILE = 89
    TOPK_V2 = 90
    TRANSPOSE_CONV_2D = 91
    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
    UNIDIRECTIONAL_SEQUENCE_RNN = 93
    RESIZE_NEAREST_NEIGHBOR = 94


class NNAPI_FuseCode:
    FUSED_NONE = 0
    FUSED_RELU = 1
    FUSED_RELU1 = 2
    FUSED_RELU6 = 3


class OperandValueSourceType:
    IMMEDIATE = 0
    NUMBERED_BUFFER = 2
    NUMBERED_MEMORY = 3


# Scalar types that appear explicitly in models.
# These must be kept in sync with
# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
# TODO: Expose these directly to Python to avoid maintaining this list.
class TorchScalarTypes(enum.Enum):
    QUINT8 = 13


def approx_equal(lhs, rhs, tolerance=1e-6):
    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)


def tensor_size(op_type, dims):
    ITEM_SIZES = {
        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
        NNAPI_OperandCode.TENSOR_INT32: 4,
        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
    }
    size = ITEM_SIZES[op_type]
    for d in dims:
        size *= d
    return size


def change_element(tup, index, value):
    ls = list(tup)
    ls[index] = value
    return tuple(ls)


class ConvPoolArgs2d(NamedTuple):
    """Configuration arguments for a convolution."""

    kernel_h: int
    kernel_w: int
    stride_h: int
    stride_w: int
    pad_t: int
    pad_b: int
    pad_l: int
    pad_r: int
    dilation_h: int
    dilation_w: int
    group: int


class DimOrder(enum.Enum):
    PRESUMED_CONTIGUOUS = 0
    CHANNELS_LAST = 1
    SCALAR_OR_VECTOR = 2
    UNKNOWN_CONSTANT = 999


class Operand(NamedTuple):
    """Representation of an NNAPI operand."""

    # NNAPI operand type. One of NNAPI_OperandCode.
    # TODO: Make this an enum.
    op_type: int

    # This is always the PyTorch shape, which is NCHW for feature maps.
    # The actual NNAPI operand might have a transposed shape.
    # We use 0 for load-time dynamic shapes and -1 for runtime dynamic shapes.
    shape: Tuple[int, ...]

    # Specifies how the shape of the operand that we define in NNAPI
    # relates to the shape we track above.
    # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match
    #   the shape of the PyTorch tensor.
    # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
    #   the NNAPI operand will be represented explicitly as NHWC.
    dim_order: DimOrder

    # Quantization params
    scale: float
    zero_point: int

    def use_nchw(self):
        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
            return True
        if self.dim_order is DimOrder.CHANNELS_LAST:
            return False
        raise Exception("Unknown dim order")


def broadcast_shapes(shape1, shape2):
    assert len(shape1) > 0
    assert len(shape2) > 0
    s1 = list(shape1)
    s2 = list(shape2)
    # TODO: Support non-equal-rank broadcast where semantics match.
    # This can be tricky for NHWC tensors because dimension orders
    # don't match between PT and NNAPI, even though semantics match.
    if len(s1) > len(s2):
        # s2 = [1] * (len(s1) - len(s2)) + s2
        raise Exception("Non-equal-rank broadcast is not supported yet.")
    if len(s2) > len(s1):
        # s3 = [1] * (len(s2) - len(s1)) + s1
        raise Exception("Non-equal-rank broadcast is not supported yet.")
    ret = []
    for d1, d2 in zip(s1, s2):
        if d1 == 1:
            ret.append(d2)
        elif d2 == 1:
            ret.append(d1)
        elif d1 == d2:
            ret.append(d1)
        else:
            raise Exception(f"Cannot broadcast shapes: {shape1} and {shape2}")
    return tuple(ret)


def get_conv_pool_shape(image_shape, args, out_ch, transpose):
    batch, in_c, in_h, in_w = image_shape

    # TODO: Handle dilation
    if args.dilation_h != 1 or args.dilation_w != 1:
        raise Exception("Dilation not supported yet.")

    if transpose:
        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
    else:
        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1

    # Handle variable-sized tensors.
    if in_h == 0:
        out_h = 0
    if in_w == 0:
        out_w = 0

    out_shape = (batch, out_ch, out_h, out_w)
    return out_shape


def fix_shape(shape, dim_order):
    # Return the actual shape that an operand should have in NNAPI,
    # given a PyTorch shape and dimension order. This is where we
    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
        return shape
    if dim_order is DimOrder.CHANNELS_LAST:
        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
    if dim_order is DimOrder.SCALAR_OR_VECTOR:
        assert len(shape) == 0 or len(shape) == 1
        return shape
    if dim_order is DimOrder.UNKNOWN_CONSTANT:
        # XXX think this through
        return shape
    raise Exception(f"Bad dim_order: {dim_order!r}.")


def reverse_map_dim(dim_order, d):
    # Return the original PyTorch dimension position for a given dimension.
    # d should be the dimension that NNAPI will see.
    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
    # reverse_map_dim(CHANNELS_LAST, 3) == 1
    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
        return d
    assert dim_order is DimOrder.CHANNELS_LAST
    return [0, 2, 3, 1][d]


def flex_name(op_id, dim):
    # Return the local variable name for the computed flexible size
    # for a given op and dimension.
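    # Illustrative example (not in the original source): flex_name(3, 2) == "s_3_2",
    # the TorchScript local that the generated shape-computation lines use to hold
    # operand 3's size along PyTorch dimension 2.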
return f"s_{op_id}_{dim}" class _NnapiSerializer: def __init__(self, config, use_int16_for_qint16=False): self.operands = [] self.values = [] self.operations = [] self.value_data = [] self.operation_args = [] self.inputs = [] self.outputs = [] self.flexible_shape_computation_lines = [] self.modules = {} self.constants = {} self.tensor_sequences = {} self.jitval_operand_map = {} self.cached_immediates = {} self.used_weights = [] self.weight_offset = 0 self.use_int16_for_qint16 = use_int16_for_qint16 if config is None: config = {} def get_next_operand_id(self): return len(self.operands) # Add a tensor operand corresponding to a JIT Value. # Returns the NNAPI operand ID. Can be looked up later with # get_tensor_operand_by_jitval. def add_tensor_operand(self, jitval, oper): assert isinstance(oper, Operand) if jitval in self.jitval_operand_map: raise Exception(f"Duplicate tensor: {jitval!r}") operand_id = self.get_next_operand_id() self.operands.append(oper) self.jitval_operand_map[jitval] = operand_id return operand_id # Add a tensor operand that does not correspond to a JIT Value. # Useful for cases where multiple NNAPI operands are required # to implement one JIT IR node. Returns the NNAPI operand ID. def add_anonymous_tensor_operand(self, oper): assert isinstance(oper, Operand) operand_id = self.get_next_operand_id() self.operands.append(oper) return operand_id def torch_tensor_to_operand(self, tensor, dim_order): dtype = str(tensor.dtype).replace("torch.", "") scale = 0.0 zero_point = 0 if dtype == "float32": op_type = NNAPI_OperandCode.TENSOR_FLOAT32 elif dtype == "int32": op_type = NNAPI_OperandCode.TENSOR_INT32 elif dtype == "quint8": op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM scale = tensor.q_scale() zero_point = tensor.q_zero_point() elif dtype == "qint32": op_type = NNAPI_OperandCode.TENSOR_INT32 scale = tensor.q_scale() zero_point = tensor.q_zero_point() assert zero_point == 0 elif dtype == "int16": if self.use_int16_for_qint16: nnapi_dtype = getattr(tensor, "nnapi_dtype", None) op_codes = ( NNAPI_OperandCode.TENSOR_QUANT16_SYMM, NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, ) if nnapi_dtype in op_codes: op_type = nnapi_dtype scale = tensor.nnapi_scale zero_point = tensor.nnapi_zero_point else: raise Exception( f"`nnapi_type` needs to be one of {op_codes} for `int16`" ) else: raise Exception( "`int16` isn't supported. 
If you're trying to represent NNAPI" " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" ) else: raise Exception(f"Can't handle input with dtype '{tensor.dtype}'") return Operand( shape=tuple(tensor.shape), op_type=op_type, dim_order=dim_order, scale=scale, zero_point=zero_point, ) def add_tensor_operand_for_input(self, arg_idx, jitval, tensor): dim_order = ( DimOrder.CHANNELS_LAST if getattr(tensor, "nnapi_nhwc", False) else DimOrder.PRESUMED_CONTIGUOUS ) toper = self.torch_tensor_to_operand(tensor, dim_order) operand_id = self.add_tensor_operand(jitval, toper) self.inputs.append(operand_id) for dim, size in enumerate(tensor.shape): if size == 0: self.compute_operand_shape( operand_id, dim, f"args[{arg_idx}].shape[{dim}]" ) return operand_id def add_tensor_operand_for_weight( self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT ): toper = self.torch_tensor_to_operand(tensor, dim_order) operand_id = len(self.operands) self.operands.append(toper) tsize = tensor_size(toper.op_type, toper.shape) psize = ((tsize - 1) | 0x3) + 1 self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) buf_num = len(self.used_weights) offset = 0 self.value_data.append(struct.pack("iii", buf_num, offset, tsize)) # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor if dim_order == DimOrder.CHANNELS_LAST: tensor = tensor.permute(0, 2, 3, 1) self.used_weights.append(tensor) return operand_id def add_immediate_operand(self, code, value, dims): assert isinstance(dims, tuple) cache_key = (code, value) if cache_key not in self.cached_immediates: operand_id = len(self.operands) self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0)) self.values.append((operand_id, OperandValueSourceType.IMMEDIATE)) self.value_data.append(value) self.cached_immediates[cache_key] = operand_id return self.cached_immediates[cache_key] def add_immediate_int_scalar(self, value): return self.add_immediate_operand( NNAPI_OperandCode.INT32, struct.pack("i", value), () ) def add_immediate_float_scalar(self, value): return self.add_immediate_operand( NNAPI_OperandCode.FLOAT32, struct.pack("f", value), () ) def add_immediate_bool_scalar(self, value): return self.add_immediate_operand( NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", () ) def add_immediate_int_vector(self, value): return self.add_immediate_operand( NNAPI_OperandCode.TENSOR_INT32, array.array("i", value).tobytes(), (len(value),), ) def has_operand_for_jitval(self, jitval): return jitval in self.jitval_operand_map def get_tensor_operand_by_jitval(self, jitval): operand_id = self.jitval_operand_map[jitval] return (operand_id, self.operands[operand_id]) def get_tensor_operand_by_jitval_fixed_size(self, jitval): op_id, oper = self.get_tensor_operand_by_jitval(jitval) for s in oper.shape: if s == 0: # TODO: Improve this error message, possibly after converting # many callsites to support flexible size. 
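                # (A size of 0 here marks a load-time flexible dimension; -1 marks a
                # runtime flexible dimension, which only triggers the warning below.)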
raise Exception("Flexible size is not supported for this operand.") if s < 0: # runtime flex LOG.warning("Operand %s has runtime flex shape", oper) return op_id, oper def get_tensor_operand_or_constant( self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS ): operand_id = self.jitval_operand_map.get(jitval) if operand_id is None: _, value = self.get_constant_value(jitval, "TensorType") operand_id = self.add_tensor_operand_for_weight(value, dim_order) return (operand_id, self.operands[operand_id]) def get_tensor_operand_for_weight(self, jitval): _, value = self.get_constant_value(jitval, "TensorType") operand_id = self.add_tensor_operand_for_weight(value) return (operand_id, self.operands[operand_id]) def add_operation(self, opcode, inputs, outputs): self.operations.append((opcode, len(inputs), len(outputs))) self.operation_args.extend(inputs + outputs) def add_tensor_sequence(self, jitval, values): assert jitval not in self.tensor_sequences self.tensor_sequences[jitval] = values def add_constant_value(self, jitval, ctype, value): assert jitval not in self.constants self.constants[jitval] = (ctype, value) def get_constant_value(self, jitval, typekind=None): record = self.constants.get(jitval) if record is None: raise Exception(f"Could not find constant value for '{jitval!r}'.") ctype, _ = record if typekind is not None and ctype.kind() != typekind: raise Exception( f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'" ) return record def operand_to_template_torchscript(self, op_id, oper, shape=None): """Return a TorchScript expression to build a template for a given operand.""" if shape is None: shape = oper.shape else: assert len(shape) == len(oper.shape) shape_parts = ["("] for d, s in enumerate(shape): if s > 0: # Fixed shape dimension: just add the value. shape_parts.append(str(s)) elif s == 0: # Load time flexible shape dimension: it should have been computed in a variable. shape_parts.append(flex_name(op_id, d)) elif s == -1: # Runtime flexible shape shape_parts.append("0") else: raise Exception("Unknown dim value, dimensions should be >= -1") shape_parts.append(",") shape_parts.append(")") shape_code = "".join(shape_parts) if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: return f"torch.zeros({shape_code}, dtype=torch.float32)" elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32: return f"torch.zeros({shape_code}, dtype=torch.int32)" elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: return ( f"torch.quantize_per_tensor(" f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)" f".expand({shape_code}).contiguous()" ) elif oper.op_type in ( NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, NNAPI_OperandCode.TENSOR_QUANT16_SYMM, ): if self.use_int16_for_qint16: return f"torch.zeros({shape_code}, dtype=torch.int16)" else: raise Exception( "`int16` isn't supported. 
If you're trying to represent NNAPI" " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" ) raise Exception(f"Unsupported output operand type: {oper.op_type}") def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim): self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim)) def compute_operand_shape(self, op_id, dim, expr): self.flexible_shape_computation_lines.append( f"{flex_name(op_id, dim)} = {expr}" ) def transpose_to_nhwc(self, in_id, oper): if oper.shape[2:] != (1, 1): raise Exception("Automatic transpose only supported for H,W == 1,1") out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST) inputs = [None] * 2 inputs[0] = in_id inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1]) outputs = [None] * 1 outputs[0] = self.add_anonymous_tensor_operand(out_oper) self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs) return outputs[0], out_oper # Transpose inputs as necessary to allow broadcasting. def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper): if in0_oper.dim_order == in1_oper.dim_order: return in0_id, in0_oper, in1_id, in1_oper # Assume NHWC is preferred if there is a mismatch. orders = (in0_oper.dim_order, in1_oper.dim_order) if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST): return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper) if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS): return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper) raise Exception( f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}" ) def get_size_arg(self, jitval): ctype, value = self.get_constant_value(jitval) if ctype.kind() == "ListType": assert ctype.getElementType().kind() == "IntType" return value raise Exception(f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'") def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config): pc = [i.item() for i in packed_config] assert pc[0] == 2 strides = [pc[1], pc[2]] paddings = [pc[3], pc[4]] dilations = [pc[5], pc[6]] output_padding = [pc[7], pc[8]] group_num = pc[9] assert len(pc) == 11 assert output_padding == [0, 0] return self.get_conv_pool_args_2d_common( kernel_size, strides, paddings, dilations, group_num ) def get_conv_pool_args_2d_from_jit( self, kernel_size, stride, padding, dilation=None, group=None ): strides = self.get_size_arg(stride) paddings = self.get_size_arg(padding) if dilation is None: dilations = [1, 1] else: dilations = self.get_size_arg(dilation) if group is not None: _, group_num = self.get_constant_value(group, "IntType") else: group_num = None return self.get_conv_pool_args_2d_common( kernel_size, strides, paddings, dilations, group_num ) def get_conv_pool_args_2d_common( self, kernel_size, strides, paddings, dilations, group_num ): kernels = list(kernel_size) assert len(kernels) == 2 assert len(strides) == 2 assert len(paddings) == 2 assert len(dilations) == 2 # NNAPI uses 4 values for padding. 
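        # Illustrative example: paddings == [1, 2] (PyTorch's [pad_h, pad_w]) expands
        # below to real_paddings == [1, 1, 2, 2], i.e. symmetric
        # (pad_top, pad_bottom, pad_left, pad_right).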
ph, pw = paddings real_paddings = [ph, ph, pw, pw] return ConvPoolArgs2d( *(kernels + strides + real_paddings + dilations + [group_num]) ) def serialize_model(self, model, inputs, return_shapes=None): self.add_immediate_bool_scalar(False) self.add_immediate_bool_scalar(True) inp_dim_orders = [] out_dim_orders = [] self_jitval = next(model.graph.inputs()) self.add_constant_value(self_jitval, self_jitval.type(), model) for arg_idx, (input_value, input_tensor) in enumerate( zip(list(model.graph.inputs())[1:], inputs) ): op_id = self.add_tensor_operand_for_input( arg_idx, input_value, input_tensor ) inp_dim_orders.append(self.operands[op_id].dim_order.value) for idx, node in enumerate(model.graph.nodes()): LOG.debug("Processing node #%d: %r", idx, node) self.add_node(node) retn = model.graph.return_node() assert retn.inputsSize() == 1 assert retn.outputsSize() == 0 retn_input = retn.inputsAt(0) template_return_lines = ["return ["] if retn_input.type().kind() == "TensorType": return_values = [retn_input] retval_count = -1 elif retn_input.type().kind() == "TupleType": return_values = self.tensor_sequences[retn_input] retval_count = len(return_values) else: raise Exception(f"Unsupported return type: {retn_input.type()}") if return_shapes is not None: assert len(return_shapes) == len(return_values) for i, v in enumerate(return_values): op_id = self.jitval_operand_map[v] self.outputs.append(op_id) out_dim_orders.append(self.operands[op_id].dim_order.value) shape = return_shapes[i] if return_shapes else None template_return_lines.append( self.operand_to_template_torchscript(op_id, self.operands[op_id], shape) + "," ) template_return_lines.append("]") model = [] version = 1 header = struct.pack( "iiiiii", version, len(self.operands), len(self.values), len(self.operations), len(self.inputs), len(self.outputs), ) model.append(header) serialized_values, serialized_value_data = self.serialize_values() model.extend( struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands ) model.extend(serialized_values) model.extend(struct.pack("iii", *x) for x in self.operations) # Compact the model so we can get its length so far. model = [b"".join(model)] model_offset = len(model[0]) # Model offset is the index into the model (in 32-bit words, not bytes) # of the next dimension we're about to serialize. If it's 0, # generate code to mutate it before passing to NNAPI. 
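        # Illustrative example (assumed numbers): if the fixed-size prefix built above
        # is 96 bytes long, model_offset becomes 24, and a flexible dimension
        # serialized at that position would be patched by a generated line like
        # "ser_model[24] = s_0_0".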
assert model_offset % 4 == 0 model_offset = int(model_offset / 4) for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands): shape = fix_shape(dims, dim_order) for d, s in enumerate(shape): if s == 0: pt_d = reverse_map_dim(dim_order, d) self.flexible_shape_computation_lines.append( f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}" ) model_offset += 1 # convert runtime flex shape from -1 to 0 shape = tuple(d if d != -1 else 0 for d in shape) model.append(self.serialize_ints(shape)) model.extend(serialized_value_data) model.append(self.serialize_ints(self.operation_args)) model.append(self.serialize_ints(self.inputs)) model.append(self.serialize_ints(self.outputs)) self.flexible_shape_computation_lines.extend(template_return_lines) return ( array.array("i", b"".join(model)), self.used_weights, inp_dim_orders, out_dim_orders, self.flexible_shape_computation_lines, retval_count, ) def serialize_values(self): serialized_values = [] serialized_value_data = [] assert len(self.values) == len(self.value_data) for (op_index, source_type), data in zip(self.values, self.value_data): source_length = len(data) # Pad with 0 bytes out to a multiple of 4 for alignment. physical_length = ((source_length - 1) | 0x3) + 1 padded_data = data + (b"\0" * (physical_length - source_length)) serialized_values.append( struct.pack("iii", op_index, source_type, source_length) ) serialized_value_data.append(padded_data) return serialized_values, serialized_value_data @staticmethod def serialize_ints(ints): return array.array("i", ints).tobytes() ADDER_MAP = { "prim::GetAttr": lambda self, node: self.add_getattr(node), "prim::Constant": lambda self, node: self.add_constant_node(node), "prim::ListConstruct": lambda self, node: self.add_list_construct(node), "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node), "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node), "aten::to": lambda self, node: self.add_to(node), "aten::detach": lambda self, node: self._identity(node), "aten::reshape": lambda self, node: self.add_reshape(node), "aten::flatten": lambda self, node: self.add_flatten(node), "aten::slice": lambda self, node: self.add_slice(node), "aten::size": lambda self, node: self.add_size(node), "aten::cat": lambda self, node: self.add_cat(node), "aten::mean": lambda self, node: self.add_mean(node), "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node), "aten::dequantize": lambda self, node: self.add_dequantize(node), "aten::add": lambda self, node: self.add_add_sub_op( node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE ), "aten::sub": lambda self, node: self.add_add_sub_op( node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE ), "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE ), "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE ), "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op( node, NNAPI_OperationCode.RELU ), "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op( node, NNAPI_OperationCode.LOGISTIC ), "aten::softmax": lambda self, node: self.add_softmax(node), "aten::hardtanh": lambda self, node: self.add_hardtanh(node), "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node), "aten::max_pool2d": lambda self, node: self.add_pool2d_node( node, NNAPI_OperationCode.MAX_POOL_2D ), "aten::adaptive_avg_pool2d": lambda self, node: 
self.add_adaptive_avg_pool2d( node ), "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d( node ), "aten::prelu": lambda self, node: self.add_prelu_op(node), "aten::addmm": lambda self, node: self.add_addmm(node), "aten::linear": lambda self, node: self.add_linear(node), "aten::_convolution": lambda self, node: self.add_conv_underscore(node), "aten::conv2d": lambda self, node: self.add_conv2d(node), "aten::log_softmax": lambda self, node: self.add_log_softmax(node), "quantized::linear": lambda self, node: self.add_qlinear(node), "quantized::conv2d": lambda self, node: self.add_qconv2d( node, NNAPI_FuseCode.FUSED_NONE ), "quantized::conv2d_relu": lambda self, node: self.add_qconv2d( node, NNAPI_FuseCode.FUSED_RELU ), "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d( node, NNAPI_FuseCode.FUSED_NONE, transpose=True ), "quantized::add": lambda self, node: self.add_qadd( node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE ), "quantized::add_relu": lambda self, node: self.add_qadd( node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU ), "quantized::mul": lambda self, node: self.add_qadd( node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE ), } def add_node(self, node): adder = self.ADDER_MAP.get(node.kind()) if not adder: raise Exception(f"Unsupported node kind ({node.kind()!r}) in node {node!r}") adder(self, node) def _identity(self, node): in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) jitval = node.outputsAt(0) self.jitval_operand_map[jitval] = in_id def add_getattr(self, node): assert node.inputsSize() == 1 assert node.outputsSize() == 1 obj_ctype, obj = self.get_constant_value(node.inputsAt(0)) assert str(obj_ctype).startswith("__torch__.") name = node.s("name") value = getattr(obj, name) output = node.outputsAt(0) ctype = output.type() self.add_constant_value(output, ctype, value) def add_constant_node(self, node): assert node.inputsSize() == 0 assert node.outputsSize() == 1 output = node.outputsAt(0) ctype = output.type() value = output.toIValue() self.add_constant_value(output, ctype, value) def add_list_construct(self, node): assert node.outputsSize() == 1 output = node.outputsAt(0) ctype = output.type() const_vals: Optional[List] = [] tensors: Optional[List] = [] for inp in node.inputs(): if const_vals is not None and inp in self.constants: _, val = self.get_constant_value(inp) const_vals.append(val) else: const_vals = None if tensors is not None and inp.type().kind() == "TensorType": tensors.append(inp) else: tensors = None if const_vals is not None: # NOTE: Now that TorchScript supports list constants, # this code path might not be used anymore. self.add_constant_value(output, ctype, const_vals) if tensors is not None: self.add_tensor_sequence(output, tensors) if const_vals is None and tensors is None: raise Exception( f"Unable to handle ListConstruct node. Neither all constants nor all tensors. 
{node!r}" ) def add_tuple_construct(self, node): assert node.outputsSize() == 1 output = node.outputsAt(0) values = list(node.inputs()) self.add_tensor_sequence(output, values) def add_unsqueeze(self, node): assert node.inputsSize() == 2 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) _, dim = self.get_constant_value(node.inputsAt(1), "IntType") assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1 out_shape_list = list(in_oper.shape) out_shape_list.insert(real_dim, 1) out_shape = tuple(out_shape_list) out_oper = in_oper._replace(shape=out_shape) inputs = [None] * 2 inputs[0] = in_id inputs[1] = self.add_immediate_int_scalar(dim) outputs = [None] * 1 outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs) def add_to(self, node): # Handle to("cpu") / to("gpu") case self._identity(node) def add_reshape(self, node): assert node.inputsSize() == 2 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) shape_ctype, shape = self.get_constant_value(node.inputsAt(1)) assert shape_ctype.kind() == "ListType" assert shape_ctype.getElementType().kind() == "IntType" is_trivial_reshape = len(shape) == 2 and shape[1] == -1 if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape: raise Exception( "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]." ) # Bit of a hack here. Use a real tensor to infer the output shape. out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape out_oper = in_oper._replace( shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS ) inputs = [None] * 2 inputs[0] = in_id inputs[1] = self.add_immediate_int_vector(shape) outputs = [None] * 1 outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) def add_flatten(self, node): assert node.inputsSize() == 3 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType") end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType") # channels last with channels == 1 or (height & width both 1) is_trivial_flatten = len(in_oper.shape) == 4 and ( in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1) ) if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten: raise Exception( "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1" ) if start_dim < 0: start_dim += len(in_oper.shape) if end_dim < 0: end_dim += len(in_oper.shape) out_shape = ( in_oper.shape[:start_dim] + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),) + in_oper.shape[end_dim + 1 :] ) if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]): raise Exception("Flattening flexible dims is not supported yet") non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :] if non_flattened_dims.count(0) > 1: raise Exception("Only 1 dim can be flexible") out_oper = in_oper._replace( shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS ) out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) for idx, dim in enumerate(out_shape): if dim == 0: self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0)) inputs_1 = tuple(dim if dim != 0 
else -1 for dim in out_shape) inputs = [None] * 2 inputs[0] = in_id inputs[1] = self.add_immediate_int_vector(inputs_1) outputs = [None] * 1 outputs[0] = out_id self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) def add_slice(self, node): assert node.inputsSize() == 5 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) _, dim_value = self.get_constant_value(node.inputsAt(1)) _, start_value = self.get_constant_value(node.inputsAt(2)) _, stop_value = self.get_constant_value(node.inputsAt(3)) _, step_value = self.get_constant_value(node.inputsAt(4)) if start_value is None: start_value = 0 if stop_value is None: stop_value = sys.maxsize if start_value < 0: start_value += in_oper.shape[dim_value] elif start_value == sys.maxsize: start_value = 0 if start_value == 0 and stop_value == sys.maxsize: self._identity(node) return if in_oper.shape[dim_value] == 0: raise Exception("Unable to slice with flexible shape") if stop_value < 0: stop_value += in_oper.shape[dim_value] elif stop_value == sys.maxsize: stop_value = in_oper.shape[dim_value] if start_value >= stop_value: raise Exception("Slice start value should be less than stop value") out_len = (stop_value - start_value) // step_value out_shape = tuple( out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape) ) out_id = self.add_tensor_operand( node.outputsAt(0), in_oper._replace(shape=out_shape) ) # flex inputs end_mask = 0 for idx, dim in enumerate(out_shape): if dim == 0: self.forward_operand_shape(out_id, idx, in_id, idx) end_mask |= 1 << idx inputs = [None] * 7 inputs[0] = in_id inputs[1] = self.add_immediate_int_vector( [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))] ) inputs[2] = self.add_immediate_int_vector( [ stop_value if i == dim_value else dim for i, dim in enumerate(in_oper.shape) ] ) inputs[3] = self.add_immediate_int_vector( [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))] ) inputs[4] = self.add_immediate_int_scalar(0) # begin mask inputs[5] = self.add_immediate_int_scalar(end_mask) inputs[6] = self.add_immediate_int_scalar(0) # shrink axis mas outputs = [None] * 1 outputs[0] = out_id self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs) def add_size(self, node): assert node.inputsSize() == 2 assert node.outputsSize() == 1 _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) _, value = self.constants[node.inputsAt(1)] res = in_oper.shape[value] output = node.outputsAt(0) self.add_constant_value(output, output.type(), res) def add_cat(self, node): assert node.inputsSize() == 2 assert node.outputsSize() == 1 tensors = self.tensor_sequences[node.inputsAt(0)] _, dim = self.get_constant_value(node.inputsAt(1), "IntType") assert len(tensors) > 0 in_ids = [] out_oper = None out_dim_size = 0 for inp in tensors: in_id, in_oper = self.get_tensor_operand_by_jitval(inp) if out_oper is None: out_shape = change_element(in_oper.shape, dim, -1) out_oper = in_oper._replace(shape=out_shape) assert in_oper.op_type == out_oper.op_type assert in_oper.dim_order == out_oper.dim_order assert change_element(in_oper.shape, dim, -1) == change_element( out_oper.shape, dim, -1 ) # TODO: Possibly check scale and zero point. in_ids.append(in_id) # TODO: Possibly support variable-sized inputs. 
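            # (Illustrative: concatenating two NCHW inputs with C == 3 and C == 5
            # along dim 1 accumulates out_dim_size to 8 for the output channels.)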
out_dim_size += in_oper.shape[dim] assert out_oper is not None out_oper = out_oper._replace( shape=change_element(out_oper.shape, dim, out_dim_size) ) if in_oper.dim_order == DimOrder.CHANNELS_LAST: # type: ignore[possibly-undefined] assert len(out_oper.shape) == 4 nnapi_dim = [0, 3, 1, 2][dim] else: nnapi_dim = dim out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) for idx, d in enumerate(out_oper.shape): if d == 0: if idx == dim: shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids) self.compute_operand_shape(out_id, idx, shape) else: self.forward_operand_shape(out_id, idx, in_ids[0], idx) inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)] outputs = [None] * 1 outputs[0] = out_id self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs) def add_mean(self, node): assert node.inputsSize() == 4 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) dim_ctype, dim = self.get_constant_value(node.inputsAt(1)) assert dim_ctype.kind() == "ListType" assert dim_ctype.getElementType().kind() == "IntType" _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType") # Expect None for dtype self.get_constant_value(node.inputsAt(3), "NoneType") if in_oper.dim_order == DimOrder.CHANNELS_LAST: assert len(in_oper.shape) == 4 nnapi_dim = [[0, 3, 1, 2][d] for d in dim] else: nnapi_dim = dim collapsed_dims = set() for d in dim: if d < 0: d += len(in_oper.shape) collapsed_dims.add(d) if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim: assert collapsed_dims.issuperset({2, 3}) out_dim_order = DimOrder.PRESUMED_CONTIGUOUS else: out_dim_order = in_oper.dim_order out_shape = [] for i, s in enumerate(in_oper.shape): if i not in collapsed_dims: out_shape.append(s) elif keep_dim: out_shape.append(1) out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order) inputs = [None] * 3 inputs[0] = in_id inputs[1] = self.add_immediate_int_vector(nnapi_dim) inputs[2] = self.add_immediate_int_scalar(keep_dim) outputs = [None] * 1 outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs) def add_quantize(self, node): assert node.inputsSize() == 4 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) if in_oper.dim_order != DimOrder.CHANNELS_LAST: raise Exception( "Most hardware backends prefer NHWC quantized tensors. " "Try setting `t.nnapi_nhwc = True` on your tensor inputs. " ) _, scale = self.get_constant_value(node.inputsAt(1), "FloatType") _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType") _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType") if scalar_type != TorchScalarTypes.QUINT8.value: raise Exception( "PyTorch NNAPI export only supports quantized tensors " "with the quint8 dtype." 
) op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM out_oper = in_oper._replace( op_type=op_type, scale=scale, zero_point=zero_point, ) inputs = [None] * 1 inputs[0] = in_id outputs = [None] * 1 outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs) def add_dequantize(self, node): assert node.inputsSize() == 1 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) out_oper = in_oper._replace( op_type=NNAPI_OperandCode.TENSOR_FLOAT32, scale=0.0, zero_point=0, ) inputs = [None] * 1 inputs[0] = in_id outputs = [None] * 1 outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs) def add_pointwise_simple_unary_op(self, node, opcode): assert node.inputsSize() == 1 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) out_oper = in_oper if opcode == NNAPI_OperationCode.LOGISTIC: # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale # must be 1.f / 256 and the zeroPoint must be 0. # https://fburl.com/h52stoog if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256) out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) for idx, dim in enumerate(in_oper.shape): if dim == 0: self.forward_operand_shape(out_id, idx, in_id, idx) inputs = [None] * 1 inputs[0] = in_id outputs = [None] * 1 outputs[0] = out_id self.add_operation(opcode, inputs, outputs) def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): # noqa: D401 """Helper for pointwise binary broadcast ops with superfluous extra args.""" assert node.outputsSize() == 1 assert node.inputsAt(0).type().kind() == "TensorType" assert node.inputsAt(1).type().kind() == "TensorType" if self.has_operand_for_jitval(node.inputsAt(0)): in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) in1_id, in1_oper = self.get_tensor_operand_or_constant( node.inputsAt(1), in0_oper.dim_order ) elif self.has_operand_for_jitval(node.inputsAt(1)): in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) in0_id, in0_oper = self.get_tensor_operand_or_constant( node.inputsAt(0), in1_oper.dim_order ) else: raise Exception(f"Can't do a NNAPI binary op: {opcode} on two constants") assert in0_oper.op_type == in1_oper.op_type in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast( in0_id, in0_oper, in1_id, in1_oper ) # NOTE: PyTorch and NNAPI have the same broadcast semantics. 
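        # (For example, broadcasting shapes (5, 1, 3) and (5, 4, 3) yields (5, 4, 3)
        # in both frameworks; broadcast_shapes above only handles equal-rank inputs.)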
out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape) out_oper = in0_oper._replace(shape=out_shape) if qparams is not None: scale, zp = qparams out_oper = out_oper._replace(scale=scale, zero_point=zp) out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)): if d0 == 1 and d1 == 0: self.forward_operand_shape(out_id, idx, in1_id, idx) elif d0 == 0 and d1 == 1: self.forward_operand_shape(out_id, idx, in0_id, idx) elif d0 == 0 and d1 == 0: self.flexible_shape_computation_lines.append( f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}" ) self.forward_operand_shape(out_id, idx, in0_id, idx) inputs = [None] * 3 inputs[0] = in0_id inputs[1] = in1_id inputs[2] = self.add_immediate_int_scalar(fuse_code) outputs = [None] * 1 outputs[0] = out_id self.add_operation(opcode, inputs, outputs) def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code): assert node.inputsSize() == 2 self._do_add_binary(node, opcode, fuse_code) def add_add_sub_op(self, node, opcode, fuse_code): assert node.inputsSize() == 3 _, alpha = self.get_constant_value(node.inputsAt(2), "IntType") if alpha != 1: raise Exception("NNAPI does not support add/sub with alpha.") self._do_add_binary(node, opcode, fuse_code) def add_qadd(self, node, opcode, fuse_code): assert node.inputsSize() == 4 _, scale = self.get_constant_value(node.inputsAt(2), "FloatType") _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType") self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point)) def add_softmax(self, node): assert node.inputsSize() == 3 in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType") out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) for dim, size in enumerate(in_oper.shape): if size == 0: self.forward_operand_shape(out_id, dim, in_id, dim) inputs = [None] * 3 inputs[0] = in_id inputs[1] = self.add_immediate_float_scalar( 1.0 ) # positive scaling factor of exponent, beta inputs[2] = self.add_immediate_int_scalar(softmax_dim) outputs = [None] * 1 outputs[0] = out_id self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs) def add_hardtanh(self, node): assert node.inputsSize() == 3 assert node.outputsSize() == 1 in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType") _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType") op_map = { (-1, 1): NNAPI_OperationCode.RELU1, (0, 6): NNAPI_OperationCode.RELU6, # noqa: E201 } opcode = op_map.get((min_val, max_val)) if opcode is None: raise Exception("NNAPI only supports hardtanh with args (-1, 1) or (0, 6).") inputs = [None] * 1 inputs[0] = in_id outputs = [None] * 1 outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) self.add_operation(opcode, inputs, outputs) def add_prelu_op(self, node): assert node.inputsSize() == 2 assert node.outputsSize() == 1 assert node.inputsAt(0).type().kind() == "TensorType" assert node.inputsAt(1).type().kind() == "TensorType" in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1)) assert len(w_oper.shape) == 1 assert w_oper.shape[0] > 0 if w_oper.shape[0] > 1: if in_oper.use_nchw(): # TODO: Support this by adding trailing 1 dims. raise Exception( "Per-channel PReLU only supports channels_last right now." 
) out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) for dim, size in enumerate(in_oper.shape): if size > 0: pass elif dim <= 1: raise Exception("PReLU requires fixed size for dim 0 and dim 1.") else: self.forward_operand_shape(out_id, dim, in_id, dim) inputs = [None] * 2 inputs[0] = in_id inputs[1] = w_id outputs = [None] * 1 outputs[0] = out_id self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs) def add_pool2d_node(self, node, opcode): assert node.inputsSize() == 6 assert node.outputsSize() == 1 image, kernel, stride, padding, dilation, ceil_mode = node.inputs() stride = stride or kernel # TODO: Validate ceil_mode semantics. args = self.get_conv_pool_args_2d_from_jit( self.get_size_arg(kernel), stride, padding, dilation ) if args.dilation_h != 1 or args.dilation_w != 1: raise Exception("NNAPI does not support dilated pooling.") image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image) assert len(image_oper.shape) == 4 out_shape = get_conv_pool_shape( image_oper.shape, args, image_oper.shape[1], False ) use_nchw = image_oper.use_nchw() inputs = [None] * 11 inputs[0] = image_id inputs[1] = self.add_immediate_int_scalar(args.pad_l) inputs[2] = self.add_immediate_int_scalar(args.pad_r) inputs[3] = self.add_immediate_int_scalar(args.pad_t) inputs[4] = self.add_immediate_int_scalar(args.pad_b) inputs[5] = self.add_immediate_int_scalar(args.stride_w) inputs[6] = self.add_immediate_int_scalar(args.stride_h) inputs[7] = self.add_immediate_int_scalar(args.kernel_w) inputs[8] = self.add_immediate_int_scalar(args.kernel_h) inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) inputs[10] = self.add_immediate_bool_scalar(use_nchw) outputs = [None] * 1 outputs[0] = self.add_tensor_operand( node.outputsAt(0), image_oper._replace(shape=out_shape) ) self.add_operation(opcode, inputs, outputs) def add_avg_pool2d(self, node): assert node.inputsSize() == 7 assert node.outputsSize() == 1 ( image, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override, ) = node.inputs() _, count_include_pad_value = self.get_constant_value(count_include_pad) _, divisor_override_value = self.get_constant_value(divisor_override) if not count_include_pad_value or divisor_override_value: raise Exception( "NNAPI doesn't support count_include_pad=False or divisor_override" ) args = self.get_conv_pool_args_2d_from_jit( self.get_size_arg(kernel), stride, padding ) image_id, image_oper = self.get_tensor_operand_by_jitval(image) assert len(image_oper.shape) == 4 out_shape = get_conv_pool_shape( image_oper.shape, args, image_oper.shape[1], False ) use_nchw = image_oper.use_nchw() inputs = [None] * 11 inputs[0] = image_id inputs[1] = self.add_immediate_int_scalar(args.pad_l) inputs[2] = self.add_immediate_int_scalar(args.pad_r) inputs[3] = self.add_immediate_int_scalar(args.pad_t) inputs[4] = self.add_immediate_int_scalar(args.pad_b) inputs[5] = self.add_immediate_int_scalar(args.stride_w) inputs[6] = self.add_immediate_int_scalar(args.stride_h) inputs[7] = self.add_immediate_int_scalar(args.kernel_w) inputs[8] = self.add_immediate_int_scalar(args.kernel_h) inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) inputs[10] = self.add_immediate_bool_scalar(use_nchw) outputs = [None] * 1 out_id = self.add_tensor_operand( node.outputsAt(0), image_oper._replace(shape=out_shape) ) self._handle_conv_pool_flexible_input(out_id, image, args, False) outputs[0] = out_id self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) def 
add_adaptive_avg_pool2d(self, node): assert node.inputsSize() == 2 assert node.outputsSize() == 1 image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size( node.inputsAt(0) ) assert len(image_oper.shape) == 4 size_ctype, size_arg = self.get_constant_value(node.inputsAt(1)) assert size_ctype.kind() == "ListType" assert size_ctype.getElementType().kind() == "IntType" if size_arg != [1, 1]: raise Exception( "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)." ) out_shape = image_oper.shape[0:2] + tuple(size_arg) use_nchw = image_oper.use_nchw() inputs = [None] * 11 inputs[0] = image_id inputs[1] = self.add_immediate_int_scalar(0) inputs[2] = self.add_immediate_int_scalar(0) inputs[3] = self.add_immediate_int_scalar(0) inputs[4] = self.add_immediate_int_scalar(0) inputs[5] = self.add_immediate_int_scalar(1) inputs[6] = self.add_immediate_int_scalar(1) inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3]) inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2]) inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) inputs[10] = self.add_immediate_bool_scalar(use_nchw) outputs = [None] * 1 outputs[0] = self.add_tensor_operand( node.outputsAt(0), image_oper._replace(shape=out_shape) ) self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) def add_upsample_nearest2d(self, node): assert node.inputsSize() == 3 or node.inputsSize() == 4 assert node.outputsSize() == 1 if node.inputsSize() == 3: image, size_jit, scale_jit = node.inputs() else: image, size_jit, scale_h_jit, scale_w_jit = node.inputs() size_ctype, size_arg = self.get_constant_value(size_jit) if node.inputsSize() == 3: scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined] else: scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined] scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined] # The only way for the 4-argument overload of upsample_nearest2d to # have been added to the graph without error is if the scale_h and # scale_w arguments are None assert scale_h_ctype.kind() == "NoneType" assert scale_w_ctype.kind() == "NoneType" scale_ctype = scale_h_ctype scale_arg = scale_h_arg image_id, image_oper = self.get_tensor_operand_by_jitval(image) assert len(image_oper.shape) == 4 if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType": raise Exception("Size and scale cannot both be non-None.") elif size_ctype.kind() != "NoneType": assert size_ctype.kind() == "ListType" assert size_ctype.getElementType().kind() == "IntType" assert scale_ctype.kind() == "NoneType" assert scale_arg is None assert isinstance(size_arg, list) assert size_arg assert all(isinstance(val, int) for val in size_arg) if len(size_arg) == 1: size_arg = size_arg * 2 assert len(size_arg) == 2 out_h = size_arg[0] out_w = size_arg[1] arg_h = self.add_immediate_int_scalar(out_h) arg_w = self.add_immediate_int_scalar(out_w) elif scale_ctype.kind() != "NoneType": assert scale_ctype.kind() == "ListType" assert scale_ctype.getElementType().kind() == "FloatType" assert size_ctype.kind() == "NoneType" assert size_arg is None assert isinstance(scale_arg, list) assert scale_arg assert all(isinstance(val, float) for val in scale_arg) if len(scale_arg) == 1: scale_arg = scale_arg * 2 assert len(scale_arg) == 2 out_h = int(scale_arg[0] * image_oper.shape[2]) out_w = int(scale_arg[1] * image_oper.shape[3]) arg_h = self.add_immediate_float_scalar(scale_arg[0]) arg_w = 
self.add_immediate_float_scalar(scale_arg[1]) else: raise Exception("Size and scale cannot both be None.") out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w) use_nchw = image_oper.use_nchw() out_id = self.add_tensor_operand( node.outputsAt(0), image_oper._replace(shape=out_shape) ) if image_oper.shape[0] == 0 or image_oper.shape[1] == 0: raise Exception("Flexible batch or channels not supported") # Handle variable input size for dim in (2, 3): # h, w indices if image_oper.shape[dim] == 0: if size_ctype.kind() != "NoneType": self.compute_operand_shape(out_id, dim, size_arg[dim - 2]) elif scale_ctype.kind() != "NoneType": self.compute_operand_shape( out_id, dim, f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})", ) else: raise Exception("Size and scale cannot both be None.") inputs = [None] * 4 inputs[0] = image_id inputs[1] = arg_w inputs[2] = arg_h inputs[3] = self.add_immediate_bool_scalar(use_nchw) outputs = [None] * 1 outputs[0] = out_id self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs) def add_addmm(self, node): assert node.inputsSize() == 5 assert node.outputsSize() == 1 jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs() for jitval in (jit_beta, jit_alpha): scale_ctype, scale_value = self.get_constant_value(jitval) assert scale_ctype.kind() in ("IntType", "FloatType") if scale_value != 1: raise Exception( "NNAPI Fully-Connected does not support alpha and beta." ) self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias) def add_linear(self, node): assert node.inputsSize() == 3 assert node.outputsSize() == 1 jit_input, jit_weight, jit_bias = node.inputs() self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias) def add_addmm_or_linear( self, node, transpose_weight, jit_input, jit_weight, jit_bias ): input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input) bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias) assert len(input_oper.shape) == 2 assert len(bias_oper.shape) == 1 # TODO: Transform at load time to share weights with CPU model. 
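        # For aten::addmm the weight constant arrives as (in_features, out_features)
        # and is transposed below; aten::linear already supplies
        # (out_features, in_features), which is what NNAPI FULLY_CONNECTED expects.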
_, weight_tensor = self.get_constant_value(jit_weight, "TensorType") assert len(weight_tensor.shape) == 2 if transpose_weight: nnapi_weight_tensor = weight_tensor.t().contiguous() else: nnapi_weight_tensor = weight_tensor.contiguous() weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) weight_oper = self.operands[weight_id] out_shape = (input_oper.shape[0], weight_oper.shape[0]) out_id = self.add_tensor_operand( node.outputsAt(0), input_oper._replace(shape=out_shape) ) if input_oper.shape[0] == 0: self.forward_operand_shape(out_id, 0, input_id, 0) inputs = [None] * 4 inputs[0] = input_id inputs[1] = weight_id inputs[2] = bias_id inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) outputs = [None] * 1 outputs[0] = out_id self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) def add_qlinear(self, node): assert node.inputsSize() == 4 assert node.outputsSize() == 1 ( jit_input, jit_packed_weight, jit_scale, jit_zero_point, ) = node.inputs() input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) # TODO: Support automatic reshape assert len(input_oper.shape) == 2 _, out_scale = self.get_constant_value(jit_scale, "FloatType") _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) assert weight_ctype.name() == "LinearPackedParamsBase" raw_weight, raw_bias = packed_weight.__getstate__()[0] assert raw_bias is not None assert len(raw_weight.shape) == 2 assert len(raw_bias.shape) == 1 assert raw_bias.shape[0] == raw_weight.shape[0] assert raw_weight.shape[1] == input_oper.shape[1] assert raw_weight.qscheme() == torch.per_tensor_affine if raw_weight.dtype == torch.quint8: unsigned_weight = raw_weight else: assert raw_weight.dtype == torch.qint8 unsigned_weight = torch._make_per_tensor_quantized_tensor( (raw_weight.int_repr().int() + 128).to(torch.uint8), scale=raw_weight.q_scale(), zero_point=raw_weight.q_zero_point() + 128, ) weight_scale = unsigned_weight.q_scale() bias_scale = input_oper.scale * weight_scale int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) bias_id = self.add_tensor_operand_for_weight(int_bias) multiplier = input_oper.scale * weight_scale / out_scale assert multiplier > 0 if multiplier >= 1: raise Exception( "Quantized convolution multiplier is greater than 1. " "This is supported by NNAPI, but not by most hardware backends. " "Try training a model without quantization-aware training. " ) # TODO: Transform at load time to share weights with CPU model. 
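        # (NNAPI FULLY_CONNECTED takes weights as [num_units, input_size], matching
        # the packed (out_features, in_features) layout, so only a contiguous copy
        # is made here.)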
nnapi_weight_tensor = unsigned_weight.contiguous() weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) weight_oper = self.operands[weight_id] out_shape = (input_oper.shape[0], weight_oper.shape[0]) out_oper = input_oper._replace( shape=out_shape, scale=out_scale, zero_point=out_zero_point, ) inputs = [None] * 4 inputs[0] = input_id inputs[1] = weight_id inputs[2] = bias_id inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) outputs = [None] * 1 outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) def get_optional_bias(self, jit_bias, weight_tensor, transpose=False): ctype, value = self.get_constant_value(jit_bias) if ctype.kind() == "NoneType": bias_idx = 1 if transpose else 0 nnapi_bias_tensor = torch.zeros( weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype ) bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor) bias_oper = self.operands[bias_id] return bias_id, bias_oper else: return self.get_tensor_operand_for_weight(jit_bias) def add_conv2d(self, node): assert node.inputsSize() == 7 assert node.outputsSize() == 1 ( jit_image, jit_weight, jit_bias, jit_stride, jit_pad, jit_dilation, jit_groups, ) = node.inputs() _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor) args = self.get_conv_pool_args_2d_from_jit( weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups ) return self.add_conv2d_common( node.outputsAt(0), 0.0, 0, jit_image, weight_tensor, bias_id, args, False, # transpose NNAPI_FuseCode.FUSED_NONE, ) def add_conv_underscore(self, node): assert node.inputsSize() == 13 assert node.outputsSize() == 1 ( jit_image, jit_weight, jit_bias, jit_stride, jit_pad, jit_dilation, jit_transpose, _, jit_groups, _, _, _, _, ) = node.inputs() _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") _, transpose = self.get_constant_value(jit_transpose) bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose) args = self.get_conv_pool_args_2d_from_jit( weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups ) return self.add_conv2d_common( node.outputsAt(0), 0.0, 0, jit_image, weight_tensor, bias_id, args, transpose, NNAPI_FuseCode.FUSED_NONE, ) def add_log_softmax(self, node): assert node.inputsSize() == 3 assert node.outputsSize() == 1 (jit_input, jit_dim, jit_half_to_float) = node.inputs() input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) _, dim = self.get_constant_value(jit_dim, "IntType") out_shape = input_oper.shape inputs = [None] * 3 inputs[0] = input_id # specifying 1 as the scaling factor for the exponent, beta inputs[1] = self.add_immediate_float_scalar(1) inputs[2] = self.add_immediate_int_scalar(dim) outputs = [None] * 1 outputs[0] = self.add_tensor_operand( node.outputsAt(0), input_oper._replace(shape=out_shape) ) self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs) def add_qconv2d(self, node, fuse_code, transpose=False): assert node.inputsSize() == 4 assert node.outputsSize() == 1 ( jit_image, jit_packed_weight, jit_scale, jit_zero_point, ) = node.inputs() _, out_scale = self.get_constant_value(jit_scale, "FloatType") _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) assert weight_ctype.name() == "Conv2dPackedParamsBase" ( pack_version, tensors, opt_tensors, ) = 
        assert pack_version == "2"
        packed_config, raw_weight = tensors
        (raw_bias,) = opt_tensors
        assert raw_bias is not None

        args = self.get_conv_pool_args_2d_from_pack(
            raw_weight.shape[2:4], packed_config
        )

        assert raw_weight.qscheme() == torch.per_tensor_affine
        if raw_weight.dtype == torch.quint8:
            unsigned_weight = raw_weight
        else:
            assert raw_weight.dtype == torch.qint8
            unsigned_weight = torch._make_per_tensor_quantized_tensor(
                (raw_weight.int_repr().int() + 128).to(torch.uint8),
                scale=raw_weight.q_scale(),
                zero_point=raw_weight.q_zero_point() + 128,
            )
        weight_scale = unsigned_weight.q_scale()

        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        bias_scale = image_oper.scale * weight_scale
        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
        bias_id = self.add_tensor_operand_for_weight(int_bias)

        multiplier = image_oper.scale * weight_scale / out_scale
        assert multiplier > 0
        if multiplier >= 1:
            raise Exception(
                "Quantized convolution multiplier is greater than or equal to 1. "
                "This is supported by NNAPI, but not by most hardware backends. "
                "Try training a model without quantization-aware training. "
            )

        return self.add_conv2d_common(
            node.outputsAt(0),
            out_scale,
            out_zero_point,
            jit_image,
            unsigned_weight,
            bias_id,
            args,
            transpose,
            fuse_code,
        )

    def add_conv2d_common(
        self,
        jit_out,
        out_scale,
        out_zero_point,
        jit_image,
        weight_tensor,
        bias_id,
        args,
        transpose,
        fuse_code,
    ):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        in_c = image_oper.shape[1]

        if args.group == 1:
            # Full convolution
            depthwise = False
            if transpose:
                weight_permutation = (1, 2, 3, 0)
            else:
                weight_permutation = (0, 2, 3, 1)
        elif args.group == in_c:
            # Depthwise convolution
            depthwise = True
            weight_permutation = (1, 2, 3, 0)
        else:
            raise Exception("Group convolution not supported yet.")

        # TODO: Transform at load time to share weights with CPU model.
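        # Layout note (a rough sketch of the assumed conventions): PyTorch conv
        # weights are OIHW (out_c, in_c, kH, kW) and transposed-conv weights are
        # IOHW, while the NNAPI CONV_2D filter is (out_c, kH, kW, in_c) and the
        # DEPTHWISE_CONV_2D filter is (1, kH, kW, out_c); the permutations
        # chosen above map the former onto the latter.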
        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        bias_oper = self.operands[bias_id]

        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
            assert bias_oper.zero_point == 0
        else:
            raise Exception(f"Unsupported input type for conv2d: {image_oper.op_type}")

        assert len(image_oper.shape) == 4
        assert len(weight_oper.shape) == 4
        assert len(bias_oper.shape) == 1

        if depthwise:
            # Depthwise convolution
            one, kern_h, kern_w, out_c = weight_oper.shape
            assert one == 1
            assert out_c % in_c == 0
            channel_multiplier = out_c // in_c
            assert channel_multiplier == 1  # Don't support multiplier
            assert out_c == in_c
        else:
            # Full convolution
            out_c, kern_h, kern_w, kern_d = weight_oper.shape
            assert kern_d == in_c

        assert out_c == bias_oper.shape[0]

        use_nchw = image_oper.use_nchw()

        if depthwise:
            num_args = 12
            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
        else:
            num_args = 11
            if transpose:
                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
            else:
                opcode = NNAPI_OperationCode.CONV_2D

        inputs = [None] * num_args
        inputs[0] = image_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
        if depthwise:
            inputs[9] = self.add_immediate_int_scalar(1)
            inputs[10] = self.add_immediate_int_scalar(fuse_code)
            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
        else:
            inputs[9] = self.add_immediate_int_scalar(fuse_code)
            inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
        out_oper = image_oper._replace(
            shape=out_shape,
            scale=out_scale,
            zero_point=out_zero_point,
        )
        out_id = self.add_tensor_operand(jit_out, out_oper)
        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)

        outputs[0] = out_id
        self.add_operation(opcode, inputs, outputs)

    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        batch, in_ch, in_h, in_w = image_oper.shape

        if batch == 0:
            self.forward_operand_shape(out_id, 0, image_id, 0)
        if in_ch == 0:
            raise Exception("Input channels can't be flexible")
        # H & W
        if transpose:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}",
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}",
                )
        else:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1",
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1",
                )
def serialize_model(
    module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
):
    """Convert a TorchScript module to NNAPI and serialize it.

    Parameters:
        module: TorchScript module to convert
        inputs: Tensors used to specify input details for NNAPI
        config (optional): Optional config to attach to the module
        return_shapes (optional): Specify the shapes of the outputs if your module
            uses runtime flexible shapes, so NNAPI output buffer sizes can be set
        use_int16_for_qint16 (optional): Use PyTorch int16 to represent NNAPI qint16 values
    """
    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(
        module, inputs, return_shapes
    )
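

# A minimal usage sketch (hypothetical model and input shape, not a definitive
# recipe; serialize_model is the function defined above):
#
#   model = torch.jit.script(MyModule().eval())      # MyModule is a placeholder
#   example_input = torch.zeros(1, 3, 224, 224)      # fixed-size example input
#   serialized = serialize_model(model, [example_input])
#
# If the module uses runtime flexible shapes, pass return_shapes so NNAPI
# output buffers can be sized ahead of time.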