Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / context_base.h

#pragma once

#include <array>
#include <cstdlib>
#include <ctime>
#include <memory>
#include <unordered_map>

#include <c10/macros/Macros.h>
#include <c10/core/Allocator.h>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <c10/core/CopyBytes.h>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"

namespace caffe2 {
// Forward declaration: the BaseContext interface below only mentions Event
// through a reference (WaitEvent) and a pointer (Record), so the complete
// type is not required in this header.
class Event;

} // namespace caffe2
namespace at {

// Forward declaration of the interface that is defined just below.
class BaseContext;

/**
 * Virtual interface for the Context class in Caffe2.
 *
 * A Context defines all the necessities to run an operator on a specific
 * device. Specific Context classes need to implement all the pure virtual
 * functions in the BaseContext class.
 * TODO: add docs after this is finalized.
 */
class TORCH_API BaseContext {
 public:
  virtual ~BaseContext() noexcept = default;

  // Returns the Device associated with this context; supplied by the
  // concrete context class.
  virtual Device device() const = 0;

  // Returns the DeviceType of this context. The name is acknowledged to be
  // awkward and is slated for removal in a future change.
  virtual DeviceType device_type() const = 0;

  // Switches the context to its device; the int argument is a stream id.
  // What "switching" entails is defined by the concrete context.
  virtual void SwitchToDevice(int /*stream_id*/) = 0;

  // Convenience overload: switches to the device using stream id 0.
  void SwitchToDevice() {
    constexpr int kDefaultStreamId = 0;
    SwitchToDevice(kDefaultStreamId);
  }

  // Event hooks; their exact semantics are defined by the concrete context.
  virtual void WaitEvent(const caffe2::Event& ev) = 0;

  // `err_msg` is optional and defaults to nullptr.
  virtual void Record(caffe2::Event* ev, const char* err_msg = nullptr)
      const = 0;

  virtual void FinishDeviceComputation() = 0;

  // Historically there was a single arbitrary cross-device copy, but in
  // practice every caller did a direct CPU<->X copy. The copy is therefore
  // split into the three functions below, which avoids double dispatch.
  // This is expected to be obsoleted by C10, where copies will be proper
  // operators and can rely on multiple dispatch.
  virtual void CopyBytesSameDevice(
      size_t nbytes,
      const void* src,
      void* dst) = 0;

  virtual void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) = 0;

  virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0;

  // Typed wrappers around the CopyBytes* functions. Each one copies `n`
  // elements of T by forwarding n * sizeof(T) bytes, and is restricted at
  // compile time to fundamental types.
  template <typename T>
  void CopySameDevice(size_t n, const T* src, T* dst) {
    static_assert(
        c10::guts::is_fundamental<T>::value,
        "CopySameDevice requires fundamental types");
    const size_t nbytes = n * sizeof(T);
    CopyBytesSameDevice(nbytes, src, dst);
  }

  template <typename T>
  void CopyFromCPU(size_t n, const T* src, T* dst) {
    static_assert(
        c10::guts::is_fundamental<T>::value,
        "CopyFromCPU requires fundamental types");
    const size_t nbytes = n * sizeof(T);
    CopyBytesFromCPU(nbytes, src, dst);
  }

  template <typename T>
  void CopyToCPU(size_t n, const T* src, T* dst) {
    static_assert(
        c10::guts::is_fundamental<T>::value,
        "CopyToCPU requires fundamental types");
    const size_t nbytes = n * sizeof(T);
    CopyBytesToCPU(nbytes, src, dst);
  }

  // Whether this context may copy items through a TypeMeta copy function.
  // The base implementation answers false; subclasses may override.
  virtual bool SupportsNonFundamentalTypes() const {
    return false;
  }

  // Asserts that SupportsNonFundamentalTypes() holds. Called by the
  // CopyItems* functions before a TypeMeta copy function is used.
  void EnforceMetaCopyOK() {
    AT_ASSERTM(
        SupportsNonFundamentalTypes(), "Context requires fundamental types");
  }

  // Copies `n` items described by `meta` from `src` to `dst`.
  //
  // When `meta` has no copy function, the items are copied as
  // n * meta.itemsize() raw bytes through CopyBytesSameDevice. Otherwise the
  // context must support non-fundamental types (asserted), and meta's copy
  // function is invoked directly with (src, dst, n).
  void CopyItemsSameDevice(
      const caffe2::TypeMeta meta,
      size_t n,
      const void* src,
      void* dst) {
    if (!meta.copy()) {
      CopyBytesSameDevice(n * meta.itemsize(), src, dst);
      return;
    }
    EnforceMetaCopyOK();
    meta.copy()(src, dst, n);
  }

  // Same contract as CopyItemsSameDevice, except that the raw-byte path goes
  // through CopyBytesFromCPU.
  void CopyItemsFromCPU(
      const caffe2::TypeMeta meta,
      size_t n,
      const void* src,
      void* dst) {
    if (!meta.copy()) {
      CopyBytesFromCPU(n * meta.itemsize(), src, dst);
      return;
    }
    EnforceMetaCopyOK();
    meta.copy()(src, dst, n);
  }

  // Same contract as CopyItemsSameDevice, except that the raw-byte path goes
  // through CopyBytesToCPU.
  void CopyItemsToCPU(
      const caffe2::TypeMeta meta,
      size_t n,
      const void* src,
      void* dst) {
    if (!meta.copy()) {
      CopyBytesToCPU(n * meta.itemsize(), src, dst);
      return;
    }
    EnforceMetaCopyOK();
    meta.copy()(src, dst, n);
  }
};

// Context constructor registry.
//
// Declares the typed registry `ContextRegistry`. CreateContext() below looks
// entries up by at::DeviceType, passes the at::Device through to the creator,
// and receives a std::unique_ptr<at::BaseContext>.
// NOTE(review): the precise role of each macro argument is defined by
// C10_DECLARE_TYPED_REGISTRY in c10/util/Registry.h — confirm there.
C10_DECLARE_TYPED_REGISTRY(
    ContextRegistry,
    at::DeviceType,
    at::BaseContext,
    std::unique_ptr,
    at::Device);

// Registers a context class in ContextRegistry under `registry_key`.
// All remaining arguments are forwarded unchanged to
// C10_REGISTER_TYPED_CLASS.
#define REGISTER_CONTEXT(registry_key, ...) \
  C10_REGISTER_TYPED_CLASS(ContextRegistry, registry_key, __VA_ARGS__)

// Builds the context registered for `device`'s type.
//
// The lookup key is device.type(); `device` itself is handed to the
// registered creator. Returns whatever ContextRegistry's Create() yields for
// that key.
inline std::unique_ptr<at::BaseContext> CreateContext(
    const at::Device& device) {
  auto* const registry = at::ContextRegistry();
  return registry->Create(device.type(), device);
}

} // namespace at

namespace caffe2 {

// Make at::BaseContext and at::CreateContext also available under the caffe2
// namespace.
using at::BaseContext;
using at::CreateContext;
} // namespace caffe2