import torch
from ..utils import has_triton
if has_triton():
import triton
class _conv1x1:
@staticmethod
def _call(
x,
w,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
):
        # TODO: validate that x, w, and bias dtypes match before dispatching
device = x.device
# input shapes
shape_x = x.shape
shape_w = w.shape
shape_bias = bias.shape if bias is not None else None
        # indices into the NCHW activation and OIHW weight layouts
xn, xc, xh, xw = 0, 1, 2, 3
yn, yc, yh, yw = 0, 1, 2, 3
wn, wc, wh, ww = 0, 1, 2, 3
# out_channel, in_channel, kernel_height, kernel_width
kernel_size = [shape_w[wh], shape_w[ww]]
input_size = [shape_x[xh], shape_x[xw]]
        assert (
            not shape_bias or shape_bias[0] == shape_w[wn]
        ), f"bias shape did not match: {shape_bias[0]} != {shape_w[wn]}"
in_channel = shape_w[wc] * groups
assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
        assert (
            len(stride)
            == len(padding)
            == len(dilation)
            == len(output_padding)
            == len(kernel_size)
            == len(input_size)
        ), "stride, padding, dilation, output_padding, kernel_size, and input_size must all have the same length"
# output shape
shape_y = [0] * 4
shape_y[yn] = shape_x[xn]
shape_y[yc] = shape_w[wn]
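        # standard conv output size is floor((in + 2*pad - dil*(k-1) - 1) / stride) + 1;
        # the "+ stride" inside the floor division below is an equivalent rewrite of the trailing "+ 1"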
shape_y[yh] = (
input_size[0]
+ 2 * padding[0]
- dilation[0] * (kernel_size[0] - 1)
- 1
+ stride[0]
) // stride[0] + 2 * output_padding[0]
shape_y[yw] = (
input_size[1]
+ 2 * padding[1]
- dilation[1] * (kernel_size[1] - 1)
- 1
+ stride[1]
) // stride[1] + 2 * output_padding[1]
BATCH = shape_x[xn]
IN_C = shape_x[xc]
KERNEL_N = shape_w[wn]
KERNEL_H = shape_w[wh]
KERNEL_W = shape_w[ww]
OUT_H = shape_y[yh]
OUT_W = shape_y[yw]
assert KERNEL_H == 1 and KERNEL_W == 1, "only support 1x1 conv"
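        # an NHWC (channels-last) tensor has stride 1 along its channel dimension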
channels_last = x.stride()[1] == 1
        if tuple(padding) == (0, 0):  # normalize so list-valued padding also matches
# nchw -> nhwc
x = x.permute(0, 2, 3, 1)
            # keep every stride-th element along H and W (for stride > 1)
x = x[:, :: stride[0], :: stride[1], :]
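            # a 1x1 conv is a GEMM over channels: (BATCH * OUT_H * OUT_W, IN_C) @ (IN_C, KERNEL_N)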
# 2d matrix
mat_x = x.reshape(-1, IN_C)
# 2d matrix
mat_w = w.view(KERNEL_N, IN_C)
mat_w = mat_w.permute(1, 0)
# 2d matrix y, (BATCH * OUT_H * OUT_W, KERNEL_N)
mat_y = triton.ops.matmul(mat_x, mat_w)
y = mat_y.view(BATCH, OUT_H, OUT_W, KERNEL_N)
if bias is not None:
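                # bias has shape (KERNEL_N,) and broadcasts over the last (channel) dim of NHWC y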
y += bias
# convert back to the original layout of y
# nhwc -> nchw
y = y.permute(0, 3, 1, 2)
if not channels_last:
y = y.to(memory_format=torch.contiguous_format)
return y
else:
y = torch.empty(
(shape_y[yn], shape_y[yh], shape_y[yw], shape_y[yc]),
device=device,
dtype=x.dtype,
)
if channels_last:
y = y.to(memory_format=torch.channels_last)
# convert x to channel-last layout;
# don't care w layout since kernel size is 1
x = x.permute(0, 2, 3, 1)
# select every stride"s element (for stride > 1)
x = x[:, :: stride[0], :: stride[1], :]
            # 2d matrix; x is non-contiguous after the permute/slice, so reshape (view would raise)
            mat_x = x.reshape(-1, IN_C)
# 2d matrix
mat_w = w.view(KERNEL_N, IN_C)
mat_w = mat_w.permute(1, 0)
            # 2d matrix y, (BATCH * (OUT_H - 2*padding[0]) * (OUT_W - 2*padding[1]), KERNEL_N)
mat_y = triton.ops.matmul(mat_x, mat_w)
mat_y = mat_y.view(
BATCH, OUT_H - 2 * padding[0], OUT_W - 2 * padding[1], KERNEL_N
)
            # padding > 0: write the GEMM result into the interior, then fill the border with bias (or zeros)
if bias is not None:
y[
:,
padding[0] : OUT_H - padding[0],
padding[1] : OUT_W - padding[1],
:,
] = (
mat_y + bias
)
y[:, : padding[0], :, :] = bias
y[:, :, : padding[1], :] = bias
y[:, OUT_H - padding[0] :, :, :] = bias
y[:, :, OUT_W - padding[1] :, :] = bias
else:
y[
:,
padding[0] : OUT_H - padding[0],
padding[1] : OUT_W - padding[1],
:,
] = mat_y
y[:, : padding[0], :, :] = 0
y[:, :, : padding[1], :] = 0
y[:, OUT_H - padding[0] :, :, :] = 0
y[:, :, OUT_W - padding[1] :, :] = 0
            # convert back to the original layout of y
            # nhwc -> nchw
            y = y.permute(0, 3, 1, 2)
            if not channels_last:
                y = y.to(memory_format=torch.contiguous_format)
            return y
@staticmethod
def forward(
x,
w,
bias,
stride=(1, 1),
padding=(0, 0),
dilation=(1, 1),
transposed=False,
output_padding=(0, 0),
groups=1,
):
        if groups != 1:
            print(f"Do not support groups = {groups}")
            return
        if transposed:
            print("Do not support transposed")
            return
return _conv1x1._call(
x,
w,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
)
conv1x1 = _conv1x1.forward
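

# A minimal usage sketch (hypothetical shapes; assumes a CUDA device and a Triton
# build that provides triton.ops.matmul):
#
#   x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.float16)
#   w = torch.randn(128, 64, 1, 1, device="cuda", dtype=torch.float16)
#   b = torch.randn(128, device="cuda", dtype=torch.float16)
#   y = conv1x1(x, w, b, stride=(1, 1), padding=(0, 0))
#   ref = torch.nn.functional.conv2d(x, w, b)
#   torch.testing.assert_close(y, ref, rtol=1e-2, atol=1e-2)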