import functools
import itertools
import logging
from collections.abc import Iterable
from typing import List, Optional, Tuple
import sympy
import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._prims_common import (
canonicalize_dims,
dtype_to_type,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
is_float_dtype,
is_integer_dtype,
Number,
)
from torch.fx.experimental.symbolic_shapes import magic_methods, method_to_operator
from .._dynamo.utils import import_submodule
from . import config, ir, overrides, test_operators # NOQA: F401
from .cuda_properties import current_device
from .decomposition import decompositions, get_decompositions
from .ir import (
ExpandView,
IndexingConstant,
PermuteView,
Pointwise,
Reduction,
SqueezeView,
TensorBox,
validate_ir,
View,
)
from .utils import ceildiv, developer_warning, sympy_product
from .virtualized import ops, V
# Module-level logger for this file.
log = logging.getLogger(__name__)
# Registry filled by _register_lowering: maps each registered op (and the
# overloads of any overload packet) to its wrapped lowering function.
lowerings = {}
# Filled by add_layout_constraint: maps an op (or every overload of an
# overload packet) to the constraint passed in.
layout_constraints = {}
# Ops considered fallbacks; the lowering wrapper only accepts TensorBox
# keyword arguments when every op it was registered for is in this set.
fallbacks = set()
# Short aliases for the torch op namespaces used throughout this file.
aten = torch.ops.aten
prims = torch.ops.prims
# Ops recorded through add_needs_realized_inputs.
needs_realized_inputs = set()
def add_needs_realized_inputs(fn):
    """
    Record ``fn`` in ``needs_realized_inputs``.

    A list, tuple or set is handled element by element (the per-element
    results are returned as a list).  When ``fn`` is an OpOverloadPacket,
    every one of its overloads is recorded alongside the packet itself.
    """
    if isinstance(fn, (list, tuple, set)):
        return [add_needs_realized_inputs(item) for item in fn]

    needs_realized_inputs.add(fn)
    if not isinstance(fn, torch._ops.OpOverloadPacket):
        return None
    # Also register each concrete overload of the packet.
    needs_realized_inputs.update(getattr(fn, name) for name in fn.overloads())
def add_layout_constraint(fn, constraint):
    """
    Associate ``constraint`` with ``fn`` in ``layout_constraints``.

    For an OpOverloadPacket the constraint is stored under each of its
    overloads (not under the packet itself); any other ``fn`` is used as the
    key directly.
    """
    if isinstance(fn, torch._ops.OpOverloadPacket):
        targets = [getattr(fn, name) for name in fn.overloads()]
    else:
        targets = [fn]
    for target in targets:
        layout_constraints[target] = constraint
# Record these ops in needs_realized_inputs; for overload packets,
# add_needs_realized_inputs also records every individual overload.
# NOTE(review): presumably these are ops whose inputs must be materialized
# before lowering -- confirm where needs_realized_inputs is consumed.
add_needs_realized_inputs(
    [
        aten.as_strided,
        aten.avg_pool2d,
        aten.avg_pool2d_backward,
        aten.bmm,
        aten.convolution,
        aten.convolution_backward,
        aten.max_pool2d_with_indices,
        aten.max_pool2d_with_indices_backward,
        aten.mm,
        aten.upsample_bilinear2d,
        aten.upsample_nearest2d,
        aten.upsample_bicubic2d,
    ]
)
# TODO(jansel): ezyang says we won't need this in the future, try removing it
# Maps the integer ids of the c10 ScalarType enum to torch dtypes; based on
# https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
DTYPE_ID_LOOKUP = {
    0: torch.uint8,  # Byte
    1: torch.int8,  # Char
    2: torch.int16,  # Short
    3: torch.int32,  # Int
    4: torch.int64,  # Long
    5: torch.float16,  # Half
    6: torch.float32,  # Float
    7: torch.float64,  # Double
    8: torch.complex32,  # ComplexHalf
    9: torch.complex64,  # ComplexFloat
    10: torch.complex128,  # ComplexDouble
    11: torch.bool,  # Bool
    15: torch.bfloat16,  # BFloat16
    # TODO(jansel): add quantized types?
    #  _(c10::qint8, QInt8) /* 12 */
    # _(c10::quint8, QUInt8) /* 13 */
    # _(c10::qint32, QInt32) /* 14 */
    # _(c10::quint4x2, QUInt4x2) /* 16 */
    # _(c10::quint2x4, QUInt2x4) /* 17 */
}
def decode_dtype(dtype: int):
    """
    Translate an integer dtype id into the dtype stored in DTYPE_ID_LOOKUP.

    Anything that is not an ``int`` is returned unchanged.  An integer id
    that is absent from the table trips an assertion.
    """
    if isinstance(dtype, int):
        assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
        return DTYPE_ID_LOOKUP[dtype]
    return dtype
def is_integer_type(x):
    """
    Tell whether ``x`` counts as integer-typed.

    A TensorBox qualifies when its dtype is an integer or boolean dtype; any
    other value qualifies when it is a Python ``int`` instance.
    """
    if not isinstance(x, TensorBox):
        return isinstance(x, int)
    dtype = x.get_dtype()
    return is_integer_dtype(dtype) or is_boolean_dtype(dtype)
def is_boolean_type(x):
    """
    Tell whether ``x`` counts as boolean-typed.

    A TensorBox qualifies when its dtype is a boolean dtype; any other value
    qualifies when it is a Python ``bool`` instance.
    """
    if not isinstance(x, TensorBox):
        return isinstance(x, bool)
    return is_boolean_dtype(x.get_dtype())
def decode_device(device):
    """
    Normalize a device specification into a ``torch.device``.

    ``None`` yields the device a freshly created tensor lands on.  Strings
    are parsed with ``torch.device``.  A CUDA device without an index is
    replaced by one indexed with ``current_device()``; every other device is
    returned as given.
    """
    if device is None:
        # default device
        return torch.tensor(0.0).device

    resolved = torch.device(device) if isinstance(device, str) else device
    lacks_cuda_index = resolved.type == "cuda" and resolved.index is None
    if lacks_cuda_index:
        return torch.device("cuda", index=current_device())
    return resolved
def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
    """
    Compute the promoted result dtype for ``args``.

    Each argument is either a Number, which is forwarded untouched, or an
    object exposing ``get_dtype()`` and ``get_size()``.  The second element
    of the pair returned by ``elementwise_dtypes`` is the answer.
    """

    def as_promotion_operand(value):
        if isinstance(value, Number):
            return value
        assert hasattr(value, "get_dtype")
        rank = len(value.get_size())
        # Stand-in tensor with the same rank and dtype (every dim is 1) so
        # the promotion helper can inspect it.
        return torch.zeros([1] * rank, dtype=value.get_dtype())

    operands = list(map(as_promotion_operand, args))
    _unused, result_dtype = elementwise_dtypes(
        *operands, type_promotion_kind=type_promotion_kind
    )
    return result_dtype
def _register_lowering(
    aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool
):
    """
    Wrap ``decomp_fn`` and store the wrapper in the ``lowerings`` dict.

    Arguments:
        aten_fn: a single op, or a list/tuple of ops, being lowered; for each
            OpOverloadPacket, every overload not already in ``lowerings`` is
            registered as well
        decomp_fn: alternate implementation on our IR
        broadcast: True to apply broadcasting to tensor inputs
        type_promotion_kind: kind of type promotion applied to tensor inputs,
            `None` means no type promotion
        convert_input_to_bool: some logical ops require inputs are converted
            to bool

    Returns:
        The wrapper that was registered.
    """

    @functools.wraps(decomp_fn)
    def wrapped(*args, **kwargs):
        args = list(args)
        # TODO maybe we need to use pytrees here
        unpacked = len(args) == 1 and isinstance(args[0], (list, tuple))
        if unpacked:
            args = args[0]

        # Positions of the arguments that are tensors; nothing else takes
        # part in promotion or broadcasting.
        tensor_positions = [
            pos for pos, arg in enumerate(args) if isinstance(arg, TensorBox)
        ]

        # explicitly assert for "out=" ops for better error messages
        assert "out" not in kwargs, "out= ops aren't yet supported"
        # kwargs tensors not supported yet unless it's a fallback op
        has_tensor_kwargs = any(isinstance(v, TensorBox) for v in kwargs.values())
        assert not has_tensor_kwargs or all(fn in fallbacks for fn in aten_fn)

        if (type_promotion_kind or convert_input_to_bool) and tensor_positions:
            if convert_input_to_bool:
                dtype = torch.bool
            else:
                # FIXME that's a crude approximation for promoting args
                dtype = get_promoted_dtype(
                    *(
                        a
                        for a in args
                        if isinstance(a, Number) or hasattr(a, "get_dtype")
                    ),
                    type_promotion_kind=type_promotion_kind,
                )
            # sometimes args are an immutable list so we can't mutate them;
            # build a fresh list instead
            converted = []
            for pos, arg in enumerate(args):
                if pos in tensor_positions:
                    converted.append(to_dtype(arg, dtype))
                elif isinstance(arg, ir.Constant):
                    # Rebuild constants with the promoted dtype on the device
                    # of the first tensor argument.
                    converted.append(
                        ir.Constant(
                            arg.value, dtype, args[tensor_positions[0]].get_device()
                        )
                    )
                else:
                    converted.append(arg)
            args = converted

        if unpacked:
            # Restore the single-sequence calling convention.
            args = [args]

        # NOTE(review): tensor_positions was computed on the unpacked
        # sequence, but args has already been re-wrapped at this point;
        # confirm that broadcast is never combined with a packed argument.
        if broadcast and tensor_positions:
            expanded = broadcast_tensors(*[args[pos] for pos in tensor_positions])
            for pos, tensor in zip(tensor_positions, expanded):
                args[pos] = tensor
            for pos, arg in enumerate(args):
                if isinstance(arg, ir.Constant):
                    args[pos] = ExpandView.create(
                        arg, list(args[tensor_positions[0]].get_size())
                    )

        out = decomp_fn(*args, **kwargs)
        validate_ir(out)
        return out

    # Rebind aten_fn itself (not a new local): the closure above reads it.
    aten_fn = list(aten_fn) if isinstance(aten_fn, (list, tuple)) else [aten_fn]

    # Overloads of any packet that do not have a lowering yet.
    extra_overloads = []
    for fn in aten_fn:
        if not isinstance(fn, torch._ops.OpOverloadPacket):
            continue
        for overload_name in fn.overloads():
            candidate = getattr(fn, overload_name)
            if candidate not in lowerings:
                extra_overloads.append(candidate)
    aten_fn.extend(extra_overloads)

    for fn in aten_fn:
        lowerings[fn] = wrapped
    return wrapped
def register_lowering(
    aten_fn,
    broadcast=False,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
):
    """
    Decorator factory around _register_lowering.

    Returns a partial of _register_lowering with ``aten_fn`` and the options
    already bound, so that applying it to a function supplies ``decomp_fn``.
    """
    options = {
        "broadcast": broadcast,
        "type_promotion_kind": type_promotion_kind,
        "convert_input_to_bool": convert_input_to_bool,
    }
    return functools.partial(_register_lowering, aten_fn, **options)
def broadcast_symbolic_shapes(a, b):
    """
    Broadcast two shapes whose entries may be symbolic sympy expressions.

    The shapes are aligned from the trailing dimension, with the shorter one
    padded by 1.  A size equal to 1 gives way to the other side's size.  For
    any other pair an equality guard is registered on the graph's sizevars,
    and the expression with fewer free symbols is kept (ties keep ``a``'s).

    Returns the broadcast shape as a tuple.
    """
    merged = []
    aligned = itertools.zip_longest(
        reversed(a), reversed(b), fillvalue=sympy.Integer(1)
    )
    for dim_a, dim_b in aligned:
        if dim_b == 1:
            merged.append(dim_a)
            continue
        if dim_a == 1:
            merged.append(dim_b)
            continue
        V.graph.sizevars.guard_equals(dim_a, dim_b)
        symbols_in_b = len(sympy.expand(dim_b).free_symbols)
        symbols_in_a = len(sympy.expand(dim_a).free_symbols)
        # prefer shorter formula
        merged.append(dim_b if symbols_in_b < symbols_in_a else dim_a)
    merged.reverse()
    return tuple(merged)
def promote_constants(inputs, override_return_dtype=None):
    """
    Turn Python scalars and sympy expressions in ``inputs`` into IR nodes.

    * With no int, float or sympy.Expr present, ``inputs`` is returned as-is.
    * When every input is an int or float, each becomes an ``ir.Constant``
      on the default device, typed with ``override_return_dtype`` if given,
      otherwise with the default-promoted dtype of the inputs.
    * Otherwise the first TensorBox/ExpandView input is the reference:
      ints and floats become constants expanded to its size, sympy
      expressions become ``IndexingConstant``s, both with its dtype and
      device; remaining inputs pass through unchanged.
    """
    scalar_types = (int, float)

    if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs):
        return inputs

    if all(isinstance(x, scalar_types) for x in inputs):
        dtype = override_return_dtype or get_promoted_dtype(
            *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
        )
        return [ir.Constant(x, dtype, decode_device(None)) for x in inputs]

    # Reference operand supplying dtype, device and size for the constants.
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView)))

    def lift(value):
        if isinstance(value, scalar_types):
            constant = ir.Constant(value, ex.get_dtype(), ex.get_device())
            return ExpandView.create(constant, list(ex.get_size()))
        if isinstance(value, sympy.Expr):
            return IndexingConstant(value, ex.get_dtype(), ex.get_device())
        return value

    return [lift(x) for x in inputs]
def make_pointwise(
fn,
override_return_dtype=None,
override_device=None,
override_fn_when_input_bool=None,
override_fn_when_cuda_float64=None,
allow_alpha=False,
):
def inner(*inputs: List[TensorBox], alpha=None):
inputs = promote_constants(inputs, override_return_dtype)
if allow_alpha:
if alpha is not None and alpha != 1:
inputs = list(inputs)
inputs[-1] = mul(inputs[-1], alpha)
else:
assert alpha is None
loaders = [x.make_loader() for x in inputs]
ranges = inputs[0].get_size()
dtype = override_return_dtype or inputs[0].get_dtype()
is_cuda = decode_device(inputs[0].get_device()).type == "cuda"
for other in inputs[1:]:
assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
other.get_size()
), f"ndim mismatch {fn} {ranges} {other.get_size()}"
def inner_fn(index):
assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
if dtype == torch.bool and override_fn_when_input_bool is not None:
return override_fn_when_input_bool(*[load(index) for load in loaders])
elif override_fn_when_cuda_float64 and is_cuda and dtype == torch.float64:
return override_fn_when_cuda_float64(*[load(index) for load in loaders])
else:
Loading ...