# Repository URL to install this package:
# Version: 2.1.2+cpu
"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
This listing is further exported to public symbols in the `torch._numpy/_ufuncs.py` module.
"""
import torch
from torch import ( # noqa: F401
add, # noqa: F401
arctan2, # noqa: F401
bitwise_and, # noqa: F401
bitwise_left_shift as left_shift, # noqa: F401
bitwise_or, # noqa: F401
bitwise_right_shift as right_shift, # noqa: F401
bitwise_xor, # noqa: F401
copysign, # noqa: F401
divide, # noqa: F401
eq as equal, # noqa: F401
float_power, # noqa: F401
floor_divide, # noqa: F401
fmax, # noqa: F401
fmin, # noqa: F401
fmod, # noqa: F401
gcd, # noqa: F401
greater, # noqa: F401
greater_equal, # noqa: F401
heaviside, # noqa: F401
hypot, # noqa: F401
lcm, # noqa: F401
ldexp, # noqa: F401
less, # noqa: F401
less_equal, # noqa: F401
logaddexp, # noqa: F401
logaddexp2, # noqa: F401
logical_and, # noqa: F401
logical_or, # noqa: F401
logical_xor, # noqa: F401
maximum, # noqa: F401
minimum, # noqa: F401
multiply, # noqa: F401
nextafter, # noqa: F401
not_equal, # noqa: F401
pow as power, # noqa: F401
remainder, # noqa: F401
remainder as mod, # noqa: F401
subtract, # noqa: F401
true_divide, # noqa: F401
)
from . import _dtypes_impl, _util
# work around torch limitations w.r.t. numpy
def matmul(x, y):
    """Matrix-multiply ``x`` and ``y``, working around ``torch.matmul`` dtype gaps.

    The output dtype is whatever ``_dtypes_impl.result_type_impl(x, y)``
    reports. The product itself may be computed in a substitute "work" dtype
    and converted to the output dtype afterwards.

    Works around:
      - RuntimeError: expected scalar type Int but found Double
      - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
      - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

    Parameters
    ----------
    x, y : torch.Tensor
        Operands; ``.dtype`` and ``.is_cpu`` are read from both.

    Returns
    -------
    torch.Tensor
        ``torch.matmul`` of the (possibly cast) operands, in the result dtype.
    """
    dtype = _dtypes_impl.result_type_impl(x, y)
    on_cpu = x.is_cpu or y.is_cpu

    # Pick the substitute dtype from the *result* dtype, not from the input
    # dtypes.  Keying the Half workaround off "either input is float16" would
    # also force float32 when the other operand widens the result (float64,
    # complex, ...): both operands would then be narrowed to float32 before
    # the product, losing precision, and a complex operand would be cast to a
    # real dtype.  When the result dtype is wider than float16 the operands
    # are cast to it directly and no Half kernel is involved.
    if dtype == torch.bool:
        # NOTE(review): uint8 accumulation presumably wraps modulo 256, so an
        # inner dimension with a multiple of 256 True/True pairs may come back
        # False -- confirm, and consider a wider integer work dtype.
        work_dtype = torch.uint8
    elif dtype == torch.float16 and on_cpu:
        work_dtype = torch.float32
    else:
        work_dtype = dtype

    x = _util.cast_if_needed(x, work_dtype)
    y = _util.cast_if_needed(y, work_dtype)

    result = torch.matmul(x, y)

    # Convert back only when a substitute dtype was actually used.
    if work_dtype != dtype:
        result = result.to(dtype)

    return result
# a stub implementation of divmod, should be improved after
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
def divmod(x, y):
    """Return the pair ``(x // y, x % y)``.

    Stub implementation: the quotient and the remainder are computed
    separately through the operands' own ``//`` and ``%`` operators.
    """
    quotient = x // y
    leftover = x % y
    return quotient, leftover