# mypy: allow-untyped-defs
from typing import cast, List, Optional

import torch
import torch.utils

from .. import ir, lowering as L
from ..kernel.mm_common import mm_args
from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
from ..virtualized import V
from .cpp_micro_gemm import create_micro_gemm
from .cpp_template import CppTemplate
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import GemmBlocking
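
# GEMM_TEMPLATE emits one C++ kernel that tiles the problem at three levels:
# threads (Mt/Nt/Kt block counts, one sub-grid per OpenMP thread), cache
# (Mc/Kc chunks walked by the mc/kc loops), and registers (M0/N0/K0, the tile
# handed to each micro_gemm call). The Jinja "{%- ... %}" directives are
# resolved at codegen time, so, e.g., the dynamic-M branch appears in the
# generated source only when M is symbolic.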
GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}

{{micro_gemm.codegen_define(kernel)}}

extern "C"
{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y})}}
{
    {{kernel.maybe_codegen_profile()}}
    constexpr int64_t num_threads = {{num_threads}};
    constexpr int64_t N = {{kernel.size(GemmOut, 1)}};
    constexpr int64_t K = {{kernel.size(X, 1)}};
    constexpr int64_t M0 = {{micro_gemm.register_blocking.block_m}};
    constexpr int64_t N0 = {{micro_gemm.register_blocking.block_n}};
    constexpr int64_t K0 = {{micro_gemm.register_blocking.block_k}};
    constexpr int64_t N0_blocks = (N + N0 - 1) / N0;
    constexpr int64_t K0_blocks = (K + K0 - 1) / K0;

    static_assert(N % N0 == 0, "N dimension must be a multiple of N0");

    // TODO(jgong5): improve cache blocking with CPU info (Mc, Kc)
{%- if is_dynamic_M %}
    const int64_t M = {{kernel.size(GemmOut, 0)}};
    const int64_t M0_blocks = (M + M0 - 1) / M0;
{%- if num_threads > 1 %}
    int64_t Mt_blocks, Nt_blocks, Kt_blocks;
    mm_get_thread_blocking(num_threads, M, N, K, M0, N0, K0, Mt_blocks, Nt_blocks, Kt_blocks);
{%- else %}
    const auto Mt_blocks = M0_blocks;
    const auto Nt_blocks = N0_blocks;
    const auto Kt_blocks = K0_blocks;
{%- endif %}
    const int64_t Mc_blocks = Mt_blocks;
    const int64_t Kc_blocks = Kt_blocks;
{%- else %}
    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
    constexpr int64_t M0_blocks = (M + M0 - 1) / M0;
    constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
    constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
    constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
    constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
    constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
{%- endif %}

    // TODO(jgong5): support k-slicing
    {{kernel.assert_function}}(Kt_blocks == K0_blocks, "K slicing is not supported yet.");
    // make sure all partitions are assigned
    {{kernel.assert_function}}(
        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks,
        "Not all partitions are assigned."
    );

{%- if num_threads > 1 %}
    #pragma omp parallel num_threads({{num_threads}})
    {
        int tid = omp_get_thread_num();
        int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
        mm_get_thread_blocks(
            tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
            m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
{%- else %}
    {
        int64_t m_block_start = 0;
        int64_t m_block_end = M0_blocks;
        int64_t n_block_start = 0;
        int64_t n_block_end = N0_blocks;
        int64_t k_block_start = 0;
        int64_t k_block_end = K0_blocks;
{%- endif %}
        for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
            const int64_t m_start = mc * M0;
            const int64_t m_end = std::min((mc + Mc_blocks) * M0, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
                const int64_t n_start = nc * N0;
                const int64_t n_size = N0;
{%- if use_local_acc %}
                {{ kernel.define_buffer("acc_local_buf", ["m_end - m_start", "N0"]) }}
{%- set acc = kernel.local_buffers["acc_local_buf"] %}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- endif %}
{%- if inp is not none and beta != 0 %}
                for (int64_t m = 0; m < m_size; ++m) {
                    #pragma omp simd
                    for (int64_t n = 0; n < n_size; ++n) {
                        {{kernel.index(acc, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m + m_start", "n + n_start"])}};
                    }
                }
{%- endif %}
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * K0;
                    int64_t k_end = std::min((kc + Kc_blocks) * K0, K);
{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
{%- if inp is not none and beta != 0 %}
                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(20, false) }}
{%- else %}
                    if (kc == k_block_start) {
                        {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False)|indent(24, false) }}
                    } else {
                        {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(24, false) }}
                    }
{%- endif %}
                }
{%- if reindexer is not none %}
{%- set Y_maybe_transposed = kernel.permute(Y, reindexer([0, 1])) %}
{%- else %}
{%- set Y_maybe_transposed = Y %}
{%- endif %}
{%- set tile_Y = kernel.slice_nd(Y_maybe_transposed, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
                {{ kernel.store_output(
                      tile_Y, acc, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer
                   )|indent(16, false)
                }}
            }
        }
    }
}
"""


class CppPackedGemmTemplate(CppTemplate):
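    """
    CPP GEMM template with a prepacked (blocked) weight.

    The weight W is reordered up front into (n // block_n, k, block_n)
    tiles (see pack_weight in add_choices) so each N0-wide column panel is
    contiguous for the micro-kernel; render() then instantiates
    GEMM_TEMPLATE with the thread/cache/register blocking computed below.
    """
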
    def __init__(
        self,
        input_nodes,
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta=1,
        alpha=1,
    ):
        super().__init__("packed_gemm", input_nodes, layout)
        self.beta = beta
        self.alpha = alpha
        self.num_threads = num_threads
        self.register_blocking = register_blocking
        m, n = layout.size
        _, k = input_nodes[0].get_size()
        self.m, self.n, self.k = m, n, k
        self.is_dynamic_M = has_free_symbols((m,))

    @cache_on_self
    def thread_blocking(self) -> GemmBlocking:
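        """
        Split the grid of register blocks (M0/N0/K0 blocks) across threads.

        Returns per-thread block counts (Mt, Nt, Kt). The search factors
        num_threads into an m x n thread grid, preferring factorizations
        that divide the block grid evenly; K is never split here (see the
        k-slicing TODO in the template).
        """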
        # TODO(jgong5): allow tuning various blocking options
        def get_factors(number):
            factors = []
            # prioritize more evenly divided factors
            for i in range(int(number**0.5), 0, -1):
                if number % i == 0:
                    factors.append(number // i)
                    factors.append(i)
            return factors
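
        # get_factors(8) yields [4, 2, 8, 1]: balanced factor pairs come
        # first, so the search below prefers a near-square thread grid
        # (e.g. 2 x 4 for 8 threads) over a 1 x 8 strip. get_blocking then
        # ceil-divides the block grid across the chosen factor pair; e.g.
        # 8 threads on a 32 x 16 block grid picks factor 4 and yields
        # GemmBlocking(16, 4, k_blocks).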
        def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks):
            thread_block_n = (n_blocks + factor - 1) // factor
            cofactor = num_threads // factor
            thread_block_m = (m_blocks + cofactor - 1) // cofactor
            return GemmBlocking(thread_block_m, thread_block_n, k_blocks)

        assert (
            not self.is_dynamic_M
        ), "Unable to determine thread blocking for dynamic M."
        register_blocking = self.register_blocking
        m_blocks = (self.m + register_blocking.block_m - 1) // register_blocking.block_m
        n_blocks = (self.n + register_blocking.block_n - 1) // register_blocking.block_n
        k_blocks = (self.k + register_blocking.block_k - 1) // register_blocking.block_k
        factors = get_factors(self.num_threads)
        assert len(factors) > 0
        for factor in factors:
            if n_blocks % factor == 0 and m_blocks % (self.num_threads // factor) == 0:
                return get_blocking(
                    self.num_threads, factor, m_blocks, n_blocks, k_blocks
                )
        for factor in factors:
            if n_blocks % factor == 0:
                return get_blocking(
                    self.num_threads, factor, m_blocks, n_blocks, k_blocks
                )
            cofactor = self.num_threads // factor
            if m_blocks % cofactor == 0:
                return get_blocking(
                    self.num_threads, factor, m_blocks, n_blocks, k_blocks
                )
        raise AssertionError("Should not reach here.")

    @cache_on_self
    def cache_blocking(self) -> GemmBlocking:
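        """
        Per-thread cache blocking (Mc, Nc, Kc) in units of register blocks.

        Currently Nc is pinned to 1 and Mc/Kc reuse the thread blocking,
        i.e. each thread streams over its whole M/K range; see the TODO
        about using CPU cache sizes instead.
        """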
        # TODO(jgong5): improve cache blocking with CPU info
        assert (
            not self.is_dynamic_M
        ), "Unable to determine cache blocking for dynamic M."
        thread_blocking = self.thread_blocking()
        return GemmBlocking(thread_blocking.block_m, 1, thread_blocking.block_k)

    @staticmethod
    def add_choices(
        choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None
    ):
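        """
        Append this template to `choices` for GEMM autotuning.

        `input_indices` maps the caller's input order onto [x, w] or
        [inp, x, w]; `trans_w` indicates W arrives as (n, k) and must be
        transposed first. A minimal sketch of the intended call pattern
        (the surrounding variable names are illustrative only):

            choices = []
            CppPackedGemmTemplate.add_choices(choices, layout, [x, w])
            # `choices` now holds a template caller that autotuning can
            # benchmark against ATen/extern choices appended elsewhere.
        """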
        if input_indices is None:
            input_indices = list(range(len(input_nodes)))

        def reorder_and_filter(inputs, layout_or_out):
            if len(input_indices) == 2:
                x_idx = input_indices[0]
                w_idx = input_indices[1]
                return [inputs[x_idx], inputs[w_idx]], layout_or_out
            else:
                assert (
                    len(input_indices) == 3
                ), "Cpp Packed GEMM template requires 2 or 3 input nodes."
                # assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
                inp_idx = input_indices[0]
                x_idx = input_indices[1]
                w_idx = input_indices[2]
                return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out

        def transpose_weight(inputs, layout_or_out):
            if not trans_w:
                return inputs, layout_or_out

            new_inputs = list(inputs)
            W = inputs[1]
            if isinstance(W, ir.IRNode):
                if not isinstance(W, ir.TensorBox):
                    W = ir.TensorBox(W)
                new_inputs[1] = L.permute(W, [1, 0])
                return new_inputs, layout_or_out
            else:
                assert isinstance(W, torch.Tensor)
                new_inputs[1] = W.transpose(0, 1)
                return new_inputs, layout_or_out

        # TODO(jgong5): decide proper number of threads per problem size
        num_threads = parallel_num_threads()
        new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout))
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
        micro_gemm = create_micro_gemm(
            "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads
        )
        assert micro_gemm is not None
        _, block_n, _ = micro_gemm.register_blocking
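
        # pack_weight reorders W from (k, n) to (n // block_n, k, block_n) so
        # that each N0-wide column panel is contiguous for the micro-kernel.
        # For example (sizes illustrative): k=64, n=32, block_n=16 gives a
        # packed weight of shape (2, 64, 16) with contiguous strides
        # (1024, 16, 1).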
        def pack_weight(inputs, layout_or_out):
            W = inputs[1]
            new_inputs = list(inputs)
            if isinstance(W, ir.IRNode):
                if not isinstance(W, ir.TensorBox):
                    W = ir.TensorBox(W)
                k, n = W.get_size()
                assert (
                    n % block_n == 0
                ), f"The last dimension of W must be a multiple of {block_n}."
                blocked_w = L.permute(
                    L.view(W, (k, n // block_n, block_n)),
                    [1, 0, 2],
                )
                blocked_w = ir.ExternKernel.realize_input(blocked_w)
                blocked_w = ir.ExternKernel.require_contiguous(blocked_w)
                if isinstance(blocked_w, ir.ReinterpretView):
                    # normalize the stride to be "contiguous_strides" per size;
                    # this avoids problems with L.view during template codegen
                    assert isinstance(blocked_w.layout, ir.FixedLayout)
                    blocked_w.layout = ir.FixedLayout(
                        blocked_w.layout.device,
                        blocked_w.layout.dtype,
                        blocked_w.layout.size,
                        ir.FlexibleLayout.contiguous_strides(blocked_w.layout.size),
                        blocked_w.layout.offset,
                    )
            else:
                k, n = list(W.shape)
                blocked_w = (
                    W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous()
                )
                # normalize the stride to be "contiguous_strides" per size;
                # this avoids problems with L.view during template codegen
                new_stride = [1]
                for sz in reversed(blocked_w.shape[1:]):
                    new_stride.insert(0, new_stride[0] * sz)
                blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride)
            new_inputs[1] = blocked_w
            return new_inputs, layout_or_out
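
        # preprocessor mirrors the transform order used above at compile time:
        # reorder inputs to [x, w(, inp)], optionally transpose W, then pack it.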
        def preprocessor(inputs, layout):
            return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout)))

        def postprocessor(output):
            if isinstance(output, ir.TensorBox):
                # prepack the weight as an input to the template buffer
                # TODO(jgong5): prune the unused constants in V.graph
                # Should we implement it with constant folding in the scheduler instead?
                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
                assert isinstance(template_buffer, ir.CppTemplateBuffer)
                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
                W_node = new_input_nodes[1]
                assert W_node.get_name() in V.graph.constants
                W = V.graph.constants[W_node.get_name()]
                new_input_nodes[1] = W
                new_input_nodes, _ = pack_weight(
                    *transpose_weight(new_input_nodes, layout)
                )
                W_packed = new_input_nodes[1]
                W_packed_constant = V.graph.add_tensor_constant(W_packed)
                template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
                    W_packed_constant
                )
            return output

        template = DataProcessorTemplateWrapper(
            CppPackedGemmTemplate,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            num_threads=num_threads,
            register_blocking=micro_gemm.register_blocking,
            beta=beta,
            alpha=alpha,
        )
        template.maybe_append_choice(choices)
        return template

    def render(  # type: ignore[override]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
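        """
        Render the C++ kernel source by instantiating GEMM_TEMPLATE.

        When epilogue nodes are fused, the last epilogue buffer becomes the
        kernel output Y; if its size and stride are the reverse of the
        template buffer's, stores are routed through a transposing
        reindexer.
        """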
        assert len(self.input_nodes) >= 2

        X, W = self.input_nodes[0], self.input_nodes[1]
        inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None
        Y = self.output_node
        if template_buffer_node is not None:
            # Use the updated prepacked weight buffer
            W = template_buffer_node.inputs[1]
            Y = template_buffer_node
        template_buffer = Y
        Y_is_transposed = False
        # TODO(jgong5): support local accumulation
        use_local_acc = False
        if epilogue_nodes:
            Y = cast(ir.Buffer, epilogue_nodes[-1])
            assert Y.get_name() in V.kernel.inplace_update_buffers
            if Y.get_size() == list(
                reversed(template_buffer.get_size())
            ) and Y.get_stride() == list(reversed(template_buffer.get_stride())):
                Y_is_transposed = True

        micro_gemm = create_micro_gemm(
            f"{kernel.kernel_name}_micro_gemm",
            self.m,
            self.n,
            self.k,
            self.layout.dtype,
            alpha=self.alpha,
            num_threads=self.num_threads,
        )
        assert micro_gemm is not None
        assert self.register_blocking == micro_gemm.register_blocking

        options = dict(
            X=X,
            W=W,
            inp=inp,
            Y=Y,
            GemmOut=template_buffer,
            beta=self.beta,
            alpha=self.alpha,
            num_threads=self.num_threads,
            micro_gemm=micro_gemm,
            is_dynamic_M=self.is_dynamic_M,
            template=self,
            kernel=kernel,
            epilogue_nodes=epilogue_nodes,
            reindexer=(lambda x: list(reversed(x))) if Y_is_transposed else None,
            use_local_acc=use_local_acc,
        )
        return self._template_from_string(GEMM_TEMPLATE).render(**options)