r"""
Pruning methods
"""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Tuple
import torch
class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """

    # Name of the module attribute this method prunes. It is only declared
    # here; :meth:`apply` assigns it on the instance.
    _tensor_name: str

    def __init__(self):
        """Create a pruning method that is not yet bound to any tensor."""
        pass

    def __call__(self, module, inputs):
        r"""Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using
        :meth:`apply_mask`.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Computes and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a new mask to
        apply on top of the ``default_mask`` according to the specific pruning
        method recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being
        pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor

        Raises:
            AssertionError: if this method has not been bound to a tensor
                name yet (that happens in :meth:`apply`).
        """
        # to carry out the multiplication, the mask needs to have been computed,
        # so the pruning method must know what tensor it's operating on.
        # `_tensor_name` is only annotated on the class, so it has to be read
        # with a default: a plain attribute access would raise AttributeError
        # before the assertion could ever report the problem.
        assert (
            getattr(self, "_tensor_name", None) is not None
        ), f"Module {module} has to be pruned"  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        # cast the mask so the product keeps the dtype of the original tensor
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`

        Returns:
            The object registered as forward pre-hook: the new instance of
            ``cls``, or the ``PruningContainer`` that now wraps it when the
            tensor had already been pruned.

        Raises:
            AssertionError: if more than one pruning hook is already attached
                to ``name``, or if ``importance_scores`` does not have the
                shape of the parameter.
            Exception: whatever the mask computation or registration raises,
                re-raised after the rollback described below.
        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Check if a pruning method has already been applied to
            # `module[name]`. If so, store that in `old_method`.
            old_method = None
            found = 0
            # there should technically be only 1 hook with hook.name == name
            # assert this using `found`
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                # if it exists, take existing thing, remove hook, then
                # go through normal thing
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            assert found <= 1, (
                "Avoid adding multiple pruning hooks to the "
                f"same tensor {name} of module {module}. Use a PruningContainer."
            )

            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # Apply the new pruning method, either from scratch or on top of
            # the previous one.
            method = cls(*args, **kwargs)  # new pruning
            # Have the pruning method remember what tensor it's been applied to
            method._tensor_name = name

            # combine `methods` with `old_method`, if `old_method` exists
            if old_method is not None:  # meaning that there was a hook
                if isinstance(old_method, PruningContainer):
                    # the hook is already a pruning container: just add the
                    # new pruning method to it
                    old_method.add_pruning_method(method)
                    method = old_method  # rename old_method --> method
                else:
                    # the hook is a single pruning method (the loop above only
                    # selects BasePruningMethod instances): create a container
                    # holding the old pruning method and the new one
                    container = PruningContainer(old_method)
                    container.add_pruning_method(method)
                    method = container  # rename container --> method
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # at this point we have no forward_pre_hooks but we could have an
        # active reparametrization of the tensor if another pruning method
        # had been applied (in which case `method` would be a PruningContainer
        # and not a simple pruning method).

        # Pruning is to be applied to the module's tensor named `name`,
        # starting from the state it is found in prior to this iteration of
        # pruning. The pruning mask is calculated based on importances scores.
        orig = getattr(module, name)
        if importance_scores is not None:
            assert importance_scores.shape == orig.shape, (
                "importance_scores should have the same shape as parameter "
                f"{name} of {module}"
            )
        else:
            importance_scores = orig

        # If this is the first time pruning is applied, take care of moving
        # the original tensor to a new parameter called name + '_orig' and
        # and deleting the original parameter
        if not isinstance(method, PruningContainer):
            # copy `module[name]` to `module[name + '_orig']`
            module.register_parameter(name + "_orig", orig)
            # temporarily delete `module[name]`
            del module._parameters[name]
            default_mask = torch.ones_like(orig)  # temp
        # If this is not the first time pruning is applied, all of the above
        # has been done before in a previous pruning iteration, so we're good
        # to go
        else:
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Use try/except because if anything goes wrong with the mask
        # computation etc., you'd want to roll back.
        mask_registered = False
        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparametrize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            mask_registered = True
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # associate the pruning method to the module via a hook to
            # compute the function before every forward() (compile by run)
            module.register_forward_pre_hook(method)
        except Exception:
            if not isinstance(method, PruningContainer):
                # First-time pruning failed: put the module back the way it
                # was. Drop the mask buffer registered above (it would
                # otherwise linger on an unpruned module) and any plain
                # attribute stored under `name`, which would make
                # `register_parameter` refuse to restore the parameter.
                if mask_registered:
                    module._buffers.pop(name + "_mask", None)
                module.__dict__.pop(name, None)
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            # NOTE(review): when `method` is a PruningContainer nothing is
            # undone here; the previous hook was already removed and the new
            # method already added to the container above.
            raise

        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Computes and returns a pruned version of input tensor ``t``
        according to the pruning rule specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute mask for pruning ``t``.
                The values in this tensor indicate the importance of the
                corresponding elements in the ``t`` that is being pruned.
                If unspecified or None, the tensor ``t`` will be used in its place.
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. To be considered when determining what
                portion of the tensor that pruning should act on. If None,
                default to a mask of ones.

        Returns:
            pruned version of tensor ``t``.

        Raises:
            AssertionError: if ``importance_scores`` is given with a shape
                different from ``t``.
        """
        if importance_scores is not None:
            assert (
                importance_scores.shape == t.shape
            ), "importance_scores should have the same shape as tensor t"
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Removes the pruning reparameterization from a module. The pruned
        parameter named ``name`` remains permanently pruned, and the parameter
        named ``name+'_orig'`` is removed from the parameter list. Similarly,
        the buffer named ``name+'_mask'`` is removed from the buffers.

        Note:
            Pruning itself is NOT undone or reversed!

        Args:
            module (nn.Module): module from which the reparameterization is
                removed.

        Raises:
            AssertionError: if this method has not been bound to a tensor
                name yet (that happens in :meth:`apply`).
        """
        # before removing pruning from a tensor, it has to have been applied;
        # read `_tensor_name` with a default so the assertion (and not an
        # AttributeError) reports an un-applied method
        assert getattr(self, "_tensor_name", None) is not None, (
            f"Module {module} has to be pruned before pruning can be removed"
        )  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete and reset
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        # keep the same parameter object, but make it hold the masked values
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)
class PruningContainer(BasePruningMethod):
"""Container holding a sequence of pruning methods for iterative pruning.
Keeps track of the order in which pruning methods are applied and handles
combining successive pruning calls.
Accepts as argument an instance of a BasePruningMethod or an iterable of
them.
"""
def __init__(self, *args):
self._pruning_methods: Tuple["BasePruningMethod", ...] = tuple()
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)
elif len(args) == 1: # only 1 item in a tuple
self._tensor_name = args[0]._tensor_name
self.add_pruning_method(args[0])
else: # manual construction from list or other iterable (or no args)
for method in args:
self.add_pruning_method(method)
def add_pruning_method(self, method):
r"""Adds a child pruning ``method`` to the container.
Args:
method (subclass of BasePruningMethod): child pruning method
to be added to the container.
"""
# check that we're adding a pruning method to the container
if not isinstance(method, BasePruningMethod) and method is not None:
raise TypeError(
"{} is not a BasePruningMethod subclass".format(type(method))
)
elif method is not None and self._tensor_name != method._tensor_name:
raise ValueError(
"Can only add pruning methods acting on "
"the parameter named '{}' to PruningContainer {}.".format(
self._tensor_name, self
)
+ " Found '{}'".format(method._tensor_name)
)
# if all checks passed, add to _pruning_methods tuple
self._pruning_methods += (method,) # type: ignore[operator]
def __len__(self):
return len(self._pruning_methods)
def __iter__(self):
return iter(self._pruning_methods)
def __getitem__(self, idx):
return self._pruning_methods[idx]
def compute_mask(self, t, default_mask):
r"""Applies the latest ``method`` by computing the new partial masks
and returning its combination with the ``default_mask``.
The new partial mask should be computed on the entries or channels
that were not zeroed out by the ``default_mask``.
Which portions of the tensor ``t`` the new mask will be calculated from
depends on the ``PRUNING_TYPE`` (handled by the type handler):
* for 'unstructured', the mask will be computed from the raveled
list of nonmasked entries;
* for 'structured', the mask will be computed from the nonmasked
channels in the tensor;
* for 'global', the mask will be computed across all entries.
Args:
t (torch.Tensor): tensor representing the parameter to prune
(of same dimensions as ``default_mask``).
default_mask (torch.Tensor): mask from previous pruning iteration.
Returns:
mask (torch.Tensor): new mask that combines the effects
of the ``default_mask`` and the new mask from the current
pruning ``method`` (of same dimensions as ``default_mask`` and
Loading ...