import sys
import enum
import struct
import array
import logging
import functools
from typing import (
Tuple,
NamedTuple,
List,
Optional,
)
import torch
# TODO: Add type annotations
# TODO: Check tensor types for ops
# Module logger, named with a fixed string rather than __name__.
LOG = logging.getLogger(name="nnapi_serialize")
class NNAPI_OperandCode:
    """Integer codes for NNAPI operand types.

    Codes with a ``TENSOR_`` prefix describe tensor operands; the others are
    scalar operand types. Values must match the NNAPI runtime's operand-code
    enum (assumed to mirror NeuralNetworks.h -- verify before adding entries).
    """

    FLOAT32 = 0
    INT32 = 1
    UINT32 = 2
    TENSOR_FLOAT32 = 3
    TENSOR_INT32 = 4
    TENSOR_QUANT8_ASYMM = 5
    BOOL = 6
    TENSOR_QUANT16_SYMM = 7
    TENSOR_FLOAT16 = 8
    TENSOR_BOOL8 = 9
    FLOAT16 = 10
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
    TENSOR_QUANT16_ASYMM = 12
class NNAPI_OperationCode:
    """Integer codes for NNAPI operations.

    Values must match the NNAPI runtime's operation-type enum (assumed to
    mirror NeuralNetworks.h -- verify before adding entries). The codes are
    contiguous from 0 to 94.
    """

    # Codes 0-28 (per the NNAPI reference, the original API-27 set -- verify).
    ADD = 0
    AVERAGE_POOL_2D = 1
    CONCATENATION = 2
    CONV_2D = 3
    DEPTHWISE_CONV_2D = 4
    DEPTH_TO_SPACE = 5
    DEQUANTIZE = 6
    EMBEDDING_LOOKUP = 7
    FLOOR = 8
    FULLY_CONNECTED = 9
    HASHTABLE_LOOKUP = 10
    L2_NORMALIZATION = 11
    L2_POOL_2D = 12
    LOCAL_RESPONSE_NORMALIZATION = 13
    LOGISTIC = 14
    LSH_PROJECTION = 15
    LSTM = 16
    MAX_POOL_2D = 17
    MUL = 18
    RELU = 19
    RELU1 = 20
    RELU6 = 21
    RESHAPE = 22
    RESIZE_BILINEAR = 23
    RNN = 24
    SOFTMAX = 25
    SPACE_TO_DEPTH = 26
    SVDF = 27
    TANH = 28

    # Codes 29-37 (per the NNAPI reference, added in API 28 -- verify).
    BATCH_TO_SPACE_ND = 29
    DIV = 30
    MEAN = 31
    PAD = 32
    SPACE_TO_BATCH_ND = 33
    SQUEEZE = 34
    STRIDED_SLICE = 35
    SUB = 36
    TRANSPOSE = 37

    # Codes 38-94 (per the NNAPI reference, added in API 29 -- verify).
    ABS = 38
    ARGMAX = 39
    ARGMIN = 40
    AXIS_ALIGNED_BBOX_TRANSFORM = 41
    BIDIRECTIONAL_SEQUENCE_LSTM = 42
    BIDIRECTIONAL_SEQUENCE_RNN = 43
    BOX_WITH_NMS_LIMIT = 44
    CAST = 45
    CHANNEL_SHUFFLE = 46
    DETECTION_POSTPROCESSING = 47
    EQUAL = 48
    EXP = 49
    EXPAND_DIMS = 50
    GATHER = 51
    GENERATE_PROPOSALS = 52
    GREATER = 53
    GREATER_EQUAL = 54
    GROUPED_CONV_2D = 55
    HEATMAP_MAX_KEYPOINT = 56
    INSTANCE_NORMALIZATION = 57
    LESS = 58
    LESS_EQUAL = 59
    LOG = 60
    LOGICAL_AND = 61
    LOGICAL_NOT = 62
    LOGICAL_OR = 63
    LOG_SOFTMAX = 64
    MAXIMUM = 65
    MINIMUM = 66
    NEG = 67
    NOT_EQUAL = 68
    PAD_V2 = 69
    POW = 70
    PRELU = 71
    QUANTIZE = 72
    QUANTIZED_16BIT_LSTM = 73
    RANDOM_MULTINOMIAL = 74
    REDUCE_ALL = 75
    REDUCE_ANY = 76
    REDUCE_MAX = 77
    REDUCE_MIN = 78
    REDUCE_PROD = 79
    REDUCE_SUM = 80
    ROI_ALIGN = 81
    ROI_POOLING = 82
    RSQRT = 83
    SELECT = 84
    SIN = 85
    SLICE = 86
    SPLIT = 87
    SQRT = 88
    TILE = 89
    TOPK_V2 = 90
    TRANSPOSE_CONV_2D = 91
    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
    UNIDIRECTIONAL_SEQUENCE_RNN = 93
    RESIZE_NEAREST_NEIGHBOR = 94
class NNAPI_FuseCode:
    """Integer codes for the activation fused into an NNAPI operation."""

    FUSED_NONE = 0   # no activation
    FUSED_RELU = 1
    FUSED_RELU1 = 2
    FUSED_RELU6 = 3
class OperandValueSourceType:
    """Codes describing where an operand's constant value comes from.

    NOTE(review): presumably IMMEDIATE means the bytes are stored inline and
    the NUMBERED_* codes refer to an indexed buffer / memory object; confirm
    against the loader that consumes these. Value 1 is deliberately unused
    here.
    """

    IMMEDIATE = 0
    NUMBERED_BUFFER = 2
    NUMBERED_MEMORY = 3
# TODO: Expose these directly to Python to avoid maintaining this list.
class TorchScalarTypes(enum.Enum):
    """PyTorch scalar-type codes that appear explicitly in models.

    Each value must stay in sync with the scalar type's position in
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
    """

    QUINT8 = 13
def approx_equal(lhs, rhs, tolerance=1e-6):
    """Return True if ``lhs`` and ``rhs`` agree within a relative tolerance.

    The allowed difference is ``tolerance`` times the smaller of the two
    magnitudes, so the check is symmetric and scale-aware. A nonzero value is
    never considered approximately equal to zero under this relative test.

    Args:
        lhs: First number.
        rhs: Second number.
        tolerance: Relative tolerance, as a fraction of the smaller magnitude.
    """
    # Use magnitudes: with a raw min(), negative inputs made the bound negative
    # and the comparison failed even when lhs == rhs.
    return abs(lhs - rhs) <= tolerance * min(abs(lhs), abs(rhs))
def tensor_size(op_type, dims):
    """Return the size in bytes of a tensor operand.

    Args:
        op_type: An ``NNAPI_OperandCode`` tensor code.
        dims: Iterable of dimension sizes. An empty iterable yields the size
            of one element; any 0 dimension yields 0.

    Raises:
        KeyError: if ``op_type`` is not a tensor type with a known element size.
    """
    item_sizes = {
        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
        NNAPI_OperandCode.TENSOR_INT32: 4,
        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
        NNAPI_OperandCode.TENSOR_FLOAT16: 2,
        NNAPI_OperandCode.TENSOR_BOOL8: 1,
        NNAPI_OperandCode.TENSOR_QUANT8_SYMM_PER_CHANNEL: 1,
        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
    }
    try:
        size = item_sizes[op_type]
    except KeyError:
        # Same exception type as before, but say what was unsupported.
        raise KeyError(f"tensor_size: unsupported operand type {op_type!r}") from None
    for d in dims:
        size *= d
    return size
def change_element(tup, index, value):
    """Return a tuple copy of ``tup`` with the item at ``index`` set to ``value``.

    ``tup`` may be any iterable. Negative indices count from the end, and an
    out-of-range index raises IndexError.
    """
    items = [*tup]
    items[index] = value
    return tuple(items)
class ConvPoolArgs2d(NamedTuple):
    """Window geometry for a 2-D convolution or pooling operation."""

    # Kernel (window) extent.
    kernel_h: int
    kernel_w: int
    # Step between successive windows.
    stride_h: int
    stride_w: int
    # Padding: top, bottom, left, right.
    pad_t: int
    pad_b: int
    pad_l: int
    pad_r: int
    # Dilation factors.
    dilation_h: int
    dilation_w: int
    # Number of convolution groups.
    group: int
class DimOrder(enum.Enum):
    """How an NNAPI operand's physical shape relates to its PyTorch shape."""

    # The NNAPI operand has exactly the PyTorch shape.
    PRESUMED_CONTIGUOUS = 0
    # The PyTorch shape is NCHW; the NNAPI operand is laid out as NHWC.
    CHANNELS_LAST = 1
    # A rank-0 or rank-1 operand whose shape is used as-is.
    SCALAR_OR_VECTOR = 2
    # A constant whose layout has not been settled.
    UNKNOWN_CONSTANT = 999
class Operand(NamedTuple):
    """Representation of an NNAPI operand as tracked by the serializer."""

    # NNAPI operand type: one of the NNAPI_OperandCode constants.
    # TODO: Make this an enum.
    op_type: int

    # Always the PyTorch shape (NCHW for feature maps); the operand actually
    # given to NNAPI may have a transposed shape, as described by dim_order.
    # A dimension of 0 marks a load-time dynamic size, -1 a runtime one.
    shape: Tuple[int, ...]

    # How the NNAPI operand's shape relates to `shape`:
    # - PRESUMED_CONTIGUOUS: it matches the PyTorch tensor's shape exactly.
    # - CHANNELS_LAST: the PyTorch tensor is NCHW and the NNAPI operand is
    #   represented explicitly as NHWC.
    dim_order: DimOrder

    # Quantization parameters.
    scale: float
    zero_point: int

    def use_nchw(self):
        """Return True for a contiguous (NCHW) operand, False for NHWC.

        Raises:
            Exception: for any other dim order, including SCALAR_OR_VECTOR.
        """
        order = self.dim_order
        if order is DimOrder.CHANNELS_LAST:
            return False
        if order is DimOrder.PRESUMED_CONTIGUOUS:
            return True
        raise Exception("Unknown dim order")
def broadcast_shapes(shape1, shape2):
    """Return the broadcast of two non-empty shapes of equal rank.

    For each dimension pair, a 1 on one side yields the other side's extent;
    otherwise both extents must match.

    Raises:
        AssertionError: if either shape is empty.
        Exception: if the ranks differ or a dimension pair is incompatible.
    """
    assert len(shape1) > 0
    assert len(shape2) > 0

    # TODO: Support non-equal-rank broadcast where semantics match. This is
    # tricky for NHWC tensors because PyTorch and NNAPI dimension orders differ
    # even when the broadcast semantics agree.
    if len(shape1) != len(shape2):
        raise Exception("Non-equal-rank broadcast is not supported yet.")

    merged = []
    for left, right in zip(shape1, shape2):
        if left == 1:
            dim = right
        elif right == 1 or left == right:
            dim = left
        else:
            raise Exception("Cannot broadcast shapes: {} and {}".format(shape1, shape2))
        merged.append(dim)
    return tuple(merged)
def get_conv_pool_shape(image_shape, args, out_ch, transpose):
    """Compute the NCHW output shape of a 2-D convolution or pooling op.

    Args:
        image_shape: Input shape as ``(batch, channels, height, width)``.
        args: A ``ConvPoolArgs2d``-like object providing kernel, stride,
            padding and dilation fields.
        out_ch: Number of output channels to report.
        transpose: If true, use the transposed-convolution size formula.

    Returns:
        ``(batch, out_ch, out_h, out_w)``. An input height or width of 0
        (variable size) is reported as 0 in the output.

    Raises:
        Exception: if either dilation factor is not 1.
    """
    batch, _in_c, in_h, in_w = image_shape

    # TODO: Handle dilation
    if args.dilation_h != 1 or args.dilation_w != 1:
        raise Exception("Dilation not supported yet.")

    if transpose:
        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
        # Width uses the left/right padding pair (pad_l used to be subtracted twice).
        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
    else:
        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1

    # Handle variable-sized tensors: a 0 input extent stays 0.
    if in_h == 0:
        out_h = 0
    if in_w == 0:
        out_w = 0

    return (batch, out_ch, out_h, out_w)
def fix_shape(shape, dim_order):
    """Return the shape an operand should actually have in NNAPI.

    PyTorch shapes are always tracked as NCHW; this is where they are
    converted to the explicit layout implied by ``dim_order``.

    Raises:
        AssertionError: for SCALAR_OR_VECTOR with a rank other than 0 or 1.
        Exception: for an unrecognized ``dim_order``.
    """
    if dim_order is DimOrder.CHANNELS_LAST:
        # Move the channel axis (index 1) to the end: NCHW -> NHWC.
        return (shape[0], *shape[2:], shape[1])
    if dim_order is DimOrder.SCALAR_OR_VECTOR:
        assert len(shape) in (0, 1)
        return shape
    if dim_order is DimOrder.PRESUMED_CONTIGUOUS or dim_order is DimOrder.UNKNOWN_CONSTANT:
        # XXX: UNKNOWN_CONSTANT is passed through unchanged; think this through.
        return shape
    raise Exception(f"Bad dim_order: {dim_order!r}.")
def reverse_map_dim(dim_order, d):
    """Map a dimension index as NNAPI sees it back to its PyTorch position.

    Examples:
        reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
        reverse_map_dim(CHANNELS_LAST, 3) == 1   # NHWC "C" is NCHW dim 1

    Raises:
        AssertionError: for dim orders other than PRESUMED_CONTIGUOUS,
            SCALAR_OR_VECTOR or CHANNELS_LAST.
    """
    if dim_order not in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
        assert dim_order is DimOrder.CHANNELS_LAST
        # Entry i is the PyTorch dim backing position i of the NHWC operand.
        nhwc_to_nchw = [0, 2, 3, 1]
        return nhwc_to_nchw[d]
    return d
def flex_name(op_id, dim):
    """Return the local variable name for a computed flexible size.

    The name identifies one dimension of one op; for example,
    ``flex_name(3, 1)`` is ``"s_3_1"``.
    """
    return "s_{}_{}".format(op_id, dim)
class _NnapiSerializer:
def __init__(self, config, use_int16_for_qint16=False):
self.operands = []
self.values = []
self.operations = []
self.value_data = []
self.operation_args = []
self.inputs = []
self.outputs = []
self.flexible_shape_computation_lines = []
self.modules = {}
self.constants = {}
self.tensor_sequences = {}
self.jitval_operand_map = {}
self.cached_immediates = {}
self.used_weights = []
self.weight_offset = 0
self.use_int16_for_qint16 = use_int16_for_qint16
if config is None:
config = {}
def get_next_operand_id(self):
Loading ...