# Package index page header (extraction residue, not part of the module): version 2.4.1
# mypy: allow-untyped-defs
import math
from collections import namedtuple
import torch
from .common import ExprPrinter
# Lookup table from a torch dtype to the C++ type-name string for that dtype.
# NOTE(review): names such as "half", "bfloat16", "complex64" and the float8
# names are not built-in C++ types; presumably they are declared by headers
# used by the generated code -- confirm against the consumer of this table.
# Not every dtype listed in DTYPE_TO_ATEN has an entry here (e.g.
# torch.complex32, torch.complex128 and the *fnuz float8 variants are absent).
DTYPE_TO_CPP = {
    # Floating point
    torch.float32: "float",
    torch.float64: "double",
    torch.float16: "half",
    # Signed integers
    torch.int64: "int64_t",
    torch.int32: "int32_t",
    torch.int16: "int16_t",
    torch.int8: "int8_t",
    # Unsigned integers
    torch.uint64: "uint64_t",
    torch.uint32: "uint32_t",
    torch.uint16: "uint16_t",
    torch.uint8: "uint8_t",
    # Boolean, reduced-precision, complex and float8 types
    torch.bool: "bool",
    torch.bfloat16: "bfloat16",
    torch.complex64: "complex64",
    torch.float8_e4m3fn: "float8_e4m3fn",
    torch.float8_e5m2: "float8_e5m2",
}
DTYPE_TO_ATEN = {
torch.float32: "at::kFloat",
torch.float64: "at::kDouble",
torch.float16: "at::kHalf",
torch.int64: "at::kLong",
torch.int32: "at::kInt",
torch.int16: "at::kShort",
torch.int8: "at::kChar",
torch.uint64: "at::kUInt64",
torch.uint32: "at::kUInt32",
torch.uint16: "at::kUInt16",
torch.uint8: "at::kByte",
torch.uint32: "at::kUInt32",
torch.uint64: "at::kUInt64",
torch.bool: "at::kBool",
torch.bfloat16: "at::kBFloat16",
torch.complex32: "at::kComplexHalf",
torch.complex64: "at::kComplexFloat",
torch.complex128: "at::kComplexDouble",
torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
torch.float8_e5m2: "at::kFloat8_e5m2",
torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
}
# Lookup table from a device-type string to the matching ``at::k*`` device
# constant name. Only "cpu" and "cuda" have entries.
DEVICE_TO_ATEN = {
    "cpu": "at::kCPU",
    "cuda": "at::kCUDA",
}
# Lookup table from a torch layout object to the matching ``at::k*`` layout
# constant name. ``torch._mkldnn`` is a private attribute, hence the
# type-checker suppression on that entry.
LAYOUT_TO_ATEN = {
    torch.strided: "at::kStrided",
    torch._mkldnn: "at::kMkldnn",  # type: ignore[attr-defined]
}
# C++ type name substituted into the ``static_cast<...>`` wrappers that this
# module emits around index expressions.
INDEX_TYPE = "long"

# Named triple with the fields ``block_m``, ``block_n`` and ``block_k``.
GemmBlocking = namedtuple("GemmBlocking", "block_m block_n block_k")
class CppPrinter(ExprPrinter):
    """Expression printer that emits C++ source strings.

    Every ``_print_<Node>`` method returns the C++ text for one kind of
    expression node.  The ``paren``, ``doprint`` and ``_print`` helpers used
    throughout are not defined in this class; they are presumably inherited
    from ``ExprPrinter`` (TODO confirm).
    """

    def _print_Integer(self, expr):
        """Print an integer constant with the ``L`` literal suffix."""
        return str(int(expr)) + "L"

    def _print_Where(self, expr):
        """Print a three-argument conditional as a C++ ternary expression."""
        cond, if_true, if_false = (
            self.paren(self.doprint(expr.args[pos])) for pos in range(3)
        )
        return f"{cond} ? {if_true} : {if_false}"

    def _print_ModularIndexing(self, expr):
        """Print ``(x // div) % mod`` with both operands cast to ``INDEX_TYPE``.

        The floor division is skipped entirely when ``div`` equals 1.
        """
        numer, divisor, modulus = expr.args
        numer_src = self.paren(self.doprint(numer))
        if divisor != 1:
            divisor_src = self.paren(self.doprint(divisor))
            if expr.is_integer:
                numer_src = f"c10::div_floor_integer({numer_src}, {divisor_src})"
            else:
                numer_src = (
                    f"c10::div_floor_floating(static_cast<double>({numer_src}), "
                    f"static_cast<double>({divisor_src}))"
                )
        modulus_src = self.paren(self.doprint(modulus))
        return (
            f"static_cast<{INDEX_TYPE}>({numer_src})"
            f" % static_cast<{INDEX_TYPE}>({modulus_src})"
        )

    def _print_FloorDiv(self, expr):
        """Print a floor division through the ``c10::div_floor_*`` helpers."""
        numer, divisor = expr.args
        numer_src = self.paren(self.doprint(numer))
        divisor_src = self.paren(self.doprint(divisor))
        if not expr.is_integer:
            return (
                f"c10::div_floor_floating(static_cast<double>({numer_src}), "
                f"static_cast<double>({divisor_src}))"
            )
        return f"c10::div_floor_integer({numer_src}, {divisor_src})"

    def _print_floor(self, expr):
        """Print ``std::floor``; integer-valued results are cast to ``INDEX_TYPE``."""
        assert len(expr.args) == 1
        floored = f"std::floor({self._print(expr.args[0])})"
        if expr.is_integer:
            return f"static_cast<{INDEX_TYPE}>({floored})"
        return floored

    def _print_FloorToInt(self, expr):
        """Print ``std::floor``; integer-valued results are cast to ``INDEX_TYPE``."""
        assert len(expr.args) == 1
        floored = f"std::floor({self._print(expr.args[0])})"
        if expr.is_integer:
            return f"static_cast<{INDEX_TYPE}>({floored})"
        return floored

    def _print_TruncToInt(self, expr):
        """Print ``std::trunc`` and always cast the result to ``INDEX_TYPE``."""
        assert len(expr.args) == 1
        truncated = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({truncated})"

    def _print_TruncToFloat(self, expr):
        """Print ``std::trunc`` without any cast."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::trunc({operand})"

    def _print_ToFloat(self, expr):
        """Print a cast of the single operand to ``double``."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"static_cast<double>({operand})"

    # TODO: This is wrong if one of the inputs is negative.  This is hard to
    # tickle though, as the inputs are typically positive (and if we can prove
    # they are positive, we will have used Mod instead, for which this codegen
    # is right).
    def _print_PythonMod(self, expr):
        """Print the operands joined by the C++ ``%`` operator."""
        return " % ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_CMod(self, expr):
        """Print the operands joined by the C++ ``%`` operator."""
        return " % ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_IntTrueDiv(self, expr):
        """Print a true division of two operands, each cast to ``double``."""
        numer, denom = expr.args
        # TODO: This is only accurate up to 2**53
        numer_src = f"static_cast<double>({self._print(numer)})"
        denom_src = f"static_cast<double>({self._print(denom)})"
        return f"{numer_src} / {denom_src}"

    # TODO: PowByNatural: we need to implement our own int-int pow.  Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr):
        """Not supported by this printer; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatTrueDiv(self, expr):
        """Print a plain ``/`` between the two parenthesized operands."""
        numer, denom = expr.args
        numer_src = self.paren(self._print(numer))
        denom_src = self.paren(self._print(denom))
        return f"{numer_src} / {denom_src}"

    def _print_FloatPow(self, expr):
        """Print a ``std::pow`` call on the two operands."""
        base_expr, exponent = expr.args
        return f"std::pow({self._print(base_expr)}, {self._print(exponent)})"

    def _print_Pow(self, expr):
        """Print a power expression.

        Exponents of +/-0.5 become ``std::sqrt`` forms, integer exponents are
        expanded into repeated multiplication, and anything else falls back
        to ``std::pow``.
        """
        # Uses float constants to perform FP div
        base_expr, exponent = expr.args
        base_src = self._print(base_expr)
        if exponent == 0.5:
            return f"std::sqrt({base_src})"
        if exponent == -0.5:
            return f"1.0/std::sqrt({base_src})"
        if not exponent.is_integer:
            # TODO: float vs double
            return f"std::pow({base_src}, {float(exponent)})"

        power = int(exponent)
        if power > 0:
            product = "*".join([self.paren(base_src)] * power)
        elif power < 0:
            product = "1.0/" + self.paren(
                "*".join([self.paren(base_src)] * abs(power))
            )
        else:  # power == 0
            product = "1.0"
        if expr.is_integer:
            return f"static_cast<{INDEX_TYPE}>({product})"
        return product

    def _print_Rational(self, expr):
        """Print a rational as ``p`` or as the floating-point quotient ``p.0/q.0``."""
        # Uses float constants to perform FP div
        quotient = f"{expr.p}" if expr.q == 1 else f"{expr.p}.0/{expr.q}.0"
        if expr.is_integer:
            return f"static_cast<{INDEX_TYPE}>({quotient})"
        return quotient

    def _print_ceiling(self, expr):
        """Print ``std::ceil``; integer-valued results are cast to ``INDEX_TYPE``."""
        assert len(expr.args) == 1
        ceiled = f"std::ceil({self._print(expr.args[0])})"
        if expr.is_integer:
            return f"static_cast<{INDEX_TYPE}>({ceiled})"
        return ceiled

    def _print_CeilToInt(self, expr):
        """Print ``std::ceil``; integer-valued results are cast to ``INDEX_TYPE``."""
        assert len(expr.args) == 1
        ceiled = f"std::ceil({self._print(expr.args[0])})"
        if expr.is_integer:
            return f"static_cast<{INDEX_TYPE}>({ceiled})"
        return ceiled

    def _print_Min(self, expr):
        """Print ``std::min`` over the operands."""
        operands = [self._print(arg) for arg in expr.args]
        if len(operands) != 2:
            # Initializer list overload
            braced = "{" + ", ".join(operands) + "}"
            return f"std::min({braced})"
        first, second = operands
        return f"std::min({first}, {second})"

    def _print_Max(self, expr):
        """Print ``std::max`` over the operands."""
        operands = [self._print(arg) for arg in expr.args]
        if len(operands) != 2:
            # Initializer list overload
            braced = "{" + ", ".join(operands) + "}"
            return f"std::max({braced})"
        first, second = operands
        return f"std::max({first}, {second})"

    def _print_Abs(self, expr):
        """Print ``std::abs`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::abs({operand})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        """Print ``std::cos`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::cos({operand})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        """Print ``std::cosh`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::cosh({operand})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        """Print ``std::acos`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::acos({operand})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        """Print ``std::sin`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::sin({operand})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        """Print ``std::sinh`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::sinh({operand})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        """Print ``std::asin`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::asin({operand})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        """Print ``std::tan`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::tan({operand})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        """Print ``std::tanh`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::tanh({operand})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        """Print ``std::atan`` of the single operand."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"std::atan({operand})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        """Print ``std::sqrt`` of the first operand."""
        operand = self._print(expr.args[0])
        return f"std::sqrt({operand})"

    def _print_RoundToInt(self, expr):
        """Print ``std::lrint`` of the single operand."""
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        operand = self._print(expr.args[0])
        return f"std::lrint({operand})"

    def _print_RoundDecimal(self, expr):
        """Print rounding of a value to ``ndigits`` decimal places.

        Raises:
            ValueError: if the value being rounded is an integer expression.
        """
        assert len(expr.args) == 2
        value, digits = expr.args
        if value.is_integer:
            # Only negative ndigits are expected to reach this point for an
            # integer input (the other cases should have been handled by the
            # sympy function); that combination is not supported here.
            assert digits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {digits}."
            )
        value_src = self.paren(self._print(value))
        return (
            f"static_cast<double>(std::nearbyint(1e{digits} * {value_src})"
            f" * 1e{-digits})"
        )

    def _print_BooleanTrue(self, expr):
        """Print the C++ literal ``true``."""
        return "true"

    def _print_BooleanFalse(self, expr):
        """Print the C++ literal ``false``."""
        return "false"
# Module-level shortcut for rendering an expression (e.g. a sympy symbol) as a
# C++ source string: the bound ``doprint`` method of a single CppPrinter
# instance created at import time.
cexpr = CppPrinter().doprint
def cexpr_index(index):
    """Render ``index`` with ``cexpr`` and wrap it in a cast to ``INDEX_TYPE``.

    Args:
        index: The expression to render; passed unchanged to ``cexpr``.

    Returns:
        A string of the form ``static_cast<INDEX_TYPE>(<rendered index>)``.
    """
    rendered = cexpr(index)
    return f"static_cast<{INDEX_TYPE}>({rendered})"
def value_to_cpp(value, cpp_type):
    """Return a C++ expression string for the Python scalar ``value``.

    Args:
        value: The Python value to render (bool, int or float).
        cpp_type: Name of the C++ type the expression should have.

    Returns:
        ``std::numeric_limits<cpp_type>`` expressions for +/-infinity and NaN,
        otherwise ``static_cast<cpp_type>(...)`` around a ``true``/``false``
        literal (for bools) or ``repr(value)``.
    """
    limits = f"std::numeric_limits<{cpp_type}>"
    if value == float("-inf"):
        return f"-{limits}::infinity()"
    if value == float("inf"):
        return f"{limits}::infinity()"
    # bool is tested before the NaN check and before repr(), so True/False are
    # emitted as the lowercase C++ literals rather than "True"/"False".
    if isinstance(value, bool):
        literal = str(value).lower()
    elif math.isnan(value):
        return f"{limits}::quiet_NaN()"
    else:
        literal = repr(value)
    return f"static_cast<{cpp_type}>({literal})"