// Source: neilisaac / torch (python package), version 1.8.0
// File: include/caffe2/operators/counter_ops.h

#ifndef CAFFE2_OPERATORS_COUNTER_OPS_H
#define CAFFE2_OPERATORS_COUNTER_OPS_H

#include <atomic>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

namespace caffe2 {
// A counter whose value is held in a std::atomic<T>. Every member function
// performs exactly one atomic operation on that value.
template <typename T>
class TORCH_API Counter {
 public:
  // Starts the counter at `count`.
  explicit Counter(T count) : count_(count) {}

  // Atomically decrements the counter by one. Returns false when the value
  // observed *before* the decrement was positive, and true otherwise. The
  // decrement is applied in both cases, so the stored value can go negative.
  bool countDown() {
    const T before = count_.fetch_sub(1);
    return !(before > 0);
  }

  // Atomically increments the counter by one and returns the value it held
  // before the increment.
  T countUp() {
    const T before = count_.fetch_add(1);
    return before;
  }

  // Returns the current value without modifying it.
  T retrieve() const {
    const T current = count_.load();
    return current;
  }

  // Returns the truth value of "current value <= 0" converted to T
  // (the return type is T, not bool).
  T checkIfDone() const {
    const bool finished = count_.load() <= 0;
    return finished;
  }

  // Atomically replaces the stored value with `init_count` and returns the
  // value that was stored previously.
  T reset(T init_count) {
    const T before = count_.exchange(init_count);
    return before;
  }

 private:
  std::atomic<T> count_;
};

// TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp

// Operator that creates a Counter<T> and stores it, owned by a
// std::unique_ptr, in output 0.
//
// Arguments:
//   init_count - starting value of the counter (defaults to 0); construction
//                fails the enforce check below if it is negative.
template <typename T, class Context>
class CreateCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CreateCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  // Builds a new counter initialised to init_count_ and moves it into
  // output 0, replacing whatever that output held before.
  bool RunOnDevice() override {
    using CounterHandle = std::unique_ptr<Counter<T>>;
    CounterHandle fresh(new Counter<T>(init_count_));
    auto* slot = this->template Output<CounterHandle>(0);
    *slot = std::move(fresh);
    return true;
  }

 private:
  // Starting value handed to every counter this operator creates.
  T init_count_ = 0;
};

// Operator that resets the Counter<T> found at input 0 to a fixed value and,
// when exactly one output is declared, reports the value the counter held
// before the reset.
//
// Arguments:
//   init_count - value the counter is reset to (defaults to 0); construction
//                fails the enforce check below if it is negative.
template <typename T, class Context>
class ResetCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit ResetCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  bool RunOnDevice() override {
    const auto& counter =
        this->template Input<std::unique_ptr<Counter<T>>>(0);
    // The reset always happens, whether or not an output was requested.
    const T old_count = counter->reset(init_count_);
    if (OutputSize() != 1) {
      return true;
    }
    // Write the pre-reset value into the single output tensor.
    auto* result = Output(0);
    result->Resize();
    *result->template mutable_data<T>() = old_count;
    return true;
  }

 private:
  // Value the counter is reset to on every run.
  T init_count_;
};

// Operator that calls countDown() on the Counter<T> at input 0 and writes the
// returned bool into output 0, which is resized with an empty dimension list.
// Will always use TensorCPU regardless of the Context
template <typename T, class Context>
class CountDownOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CountDownOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    const auto& counter =
        this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* result = Output(0);
    result->Resize(std::vector<int>{});
    // The counter is decremented only after the output has been shaped.
    const bool done = counter->countDown();
    *result->template mutable_data<bool>() = done;
    return true;
  }
};

// Operator that calls checkIfDone() on the Counter<T> at input 0 (without
// modifying the counter) and writes the result, converted to bool, into
// output 0, which is resized with an empty dimension list.
// Will always use TensorCPU regardless of the Context
template <typename T, class Context>
class CheckCounterDoneOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CheckCounterDoneOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    const auto& counter =
        this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* result = Output(0);
    result->Resize(std::vector<int>{});
    // checkIfDone() returns T; it is narrowed to bool for the output.
    const bool finished = counter->checkIfDone();
    *result->template mutable_data<bool>() = finished;
    return true;
  }
};

// Operator that calls countUp() on the Counter<T> at input 0 and writes the
// value it returns (the count before the increment) into output 0, which is
// resized with an empty dimension list.
// Will always use TensorCPU regardless of the Context
template <typename T, class Context>
class CountUpOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CountUpOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    const auto& counter =
        this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* result = Output(0);
    result->Resize(std::vector<int>{});
    // The counter is incremented only after the output has been shaped.
    const T before_increment = counter->countUp();
    *result->template mutable_data<T>() = before_increment;
    return true;
  }
};

// Operator that reads the current value of the Counter<T> at input 0 via
// retrieve() and writes it into output 0, which is resized with an empty
// dimension list. The counter itself is not modified.
// Will always use TensorCPU regardless of the Context
template <typename T, class Context>
class RetrieveCountOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit RetrieveCountOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    const auto& counter =
        this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* result = Output(0);
    result->Resize(std::vector<int>{});
    const T current = counter->retrieve();
    *result->template mutable_data<T>() = current;
    return true;
  }
};

} // namespace caffe2
#endif // CAFFE2_OPERATORS_COUNTER_OPS_H