Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ _inductor / triton_ops / conv1x1.py

import torch

from ..utils import has_triton

if has_triton():

    import triton

    class _conv1x1:
        """2-D convolution with a 1x1 kernel, computed as one Triton matmul.

        With a 1x1 kernel every output pixel is a linear map over the input
        channels of a single input pixel. So the (padded, strided) pixels of
        ``x`` are gathered into a ``(pixels, in_channels)`` matrix, multiplied
        by the ``(in_channels, out_channels)`` weight matrix, offset by the
        bias, and reshaped back to NCHW.
        """

        @staticmethod
        def _output_size(size, pad, stride, dilation, kernel):
            """Return the output extent of one spatial dim of a non-transposed conv.

            Standard formula ``floor((size + 2*pad - dilation*(kernel-1) - 1) / stride) + 1``.
            """
            return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

        @staticmethod
        def _input_matrix(x, padding, stride):
            """Return ``x`` as the 2-D ``(N * OUT_H * OUT_W, C)`` left matmul operand.

            Args:
                x: input indexed as (N, C, H, W).
                padding: 2-element (pad_h, pad_w) sequence (tuple or list).
                stride: 2-element (stride_h, stride_w) sequence.
            """
            if any(p != 0 for p in padding):
                # For a 1x1 kernel a padded output position only ever reads
                # zeros, so zero-padding the input up front reproduces the conv
                # exactly (0 @ w + bias == bias) for ANY stride. The previous
                # border-filling code assumed stride == 1.
                x = torch.nn.functional.pad(
                    x, (padding[1], padding[1], padding[0], padding[0])
                )
            # nchw -> nhwc so channels become the last (reduction) axis
            x = x.permute(0, 2, 3, 1)
            # a 1x1 kernel only visits every stride-th pixel
            x = x[:, :: stride[0], :: stride[1], :]
            # reshape, not view: the permuted/sliced tensor is generally not
            # viewable as 2-D (e.g. for contiguous NCHW input), so a copy may
            # be required.
            return x.reshape(-1, x.shape[-1])

        @staticmethod
        def _call(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ):
            """Validate arguments and run the 1x1 convolution.

            Args:
                x: input indexed as (N, C_in, H, W).
                w: weight indexed as (C_out, C_in / groups, KH, KW); KH == KW == 1.
                bias: optional tensor whose first dim must equal C_out.
                stride, padding, dilation, output_padding: 2-element sequences.
                transposed: not handled here; ``forward`` rejects it.
                groups: only validated here; the matmul assumes groups == 1.

            Returns:
                Output indexed as (N, C_out, OUT_H, OUT_W). It is in contiguous
                (NCHW) memory format unless ``x`` had unit channel stride
                (channels-last), in which case the channels-last result of the
                matmul is returned as is.

            Raises:
                AssertionError: on mismatched shapes or a kernel that is not 1x1.
            """
            # Q: should we check x, w, bias dtypes?
            shape_x = x.shape
            shape_w = w.shape
            shape_bias = bias.shape if bias is not None else None

            # indices for the (N, C, H, W) layout of x and w
            xn, xc, xh, xw = 0, 1, 2, 3
            wn, wc, wh, ww = 0, 1, 2, 3

            # out_channel, in_channel, kernel_height, kernel_width
            kernel_size = [shape_w[wh], shape_w[ww]]
            input_size = [shape_x[xh], shape_x[xw]]
            assert (
                not shape_bias or shape_bias[0] == shape_w[wn]
            ), f"bias shape did not match{shape_bias} != {shape_w[wn]}"
            in_channel = shape_w[wc] * groups

            assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
            assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
            assert (
                shape_x[xc] == in_channel
            ), f"in_channel did not match {shape_x[xc]} != {in_channel}"

            assert (
                len(stride)
                == len(padding)
                == len(dilation)
                == len(output_padding)
                == len(kernel_size)
                == len(input_size)
            )

            # Output spatial size of a non-transposed conv. output_padding only
            # applies to transposed convs, so it does not enter the formula
            # (adding it made the later .view fail for non-zero values).
            OUT_H = _conv1x1._output_size(
                input_size[0], padding[0], stride[0], dilation[0], kernel_size[0]
            )
            OUT_W = _conv1x1._output_size(
                input_size[1], padding[1], stride[1], dilation[1], kernel_size[1]
            )

            BATCH = shape_x[xn]
            IN_C = shape_x[xc]
            KERNEL_N = shape_w[wn]

            assert kernel_size[0] == 1 and kernel_size[1] == 1, "only support 1x1 conv"
            # unit channel stride == x stored channels-last
            channels_last = x.stride()[1] == 1

            # (BATCH * OUT_H * OUT_W, IN_C); padding and stride are folded in here
            mat_x = _conv1x1._input_matrix(x, padding, stride)
            # (IN_C, KERNEL_N); w's 1x1 spatial dims are dropped
            mat_w = w.view(KERNEL_N, IN_C).permute(1, 0)
            # NOTE(review): assumes triton.ops.matmul returns a row-major
            # (M, N) tensor so the .view below succeeds — confirm.
            mat_y = triton.ops.matmul(mat_x, mat_w)
            y = mat_y.view(BATCH, OUT_H, OUT_W, KERNEL_N)
            if bias is not None:
                # broadcasts over the trailing channel axis
                y += bias
            # nhwc -> nchw view; memory stays channels-last
            y = y.permute(0, 3, 1, 2)
            if not channels_last:
                y = y.to(memory_format=torch.contiguous_format)
            return y

        @staticmethod
        def forward(
            x,
            w,
            bias,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            transposed=False,
            output_padding=(0, 0),
            groups=1,
        ):
            """Public entry point; rejects configurations this kernel cannot compute.

            Raises:
                NotImplementedError: if ``groups != 1`` or ``transposed`` is set.
                    Previously these printed a message and then returned None
                    (groups) or silently computed a non-transposed conv
                    (transposed).
            """
            if groups != 1:
                raise NotImplementedError(f"Do not support groups = {groups}")
            if transposed:
                raise NotImplementedError("Do not support transposed")
            return _conv1x1._call(
                x,
                w,
                bias,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )

    # public functional alias
    conv1x1 = _conv1x1.forward