Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ nn / utils / _expanded_weights / linear_expanded_weights.py

import torch
import torch.nn.functional as F
from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import \
    forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor, is_batch_first
from typing import List, Optional

@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _, __, *expanded_args_and_kwargs):
        if len(expanded_args_and_kwargs[0].shape) <= 1:
            raise RuntimeError("Input does not have a batch dimension. Expanded Weights expected input "
                               f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}")
        expanded_kwargs = {'bias': expanded_args_and_kwargs[2] if len(expanded_args_and_kwargs) == 3 else None}
        expanded_args = expanded_args_and_kwargs[:2]
        ctx.batch_first = is_batch_first(expanded_args_and_kwargs)
        output = forward_helper(F.linear, expanded_args, expanded_kwargs)
        ctx.args = expanded_args
        ctx.kwargs = expanded_kwargs
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.args
        bias = ctx.kwargs['bias']
        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg_names
        results.append(None)  # for op reference

        if input.requires_grad:
            results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight)))
        else:
            results.append(None)
        results.extend([None] * 2)  # weight and bias don't compute batched gradients

        if not ctx.batch_first:
            grad_output = grad_output.transpose(0, 1)
            input = input.transpose(0, 1)

        # weight and bias get their grad_sample fields set directly if they exist
        set_grad_sample_if_exists(weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input))
        set_grad_sample_if_exists(bias, lambda _: torch.einsum("n...k->nk", grad_output))
        return tuple(results)