#pragma once
#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>
#include <stdexcept>
namespace c10 {
/**
 * This legacy enum class defines the set of backends supported by old-school,
 * code-generated, Type-based ATen. A "backend" in this sense roughly
 * corresponds to the Cartesian product of (device type, layout), but restricted
 * only to combinations for which we actually have kernels. Backend does NOT
 * include dtype.
 *
 * We are sunsetting this enum class because it doesn't allow for open
 * registration; e.g., if you want to add SparseXLA, you'd have to edit this
 * enum; you wouldn't be able to do it out of tree. DispatchKey is the
 * replacement for Backend and supports open registration.
 *
 * NB: The concept of 'Backend' here disagrees with the notion of backend
 * exposed to users in torch.backends. Backend here is something like "CPU"
 * or "SparseCUDA"; backend in torch.backends is something like "MKL" or
 * "CUDNN".
 *
 * NB: Enumerators have implicit, sequential underlying values, so inserting or
 * reordering an entry changes the integer value of every entry after it.
 * NOTE(review): confirm no code persists or exchanges these integer values
 * before reordering; appending just before NumOptions is the safe option.
 */
enum class Backend {
CPU,
CUDA,
HIP,
VE,
FPGA,
IPU,
XPU,
SparseCPU,
SparseCUDA,
SparseCsrCPU,
SparseCsrCUDA,
SparseHIP,
SparseVE,
SparseXPU,
ORT,
XLA,
Vulkan,
Metal,
Meta,
QuantizedCPU,
QuantizedCUDA,
QuantizedXPU,
Undefined,
MkldnnCPU,
MPS,
HPU,
Lazy,
MTIA,
PrivateUse1,
// Not a real backend. Because it is the last enumerator, its value equals the
// number of entries above it (presumably used as a count sentinel).
NumOptions
};
/**
 * Maps a DispatchKey onto the legacy Backend it corresponds to.
 *
 * Every backend's own runtime key maps to that backend. For CPU, CUDA, XLA,
 * Lazy, MPS, IPU, XPU, HPU and PrivateUse1 the backend's Autograd key is
 * accepted too and maps to the same Backend; other Autograd keys are rejected.
 *
 * @param t the dispatch key to convert.
 * @return the matching Backend; DispatchKey::Undefined yields
 *         Backend::Undefined.
 * @throws c10::Error (raised by TORCH_CHECK) when t has no Backend equivalent.
 */
static inline Backend dispatchKeyToBackend(DispatchKey t) {
  if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) {
    return Backend::CPU;
  } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) {
    return Backend::CUDA;
  } else if (t == DispatchKey::HIP) {
    return Backend::HIP;
  } else if (t == DispatchKey::VE) {
    return Backend::VE;
  } else if (t == DispatchKey::FPGA) {
    return Backend::FPGA;
  } else if (t == DispatchKey::ORT) {
    return Backend::ORT;
  } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
    return Backend::XLA;
  } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
    return Backend::Lazy;
  } else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) {
    return Backend::MPS;
  } else if (t == DispatchKey::Vulkan) {
    return Backend::Vulkan;
  } else if (t == DispatchKey::Metal) {
    return Backend::Metal;
  } else if (t == DispatchKey::Meta) {
    return Backend::Meta;
  } else if (t == DispatchKey::SparseCPU) {
    return Backend::SparseCPU;
  } else if (t == DispatchKey::SparseCUDA) {
    return Backend::SparseCUDA;
  } else if (t == DispatchKey::SparseHIP) {
    return Backend::SparseHIP;
  } else if (t == DispatchKey::SparseVE) {
    return Backend::SparseVE;
  } else if (t == DispatchKey::SparseCsrCPU) {
    return Backend::SparseCsrCPU;
  } else if (t == DispatchKey::SparseCsrCUDA) {
    return Backend::SparseCsrCUDA;
  } else if (t == DispatchKey::MkldnnCPU) {
    return Backend::MkldnnCPU;
  } else if (t == DispatchKey::QuantizedCPU) {
    return Backend::QuantizedCPU;
  } else if (t == DispatchKey::QuantizedCUDA) {
    return Backend::QuantizedCUDA;
  } else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) {
    return Backend::IPU;
  } else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) {
    return Backend::XPU;
  } else if (t == DispatchKey::SparseXPU) {
    return Backend::SparseXPU;
  } else if (t == DispatchKey::QuantizedXPU) {
    return Backend::QuantizedXPU;
  } else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) {
    return Backend::HPU;
  } else if (t == DispatchKey::MTIA) {
    return Backend::MTIA;
  } else if (
      t == DispatchKey::PrivateUse1 ||
      // PrivateUse1 is the out-of-tree extension backend; like the other
      // external backends above (XLA, Lazy, HPU, ...) it must also accept its
      // Autograd key, otherwise that key is rejected as unrecognized.
      t == DispatchKey::AutogradPrivateUse1) {
    return Backend::PrivateUse1;
  } else if (t == DispatchKey::Undefined) {
    return Backend::Undefined;
  } else {
    TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
  }
}
/**
 * Maps a legacy Backend onto its (non-autograd) DispatchKey.
 *
 * For every Backend handled here, dispatchKeyToBackend maps the returned key
 * back to the same Backend.
 *
 * @param b the backend to convert.
 * @return the DispatchKey for b; Backend::Undefined yields
 *         DispatchKey::Undefined.
 * @throws std::runtime_error("Unknown backend") for values without a mapping
 *         (e.g. Backend::NumOptions).
 */
static inline DispatchKey backendToDispatchKey(Backend b) {
  switch (b) {
    case Backend::CPU:
      return DispatchKey::CPU;
    case Backend::CUDA:
      return DispatchKey::CUDA;
    case Backend::HIP:
      return DispatchKey::HIP;
    case Backend::VE:
      return DispatchKey::VE;
    case Backend::FPGA:
      return DispatchKey::FPGA;
    case Backend::ORT:
      return DispatchKey::ORT;
    case Backend::XLA:
      return DispatchKey::XLA;
    case Backend::Lazy:
      return DispatchKey::Lazy;
    case Backend::IPU:
      return DispatchKey::IPU;
    case Backend::XPU:
      return DispatchKey::XPU;
    case Backend::SparseXPU:
      return DispatchKey::SparseXPU;
    case Backend::SparseCPU:
      return DispatchKey::SparseCPU;
    case Backend::SparseCUDA:
      return DispatchKey::SparseCUDA;
    case Backend::SparseHIP:
      return DispatchKey::SparseHIP;
    case Backend::SparseVE:
      return DispatchKey::SparseVE;
    case Backend::SparseCsrCPU:
      return DispatchKey::SparseCsrCPU;
    case Backend::SparseCsrCUDA:
      return DispatchKey::SparseCsrCUDA;
    case Backend::MkldnnCPU:
      return DispatchKey::MkldnnCPU;
    case Backend::Vulkan:
      return DispatchKey::Vulkan;
    case Backend::Metal:
      return DispatchKey::Metal;
    case Backend::Meta:
      return DispatchKey::Meta;
    case Backend::QuantizedCPU:
      return DispatchKey::QuantizedCPU;
    case Backend::QuantizedCUDA:
      return DispatchKey::QuantizedCUDA;
    // Previously missing: QuantizedXPU fell through to the "Unknown backend"
    // error even though dispatchKeyToBackend maps DispatchKey::QuantizedXPU
    // to Backend::QuantizedXPU.
    case Backend::QuantizedXPU:
      return DispatchKey::QuantizedXPU;
    case Backend::Undefined:
      return DispatchKey::Undefined;
    case Backend::MPS:
      return DispatchKey::MPS;
    case Backend::HPU:
      return DispatchKey::HPU;
    case Backend::MTIA:
      return DispatchKey::MTIA;
    case Backend::PrivateUse1:
      return DispatchKey::PrivateUse1;
    default:
      throw std::runtime_error("Unknown backend");
  }
}
/**
 * Returns the DeviceType on which tensors of the given legacy Backend live.
 *
 * Layout variants (sparse, sparse-CSR, MKLDNN, quantized) share the device
 * type of their dense counterpart.
 *
 * @throws c10::Error (raised by TORCH_CHECK) for Backend::Undefined and for
 *         any value not listed (e.g. Backend::NumOptions).
 */
static inline DeviceType backendToDeviceType(Backend b) {
  switch (b) {
    // CPU-resident backends.
    case Backend::CPU:
    case Backend::SparseCPU:
    case Backend::SparseCsrCPU:
    case Backend::MkldnnCPU:
    case Backend::QuantizedCPU:
      return DeviceType::CPU;

    // CUDA-resident backends.
    case Backend::CUDA:
    case Backend::SparseCUDA:
    case Backend::SparseCsrCUDA:
    case Backend::QuantizedCUDA:
      return DeviceType::CUDA;

    case Backend::HIP:
    case Backend::SparseHIP:
      return DeviceType::HIP;

    case Backend::VE:
    case Backend::SparseVE:
      return DeviceType::VE;

    case Backend::XPU:
    case Backend::SparseXPU:
    case Backend::QuantizedXPU:
      return DeviceType::XPU;

    // Backends with a single, dense-only entry.
    case Backend::FPGA:
      return DeviceType::FPGA;
    case Backend::IPU:
      return DeviceType::IPU;
    case Backend::ORT:
      return DeviceType::ORT;
    case Backend::XLA:
      return DeviceType::XLA;
    case Backend::Lazy:
      return DeviceType::Lazy;
    case Backend::Vulkan:
      return DeviceType::Vulkan;
    case Backend::Metal:
      return DeviceType::Metal;
    case Backend::Meta:
      return DeviceType::Meta;
    case Backend::MPS:
      return DeviceType::MPS;
    case Backend::HPU:
      return DeviceType::HPU;
    case Backend::MTIA:
      return DeviceType::MTIA;
    case Backend::PrivateUse1:
      return DeviceType::PrivateUse1;

    // No device type exists for these; TORCH_CHECK(false, ...) throws.
    case Backend::Undefined:
      TORCH_CHECK(false, "Undefined backend is not a valid device type");
    default:
      TORCH_CHECK(false, "Unknown backend");
  }
}
// TODO: This probably shouldn't actually be static inline
/**
 * Returns a static, human-readable name for a legacy Backend.
 *
 * Backend::PrivateUse1 is reported as "PrivateUseOne". Values without a
 * dedicated name (including Backend::Undefined and Backend::NumOptions)
 * yield "UNKNOWN_BACKEND".
 */
static inline const char* toString(Backend b) {
  const char* label = "UNKNOWN_BACKEND";
  // Cases follow the declaration order of the Backend enum.
  switch (b) {
    case Backend::CPU: label = "CPU"; break;
    case Backend::CUDA: label = "CUDA"; break;
    case Backend::HIP: label = "HIP"; break;
    case Backend::VE: label = "VE"; break;
    case Backend::FPGA: label = "FPGA"; break;
    case Backend::IPU: label = "IPU"; break;
    case Backend::XPU: label = "XPU"; break;
    case Backend::SparseCPU: label = "SparseCPU"; break;
    case Backend::SparseCUDA: label = "SparseCUDA"; break;
    case Backend::SparseCsrCPU: label = "SparseCsrCPU"; break;
    case Backend::SparseCsrCUDA: label = "SparseCsrCUDA"; break;
    case Backend::SparseHIP: label = "SparseHIP"; break;
    case Backend::SparseVE: label = "SparseVE"; break;
    case Backend::SparseXPU: label = "SparseXPU"; break;
    case Backend::ORT: label = "ORT"; break;
    case Backend::XLA: label = "XLA"; break;
    case Backend::Vulkan: label = "Vulkan"; break;
    case Backend::Metal: label = "Metal"; break;
    case Backend::Meta: label = "Meta"; break;
    case Backend::QuantizedCPU: label = "QuantizedCPU"; break;
    case Backend::QuantizedCUDA: label = "QuantizedCUDA"; break;
    case Backend::QuantizedXPU: label = "QuantizedXPU"; break;
    case Backend::MkldnnCPU: label = "MkldnnCPU"; break;
    case Backend::MPS: label = "MPS"; break;
    case Backend::HPU: label = "HPU"; break;
    case Backend::Lazy: label = "Lazy"; break;
    case Backend::MTIA: label = "MTIA"; break;
    case Backend::PrivateUse1: label = "PrivateUseOne"; break;
    default: break; // keep the "UNKNOWN_BACKEND" fallback
  }
  return label;
}
/**
 * Returns true iff `b` is one of the `Sparse*` backends: SparseCPU,
 * SparseCUDA, SparseHIP, SparseVE or SparseXPU.
 *
 * The `SparseCsr*` backends are NOT included here; query isSparseCsr for them.
 */
static inline bool isSparse(Backend b) {
  return b == Backend::SparseCPU || b == Backend::SparseCUDA ||
      b == Backend::SparseHIP || b == Backend::SparseVE ||
      b == Backend::SparseXPU;
}
/**
 * Returns true iff `b` is one of the `SparseCsr*` backends: SparseCsrCPU or
 * SparseCsrCUDA.
 */
static inline bool isSparseCsr(Backend b) {
  return b == Backend::SparseCsrCPU || b == Backend::SparseCsrCUDA;
}
} // namespace c10