import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.nn.functional as F
import torch.ao.nn.quantized.reference as nnqr
from ._common_operator_config_utils import (
_get_conv_configs,
_get_linear_configs,
_get_binary_op_configs,
_get_bn_configs,
_get_cat_config,
_get_default_op_configs,
_get_embedding_op_configs,
_get_fixed_qparams_op_configs,
_get_ln_configs,
_get_rnn_op_configs,
_get_share_qparams_op_configs,
)
from .backend_config import (
BackendPatternConfig,
BackendConfig,
DTypeConfig,
ObservationType,
)
from ..fuser_method_mappings import (
_sequential_wrapper2,
)
import operator
from torch.ao.quantization.utils import MatchAllNode
import itertools
# ===================
# | DTYPE CONFIGS |
# ===================
# Weighted ops: quint8 activations in and out, qint8 weights, float bias.
# This is the config used for the conv patterns in this file.
onednn_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)
# Ops described only by their activations: quint8 in, quint8 out,
# no weight or bias dtype given.
onednn_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)
# Dynamic variant (is_dynamic=True): quint8 input, float output,
# qint8 weights, float bias.
onednn_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)
# Weight-only: float activations in and out, qint8 weights.
onednn_weight_only_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
)
# Input/output-only: quint8 activations in and out, while weight and bias
# stay in float.
onednn_input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)
# ===================
# | FUSER METHODS |
# ===================
def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
    r"""Fuse a linear, a batch norm and a leaky relu module into one fused module.

    The batch norm is folded into the linear layer with
    ``nn.utils.fusion.fuse_linear_bn_eval`` and the result is wrapped together
    with the activation in ``nni.LinearLeakyReLU``.

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer
        leaky_relu: LeakyReLU instance that needs to be fused with the linear layer

    Returns:
        The fused module built from the folded linear layer and ``leaky_relu``.

    Raises:
        AssertionError: if the three modules do not share the same ``training`` flag.
        NotImplementedError: if ``is_qat`` is truthy, or if ``type(linear)`` has no
            entry in the fused-module table.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m1 = nn.Linear(20, 10).eval()
        >>> b1 = nn.BatchNorm1d(10).eval()
        >>> lr = nn.LeakyReLU(0.01).eval()
        >>> m2 = _fuse_linear_bn_leaky_relu(False, m1, b1, lr)
    """
    modules = (linear, bn, leaky_relu)
    same_mode = linear.training == bn.training and bn.training == leaky_relu.training
    assert same_mode, \
        "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."

    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format(modules))

    # The lookup is keyed on the exact type of `linear`, so subclasses of
    # nn.Linear do not match and fall through to the error below.
    fused_cls_by_linear_type = {
        nn.Linear: nni.LinearLeakyReLU,
    }
    fused_cls = fused_cls_by_linear_type.get(type(linear), None)
    if fused_cls is None:
        raise NotImplementedError("Cannot fuse eval modules: {}".format(modules))

    folded_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
    return fused_cls(folded_linear, leaky_relu)
# ======================
# | CONFIGS FOR CONV |
# ======================
# Observation type applied to every conv pattern config registered in this section.
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
# Conv patterns only use the weighted-op int8 dtype config.
conv_dtype_configs = [onednn_weighted_op_int8_dtype_config]
# Start from the common conv configs; the fusion patterns defined further
# down are appended to this collection.
conv_configs = _get_conv_configs(conv_dtype_configs)
# (1) Conv2d + Add
#   conv2d   Y
#      \   /
#       add
# include:
#   conv2d  conv2d
#      \   /
#       add
def _fuse_conv_add_left(is_qat, add, conv, _):
    """Fuse an ``add(conv, Y)`` pattern into an ``nni.ConvAdd2d`` module.

    Args:
        is_qat: not consulted by this fuser.
        add: the add op at the root of the matched pattern.
        conv: the conv module that is the first operand of the add.
        _: the second operand of the add; ignored.

    Returns:
        ``nni.ConvAdd2d`` built from ``conv`` and ``add``.
    """
    fused = nni.ConvAdd2d(conv, add)
    return fused
def _conv_add_root_node_getter_left(pattern):
    """Return the conv entry of an ``(add, conv, Y)`` pattern as the root node."""
    _add, root, _other = pattern
    return root
def _conv_add_extra_inputs_getter_left(pattern):
    """Return the extra-input patterns of an ``(add, conv, Y)`` pattern.

    Inputs of the root node are assumed to be copied over from the root node
    to the fused node, so only the non-conv operand ``Y`` is listed.
    """
    _add, _conv, other_operand = pattern
    return [other_operand]
#  conv2d
#     \
#      bn   Y
#       \   /
#        add
def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _):
    """Fuse an ``add(bn(conv), Y)`` pattern into an ``nni.ConvAdd2d`` module.

    Args:
        is_qat: when truthy, fusion is refused with ``NotImplementedError``.
        add: the add op at the root of the matched pattern.
        bn_conv: ``(bn, conv)`` pair forming the first operand of the add.
        _: the second operand of the add; ignored.

    Returns:
        ``nni.ConvAdd2d`` wrapping the conv with the batch norm folded in.

    Raises:
        NotImplementedError: if ``is_qat`` is truthy.
    """
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, add)))
    # Fold the batch norm into the conv before wrapping it with the add.
    folded_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
    return nni.ConvAdd2d(folded_conv, add)
def _conv_bn_add_root_node_getter_left(add_pattern):
    """Return the conv entry of an ``(add, (bn, conv), Y)`` pattern as the root node."""
    _add, (_bn, root), _other = add_pattern
    return root
def _conv_bn_add_extra_inputs_getter_left(add_pattern):
    """Return the extra-input patterns of an ``(add, (bn, conv), Y)`` pattern.

    Inputs of the root node are assumed to be copied over from the root node
    to the fused node, so only the non-conv operand ``Y`` is listed.
    """
    # The (bn, conv) pair is still unpacked so a malformed pattern fails here.
    _add, (_bn, _conv), other_operand = add_pattern
    return [other_operand]
# Register the "conv on the left of add" fusion patterns, with and without
# a batch norm after the conv, for both torch.add and operator.add.
conv_add_left_optioins = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)
for with_bn, add_op in conv_add_left_optioins:
    # The two variants differ only in the conv operand of the pattern and
    # in which fuser/getter helpers handle the extra (bn, conv) nesting.
    conv_configs.append(
        BackendPatternConfig()
        ._set_pattern_complex_format(
            (add_op, (nn.BatchNorm2d, nn.Conv2d) if with_bn else nn.Conv2d, MatchAllNode)
        )
        .set_observation_type(observation_type)
        .set_dtype_configs(conv_dtype_configs)
        .set_fuser_method(
            _fuse_conv_bn_add_left if with_bn else _fuse_conv_add_left
        )
        ._set_root_node_getter(
            _conv_bn_add_root_node_getter_left if with_bn else _conv_add_root_node_getter_left
        )
        ._set_extra_inputs_getter(
            _conv_bn_add_extra_inputs_getter_left if with_bn else _conv_add_extra_inputs_getter_left
        )
        .set_fused_module(nni.ConvAdd2d)
    )
#  Y   conv2d
#   \   /
#    add
def _fuse_conv_add_right(is_qat, add, _, conv):
    """Fuse an ``add(Y, conv)`` pattern into an ``nni.ConvAdd2d`` module.

    Args:
        is_qat: not consulted by this fuser.
        add: the add op at the root of the matched pattern.
        _: the first operand of the add; ignored.
        conv: the conv module that is the second operand of the add.

    Returns:
        ``nni.ConvAdd2d`` built from ``conv`` and ``add``.
    """
    fused = nni.ConvAdd2d(conv, add)
    return fused
def _conv_add_root_node_getter_right(pattern):
    """Return the conv entry of an ``(add, Y, conv)`` pattern as the root node."""
    _add, _other, root = pattern
    return root
def _conv_add_extra_inputs_getter_right(pattern):
    """Return the extra-input patterns of an ``(add, Y, conv)`` pattern.

    Inputs of the root node are assumed to be copied over from the root node
    to the fused node, so only the non-conv operand ``Y`` is listed.
    """
    _add, other_operand, _conv = pattern
    return [other_operand]
#        conv2d
#        /
#  Y    bn
#   \   /
#    add
def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv):
    """Fuse an ``add(Y, bn(conv))`` pattern into an ``nni.ConvAdd2d`` module.

    Args:
        is_qat: when truthy, fusion is refused with ``NotImplementedError``.
        add: the add op at the root of the matched pattern.
        _: the first operand of the add; ignored.
        bn_conv: ``(bn, conv)`` pair forming the second operand of the add.

    Returns:
        ``nni.ConvAdd2d`` wrapping the conv with the batch norm folded in.

    Raises:
        NotImplementedError: if ``is_qat`` is truthy.
    """
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, add)))
    # Fold the batch norm into the conv before wrapping it with the add.
    folded_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
    return nni.ConvAdd2d(folded_conv, add)
def _conv_bn_add_root_node_getter_right(pattern):
    """Return the conv entry of an ``(add, Y, (bn, conv))`` pattern as the root node."""
    _add, _other, (_bn, root) = pattern
    return root
def _conv_bn_add_extra_inputs_getter_right(pattern):
    """Return the extra-input patterns of an ``(add, Y, (bn, conv))`` pattern.

    Inputs of the root node are assumed to be copied over from the root node
    to the fused node, so only the non-conv operand ``Y`` is listed.
    """
    # The (bn, conv) pair is still unpacked so a malformed pattern fails here.
    _add, other_operand, (_bn, _conv) = pattern
    return [other_operand]
# Register the "conv on the right of add" fusion patterns, with and without
# a batch norm after the conv, for both torch.add and operator.add.
conv_add_optioins = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)
for with_bn, add_op in conv_add_optioins:
    # The two variants differ only in the conv operand of the pattern and
    # in which fuser/getter helpers handle the extra (bn, conv) nesting.
    conv_configs.append(
        BackendPatternConfig()
        ._set_pattern_complex_format(
            (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d) if with_bn else nn.Conv2d)
        )
        .set_observation_type(observation_type)
        .set_dtype_configs(conv_dtype_configs)
        .set_fuser_method(
            _fuse_conv_bn_add_right if with_bn else _fuse_conv_add_right
        )
        ._set_root_node_getter(
            _conv_bn_add_root_node_getter_right if with_bn else _conv_add_root_node_getter_right
        )
        ._set_extra_inputs_getter(
            _conv_bn_add_extra_inputs_getter_right if with_bn else _conv_add_extra_inputs_getter_right
        )
        .set_fused_module(nni.ConvAdd2d)
    )
# Config for the fused nni.ConvAdd2d module itself (the module produced by
# the add fusion patterns above): nn.Conv2d is its root module and
# nnqr.Conv2d its reference quantized module.
conv_configs.append(
    BackendPatternConfig(nni.ConvAdd2d)
    .set_observation_type(observation_type)  # noqa: E131
    .set_dtype_configs(conv_dtype_configs)
    .set_root_module(nn.Conv2d)
    .set_reference_quantized_module(nnqr.Conv2d))
# (2) Conv2d + Add + Relu
#   conv2d   Y
#      \   /
#       add
#        \
#        relu
def _fuse_conv_add_relu_left(is_qat, relu, add_pattern):
    """Fuse a ``relu(add(conv, Y))`` pattern into an ``nni.ConvAddReLU2d`` module.

    Args:
        is_qat: not consulted by this fuser.
        relu: the relu at the root of the matched pattern.
        add_pattern: ``(add, conv, Y)`` triple feeding the relu; ``Y`` is ignored.

    Returns:
        ``nni.ConvAddReLU2d`` built from ``conv``, ``add`` and ``relu``.
    """
    add_op, conv_module, _other = add_pattern
    fused = nni.ConvAddReLU2d(conv_module, add_op, relu)
    return fused
def _conv_add_relu_root_node_getter_left(pattern):
    """Return the conv entry of a ``(relu, (add, conv, Y))`` pattern as the root node."""
    _relu, (_add, root, _other) = pattern
    return root
def _conv_add_relu_extra_inputs_getter_left(pattern):
    """Return the extra-input patterns of a ``(relu, (add, conv, Y))`` pattern.

    Inputs of the root node are assumed to be copied over from the root node
    to the fused node, so only the non-conv operand ``Y`` is listed.
    """
    _relu, (_add, _conv, other_operand) = pattern
    return [other_operand]
#  conv2d
#     \
#      bn   Y
#       \   /
#        add
#         \
#         relu
def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern):
    """Fuse a ``relu(add(bn(conv), Y))`` pattern into an ``nni.ConvAddReLU2d`` module.

    Args:
        is_qat: when truthy, fusion is refused with ``NotImplementedError``.
        relu: the relu at the root of the matched pattern.
        add_pattern: ``(add, (bn, conv), Y)`` triple feeding the relu; ``Y``
            is ignored.

    Returns:
        ``nni.ConvAddReLU2d`` wrapping the conv with the batch norm folded in.

    Raises:
        NotImplementedError: if ``is_qat`` is truthy.
    """
    add, (bn, conv), _other = add_pattern
    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, add, relu)))
    # Fold the batch norm into the conv before wrapping it with add and relu.
    folded_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
    return nni.ConvAddReLU2d(folded_conv, add, relu)
def _conv_bn_add_relu_root_node_getter_left(pattern):
    """Return the conv entry of a ``(relu, (add, (bn, conv), Y))`` pattern as the root node."""
    _relu, (_add, (_bn, root), _other) = pattern
    return root
def _conv_bn_add_relu_extra_inputs_getter_left(pattern):
    """Return the extra-input patterns of a ``(relu, (add, (bn, conv), Y))`` pattern.

    Inputs of the root node are assumed to be copied over from the root node
    to the fused node, so only the non-conv operand ``Y`` is listed.
    """
    # The (bn, conv) pair is still unpacked so a malformed pattern fails here.
    _relu, (_add, (_bn, _conv), other_operand) = pattern
    return [other_operand]
# Every (with_bn, add_op) combination for the conv (+ bn) + add + relu
# patterns with the conv on the left of the add; iterated by the loop
# that follows.
conv_add_relu_left_optioins = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)
for with_bn, add_op in conv_add_relu_left_optioins:
if with_bn:
conv_configs.append(
BackendPatternConfig()
._set_pattern_complex_format((nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) # noqa: E131
.set_observation_type(observation_type)
.set_dtype_configs(conv_dtype_configs)
.set_fuser_method(_fuse_conv_bn_add_relu_left)
._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left)
._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left)
.set_fused_module(nni.ConvAddReLU2d))
else:
conv_configs.append(
Loading ...