"""
helper class that supports empty tensors on some nn functions.
Ideally, add support directly in PyTorch to empty tensors in
those functions.
This can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented
"""
import warnings
import torch
from torch import Tensor, Size
from torch.jit.annotations import List, Optional, Tuple
class Conv2d(torch.nn.Conv2d):
    """
    Deprecated drop-in subclass of ``torch.nn.Conv2d``.

    All constructor arguments are forwarded unchanged to the parent class.
    Instantiating it emits a ``FutureWarning`` telling callers to use
    ``torch.nn.Conv2d`` directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stacklevel=2 attributes the warning to the code that instantiated
        # the layer instead of to this constructor, so users can find the
        # call site they need to migrate.
        warnings.warn(
            "torchvision.ops.misc.Conv2d is deprecated and will be "
            "removed in future versions, use torch.nn.Conv2d instead.",
            FutureWarning,
            stacklevel=2,
        )
class ConvTranspose2d(torch.nn.ConvTranspose2d):
    """
    Deprecated drop-in subclass of ``torch.nn.ConvTranspose2d``.

    All constructor arguments are forwarded unchanged to the parent class.
    Instantiating it emits a ``FutureWarning`` telling callers to use
    ``torch.nn.ConvTranspose2d`` directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stacklevel=2 attributes the warning to the code that instantiated
        # the layer instead of to this constructor.
        warnings.warn(
            "torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
            "removed in future versions, use torch.nn.ConvTranspose2d instead.",
            FutureWarning,
            stacklevel=2,
        )
class BatchNorm2d(torch.nn.BatchNorm2d):
    """
    Deprecated drop-in subclass of ``torch.nn.BatchNorm2d``.

    All constructor arguments are forwarded unchanged to the parent class.
    Instantiating it emits a ``FutureWarning`` telling callers to use
    ``torch.nn.BatchNorm2d`` directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stacklevel=2 attributes the warning to the code that instantiated
        # the layer instead of to this constructor.
        warnings.warn(
            "torchvision.ops.misc.BatchNorm2d is deprecated and will be "
            "removed in future versions, use torch.nn.BatchNorm2d instead.",
            FutureWarning,
            stacklevel=2,
        )
# Alias that exposes torch.nn.functional.interpolate under this module's
# namespace (presumably kept so existing imports of `interpolate` from this
# module keep working -- confirm against callers before removing).
interpolate = torch.nn.functional.interpolate
# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers (not parameters), initialised to ones / zeros / zeros / ones.
    The forward pass applies the per-channel affine transform
    ``x * scale + shift`` with
    ``scale = weight * rsqrt(running_var + eps)`` and
    ``shift = bias - running_mean * scale``.

    Args:
        num_features: number of channels; size of every buffer.
        eps: value added to ``running_var`` before the reciprocal square root.
        n: deprecated alias of ``num_features``; when given it takes
            precedence and a ``DeprecationWarning`` is emitted.

    Raises:
        TypeError: if neither ``num_features`` nor ``n`` is provided.
    """

    def __init__(
        self,
        num_features: Optional[int] = None,
        eps: float = 0.,
        n: Optional[int] = None,
    ):
        # `n` is kept for backward-compatibility. `num_features` defaults to
        # None so that the legacy keyword-only call `FrozenBatchNorm2d(n=...)`
        # keeps working instead of failing on a missing positional argument.
        if n is not None:
            # stacklevel=2 points the warning at the caller's code.
            warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
                          DeprecationWarning, stacklevel=2)
            num_features = n
        if num_features is None:
            raise TypeError(
                "__init__() missing 1 required positional argument: 'num_features'")
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        """
        Load buffers from ``state_dict``, ignoring ``num_batches_tracked``.

        This module registers no ``num_batches_tracked`` buffer, so that
        entry is removed from ``state_dict`` (in place) before delegating to
        the parent implementation.
        """
        state_dict.pop(prefix + 'num_batches_tracked', None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the frozen normalisation to ``x``.

        The buffers are reshaped to ``(1, -1, 1, 1)``, so they broadcast
        against a 4-D input along dimension 1.
        """
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def __repr__(self) -> str:
        """Return ``ClassName(num_features)``, read from the weight buffer."""
        return f"{self.__class__.__name__}({self.weight.shape[0]})"