Repository URL to install this package:
|
Version:
1.11.0 ▾
|
ccc-model-manager
/
lib
/
python3.9
/
site-packages
/
functorch
/
experimental
/
batch_norm_replacement.py
|
|---|
import torch.nn as nn
def batch_norm_without_running_stats(module: nn.Module) -> None:
    """Disable running-statistics tracking on a single batch-norm module, in place.

    If ``module`` is an instance of ``nn.modules.batchnorm._BatchNorm`` and its
    ``track_running_stats`` flag is truthy, ``running_mean``, ``running_var``
    and ``num_batches_tracked`` are set to ``None`` and ``track_running_stats``
    is set to ``False``. Any other module is left untouched.

    Args:
        module: The module to inspect and, if applicable, modify in place.
    """
    # Only batch-norm layers carry the running-statistics attributes.
    if not isinstance(module, nn.modules.batchnorm._BatchNorm):
        return
    # Already not tracking running stats: nothing to clear.
    if not module.track_running_stats:
        return

    module.running_mean = None
    module.running_var = None
    module.num_batches_tracked = None
    module.track_running_stats = False
def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
    """
    In place updates :attr:`root` by setting the ``running_mean``, ``running_var`` and
    ``num_batches_tracked`` to be None and setting ``track_running_stats`` to be False
    for any nn.BatchNorm module in :attr:`root` (including :attr:`root` itself).

    Args:
        root: The module whose module tree is updated in place.

    Returns:
        The same :attr:`root` object, to allow call chaining.
    """
    # ``root.modules()`` yields ``root`` itself as well as every descendant,
    # so a separate "base case" call on ``root`` is not needed.
    for submodule in root.modules():
        batch_norm_without_running_stats(submodule)
    return root