import builtins
import copy
import functools
import hashlib
import json
import logging
import operator
import os.path
import re
import threading
from typing import List, Optional
import torch
from torch._dynamo.utils import dynamo_timed
from .. import config
from ..codecache import cache_dir
from ..ir import ReductionHint, TileHint
from ..utils import conditional_product, has_triton
from .conv_perf_model import (
early_config_prune as conv_early_config_prune,
estimate_conv_time,
)
log = logging.getLogger(__name__)
if has_triton():
import triton
from triton import cdiv, Config, next_power_of_2
from triton.runtime.jit import get_cuda_stream, KernelInterface
else:
cdiv = None
Config = object
get_cuda_stream = None
KernelInterface = object
next_power_of_2 = None
triton = None
class CachingAutotuner(KernelInterface):
"""
Simplified version of Triton autotuner that has no invalidation
key and caches the best config to disk to improve cold start times.
Unlike the main triton Autotuner, this version can precompile all
configs, and does not rely on the Triton JIT.
"""
def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
super().__init__()
self.fn = fn
self.meta = meta
self.save_cache_hook = save_cache_hook
self.mutated_arg_names = mutated_arg_names
self.configs = configs
self.launchers = []
self.lock = threading.Lock()
if os.getenv("TRITON_CACHE_DIR") is None:
os.environ["TRITON_CACHE_DIR"] = os.path.join(
cache_dir(),
"triton",
str(self.meta.get("device", 0)),
)
def precompile(self, warm_cache_only_with_cc=None):
with self.lock:
if self.launchers:
return
self.launchers = [
self._precompile_config(c, warm_cache_only_with_cc)
for c in self.configs
]
self.configs = None
    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]):
"""Ahead of time compile a given autotuner config."""
compile_meta = copy.deepcopy(self.meta)
for k, v in cfg.kwargs.items():
compile_meta["constants"][self.fn.arg_names.index(k)] = v
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
if warm_cache_only_with_cc:
triton.compile(
self.fn,
warm_cache_only=True,
cc=warm_cache_only_with_cc,
**compile_meta,
)
return
# load binary to the correct device
with torch.cuda.device(compile_meta["device"]):
# need to initialize context
torch.cuda.synchronize(torch.cuda.current_device())
binary = triton.compile(
self.fn,
**compile_meta,
)
call_args = [
arg
for i, arg in enumerate(self.fn.arg_names)
if i not in self.fn.constexprs
]
def_args = list(self.fn.arg_names)
while def_args and def_args[-1] in cfg.kwargs:
def_args.pop()
scope = {
"grid_meta": cfg.kwargs,
"bin": binary,
"torch": torch,
"set_device": torch.cuda.set_device,
"current_device": torch.cuda.current_device,
}
exec(
f"""
def launcher({', '.join(def_args)}, grid, stream):
if callable(grid):
grid_0, grid_1, grid_2 = grid(grid_meta)
else:
grid_0, grid_1, grid_2 = grid
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
""".lstrip(),
scope,
)
launcher = scope["launcher"]
launcher.config = cfg
return launcher
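    # Illustrative sketch (hypothetical argument names): for a pointwise kernel
    # with arg_names (in_ptr0, out_ptr0, xnumel, XBLOCK) and cfg.kwargs of
    # {"XBLOCK": 1024}, the exec() above produces roughly:
    #
    #     def launcher(in_ptr0, out_ptr0, xnumel, grid, stream):
    #         if callable(grid):
    #             grid_0, grid_1, grid_2 = grid(grid_meta)
    #         else:
    #             grid_0, grid_1, grid_2 = grid
    #         bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
    #                       stream, bin.cu_function, None, None, None,
    #                       in_ptr0, out_ptr0, xnumel)
    #
    # XBLOCK is omitted: it is popped from def_args (it appears in cfg.kwargs)
    # and excluded from call_args (it is a constexpr), so its value lives only
    # in the compiled binary and in grid_meta used for grid computation.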
def bench(self, launcher, *args, grid):
"""Measure the performance of a given launcher"""
stream = get_cuda_stream(torch.cuda.current_device())
def kernel_call():
if launcher.config.pre_hook is not None:
launcher.config.pre_hook(
                    {**dict(zip(self.fn.arg_names, args)), **launcher.config.kwargs}
)
launcher(
*args,
grid=grid,
stream=stream,
)
from triton.testing import do_bench
return do_bench(kernel_call, rep=40, fast_flush=True)
@dynamo_timed
def autotune_to_one_config(self, *args, **kwargs):
"""Do the actual autotuning"""
from ..compile_fx import clone_preserve_strides
        # Clone in-place buffers so autotuning does not contaminate them when the
        # kernel does in-place stores. Avoid cloning other buffers because that
        # would increase memory use.
cloned_args = []
for i, arg in enumerate(args):
if self.fn.arg_names[i] in self.mutated_arg_names:
assert isinstance(arg, torch.Tensor)
cloned_args.append(clone_preserve_strides(arg))
else:
cloned_args.append(arg)
timings = {
launcher: self.bench(launcher, *cloned_args, **kwargs)
for launcher in self.launchers
}
self.launchers = [builtins.min(timings, key=timings.get)]
if self.save_cache_hook:
self.save_cache_hook(self.launchers[0].config)
def run(self, *args, grid, stream):
if len(self.launchers) != 1:
if len(self.launchers) == 0:
self.precompile()
if len(self.launchers) > 1:
self.autotune_to_one_config(*args, grid=grid)
(launcher,) = self.launchers
if launcher.config.pre_hook is not None:
launcher.config.pre_hook(
                {**dict(zip(self.fn.arg_names, args)), **launcher.config.kwargs}
)
try:
result = launcher(
*args,
grid=grid,
stream=stream,
)
except TypeError as e:
if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
raise RuntimeError(
"""Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
) from e
else:
raise e
return result
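# Illustrative lifecycle sketch (hypothetical names): a CachingAutotuner wraps a
# triton JIT function and is normally constructed via cached_autotune() below
# rather than directly.
#
#     tuner = cached_autotune(configs, meta=meta, filename="kernel.py")(my_jit_fn)
#     tuner.precompile()  # optional: run() precompiles lazily if this is skipped
#     tuner.run(*call_args, grid=grid(xnumel), stream=get_cuda_stream(0))
#
# When more than one launcher exists, the first run() benchmarks each one, keeps
# only the fastest, and (if a filename was given) writes the winning config to a
# ".best_config" file next to the generated kernel.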
def hash_configs(configs: List[Config]):
"""
Hash used to check for changes in configurations
"""
hasher = hashlib.sha256()
for cfg in configs:
hasher.update(
f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode(
"utf-8"
)
)
return hasher.hexdigest()
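# Example (illustrative): a Config({"XBLOCK": 32, "RBLOCK": 64}, num_warps=4,
# num_stages=2) contributes the line
#     "[('RBLOCK', 64), ('XBLOCK', 32)] 4 2\n"
# to the sha256 digest, so kwarg ordering does not affect the hash but changing
# any value, num_warps, or num_stages does.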
def load_cached_autotuning(
cache_filename: str, configs_hash: str, configs: List[Config]
):
"""
Read a cached autotuning result from disk
"""
if not os.path.exists(cache_filename):
return None
with open(cache_filename, "r") as fd:
best_config = json.loads(fd.read())
if best_config.get("configs_hash") != configs_hash:
return None
matching_configs = [
cfg
for cfg in configs
if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
]
if len(matching_configs) != 1:
return None
return matching_configs[0]
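# Example (illustrative): the ".best_config" file written by save_cache_hook in
# cached_autotune() below is a flat JSON object such as
#     {"XBLOCK": 64, "RBLOCK": 64, "configs_hash": "<sha256 hex digest>"}
# A cached entry is reused only when the hash matches the current config list
# and exactly one candidate config has the same kwargs.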
def cached_autotune(
configs: List[Config],
meta,
filename=None,
):
"""
A copy of triton.autotune that calls our subclass. Our subclass
has additional debugging, error handling, and on-disk caching.
"""
configs = unique_configs(configs)
assert len(configs) == 1 or filename
# on disk caching logic
if filename is not None and len(configs) > 1:
cache_filename = os.path.splitext(filename)[0] + ".best_config"
configs_hash = hash_configs(configs)
best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
if best_config:
configs = [best_config]
def save_cache_hook(cfg):
with open(cache_filename, "w") as fd:
fd.write(json.dumps({**cfg.kwargs, "configs_hash": configs_hash}))
else:
save_cache_hook = None
mutated_arg_names = meta.pop("mutated_arg_names", ())
def decorator(fn):
return CachingAutotuner(
fn,
meta=meta,
configs=configs,
save_cache_hook=save_cache_hook,
mutated_arg_names=mutated_arg_names,
)
return decorator
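# Illustrative decorator usage (hypothetical kernel; meta abridged -- it must
# carry the keyword arguments triton.compile expects plus the keys used above,
# i.e. "constants", "device", and optionally "mutated_arg_names"). Real callers
# are pointwise(), reduction(), etc. below:
#
#     @cached_autotune(
#         [triton_config([8192], 1024)],
#         meta={...},
#         filename=__file__,
#     )
#     @triton.jit
#     def my_kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
#         ...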
def unique_configs(configs: List[Config]):
"""Remove duplicate configurations"""
seen = set()
pruned_configs = []
for cfg in configs:
key = tuple(cfg.kwargs.items())
if key not in seen:
seen.add(key)
pruned_configs.append(cfg)
return pruned_configs
def triton_config(size_hints, x, y=None, z=None, num_stages=1) -> Config:
"""
Construct a pointwise triton config with some adjustment heuristics
based on size_hints. Size_hints is a tuple of numels in each tile
dimension and will be rounded up to the nearest power of 2.
"""
# Ideally we want to read this from some device config
maxGridSize = [2147483647, 65535, 65535]
target = conditional_product(x, y, z)
if conditional_product(*size_hints) < target:
target //= 8
# shrink sizes to size hints
x = min(x, size_hints[0])
if y:
y = min(y, size_hints[1])
if z:
z = min(z, size_hints[2])
    # if we are below the original block size, scale up where we can; also bump
    # up any dimension whose resulting grid size would exceed the hardware limit
while x < size_hints[0] and (
x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
):
x *= 2
while (
y
and y < size_hints[1]
and (
y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
)
):
y *= 2
while (
z
and z < size_hints[2]
and (
z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
)
):
z *= 2
cfg = {"XBLOCK": x}
if y:
cfg["YBLOCK"] = y
if z:
cfg["ZBLOCK"] = z
num_warps = next_power_of_2(min(max(conditional_product(x, y, z) // 256, 1), 8))
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
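# Worked example (illustrative): triton_config([16384], 1024) keeps XBLOCK at
# 1024 (already <= the size hint and within the grid limit) and returns
# Config({"XBLOCK": 1024}, num_warps=4, num_stages=1), since 1024 // 256 == 4.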
def triton_config_reduction(size_hints, x, r, num_stages=2) -> Config:
"""
Construct a reduction triton config with some adjustment heuristics
based on size_hints. Size_hints is a tuple of numels in each tile
dimension and will be rounded up to the nearest power of 2.
"""
target = conditional_product(x, r)
if conditional_product(*size_hints) < target:
target //= 8
# shrink sizes to size hints
x = min(x, size_hints[0])
r = min(r, size_hints[1])
# if we are below original block size, scale up where we can
while x < size_hints[0] and conditional_product(x, r) < target:
x *= 2
while r < size_hints[1] and conditional_product(x, r) < target:
r *= 2
cfg = {"XBLOCK": x, "RBLOCK": r}
num_warps = next_power_of_2(min(max(conditional_product(x, r) // 128, 2), 8))
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
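# Worked example (illustrative): triton_config_reduction([64, 4096], 64, 64)
# keeps XBLOCK=64, RBLOCK=64 (the product already meets the 64*64 target) and
# returns Config({"XBLOCK": 64, "RBLOCK": 64}, num_warps=8, num_stages=2),
# since 4096 // 128 == 32 is clamped to 8.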
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=2):
"""
Construct a tile reduction triton config with some adjustment
heuristics based on size_hints. Size_hints is a tuple of numels in
each tile dimension and will be rounded up to the nearest power of 2.
"""
target = conditional_product(x, y, r)
if conditional_product(*size_hints) < target:
target //= 8
# shrink sizes to size hints
x = min(x, size_hints[0])
y = min(y, size_hints[1])
r = min(r, size_hints[2])
# if we are below original block size, scale up where we can
while x < size_hints[0] and conditional_product(x, y, r) < target:
x *= 2
while r < size_hints[2] and conditional_product(x, y, r) < target:
r *= 2
while y < size_hints[1] and conditional_product(x, y, r) < target:
y *= 2
cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8))
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def pointwise(size_hints, meta, tile_hint=None, filename=None):
"""
Construct @triton.heuristics() based on size_hints.
"""
numel = functools.reduce(operator.mul, size_hints)
bs = max(256, min(numel // 128, 1024))
if len(size_hints) == 1:
return cached_autotune([triton_config(size_hints, bs)], meta=meta)
if len(size_hints) == 2:
if (
not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE
) and not config.max_autotune:
return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
return cached_autotune(
[
triton_config(size_hints, 32, 32),
triton_config(size_hints, 64, 64), # ~8% better for fp16
triton_config(size_hints, 256, 16),
triton_config(size_hints, 16, 256),
triton_config(size_hints, bs, 1),
triton_config(size_hints, 1, bs),
],
meta=meta,
filename=filename,
)
if len(size_hints) == 3:
if not config.triton.autotune_pointwise:
return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
return cached_autotune(
[
triton_config(size_hints, 16, 16, 16),
triton_config(size_hints, 64, 8, 8),
triton_config(size_hints, 8, 64, 8),
triton_config(size_hints, 8, 8, 64),
triton_config(size_hints, bs, 1, 1),
triton_config(size_hints, 1, bs, 1),
triton_config(size_hints, 1, 1, bs),
],
meta=meta,
filename=filename,
)
raise NotImplementedError(f"size_hints: {size_hints}")
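# Illustrative sketch of how generated code is expected to use this heuristic
# (kernel name and meta are made up and abridged):
#
#     @pointwise(size_hints=[1024], filename=__file__, meta={...})
#     @triton.jit
#     def triton_fused_add_relu(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
#         ...
#
# With a single size hint this resolves to one triton_config and no autotuning;
# with 2D/3D hints and autotune_pointwise enabled, the tile shapes listed above
# are benchmarked against each other on first use.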
def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
"""args to @triton.heuristics()"""
assert meta is not None
rnumel = size_hints[-1]
if len(size_hints) == 2:
contiguous_config = triton_config_reduction(
size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048), num_stages=1
)
outer_config = triton_config_reduction(size_hints, 128, 8)
tiny_config = triton_config_reduction(
size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048)
)
if config.max_autotune:
pass # skip all these cases
elif reduction_hint == ReductionHint.INNER:
return cached_autotune([contiguous_config], meta=meta)
elif reduction_hint == ReductionHint.OUTER:
return cached_autotune([outer_config], meta=meta)
elif reduction_hint == ReductionHint.OUTER_TINY:
return cached_autotune([tiny_config], meta=meta)
if not config.triton.autotune_pointwise:
return cached_autotune(
[triton_config_reduction(size_hints, 32, 128)], meta=meta
)
return cached_autotune(
[
contiguous_config,
outer_config,
tiny_config,
triton_config_reduction(size_hints, 64, 64),
triton_config_reduction(size_hints, 8, 512),
],
meta=meta,
filename=filename,
)
raise NotImplementedError(f"size_hints: {size_hints}")
def persistent_reduction(size_hints, reduction_hint=False, meta=None, filename=None):
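    """
    Configs for a persistent reduction, where a single RBLOCK spans the entire
    reduction numel so the whole reduction fits in one block.
    """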
xnumel, rnumel = size_hints
configs = [
triton_config_reduction(size_hints, xblock, rnumel)
for xblock in (1, 8, 32, 128)
if rnumel * xblock <= 4096 and xblock <= xnumel
]
# TODO(jansel): we should be able to improve these heuristics
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
configs = configs[:1]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = [
triton_config_reduction(
size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
)
]
return cached_autotune(
configs,
meta=meta,
filename=filename,
)
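# Worked example (illustrative): persistent_reduction([128, 512], meta=meta,
# filename=...) considers XBLOCK in (1, 8, 32, 128) but keeps only XBLOCK=1 and
# XBLOCK=8 (larger tiles exceed the 4096-element cap), producing
#     Config({"XBLOCK": 1, "RBLOCK": 512}, num_warps=4, num_stages=2)
#     Config({"XBLOCK": 8, "RBLOCK": 512}, num_warps=8, num_stages=2)
# which, absent a reduction hint, are autotuned against each other on first use.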
def template(num_stages, num_warps, meta, filename=None):
"""
Compile a triton template
"""
return cached_autotune(
[triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
)
def conv_heuristics():
configs = [
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=2
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=2
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=2
),
# triton.Config(
# {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=4, num_warps=2
# ),
]
key = [
"BATCH",
"IN_C",
"IN_H",
"IN_W",
"KERNEL_N",
"KERNEL_H",
"KERNEL_W",
"OUT_H",
"OUT_W",
# parameters of conv
"stride_h",
"stride_w",
"padding_h",
"padding_w",
"dilation_h",
"dilation_w",
"output_padding_h",
"output_padding_w",
"groups",
]
prune_configs_by = {
"early_config_prune": conv_early_config_prune,
"perf_model": estimate_conv_time,
"top_k": 10,
}
return triton.autotune(configs, key, prune_configs_by=prune_configs_by)
def grid(xnumel, ynumel=None, znumel=None):
"""Helper function to compute triton grids"""
def get_grid_dim(numel, block_name, block):
if numel is None:
return 1
label = block_name[0]
if numel == 1:
assert block == 1, (
f"TritonKernel.indexing assumes {label.lower()}numel == 1 => {block_name} == 1"
f"({label.lower()}numel=={numel}, {block_name}={block})."
)
return cdiv(numel, block)
def grid_fn(meta):
return (
get_grid_dim(xnumel, "XBLOCK", meta.get("XBLOCK", None)),
get_grid_dim(ynumel, "YBLOCK", meta.get("YBLOCK", None)),
get_grid_dim(znumel, "ZBLOCK", meta.get("ZBLOCK", None)),
)
return grid_fn
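# Worked example (illustrative): grid(8192) returns a grid_fn such that
# grid_fn({"XBLOCK": 1024}) == (8, 1, 1), i.e. cdiv(8192, 1024) blocks along x
# and size-1 grids for the unused y/z dimensions. The launcher generated in
# CachingAutotuner._precompile_config calls this grid_fn with the winning
# config's kwargs (grid_meta).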