edgify / torch (Python), version 2.0.1+cpu
_inductor/triton_ops/conv.py

import torch

from ..utils import has_triton

if has_triton():
    import triton
    import triton.language as tl

    from .autotune import conv_heuristics
    from .utils import _unpack

    @conv_heuristics()
    @triton.jit
    def _kernel_delta_x_hwc(
        x,
        w,
        y,
        # stride of tensor
        stride_xn,
        stride_xc,
        stride_xh,
        stride_xw,
        stride_wn,
        stride_wc,
        stride_wh,
        stride_ww,
        stride_yn,
        stride_yc,
        stride_yh,
        stride_yw,
        stride_biasn,
        # pointer inc for x
        delta_xh_ptr,
        delta_xw_ptr,
        delta_xc_ptr,
        # Tensor dimensions
        BATCH,
        IN_C,
        IN_H,
        IN_W,
        KERNEL_N,
        KERNEL_H,
        KERNEL_W,
        OUT_H,
        OUT_W,
        # parameters of conv
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
        # Metaparameters
        ACC_TYPE: tl.constexpr,
        CONV1X1_NHWC: tl.constexpr,
        # blocks in different dimension
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        # reduction tiling parameter for matmul
        BLOCK_K: tl.constexpr,
        # Super-blocking for better L2 performance
        GROUP_H: tl.constexpr,
    ):
        """
        each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
        """
        # -----------------------------------------------------------
        # Map program ids `pid` to the block of y it should compute.
        pid_nhw = tl.program_id(0)
        pid_k = tl.program_id(1)
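        # The convolution is computed as an implicit GEMM of shape
        # (BATCH * OUT_H * OUT_W, CRS) x (CRS, KERNEL_N), where
        # CRS = IN_C * KERNEL_H * KERNEL_W. Program axis 0 tiles the
        # BATCH * OUT_H * OUT_W rows and axis 1 tiles the KERNEL_N columns.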

        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # offset for the initial ptr for x
        off_x_n = off_y_n
        off_x_h = off_y_h * stride_h - padding_h
        off_x_w = off_y_w * stride_w - padding_w
        off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
        off_x_crs = tl.arange(0, BLOCK_K)

        CRS = IN_C * KERNEL_H * KERNEL_W
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_xh_ptrs = delta_xh_ptr + off_x_crs
            delta_xw_ptrs = delta_xw_ptr + off_x_crs
            delta_xc_ptrs = delta_xc_ptr + off_x_crs
            delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
            off_x_crs_unpacked = (
                delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
            )
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
            delta_xh = 0
            delta_xw = 0

        mask_x = (
            (off_x_n < BATCH)[:, None]
            & (off_x_crs < CRS)[None, :]
            & (off_x_h[:, None] + delta_xh[None, :] >= 0)
            & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
            & (off_x_w[:, None] + delta_xw[None, :] >= 0)
            & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
        )
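        # Lanes masked out here load 0.0 below (other=0.0), which realizes the
        # implicit zero padding without materializing a padded copy of x.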

        # offset for the initial ptr for w
        off_w_crs = tl.arange(0, BLOCK_K)
        off_w_k = off_y_k
        w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
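        # The loads above fill the first K-tile. Inside the loop, the next
        # tile's loads are issued after the dot, so memory traffic overlaps
        # with the matrix multiply (a simple software-pipelining pattern).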

        # -----------------------------------------------------------
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):

            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------
            w_ptrs += BLOCK_K
            # load inc ptr of x, update x_ptrs
            off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
            if not CONV1X1_NHWC:
                delta_xh_ptrs += BLOCK_K
                delta_xw_ptrs += BLOCK_K
                delta_xc_ptrs += BLOCK_K
                delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
                delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
                delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
                off_x_crs_unpacked = (
                    delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
                )
                x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
            else:
                x_ptrs += BLOCK_K

            mask_x = (
                (off_x_n < BATCH)[:, None]
                & (off_x_crs < CRS)[None, :]
                & (off_x_h[:, None] + delta_xh[None, :] >= 0)
                & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
                & (off_x_w[:, None] + delta_xw[None, :] >= 0)
                & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
            )
            mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
            # ------ prefetch ------
            # ------ load x ------
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
            # ------ load w ------
            matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        acc = acc.to(y.dtype.element_ty)

        # rematerialize -- this saves some registers
        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        # consider output padding
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # y ptrs in the block of [BLOCK_M, BLOCK_N]
        y_ptrs = (
            y
            + off_y_n[:, None] * stride_yn
            + off_y_h[:, None] * stride_yh
            + off_y_w[:, None] * stride_yw
            + off_y_k[None, :] * stride_yc
        )

        # out-of-bounds check
        mask_y = (
            (off_y_n < BATCH)[:, None]
            & (off_y_h < OUT_H + output_padding_h)[:, None]
            & (off_y_w < OUT_W + output_padding_w)[:, None]
            & (off_y_k < KERNEL_N)[None, :]
        )

        tl.store(y_ptrs, acc, mask=mask_y)

        return

    @conv_heuristics()
    @triton.jit
    def _kernel_delta_x(
        x,
        w,
        y,
        # stride of tensor
        stride_xn,
        stride_xc,
        stride_xh,
        stride_xw,
        stride_wn,
        stride_wc,
        stride_wh,
        stride_ww,
        stride_yn,
        stride_yc,
        stride_yh,
        stride_yw,
        stride_biasn,
        # pointer inc for x
        delta_x_ptr,
        # Tensor dimensions
        BATCH,
        IN_C,
        IN_H,
        IN_W,
        KERNEL_N,
        KERNEL_H,
        KERNEL_W,
        OUT_H,
        OUT_W,
        # parameters of conv
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
        # Metaparameters
        ACC_TYPE: tl.constexpr,
        CONV1X1_NHWC: tl.constexpr,
        # blocks in different dimension
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        # reduction tiling parameter for matmul
        BLOCK_K: tl.constexpr,
        # Super-blocking for better L2 performance
        GROUP_H: tl.constexpr,
    ):
        """
        each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
        """
        # -----------------------------------------------------------
        # Map program ids `pid` to the block of y it should compute.
        pid_nhw = tl.program_id(0)
        pid_k = tl.program_id(1)

        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # offset for the initial ptr for x
        off_x_n = off_y_n
        off_x_h = off_y_h * stride_h - padding_h
        off_x_w = off_y_w * stride_w - padding_w
        off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
        off_x_crs = tl.arange(0, BLOCK_K)

        CRS = IN_C * KERNEL_H * KERNEL_W
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_x_ptrs = delta_x_ptr + off_x_crs
            off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]

        mask_x = (
            (off_x_n < BATCH)
            & (off_x_h >= 0)
            & (off_x_h < IN_H)
            & (off_x_w >= 0)
            & (off_x_w < IN_W)
        )[:, None] & (off_x_crs < CRS)[None, :]
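        # This kernel is only selected when no per-window bounds check is
        # needed (padding == 0 or a 1x1 kernel), so checking the base h/w
        # offsets is sufficient and one packed delta_x increment can be used.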

        # offset for the initial ptr for w
        off_w_crs = tl.arange(0, BLOCK_K)
        off_w_k = off_y_k
        w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        # -----------------------------------------------------------
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):

            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------
            w_ptrs += BLOCK_K
            # load inc ptr of x, update x_ptrs
            if not CONV1X1_NHWC:
                delta_x_ptrs += BLOCK_K
                off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
                off_x_crs_unpacked = tl.load(
                    delta_x_ptrs, mask=off_x_crs < CRS, other=0
                )
                x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
            else:
                off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
                x_ptrs += BLOCK_K

            mask_x = (
                (off_x_n < BATCH)
                & (off_x_h >= 0)
                & (off_x_h < IN_H)
                & (off_x_w >= 0)
                & (off_x_w < IN_W)
            )[:, None] & (off_x_crs < CRS)[None, :]
            mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
            # ------ prefetch ------
            # ------ load x ------
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
            # ------ load w ------
            matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        acc = acc.to(y.dtype.element_ty)

        # rematerialize -- this saves some registers
        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        # consider output padding
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # y ptrs in the block of [BLOCK_M, BLOCK_N]
        y_ptrs = (
            y
            + off_y_n[:, None] * stride_yn
            + off_y_h[:, None] * stride_yh
            + off_y_w[:, None] * stride_yw
            + off_y_k[None, :] * stride_yc
        )

        # out-of-bounds check
        mask_y = (
            (off_y_n < BATCH)[:, None]
            & (off_y_h < OUT_H + output_padding_h)[:, None]
            & (off_y_w < OUT_W + output_padding_w)[:, None]
            & (off_y_k < KERNEL_N)[None, :]
        )

        tl.store(y_ptrs, acc, mask=mask_y)

        return

    class _conv:
        kernel = _kernel_delta_x_hwc

        # for the contiguous order of the w ptr, what's the corresponding
        # ptr change for x in a sliding window
        @staticmethod
        def _delta_x_ptr_hwc(
            IN_C,
            KERNEL_H,
            KERNEL_W,
            dilation_h,
            dilation_w,
            stride_wc,
            stride_wh,
            stride_ww,
            stride_xc,
            stride_xh,
            stride_xw,
            device,
        ):
            # get the order of axes in w, innermost dimension outward
            stride_w_3d = [stride_wc, stride_wh, stride_ww]
            order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
            window_size = IN_C * KERNEL_H * KERNEL_W

            r_window = torch.arange(0, window_size, 1, device=device)
            window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
            window_unpack_c = window_unpack[order[0]]
            window_unpack_h = window_unpack[order[1]]
            window_unpack_w = window_unpack[order[2]]
            r_dilation_h = dilation_h * window_unpack_h
            r_dilation_w = dilation_w * window_unpack_w
            r_inc = window_unpack_c
            # delta_x = (
            #     r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
            # )
            # return delta_x
            return (
                r_dilation_h,
                r_dilation_w,
                r_inc,
            )
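        # The three vectors returned above give, for each of the
        # IN_C * KERNEL_H * KERNEL_W positions of the sliding window
        # (enumerated in the memory order of w), the dilation-scaled h
        # offset, the dilation-scaled w offset, and the channel index of
        # the x element that the weight at that position multiplies.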

        @staticmethod
        def _delta_x_ptr(
            IN_C,
            KERNEL_H,
            KERNEL_W,
            dilation_h,
            dilation_w,
            stride_wc,
            stride_wh,
            stride_ww,
            stride_xc,
            stride_xh,
            stride_xw,
            device,
        ):
            # get the order of axes in w, innermost dimension outward
            stride_w_3d = [stride_wc, stride_wh, stride_ww]
            order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
            window_size = IN_C * KERNEL_H * KERNEL_W

            r_window = torch.arange(0, window_size, 1, device=device)
            window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
            window_unpack_c = window_unpack[order[0]]
            window_unpack_h = window_unpack[order[1]]
            window_unpack_w = window_unpack[order[2]]
            r_dilation_h = dilation_h * window_unpack_h
            r_dilation_w = dilation_w * window_unpack_w
            r_inc = window_unpack_c
            delta_x = (
                r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
            )
            return delta_x
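        # Packed variant of the above: the per-dimension offsets are folded
        # into a single flat pointer increment, which is sufficient when the
        # kernel does not need separate h/w bounds checks.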

        @staticmethod
        def _call(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ):
            # Q: should we check x, w, bias dtypes?
            device = x.device
            # input shapes
            shape_x = x.shape
            shape_w = w.shape
            shape_bias = bias.shape if bias is not None else None

            # indices for the layout
            xn, xc, xh, xw = 0, 1, 2, 3
            yn, yc, yh, yw = 0, 1, 2, 3
            wn, wc, wh, ww = 0, 1, 2, 3

            # out_channel, in_channel, kernel_height, kernel_width
            kernel_size = [shape_w[wh], shape_w[ww]]
            input_size = [shape_x[xh], shape_x[xw]]
            assert (
                not shape_bias or shape_bias[0] == shape_w[wn]
            ), f"bias shape did not match{shape_bias} != {shape_w[wn]}"
            in_channel = shape_w[wc] * groups

            assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
            assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
            assert (
                shape_x[xc] == in_channel
            ), f"in_channel did not match {shape_x[xc]} != {in_channel}"

            assert (
                len(stride)
                == len(padding)
                == len(dilation)
                == len(output_padding)
                == len(kernel_size)
                == len(input_size)
            )

            # output shape
            shape_y = [0] * 4
            shape_y[yn] = shape_x[xn]
            shape_y[yc] = shape_w[wn]
            shape_y[yh] = (
                input_size[0]
                + 2 * padding[0]
                - dilation[0] * (kernel_size[0] - 1)
                - 1
                + stride[0]
            ) // stride[0] + 2 * output_padding[0]
            shape_y[yw] = (
                input_size[1]
                + 2 * padding[1]
                - dilation[1] * (kernel_size[1] - 1)
                - 1
                + stride[1]
            ) // stride[1] + 2 * output_padding[1]
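            # This is the standard conv output-size formula,
            #   floor((IN + 2*padding - dilation*(KERNEL - 1) - 1) / stride) + 1,
            # with the trailing "+ 1" folded in as "+ stride" inside the floor
            # division, plus 2 * output_padding.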

            BATCH = shape_x[xn]
            IN_C = shape_x[xc]
            IN_H = shape_x[xh]
            IN_W = shape_x[xw]
            KERNEL_N = shape_w[wn]
            KERNEL_H = shape_w[wh]
            KERNEL_W = shape_w[ww]
            OUT_H = shape_y[yh]
            OUT_W = shape_y[yw]

            # allocate output
            y = torch.empty(shape_y, device=device, dtype=x.dtype)

            # get strides for tensors
            stride_x = x.stride()
            stride_w = w.stride()
            stride_bias = bias.stride() if shape_bias else None
            stride_biasn = stride_bias[0] if stride_bias else None

            # output layout should be the same as x
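            # (the channels stride being the smallest means x is channels-last
            # / NHWC, so y is allocated channels-last as well)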
            if stride_x[xc] < stride_x[xh] and stride_x[xc] < stride_x[xw]:
                y = y.to(memory_format=torch.channels_last)
            stride_y = y.stride()

            # allocate tmp
            # WINDOW_SIZE = KERNEL_H * KERNEL_W * IN_C
            # tmp_x = torch.empty((BATCH * OUT_H * OUT_W, WINDOW_SIZE), device=device, dtype=x.dtype)
            # tmp_w = torch.empty((WINDOW_SIZE, KERNEL_N), device=device, dtype=w.dtype)
            # accumulator types
            ACC_TYPE = (
                tl.float32
                if x.dtype in [torch.float16, torch.bfloat16, torch.float32]
                else tl.int32
            )
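            # float16 / bfloat16 / float32 inputs accumulate in float32;
            # integer inputs accumulate in int32.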
            # if stride_x[xc] == 1 and stride_x > 1 and stride_y > 1:
            CONV1X1_NHWC = False
            if stride_x[xc] == 1 and KERNEL_H == 1 and KERNEL_W == 1:
                CONV1X1_NHWC = True
            # do we need a separate delta x ptr for each of the h, w, c dimensions?
            DELTA_X_PTR_HWC = (
                False
                if (
                    (padding[0] == 0 and padding[1] == 0)
                    or (KERNEL_H == 1 and KERNEL_W == 1)
                )
                else True
            )
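            # With padding > 0 and a kernel larger than 1x1, the sliding
            # window can step outside the input, so per-dimension (h, w, c)
            # offsets are needed for the bounds check; otherwise a single
            # packed pointer increment suffices (see _kernel_delta_x).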
            if not CONV1X1_NHWC:
                if DELTA_X_PTR_HWC:
                    delta_xh, delta_xw, delta_xc = _conv._delta_x_ptr_hwc(
                        IN_C,
                        KERNEL_H,
                        KERNEL_W,
                        dilation[0],
                        dilation[1],
                        stride_w[wc],
                        stride_w[wh],
                        stride_w[ww],
                        stride_x[xc],
                        stride_x[xh],
                        stride_x[xw],
                        device,
                    )
                else:
                    delta_x = _conv._delta_x_ptr(
                        IN_C,
                        KERNEL_H,
                        KERNEL_W,
                        dilation[0],
                        dilation[1],
                        stride_w[wc],
                        stride_w[wh],
                        stride_w[ww],
                        stride_x[xc],
                        stride_x[xh],
                        stride_x[xw],
                        device,
                    )
            else:
                delta_x = None
                delta_xh, delta_xw, delta_xc = None, None, None

            # launch kernel, 2-dim, batch*h*w, kernel
            def grid(META):
                return (
                    triton.cdiv(BATCH * OUT_H * OUT_W, META["BLOCK_M"]),
                    triton.cdiv(KERNEL_N, META["BLOCK_N"]),
                )
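            # BLOCK_M / BLOCK_N / BLOCK_K appear to be supplied by the
            # @conv_heuristics() autotuner (hence the commented-out values in
            # the launches below), so they are not passed explicitly here.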

            # conv1x1 or padding==0
            if CONV1X1_NHWC or not DELTA_X_PTR_HWC:
                _kernel_delta_x[grid](
                    x,
                    w,
                    y,
                    # stride nchw for x,w,y tensor
                    stride_x[xn],
                    stride_x[xc],
                    stride_x[xh],
                    stride_x[xw],
                    stride_w[wn],
                    stride_w[wc],
                    stride_w[wh],
                    stride_w[ww],
                    stride_y[yn],
                    stride_y[yc],
                    stride_y[yh],
                    stride_y[yw],
                    stride_biasn,
                    # pointer inc for x
                    delta_x,
                    # Tensor dimensions
                    BATCH,
                    IN_C,
                    IN_H,
                    IN_W,
                    KERNEL_N,
                    KERNEL_H,
                    KERNEL_W,
                    OUT_H,
                    OUT_W,
                    # conv parameters
                    stride[0],
                    stride[1],
                    padding[0],
                    padding[1],
                    dilation[0],
                    dilation[1],
                    output_padding[0],
                    output_padding[1],
                    groups,
                    # Metaparameters
                    ACC_TYPE=ACC_TYPE,
                    CONV1X1_NHWC=CONV1X1_NHWC,
                    # BLOCK_M=128,
                    # BLOCK_N=32,
                    # BLOCK_K=32,
                    GROUP_H=1,
                )
            # need to know ptr update for each dimension to check if
            # the sliding window is out of bounds
            else:
                # kernel = _kernel_delta_x_hwc
                _kernel_delta_x_hwc[grid](
                    x,
                    w,
                    y,
                    # stride nchw for x,w,y tensor
                    stride_x[xn],
                    stride_x[xc],
                    stride_x[xh],
                    stride_x[xw],
                    stride_w[wn],
                    stride_w[wc],
                    stride_w[wh],
                    stride_w[ww],
                    stride_y[yn],
                    stride_y[yc],
                    stride_y[yh],
                    stride_y[yw],
                    stride_biasn,
                    # pointer inc for x
                    delta_xh,
                    delta_xw,
                    delta_xc,
                    # Tensor dimensions
                    BATCH,
                    IN_C,
                    IN_H,
                    IN_W,
                    KERNEL_N,
                    KERNEL_H,
                    KERNEL_W,
                    OUT_H,
                    OUT_W,
                    # conv parameters
                    stride[0],
                    stride[1],
                    padding[0],
                    padding[1],
                    dilation[0],
                    dilation[1],
                    output_padding[0],
                    output_padding[1],
                    groups,
                    # Metaparameters
                    ACC_TYPE=ACC_TYPE,
                    CONV1X1_NHWC=CONV1X1_NHWC,
                    # BLOCK_M=128,
                    # BLOCK_N=32,
                    # BLOCK_K=32,
                    GROUP_H=1,
                )

            if bias is not None:
                if len(bias.shape) == 1:
                    bias = bias.reshape([1, bias.shape[0], 1, 1])
                y += bias
            return y

        @staticmethod
        def forward(
            x,
            w,
            bias,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            transposed=False,
            output_padding=(0, 0),
            groups=1,
        ):
            if groups != 1:
                print(f"Do not support groups = {groups}")
                return
            if transposed:
                print("Do not support transposed")
                return
            return _conv._call(
                x,
                w,
                bias,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )

    conv = _conv.forward
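
    # A minimal usage sketch (an illustration, not part of the module;
    # assumes a CUDA device with Triton available, and the shapes and dtypes
    # here are arbitrary):
    #
    #   x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.float16)
    #   x = x.to(memory_format=torch.channels_last)
    #   w = torch.randn(128, 64, 3, 3, device="cuda", dtype=torch.float16)
    #   w = w.to(memory_format=torch.channels_last)
    #   bias = torch.randn(128, device="cuda", dtype=torch.float16)
    #   y = conv(x, w, bias, stride=(1, 1), padding=(1, 1))
    #   # sanity check against PyTorch's reference convolution
    #   y_ref = torch.nn.functional.conv2d(x, w, bias, stride=(1, 1), padding=(1, 1))
    #   torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2)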