Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / proto / caffe2_pb.h

#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <caffe2/proto/caffe2.pb.h>

namespace caffe2 {

// Short aliases so caffe2 code can name device types without spelling out
// the `DeviceType::` qualifier (e.g. `caffe2::CPU`). Each constant simply
// mirrors the at::DeviceType enumerator of the same name.
using DeviceType = at::DeviceType;

constexpr DeviceType CPU{DeviceType::CPU};
constexpr DeviceType CUDA{DeviceType::CUDA};
constexpr DeviceType OPENGL{DeviceType::OPENGL};
constexpr DeviceType OPENCL{DeviceType::OPENCL};
constexpr DeviceType MKLDNN{DeviceType::MKLDNN};
constexpr DeviceType IDEEP{DeviceType::IDEEP};
constexpr DeviceType HIP{DeviceType::HIP};
constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES{
    DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES};

/// Converts a caffe2 protobuf device type into the corresponding DeviceType.
///
/// @param p protobuf device type to convert.
/// @return the DeviceType whose name matches `p`.
/// Raises an error via AT_ERROR when `p` is not one of the handled values.
inline TORCH_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) {
  switch (p) {
    case caffe2::PROTO_CPU:
      return DeviceType::CPU;
    case caffe2::PROTO_CUDA:
      return DeviceType::CUDA;
    case caffe2::PROTO_OPENGL:
      return DeviceType::OPENGL;
    case caffe2::PROTO_OPENCL:
      return DeviceType::OPENCL;
    case caffe2::PROTO_MKLDNN:
      return DeviceType::MKLDNN;
    case caffe2::PROTO_IDEEP:
      return DeviceType::IDEEP;
    case caffe2::PROTO_HIP:
      return DeviceType::HIP;
    case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES:
      return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
    default:
      // Adjacent string literals are concatenated verbatim by the compiler,
      // so every fragment must end with its own separating space.
      AT_ERROR(
          "Unknown device:",
          static_cast<int32_t>(p),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}

/// Integer overload of ProtoToType: treats `p` as a raw DeviceTypeProto value
/// and forwards to the enum overload, whose default branch reports values it
/// does not recognize.
inline TORCH_API DeviceType ProtoToType(int p) {
  const auto proto_type = static_cast<caffe2::DeviceTypeProto>(p);
  return ProtoToType(proto_type);
}

/// Converts a DeviceType into the corresponding caffe2 protobuf device type.
///
/// @param t device type to convert.
/// @return the DeviceTypeProto whose name matches `t`.
/// Raises an error via AT_ERROR when `t` is not one of the handled values.
inline TORCH_API DeviceTypeProto TypeToProto(const DeviceType& t) {
  switch (t) {
    case DeviceType::CPU:
      return caffe2::PROTO_CPU;
    case DeviceType::CUDA:
      return caffe2::PROTO_CUDA;
    case DeviceType::OPENGL:
      return caffe2::PROTO_OPENGL;
    case DeviceType::OPENCL:
      return caffe2::PROTO_OPENCL;
    case DeviceType::MKLDNN:
      return caffe2::PROTO_MKLDNN;
    case DeviceType::IDEEP:
      return caffe2::PROTO_IDEEP;
    case DeviceType::HIP:
      return caffe2::PROTO_HIP;
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
      return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
    default:
      // Adjacent string literals are concatenated verbatim by the compiler,
      // so every fragment must end with its own separating space.
      AT_ERROR(
          "Unknown device:",
          static_cast<int32_t>(t),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}

/// Builds a caffe2 DeviceOption protobuf describing `device`.
///
/// The device type is always set. A CPU device with an explicit index stores
/// it as numa_node_id. CUDA/HIP devices store their index as device_id. Every
/// other handled type records only the device type.
///
/// @param device the device to describe.
/// @return a freshly built DeviceOption.
/// Raises an error via AT_ERROR (here or in TypeToProto()) for unhandled
/// device types.
inline TORCH_API caffe2::DeviceOption DeviceToOption(
    const at::Device& device) {
  caffe2::DeviceOption option;
  auto type = device.type();
  option.set_device_type(TypeToProto(type));

  switch (type) {
    case DeviceType::CPU:
      // -1 means "no explicit index"; only record a real index.
      if (device.index() != -1) {
        option.set_numa_node_id(device.index());
      }
      break;
    case DeviceType::CUDA:
    case DeviceType::HIP:
      option.set_device_id(device.index());
      break;
    case DeviceType::OPENGL:
    case DeviceType::OPENCL:
    case DeviceType::MKLDNN:
    case DeviceType::IDEEP:
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
      break;
    default:
      // Adjacent string literals are concatenated verbatim by the compiler,
      // so every fragment must end with its own separating space.
      AT_ERROR(
          "Unknown device:",
          static_cast<int32_t>(type),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
  return option;
}

/// Converts a caffe2 DeviceOption protobuf into an at::Device.
///
/// Only CPU (numa_node_id, when present) and CUDA/HIP (device_id) options
/// contribute an index. All other device types map to index -1.
///
/// @param option the protobuf to read. It is taken by const reference so the
///        message is not copied on every call; existing call sites compile
///        unchanged.
/// @return the device described by `option`.
/// Raises an error via ProtoToType() if the stored device type is unknown.
inline TORCH_API at::Device OptionToDevice(
    const caffe2::DeviceOption& option) {
  const auto type = option.device_type();
  int32_t id = -1;
  switch (type) {
    case caffe2::PROTO_CPU:
      if (option.has_numa_node_id()) {
        id = option.numa_node_id();
      }
      break;
    case caffe2::PROTO_CUDA:
    case caffe2::PROTO_HIP:
      id = option.device_id();
      break;
    default:
      // Other device types carry no index; ProtoToType() below still
      // validates the type itself.
      break;
  }
  return at::Device(ProtoToType(type), id);
}

/// Overwrites `*device_option` with the DeviceOption built from `device`.
///
/// @param device_option output protobuf; must not be null (asserted).
/// @param device the device to describe.
inline void ExtractDeviceOption(
    DeviceOption* device_option,
    const at::Device& device) {
  AT_ASSERT(device_option);
  const DeviceOption converted = DeviceToOption(device);
  device_option->CopyFrom(converted);
}

} // namespace caffe2