import math
import torch
import torch.nn as nn
class LinearReluFunctionalChild(nn.Module):
    """Functional linear layer followed by ReLU, holding its own parameters.

    The weight and bias are stored directly on the module as ``w1`` and
    ``b1`` and applied through ``torch.nn.functional`` calls instead of an
    ``nn.Linear`` submodule.

    Args:
        N: Side length of the square ``(N, N)`` weight and length of the bias.
    """

    def __init__(self, N):
        super().__init__()
        # Register the weight first so parameter ordering stays (w1, b1).
        self.w1 = nn.Parameter(torch.empty(N, N))
        # Fill the uninitialized weight in place with Kaiming-uniform values.
        nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
        # Bias starts at zero.
        self.b1 = nn.Parameter(torch.zeros(N))

    def forward(self, x):
        """Return ``relu(linear(x, w1, b1))``."""
        return nn.functional.relu(nn.functional.linear(x, self.w1, self.b1))
class LinearReluFunctional(nn.Module):
    """Two stacked functional linear+ReLU stages, the first in a child module.

    The input first passes through a ``LinearReluFunctionalChild`` and then
    through this module's own ``w1``/``b1`` functional linear layer and ReLU.

    Args:
        N: Side length of the square ``(N, N)`` weight and length of the bias;
            the same value is forwarded to the child module.
    """

    def __init__(self, N):
        super().__init__()
        # The child is built first, so its weight init runs before this
        # module's own weight init.
        self.child = LinearReluFunctionalChild(N)
        self.w1 = nn.Parameter(torch.empty(N, N))
        # Fill the uninitialized weight in place with Kaiming-uniform values.
        nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
        # Bias starts at zero.
        self.b1 = nn.Parameter(torch.zeros(N))

    def forward(self, x):
        """Apply the child, then this module's linear layer and ReLU."""
        hidden = self.child(x)
        projected = nn.functional.linear(hidden, self.w1, self.b1)
        return nn.functional.relu(projected)