Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / sgd / iter_op.h

#ifndef CAFFE2_SGD_ITER_OP_H_
#define CAFFE2_SGD_ITER_OP_H_

#include <limits>
#include <mutex>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/stats.h"

namespace caffe2 {

// Adds one, in place, to the iteration counter stored in `output`.
//
// Preconditions (each enforced with CAFFE_ENFORCE*, which fails the op):
//   * `output` holds exactly one element;
//   * that element is an int64_t (the tensor's dtype is int64_t);
//   * the current value is non-negative and strictly below
//     std::numeric_limits<int64_t>::max(), so the increment cannot overflow.
inline void IncrementIter(TensorCPU* output) {
  CAFFE_ENFORCE_EQ(
      output->numel(),
      1,
      "The output of IterOp exists, but not of the right size.");
  // Check the dtype before calling mutable_data<int64_t>(). On a tensor of a
  // different dtype, that call does not fail: it switches the tensor to
  // int64_t (reinterpreting or reallocating the storage), and the "previous"
  // counter value read below would be meaningless.
  CAFFE_ENFORCE(
      output->template IsType<int64_t>(),
      "The output of IterOp exists, but is not of type int64_t.");
  int64_t* iter = output->template mutable_data<int64_t>();
  CAFFE_ENFORCE(*iter >= 0, "Previous iteration number is negative.");
  CAFFE_ENFORCE(
      *iter < std::numeric_limits<int64_t>::max(), "Overflow will happen!");
  ++(*iter);
}

// IterOp maintains an int64_t iteration counter in output 0. The counter is
// always a CPU tensor, since there is no known need to read it on a device.
//
// Every run increments the counter by one (via IncrementIter). An existing
// counter blob, for example one restored when resuming training, therefore
// just keeps counting. The current form of the op is used with an explicit
// in-place input/output. The input itself is never read here; only output 0
// is touched.
//
// Legacy form (no inputs): if output 0 does not yet hold a CPU tensor, a
// {1}-shaped int64_t counter is created with value 0. It is then incremented
// in the same run, so the first run leaves the counter at 1.
template <class Context>
class IterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  IterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    const bool no_inputs = (InputSize() == 0);
    if (no_inputs) {
      VLOG(1) << "[Input size is zero]";
      CreateLegacyCounterIfAbsent();
    }
    IncrementIter(OperatorBase::Output<Tensor>(0, CPU));
    return true;
  }

 private:
  // Legacy-path helper. Returns without doing anything when output 0 already
  // holds a CPU tensor. Otherwise it warns about the deprecated no-input form
  // and creates a {1}-shaped int64_t CPU tensor holding 0 in output 0.
  void CreateLegacyCounterIfAbsent() {
    if (OperatorBase::OutputIsTensorType(0, CPU)) {
      return;
    }
    LOG(ERROR) << "You are using an old definition of IterOp that will "
                  "be deprecated soon. More specifically, IterOp now "
                  "requires an explicit in-place input and output.";

    VLOG(1) << "Initializing iter counter.";
    auto* counter = OperatorBase::OutputTensor(
        0, {1}, at::dtype<int64_t>().device(CPU));
    counter->template mutable_data<int64_t>()[0] = 0;
  }
};

// AtomicIterOp increments the int64_t counter in output 0 while holding the
// std::mutex supplied (as a std::unique_ptr<std::mutex>) in input 0. Runs that
// share the same mutex blob therefore serialize their increments. Each run
// also records one `num_iter` event in a stats group whose name is built as
// "atomic_iter/stats/" followed by the name of input 1.
template <class Context>
class AtomicIterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  AtomicIterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        stats_(std::string("atomic_iter/stats/") + operator_def.input(1)) {}

  bool RunOnDevice() override {
    const auto& iter_mutex =
        OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
    // The stats event below is deliberately recorded while the lock is held.
    std::lock_guard<std::mutex> guard(*iter_mutex);
    auto* counter = OperatorBase::Output<Tensor>(0, CPU);
    IncrementIter(counter);
    CAFFE_EVENT(stats_, num_iter);
    return true;
  }

 private:
  // Per-op stats group; `num_iter` receives one event per successful run.
  struct AtomicIterOpStats {
    CAFFE_STAT_CTOR(AtomicIterOpStats);
    CAFFE_EXPORTED_STAT(num_iter);
  } stats_;
};

class MutexSerializer : public BlobSerializerBase {
 public:
  /**
   * Serializes a std::unique_ptr<std::mutex>. Note that this blob has to
   * contain std::unique_ptr<std::mutex>, otherwise this function produces a
   * fatal error.
   */
  void Serialize(
      const void* pointer,
      TypeMeta typeMeta,
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor) override;
};

/**
 * Blob deserializer for mutex blobs. Judging by its name it pairs with
 * MutexSerializer; the registration is not visible in this header.
 *
 * NOTE(review): presumably it stores a freshly created
 * std::unique_ptr<std::mutex> into `blob`. Confirm against the out-of-line
 * definition.
 */
class MutexDeserializer : public BlobDeserializerBase {
 public:
  void Deserialize(
      const BlobProto& proto,
      Blob* blob) override;
};

} // namespace caffe2

#endif // CAFFE2_SGD_ITER_OP_H_