import functools
import logging

import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V

from ..utils import ceildiv as cdiv

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def mm_configs():
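    """
    Candidate tiling configurations for the Triton matmul templates: each
    entry pairs BLOCK_M/BLOCK_N/BLOCK_K tile sizes with a software pipelining
    depth (num_stages) and a warp count (num_warps); the autotuner benchmarks
    these and keeps the fastest one for a given problem.
    """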
import triton
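    # Imported lazily so merely importing this module does not require triton.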
return [
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=3, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=3, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=5, num_warps=8
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=5, num_warps=8
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, num_stages=1, num_warps=2
),
    ]


def mm_grid(m, n, meta):
"""
    The CUDA launch grid for the matmul triton templates; all (M, N) output
    tiles are flattened into the first grid dimension.
"""
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def acc_type(dtype):
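    """
    Accumulator dtype (as a triton type string) for a given output dtype;
    fp16/bf16 accumulate in fp32, everything else accumulates in its own type.
    """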
if dtype in (torch.float16, torch.bfloat16):
return "tl.float32"
return f"tl.{dtype}".replace("torch.", "")
def mm_options(config, sym_k, layout):
"""
    Common options passed to the matmul triton templates.
"""
even_k_symbolic = (
# it isn't worth guarding on this
sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
== config.kwargs["BLOCK_K"]
)
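    # When K is statically a multiple of BLOCK_K, EVEN_K=True lets the
    # template skip the masking on the K-dimension loads.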
return dict(
GROUP_M=8,
EVEN_K=even_k_symbolic,
ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
ACC_TYPE=acc_type(layout.dtype),
num_stages=config.num_stages,
num_warps=config.num_warps,
**config.kwargs,
    )


def mm_args(mat1, mat2, *others, layout=None):
"""
    Common argument processing for mm, bmm, addmm, etc.
"""
mat1, mat2 = realize_inputs(mat1, mat2)
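    # Split each size into leading batch dims and the matmul dims, then guard
    # that the batch dims match and that the contraction dims (k1, k2) agree.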
*b1, m, k1 = mat1.get_size()
*b2, k2, n = mat2.get_size()
b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
k = V.graph.sizevars.guard_equals(k1, k2)
if layout is None:
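        # No layout was provided: create one for an output of shape [*b, m, n]
        # on mat1's device and dtype.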
from torch._inductor.ir import FixedLayout
layout = FixedLayout(
mat1.get_device(),
mat1.get_dtype(),
[*b, m, n],
)
from ..lowering import expand
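    # Broadcast any extra inputs (e.g. the addmm bias) to the output size so
    # the epilogue can read them elementwise.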
others = [realize_inputs(expand(x, layout.size)) for x in others]
    return [m, n, k, layout, mat1, mat2, *others]


def addmm_epilogue(dtype, alpha, beta):
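    """
    Build the epilogue for addmm-style templates: returns a closure computing
    alpha * acc + beta * bias, skipping a multiply whenever the corresponding
    scalar is 1.
    """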
def epilogue(acc, bias):
if alpha != 1:
acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
if beta != 1:
bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
return V.ops.add(acc, bias)
return epilogue