Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / c10 / core / Backend.h

#pragma once

#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>

#include <stdexcept>

namespace c10 {

/**
 * This legacy enum class defines the set of backends supported by old-school,
 * code-generated, Type-based ATen.  A "backend" in this sense roughly
 * corresponds to the Cartesian product of (device type, layout), restricted
 * to only those combinations for which we actually have kernels.  Backend
 * does NOT include dtype.
 *
 * The reason we are sunsetting this enum class is that it doesn't allow for
 * open registration; e.g., if you wanted to add SparseXLA, you'd have to
 * edit this enum; you wouldn't be able to do it out of tree.  DispatchKey is
 * the replacement for Backend, and it supports open registration.
 *
 * NB: The concept of 'Backend' here disagrees with the notion of backend
 * exposed to users in torch.backends.  Backend here is something like "CPU"
 * or "SparseCUDA"; backend in torch.backends is something like "MKL" or
 * "CUDNN".
 */
enum class Backend {
  // No enumerator has an explicit value, so each one's underlying value is
  // simply its position in this list.  Inserting, removing or reordering
  // entries therefore changes the numeric value of everything after them.

  // Backends that backendToDeviceType() maps onto the device type of the
  // same name.
  CPU,
  CUDA,
  HIP,
  VE,
  FPGA,
  IPU,
  XPU,
  // Sparse variants.  isSparse() is true for the Sparse{CPU,CUDA,HIP,VE,XPU}
  // entries and isSparseCsr() is true for the SparseCsr{CPU,CUDA} entries;
  // backendToDeviceType() maps each onto the device named in its suffix.
  SparseCPU,
  SparseCUDA,
  SparseCsrCPU,
  SparseCsrCUDA,
  SparseHIP,
  SparseVE,
  SparseXPU,
  // Further backends that map onto the device type of the same name.
  ORT,
  XLA,
  Vulkan,
  Metal,
  Meta,
  // Quantized variants; backendToDeviceType() maps each onto the device named
  // in its suffix.
  QuantizedCPU,
  QuantizedCUDA,
  QuantizedXPU,
  // Placeholder for "no backend".  backendToDeviceType() rejects it, while
  // backendToDispatchKey() maps it to DispatchKey::Undefined.
  Undefined,
  // Maps onto DeviceType::CPU in backendToDeviceType().
  MkldnnCPU,
  // Further backends that map onto the device type of the same name.
  MPS,
  HPU,
  Lazy,
  MTIA,
  PrivateUse1,
  // Not a backend.  Because it is the last enumerator, its value equals the
  // number of enumerators declared before it; keep it last.
  NumOptions
};

/**
 * Translates a DispatchKey into the legacy Backend it corresponds to.
 *
 * For the backends listed first below, the Autograd flavour of the key is
 * accepted as well and collapses onto the same Backend as the plain key.
 * All remaining keys are matched by their plain form only.
 *
 * Any key without an entry here fails a TORCH_CHECK with the message
 * "Unrecognized tensor type ID: <key>".
 */
static inline Backend dispatchKeyToBackend(DispatchKey t) {
  switch (t) {
    // Keys accepted in both plain and Autograd form.
    case DispatchKey::CPU:
    case DispatchKey::AutogradCPU:
      return Backend::CPU;
    case DispatchKey::CUDA:
    case DispatchKey::AutogradCUDA:
      return Backend::CUDA;
    case DispatchKey::XLA:
    case DispatchKey::AutogradXLA:
      return Backend::XLA;
    case DispatchKey::Lazy:
    case DispatchKey::AutogradLazy:
      return Backend::Lazy;
    case DispatchKey::MPS:
    case DispatchKey::AutogradMPS:
      return Backend::MPS;
    case DispatchKey::IPU:
    case DispatchKey::AutogradIPU:
      return Backend::IPU;
    case DispatchKey::XPU:
    case DispatchKey::AutogradXPU:
      return Backend::XPU;
    case DispatchKey::HPU:
    case DispatchKey::AutogradHPU:
      return Backend::HPU;

    // Keys accepted in plain form only.
    case DispatchKey::HIP:
      return Backend::HIP;
    case DispatchKey::VE:
      return Backend::VE;
    case DispatchKey::FPGA:
      return Backend::FPGA;
    case DispatchKey::ORT:
      return Backend::ORT;
    case DispatchKey::Vulkan:
      return Backend::Vulkan;
    case DispatchKey::Metal:
      return Backend::Metal;
    case DispatchKey::Meta:
      return Backend::Meta;
    case DispatchKey::MkldnnCPU:
      return Backend::MkldnnCPU;
    case DispatchKey::MTIA:
      return Backend::MTIA;
    case DispatchKey::PrivateUse1:
      return Backend::PrivateUse1;
    case DispatchKey::Undefined:
      return Backend::Undefined;

    // Sparse keys.
    case DispatchKey::SparseCPU:
      return Backend::SparseCPU;
    case DispatchKey::SparseCUDA:
      return Backend::SparseCUDA;
    case DispatchKey::SparseHIP:
      return Backend::SparseHIP;
    case DispatchKey::SparseVE:
      return Backend::SparseVE;
    case DispatchKey::SparseXPU:
      return Backend::SparseXPU;
    case DispatchKey::SparseCsrCPU:
      return Backend::SparseCsrCPU;
    case DispatchKey::SparseCsrCUDA:
      return Backend::SparseCsrCUDA;

    // Quantized keys.
    case DispatchKey::QuantizedCPU:
      return Backend::QuantizedCPU;
    case DispatchKey::QuantizedCUDA:
      return Backend::QuantizedCUDA;
    case DispatchKey::QuantizedXPU:
      return Backend::QuantizedXPU;

    default:
      TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
  }
}

/**
 * Maps a legacy Backend onto the DispatchKey of the same name.
 *
 * Every Backend that dispatchKeyToBackend() can produce has an entry here,
 * including Backend::QuantizedXPU, so converting a plain (non-Autograd) key
 * to a Backend and back yields the original key.
 *
 * @param b  Backend to convert.
 * @return   The DispatchKey corresponding to `b`.
 * @throws std::runtime_error ("Unknown backend") if `b` has no entry, e.g.
 *         Backend::NumOptions or a value outside the enum.
 */
static inline DispatchKey backendToDispatchKey(Backend b) {
  switch (b) {
    case Backend::CPU:
      return DispatchKey::CPU;
    case Backend::CUDA:
      return DispatchKey::CUDA;
    case Backend::HIP:
      return DispatchKey::HIP;
    case Backend::VE:
      return DispatchKey::VE;
    case Backend::FPGA:
      return DispatchKey::FPGA;
    case Backend::ORT:
      return DispatchKey::ORT;
    case Backend::XLA:
      return DispatchKey::XLA;
    case Backend::Lazy:
      return DispatchKey::Lazy;
    case Backend::IPU:
      return DispatchKey::IPU;
    case Backend::XPU:
      return DispatchKey::XPU;
    case Backend::SparseXPU:
      return DispatchKey::SparseXPU;
    case Backend::SparseCPU:
      return DispatchKey::SparseCPU;
    case Backend::SparseCUDA:
      return DispatchKey::SparseCUDA;
    case Backend::SparseHIP:
      return DispatchKey::SparseHIP;
    case Backend::SparseVE:
      return DispatchKey::SparseVE;
    case Backend::SparseCsrCPU:
      return DispatchKey::SparseCsrCPU;
    case Backend::SparseCsrCUDA:
      return DispatchKey::SparseCsrCUDA;
    case Backend::MkldnnCPU:
      return DispatchKey::MkldnnCPU;
    case Backend::Vulkan:
      return DispatchKey::Vulkan;
    case Backend::Metal:
      return DispatchKey::Metal;
    case Backend::Meta:
      return DispatchKey::Meta;
    case Backend::QuantizedCPU:
      return DispatchKey::QuantizedCPU;
    case Backend::QuantizedCUDA:
      return DispatchKey::QuantizedCUDA;
    // Previously missing: QuantizedXPU fell through to the default branch and
    // threw, even though dispatchKeyToBackend() produces this Backend.
    case Backend::QuantizedXPU:
      return DispatchKey::QuantizedXPU;
    case Backend::Undefined:
      return DispatchKey::Undefined;
    case Backend::MPS:
      return DispatchKey::MPS;
    case Backend::HPU:
      return DispatchKey::HPU;
    case Backend::MTIA:
      return DispatchKey::MTIA;
    case Backend::PrivateUse1:
      return DispatchKey::PrivateUse1;
    default:
      throw std::runtime_error("Unknown backend");
  }
}

/**
 * Returns the DeviceType on which tensors of the given Backend live.
 *
 * Several backends share a device: the sparse, sparse-CSR, MKL-DNN and
 * quantized variants resolve to the same DeviceType as their base backend,
 * so the cases below are grouped by the device they produce.
 *
 * Backend::Undefined fails a TORCH_CHECK with
 * "Undefined backend is not a valid device type"; any other value without an
 * entry fails a TORCH_CHECK with "Unknown backend".
 */
static inline DeviceType backendToDeviceType(Backend b) {
  switch (b) {
    // Backends that share a device with other layout/quantization variants.
    case Backend::CPU:
    case Backend::SparseCPU:
    case Backend::SparseCsrCPU:
    case Backend::MkldnnCPU:
    case Backend::QuantizedCPU:
      return DeviceType::CPU;

    case Backend::CUDA:
    case Backend::SparseCUDA:
    case Backend::SparseCsrCUDA:
    case Backend::QuantizedCUDA:
      return DeviceType::CUDA;

    case Backend::HIP:
    case Backend::SparseHIP:
      return DeviceType::HIP;

    case Backend::VE:
    case Backend::SparseVE:
      return DeviceType::VE;

    case Backend::XPU:
    case Backend::SparseXPU:
    case Backend::QuantizedXPU:
      return DeviceType::XPU;

    // Backends with exactly one entry per device.
    case Backend::FPGA:
      return DeviceType::FPGA;
    case Backend::IPU:
      return DeviceType::IPU;
    case Backend::ORT:
      return DeviceType::ORT;
    case Backend::XLA:
      return DeviceType::XLA;
    case Backend::Lazy:
      return DeviceType::Lazy;
    case Backend::Vulkan:
      return DeviceType::Vulkan;
    case Backend::Metal:
      return DeviceType::Metal;
    case Backend::Meta:
      return DeviceType::Meta;
    case Backend::MPS:
      return DeviceType::MPS;
    case Backend::HPU:
      return DeviceType::HPU;
    case Backend::MTIA:
      return DeviceType::MTIA;
    case Backend::PrivateUse1:
      return DeviceType::PrivateUse1;

    // Error paths: neither of these produces a value.
    case Backend::Undefined:
      TORCH_CHECK(false, "Undefined backend is not a valid device type");
    default:
      TORCH_CHECK(false, "Unknown backend");
  }
}

// TODO: This probably shouldn't actually be static inline
//
// Returns a string-literal name for the given Backend.  Cases are listed in
// the same order as the enumerators are declared.  Each returned name matches
// the enumerator's spelling, except PrivateUse1, which is reported as
// "PrivateUseOne".  Values without a case here (Undefined, NumOptions and
// anything outside the enum) are reported as "UNKNOWN_BACKEND".
static inline const char* toString(Backend b) {
  // clang-format off
  switch (b) {
    case Backend::CPU:           return "CPU";
    case Backend::CUDA:          return "CUDA";
    case Backend::HIP:           return "HIP";
    case Backend::VE:            return "VE";
    case Backend::FPGA:          return "FPGA";
    case Backend::IPU:           return "IPU";
    case Backend::XPU:           return "XPU";
    case Backend::SparseCPU:     return "SparseCPU";
    case Backend::SparseCUDA:    return "SparseCUDA";
    case Backend::SparseCsrCPU:  return "SparseCsrCPU";
    case Backend::SparseCsrCUDA: return "SparseCsrCUDA";
    case Backend::SparseHIP:     return "SparseHIP";
    case Backend::SparseVE:      return "SparseVE";
    case Backend::SparseXPU:     return "SparseXPU";
    case Backend::ORT:           return "ORT";
    case Backend::XLA:           return "XLA";
    case Backend::Vulkan:        return "Vulkan";
    case Backend::Metal:         return "Metal";
    case Backend::Meta:          return "Meta";
    case Backend::QuantizedCPU:  return "QuantizedCPU";
    case Backend::QuantizedCUDA: return "QuantizedCUDA";
    case Backend::QuantizedXPU:  return "QuantizedXPU";
    case Backend::MkldnnCPU:     return "MkldnnCPU";
    case Backend::MPS:           return "MPS";
    case Backend::HPU:           return "HPU";
    case Backend::Lazy:          return "Lazy";
    case Backend::MTIA:          return "MTIA";
    case Backend::PrivateUse1:   return "PrivateUseOne";
    default:                     return "UNKNOWN_BACKEND";
  }
  // clang-format on
}

// True exactly for the Sparse{CPU,CUDA,HIP,VE,XPU} backends.  The SparseCsr*
// backends are NOT included; isSparseCsr() covers those.
static inline bool isSparse(Backend b) {
  return b == Backend::SparseCPU || b == Backend::SparseCUDA ||
      b == Backend::SparseHIP || b == Backend::SparseVE ||
      b == Backend::SparseXPU;
}

// True exactly for the SparseCsrCPU and SparseCsrCUDA backends.
static inline bool isSparseCsr(Backend b) {
  return b == Backend::SparseCsrCPU || b == Backend::SparseCsrCUDA;
}

} // namespace c10