import torch
import torch.nn as nn
from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module
from typing import Dict, List, Optional
# Module types that can be sparsified & quantized by the utilities in this file
SUPPORTED_MODULES = {
nn.Embedding,
nn.EmbeddingBag
}


def _fetch_all_embeddings(model):
"""Fetches Embedding and EmbeddingBag modules from the model
"""
embedding_modules = []
stack = [model]
while stack:
module = stack.pop()
for _, child in module.named_children():
fqn_name = module_to_fqn(model, child)
if type(child) in SUPPORTED_MODULES:
embedding_modules.append((fqn_name, child))
else:
stack.append(child)
return embedding_modules


def post_training_sparse_quantize(model,
                                  data_sparsifier_class,
                                  sparsify_first=True,
                                  select_embeddings: Optional[List[nn.Module]] = None,
                                  **sparse_config):
"""Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.
Args:
- model (nn.Module)
model whose embeddings needs to be sparsified
- data_sparsifier_class (type of data sparsifier)
Type of sparsification that needs to be applied to model
- sparsify_first (bool)
if true, sparsifies first and then quantizes
otherwise, quantizes first and then sparsifies.
- select_embeddings (List of Embedding modules)
List of embedding modules to in the model to be sparsified & quantized.
If None, all embedding modules with be sparsified
- sparse_config (Dict)
config that will be passed to the constructor of data sparsifier object.
Note:
1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
- before sparsifying, the embedding layers are dequantized.
- scales and zero-points are saved
- embedding layers are sparsified and `squash_mask` is applied
- embedding weights are requantized using the saved scales and zero-points
2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
- embeddings are sparsified first
- quantization is applied on the sparsified embeddings
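
    Example (a minimal usage sketch; ``DataNormSparsifier`` and the config
    values below are just one possible choice of sparsifier)::

        from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
        model = nn.Sequential(nn.EmbeddingBag(100, 16), nn.Linear(16, 1))
        post_training_sparse_quantize(model, DataNormSparsifier,
                                      sparsify_first=True,
                                      sparsity_level=0.8)
        # model[0] is now a quantized EmbeddingBag whose weights were
        # sparsified before quantization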
"""
data_sparsifier = data_sparsifier_class(**sparse_config)
    # if select_embeddings is None, sparsify & quantize all embeddings in the model
if select_embeddings is None:
embedding_modules = _fetch_all_embeddings(model)
else:
embedding_modules = []
        assert isinstance(select_embeddings, list), "select_embeddings must be a list of embedding modules"
        for emb in select_embeddings:
            assert type(emb) in SUPPORTED_MODULES, "each module in select_embeddings must be an nn.Embedding or nn.EmbeddingBag"
            fqn_name = module_to_fqn(model, emb)
            assert fqn_name is not None, "each selected embedding module must be a submodule of the input model"
embedding_modules.append((fqn_name, emb))
if sparsify_first:
        # sparsify: register each embedding with the sparsifier, compute the
        # sparsity masks, and squash them into the weights
        for name, emb_module in embedding_modules:
            # '.' is replaced so the fqn is a valid name inside the sparsifier
            valid_name = name.replace('.', '_')
data_sparsifier.add_data(name=valid_name, data=emb_module)
data_sparsifier.step()
data_sparsifier.squash_mask()
        # quantize: mark each embedding with the weight-only qconfig and run
        # the eager-mode prepare/convert flow on the model
for _, emb_module in embedding_modules:
emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
torch.ao.quantization.prepare(model, inplace=True)
torch.ao.quantization.convert(model, inplace=True)
else:
        # quantize first, using the same eager-mode prepare/convert flow
for _, emb_module in embedding_modules:
emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
torch.ao.quantization.prepare(model, inplace=True)
torch.ao.quantization.convert(model, inplace=True)
        # retrieve the scales, zero-points and dequantized weights of each
        # quantized embedding so the weights can be requantized after sparsification
quantize_params: Dict[str, Dict] = {'scales': {}, 'zero_points': {},
'dequant_weights': {}, 'axis': {},
'dtype': {}}
for name, _ in embedding_modules:
quantized_emb = fqn_to_module(model, name)
assert quantized_emb is not None # satisfy mypy
quantized_weight = quantized_emb.weight() # type: ignore[operator]
quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
quantize_params['dtype'][name] = quantized_weight.dtype
            # attach the dequantized weights to the sparsifier
            data_sparsifier.add_data(name=name.replace('.', '_'), data=quantize_params['dequant_weights'][name])
data_sparsifier.step()
data_sparsifier.squash_mask()
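
        # requantize: rebuild each embedding's weight from the (now sparsified)
        # float tensor using the per-channel scales and zero-points saved above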
for name, _ in embedding_modules:
quantized_emb = fqn_to_module(model, name)
assert quantized_emb is not None # satisfy mypy
requantized_vector = torch.quantize_per_channel(quantize_params['dequant_weights'][name],
scales=quantize_params['scales'][name],
zero_points=quantize_params['zero_points'][name],
dtype=quantize_params['dtype'][name],
axis=quantize_params['axis'][name])
quantized_emb.set_weight(requantized_vector) # type: ignore[operator]