import torch
from torch import nn
def _replace_relu(module):
reassign = {}
for name, mod in module.named_children():
_replace_relu(mod)
# Checking for explicit type instead of instance
# as we only want to replace modules of the exact type
# not inherited classes
if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:
reassign[name] = nn.ReLU(inplace=False)
for key, value in reassign.items():
module._modules[key] = value
def quantize_model(model, backend):
    """Post-training static quantization of *model*, in place.

    Args:
        model: a float module exposing ``fuse_model()``; it is switched to
            eval mode, fused, calibrated on one dummy batch and converted.
        backend (str): quantized engine name, e.g. ``'fbgemm'`` or
            ``'qnnpack'``.

    Raises:
        RuntimeError: if *backend* is not among
            ``torch.backends.quantized.supported_engines``.
    """
    # Validate before doing any work (the original allocated the dummy
    # input first and raised an uninformative, trailing-space message).
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + backend)
    torch.backends.quantized.engine = backend
    model.eval()
    # Make sure that weight qconfig matches that of the serialized models
    if backend == 'fbgemm':
        model.qconfig = torch.quantization.QConfig(
            activation=torch.quantization.default_observer,
            weight=torch.quantization.default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        model.qconfig = torch.quantization.QConfig(
            activation=torch.quantization.default_observer,
            weight=torch.quantization.default_weight_observer)
    model.fuse_model()
    torch.quantization.prepare(model, inplace=True)
    # Calibrate the observers on one dummy batch; no_grad avoids building
    # an autograd graph for a forward pass whose output is discarded.
    with torch.no_grad():
        model(torch.rand(1, 3, 299, 299))
    torch.quantization.convert(model, inplace=True)