Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / context.h

#ifndef CAFFE2_CORE_CONTEXT_H_
#define CAFFE2_CORE_CONTEXT_H_

#include <cstdlib>
#include <ctime>
#include <random>
#include <unordered_map>

#include <c10/util/typeid.h>
#include "caffe2/core/allocator.h"
#include "caffe2/core/context_base.h"
#include "caffe2/core/event.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"

#include <c10/util/ArrayRef.h>

#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
#include <ATen/core/DistributionsHelper.h>
#include <ATen/CPUGeneratorImpl.h>
#else
#include "caffe2/core/distributions_stubs.h"
#endif

C10_DECLARE_bool(caffe2_report_cpu_memory_usage);

namespace caffe2 {

/**
 * A function to generate a random number seed that is unique in a best-effort
 * basis, using an ever-incrementing seed and the current time.
 */
TORCH_API uint32_t RandomNumberSeed();

/**
 * The CPU Context, representing the bare minimum of what a Context class in
 * Caffe2 should implement.
 *
 * // TODO modify docs
 * See operator.h, especially Operator<Context>, for how Context are used in
 * actual operator implementations that are associated with specific devices.
 * In general, the Context class is passed in as a template argument, and
 * the operator can use the functions defined in the context to execute whatever
 * computation it has.
 *
 */
class TORCH_API CPUContext final : public BaseContext {
 public:
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
  typedef at::CPUGeneratorImpl rand_gen_type;
#else
  typedef std::mt19937 rand_gen_type;
#endif

  CPUContext() {}
  explicit CPUContext(const DeviceOption& option)
      : random_seed_(option.has_random_seed() ? option.random_seed() : 1701),
        random_seed_set_(option.has_random_seed() ? true : false) {
    CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU);
  }
  explicit CPUContext(const at::Device& device)
      : CPUContext(DeviceToOption(device)) {}

  ~CPUContext() noexcept override {}

  inline void SwitchToDevice(int /*stream_id*/) override {}

  using BaseContext::SwitchToDevice;

  inline void WaitEvent(const Event& ev) override {
    ev.Wait(CPU, this);
  }

  inline void Record(Event* ev, const char* err_msg = nullptr) const override {
    CAFFE_ENFORCE(ev, "Event must not be null.");
    ev->Record(CPU, this, err_msg);
  }

  inline void FinishDeviceComputation() override {}

  inline rand_gen_type* RandGenerator() {
    if (!random_generator_.get()) {
      random_generator_.reset(new rand_gen_type(RandSeed()));
    }
    return random_generator_.get();
  }

  inline uint32_t RandSeed() {
    if (!random_seed_set_) {
      random_seed_ = RandomNumberSeed();
      random_seed_set_ = true;
    }
    return static_cast<uint32_t>(random_seed_);
  }

  inline static at::DataPtr New(size_t nbytes) {
    return GetCPUAllocator()->allocate(nbytes);
  }

  void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override;

  void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) override {
    CopyBytesSameDevice(nbytes, src, dst);
  }

  void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) override {
    CopyBytesSameDevice(nbytes, src, dst);
  }

  bool SupportsNonFundamentalTypes() const override {
    // CPU non fumdamental type copy OK
    return true;
  }

  template <class SrcContext, class DstContext>
  inline void CopyBytes(size_t nbytes, const void* src, void* dst);

  template <typename T, class SrcContext, class DstContext>
  inline void Copy(size_t n, const T* src, T* dst) {
    if (c10::guts::is_fundamental<T>::value) {
      CopyBytes<SrcContext, DstContext>(
          n * sizeof(T),
          static_cast<const void*>(src),
          static_cast<void*>(dst));
    } else {
      for (size_t i = 0; i < n; ++i) {
        dst[i] = src[i];
      }
    }
  }

  template <class SrcContext, class DstContext>
  inline void
  CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) {
    if (meta.copy()) {
      meta.copy()(src, dst, n);
    } else {
      CopyBytes<SrcContext, DstContext>(n * meta.itemsize(), src, dst);
    }
  }

  // By default CPU operators don't have async device parts
  static bool HasAsyncPartDefault() {
    return false;
  }

  static bool SupportsAsyncScheduling() {
    return false;
  }

  // CPU streams are not implemented and are silently ignored by CPU ops,
  // return true to signal executor to schedule a CPU op
  static bool IsStreamFree(
      const DeviceOption& /* option */,
      int /* stream_id */) {
    return true;
  }

  at::Device device() const override {
    // TODO: numa?
    return at::Device(CPU);
  }

  DeviceType device_type() const override {
    return CPU;
  }

  static constexpr DeviceType GetDeviceType() {
    return CPU;
  }

 protected:
  // TODO(jiayq): instead of hard-coding a generator, make it more flexible.
  int random_seed_{1701};
  bool random_seed_set_{false};
  std::unique_ptr<rand_gen_type> random_generator_;
};

template <>
inline void CPUContext::CopyBytes<CPUContext, CPUContext>(
    size_t nbytes,
    const void* src,
    void* dst) {
  if (nbytes == 0) {
    return;
  }
  CAFFE_ENFORCE(src);
  CAFFE_ENFORCE(dst);
  memcpy(dst, src, nbytes);
}

} // namespace caffe2

#endif // CAFFE2_CORE_CONTEXT_H_