# Repository URL to install this package:
# Version: 2.4.1
# mypy: allow-untyped-defs
from collections import namedtuple
from typing import Dict, List, Optional, Type
import sympy
import torch
from .. import ir
from ..codecache import pick_vec_isa, VecAVX2, VecAVX512
from ..utils import IndentedBuffer, parallel_num_threads
from ..virtualized import V
from .common import KernelTemplate
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp
class CppMicroGemm:
    """
    Code generator for a micro GEMM kernel, i.e. a small-sized matrix multiplication.

    The micro kernel is the place for register blocking, instruction selection
    and other CPU architecture-specific optimizations.
    Subclasses override `codegen_define` to emit the kernel function that the
    code produced by `codegen_call` invokes.
    """

    # TODO(jgong5): support constant shapes and lds as template args.
    DECLARE_KERNEL = r"""
template <bool accum>
inline void {{kernel_name}}(
const {{input_t}}* __restrict__ A,
const {{input_t}}* __restrict__ B,
{{output_t}}* __restrict__ C,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc
)
"""

    def __init__(
        self,
        name,
        input_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
    ):
        """Store the kernel name, the three dtypes, the register blocking and `alpha`."""
        self.name, self.alpha = name, alpha
        self.input_dtype, self.output_dtype, self.compute_dtype = (
            input_dtype,
            output_dtype,
            compute_dtype,
        )
        self.register_blocking = register_blocking

    def get_common_options(self):
        """Return the template variables shared by every template of this kernel."""
        # Translate the three dtypes to their C++ type names in one pass.
        input_t, output_t, compute_t = (
            DTYPE_TO_CPP[dtype]
            for dtype in (self.input_dtype, self.output_dtype, self.compute_dtype)
        )
        return {
            "kernel_name": self.name,
            "input_t": input_t,
            "output_t": output_t,
            "compute_t": compute_t,
            "alpha": self.alpha,
        }

    def get_kernel_declaration(self):
        """Render `DECLARE_KERNEL` (the C++ function signature) with the common options."""
        template_vars = self.get_common_options()
        declaration = KernelTemplate._template_from_string(self.DECLARE_KERNEL)
        return declaration.render(template_vars)

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Emit the C++ definition of the kernel; must be provided by subclasses."""
        raise NotImplementedError

    def codegen_call(
        self,
        kernel: CppTemplateKernel,
        A: ir.Buffer,
        B: ir.Buffer,
        C: ir.Buffer,
        accum: bool,
    ) -> str:
        """
        Generate the code for calling the templated kernel that computes
        `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.

        The emitted call passes, in order: the addresses of element [0, 0] of
        A, B and C; the sizes M, N (taken from C) and K (taken from A); and the
        dim-0 strides of A, B and C as lda, ldb and ldc.
        """
        operands = (A, B, C)
        pointers = [f"&({kernel.index(buf, [0, 0])})" for buf in operands]
        sizes = [kernel.size(C, 0), kernel.size(C, 1), kernel.size(A, 1)]
        leading_dims = [kernel.stride(buf, 0) for buf in operands]
        call_args = [*pointers, *sizes, *leading_dims]

        code = IndentedBuffer()
        code.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
        with code.indent():
            # Every argument but the last is followed by a comma.
            for arg in call_args[:-1]:
                code.writeline(f"{arg},")
            code.writeline(f"{call_args[-1]}")
        code.writeline(");")
        return code.getvalue()
# One supported micro GEMM variant: the input/output/compute dtypes, the vector
# ISA class it is tied to and the register blocking it uses.
CppMicroGemmConfig = namedtuple(
    "CppMicroGemmConfig",
    (
        "input_dtype",
        "output_dtype",
        "compute_dtype",
        "vec_isa_cls",
        "register_blocking",
    ),
)
# Registry filled by `register_micro_gemm`: maps each registered micro GEMM
# class to the list of configs it was registered with.
micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {}


def register_micro_gemm(*configs):
    """
    Build a class decorator that records `configs` for the decorated class.

    The decorator stores `list(configs)` under the class in
    `micro_gemm_configs` and returns the class unchanged.

    Raises:
        AssertionError: when the decorator is applied to a class that is
            already registered, or when no configs were given.
    """

    def inner(cls):
        # Raise explicitly instead of using `assert` statements so the checks
        # are still enforced when Python runs with -O (which strips asserts).
        # The exception type and messages are the same as before.
        if cls in micro_gemm_configs:
            raise AssertionError(f"Duplicate micro_gemm registration for {cls}")
        if not configs:
            raise AssertionError(f"No micro_gemm configs provided for {cls}")
        micro_gemm_configs[cls] = list(configs)
        return cls

    return inner
class CppMicroGemmRef(CppMicroGemm):
    """
    Naive C++ reference implementation of a micro GEMM.

    It emits a plain triple loop over m, n and k and is meant for correctness
    debugging.
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
{{compute_t}} result = accum ? C[m * ldc + n] : 0;
for (int64_t k = 0; k < K; ++k) {
result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
}
C[m * ldc + n] = result;
}
}
}
"""

    def __init__(self, name, input_dtype, output_dtype, compute_dtype, alpha):
        """Initialize the base class with a fixed 1x1x1 register blocking."""
        super().__init__(
            name,
            input_dtype,
            output_dtype,
            compute_dtype,
            register_blocking=GemmBlocking(1, 1, 1),
            alpha=alpha,
        )

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render `TEMPLATE_ENTRY` with the kernel declaration and the common options."""
        template_vars = {"declare_kernel": self.get_kernel_declaration()}
        template_vars.update(self.get_common_options())
        entry = KernelTemplate._template_from_string(self.TEMPLATE_ENTRY)
        return entry.render(template_vars)
@register_micro_gemm(
    CppMicroGemmConfig(
        torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 48, 1)
    ),
    CppMicroGemmConfig(
        torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1)
    ),
    CppMicroGemmConfig(
        torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1)
    ),
    CppMicroGemmConfig(
        torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 24, 1)
    ),
    CppMicroGemmConfig(
        torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1)
    ),
    CppMicroGemmConfig(
        torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1)
    ),
)
class CppMicroGemmFP32Vec(CppMicroGemm):
    """
    This class generates the code for fp32 micro gemm using vec instructions.

    Two templates are rendered by `codegen_define`:

    * `TEMPLATE_KERNEL` defines `<kernel_name>_kernel<BLOCK_M, BLOCK_N, accum>`,
      which computes one BLOCK_M x BLOCK_N tile of C using
      `at::vec::Vectorized<float>` registers.
    * `TEMPLATE_ENTRY` defines the entry function declared by `DECLARE_KERNEL`.
      It walks M and N in steps of `block_m` / `block_n`, calls the tile kernel
      for full tiles, and dispatches through a `switch` to a smaller BLOCK_M
      instantiation for the remaining rows when M is not a multiple of
      `block_m`. N must be a multiple of `block_n` and K a multiple of
      `block_k`; both are checked in the generated code.
    """

    # The N/K divisibility checks go through `kernel.assert_function`, the same
    # hook the `default:` branch below already uses, so every runtime check in
    # the generated entry function is emitted with one assertion macro instead
    # of mixing it with a hard-coded TORCH_CHECK.
    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
{{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
{{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
// TODO(jgong5): loop unroll for M and N
for (int64_t m = 0; m < M; m += {{block_m}}) {
int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
for (int64_t n = 0; n < N; n += {{block_n}}) {
if (block_m == {{block_m}}) {
{{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
} else {
switch (block_m) {
{%- for b in range(block_m - 1, 0, -1) %}
case {{b}}:
{{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
{%- endfor %}
default:
{{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
}
}
}
}
}
"""

    # Tile kernel: keeps ROWS * COLS accumulator vectors, loads them from C when
    # `accum` is set (zero otherwise), runs one fmadd per accumulator for each
    # k, then stores the accumulators back to C. `alpha` is folded into the
    # broadcast A value when it is not 1.
    TEMPLATE_KERNEL = r"""
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void {{kernel_name}}_kernel(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc
) {
using Vectorized = at::vec::Vectorized<float>;
constexpr auto VLEN = Vectorized::size();
constexpr auto ROWS = BLOCK_M;
constexpr auto COLS = BLOCK_N / VLEN;
Vectorized va;
at::vec::VectorizedN<float, COLS> vb;
at::vec::VectorizedN<float, ROWS*COLS> vc;
auto loadc = [&](auto i) {
if constexpr (accum) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
} else {
vc[i] = Vectorized(0.0f);
}
};
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
auto compute = [&, COLS](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
{%- if alpha != 1 %}
va = Vectorized(A[row * lda + k] * {{alpha}});
{%- else %}
va = Vectorized(A[row * lda + k]);
{%- endif %}
}
if constexpr (row == 0) {
vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
}
constexpr int idx = row * COLS + col;
vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
};
{{kernel.unroll_pragma(4)}}
for (int k = 0; k < K; ++k) {
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
}
// store to C
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
vc[i].store(C + row * ldc + col * VLEN);
};
c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """
        Return the C++ source of the tile kernel followed by the entry function.

        Args:
            kernel: the template kernel being generated; it is exposed to the
                templates as `kernel` (they use `kernel.assert_function` and
                `kernel.unroll_pragma`).
        """
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": self.register_blocking.block_m,
            "block_n": self.register_blocking.block_n,
            "block_k": self.register_blocking.block_k,
            **self.get_common_options(),
        }
        # The tile kernel must be defined before the entry function that calls it.
        result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
            options
        )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result
def create_micro_gemm(
    name,
    m,
    n,
    k,
    input_dtype,
    output_dtype=None,
    compute_dtype=None,
    alpha=1,
    num_threads=-1,
    use_ref=True,
) -> Optional[CppMicroGemm]:
    """
    Pick the best registered micro GEMM for the given problem and instantiate it.

    Args:
        name: name given to the generated kernel.
        m: M dimension; a sympy expression is replaced by its size hint
            (fallback 1) and is used for ranking only.
        n, k: N and K dimensions; each must be an int or a numeric expression.
        input_dtype: dtype of the inputs.
        output_dtype: dtype of the output, defaults to `input_dtype`.
        compute_dtype: dtype used for computation, defaults to `input_dtype`.
        alpha: scaling factor handed to the selected micro GEMM.
        num_threads: thread count used for the occupancy ranking; a negative
            value means "use `parallel_num_threads()`".
        use_ref: when no registered config matches, return a `CppMicroGemmRef`
            if True and None otherwise.

    Returns:
        An instance of the highest-ranked matching micro GEMM class, or the
        reference implementation / None as described for `use_ref`.
    """
    assert isinstance(n, int) or n.is_number, n
    assert isinstance(k, int) or k.is_number, k
    if isinstance(m, sympy.Expr):
        m = V.graph.sizevars.size_hint(m, fallback=1)
    assert isinstance(m, int), m

    if output_dtype is None:
        output_dtype = input_dtype
    if compute_dtype is None:
        compute_dtype = input_dtype
    if num_threads < 0:
        num_threads = parallel_num_threads()

    vec_isa = pick_vec_isa()
    wanted_dtypes = (input_dtype, output_dtype, compute_dtype)
    candidates = []
    for gemm_cls, cls_configs in micro_gemm_configs.items():
        for cfg in cls_configs:
            if not isinstance(vec_isa, cfg.vec_isa_cls):
                continue
            if (cfg.input_dtype, cfg.output_dtype, cfg.compute_dtype) != wanted_dtypes:
                continue
            block_m, block_n, block_k = cfg.register_blocking
            # TODO(jgong5): support n % n_block_size != 0
            if n % block_n != 0:
                continue

            # Criteria on the ranking of configurations
            # 1. Dividable by block sizes (block_m, block_k)
            # 2. Number of mxn blocks is large enough to occupy all the threads
            # 3. Register blocks are larger
            dividable_score = sum(
                1 for fits in (k % block_k == 0, m % block_m == 0) if fits
            )
            n_blocks = n // block_n
            total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
            occupancy_score = sum(
                1
                for enough in (
                    n_blocks >= num_threads,
                    total_mxn_blocks >= num_threads,
                )
                if enough
            )
            rank = (dividable_score, occupancy_score, block_m * block_n * block_k)
            candidates.append((rank, gemm_cls, cfg))

    if not candidates:
        if not use_ref:
            return None
        return CppMicroGemmRef(name, input_dtype, output_dtype, compute_dtype, alpha)

    # TODO(jgong5): allow autotuning on choices of configs
    # Compare on the rank tuple only; the first candidate with the highest rank wins.
    _, best_cls, best_cfg = max(candidates, key=lambda entry: entry[0])
    return best_cls(
        name,
        best_cfg.input_dtype,
        best_cfg.output_dtype,
        best_cfg.compute_dtype,
        best_cfg.register_blocking,
        alpha,
    )