Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / onnx_while_op.h

#ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
#define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/create_scope_op.h"

namespace caffe2 {

/**
 * Operator that repeatedly runs a "body" net, passing state from one
 * iteration to the next.
 *
 * Operator
 *   Inputs: max trip count, condition, initial loop-carried dependencies
 *   Outputs: final loop-carried dependencies, scan_outputs
 * Body
 *   Inputs: iteration number, condition, loop-carried dependencies
 *   Outputs: condition, loop-carried dependencies, scan_outputs
 *
 * Operator arguments read by the constructor:
 *   body                  (NetDef, required) the net executed every iteration
 *   has_trip_count        (default 0) honour the max trip count input
 *   has_cond              (default 0) honour the condition variable
 *   save_scopes           (default 0) push a fresh workspace per iteration
 *   disable_scopes        (default 0) run the body directly in the parent
 *                         workspace; mutually exclusive with save_scopes
 *   num_loop_carried_deps (default -1) number of loop-carried dependencies;
 *                         -1 means "all inputs after the first two"
 */
template <class Context>
class ONNXWhileOp final : public Operator<Context> {
 public:
  /**
   * Parses the operator arguments and stores a copy of the body NetDef.
   *
   * Throws (via CAFFE_ENFORCE) if the "body" argument is missing or if both
   * save_scopes and disable_scopes are set.
   */
  explicit ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        parent_ws_(ws),
        has_trip_count_(
            this->template GetSingleArgument<int64_t>("has_trip_count", 0)),
        has_cond_(this->template GetSingleArgument<int64_t>("has_cond", 0)),
        save_scopes_(
            this->template GetSingleArgument<int64_t>("save_scopes", 0)),
        disable_scopes_(
            this->template GetSingleArgument<int64_t>("disable_scopes", 0)),
        num_loop_carried_deps_(this->template GetSingleArgument<int64_t>(
            "num_loop_carried_deps",
            -1)) {
    CAFFE_ENFORCE(
        this->template HasSingleArgumentOfType<NetDef>("body"),
        "body net must be specified in ONNXWhile operator");
    if (disable_scopes_) {
      CAFFE_ENFORCE(
          !save_scopes_, "Cannot save scopes when disable_scopes=True");
    }
    body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef());

    // Unnamed body nets get a generated name: the first one is "loop_net",
    // later ones are "loop_net.<counter>".
    // NOTE(review): this function-local static is not synchronized; confirm
    // that operators are never constructed concurrently.
    static int64_t counter = -1;
    if (!body_net_def_.has_name()) {
      if (counter == -1) {
        ++counter;
        body_net_def_.set_name("loop_net");
      } else {
        ++counter;
        body_net_def_.set_name("loop_net." + c10::to_string(counter));
      }
    }
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /**
   * Dispatches to DoRunWithType<T>() based on the element type of the
   * condition input (Input(1)); int, bool and long are supported.
   */
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int, bool, long>>::call(this, Input(1));
  }

  /**
   * Runs the loop.
   *
   * @tparam CondVarType element type of the condition tensors.
   * @return false if a run of the body net fails, true otherwise.
   *
   * Throws (via CAFFE_ENFORCE*) when the number of loop-carried dependencies
   * is negative or exceeds the available inputs, when the body net has fewer
   * outputs than 1 + number of loop-carried dependencies, or when a scan
   * output changes shape between iterations.
   */
  template <typename CondVarType>
  bool DoRunWithType() {
    // Clear workspaces from the previous invocations of the loop
    // and setup a local scope for the first iteration
    ws_stack_.clear();
    auto loop_ws = !disable_scopes_
        ? ws_stack_.pushForwardWorkspace(parent_ws_).get()
        : parent_ws_;

    constexpr int64_t num_inputs_before_lcds = 2;
    // First input is the maximum trip count. Second input is the condition
    // variable (for the first iteration). The rest of the inputs are
    // loop-carried dependencies.
    const int64_t num_loop_carried_deps = num_loop_carried_deps_ != -1
        ? num_loop_carried_deps_
        : InputSize() - num_inputs_before_lcds;

    // Reject counts that would otherwise wrap around when converted to
    // size_t in LocalScope, or index past the operator's inputs below.
    CAFFE_ENFORCE_GE(
        num_loop_carried_deps,
        0,
        "Number of loop-carried dependencies must be non-negative");
    CAFFE_ENFORCE_LE(
        num_loop_carried_deps + num_inputs_before_lcds,
        static_cast<int64_t>(InputSize()),
        "Operator has fewer inputs than required by the number of "
        "loop-carried dependencies");

    int64_t max_trip_count = *Input(0).template data<int64_t>();
    const bool first_iter_condition = *Input(1).template data<CondVarType>();

    scope_ = std::make_shared<LocalScope>(
        loop_ws, body_net_def_, num_loop_carried_deps);

    // Body graph has 1+N+K outputs: recalculated condition variable, N
    // loop-carried dependencies, and K scan_outputs
    int num_scan_outputs =
        scope_->net()->external_output().size() - num_loop_carried_deps - 1;

    CAFFE_ENFORCE_GE(
        num_scan_outputs,
        0,
        "Body graph must have N+K outputs, where N is the number "
        "of loop-carried dependencies and K is the number of scan "
        "outputs");

    // Copy initial loop-carried dependencies
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds));
    }

    // Initialize iteration variable
    scope_->set_iteration(0ll);

    // Initialize input condition variable
    scope_->template set_input_condition<CondVarType>(first_iter_condition);

    // True while the iteration index is below the trip count; always true
    // when the trip count is not in use.
    auto valid_iter_num = [this, max_trip_count](int64_t i) {
      if (has_trip_count_) {
        return i < max_trip_count;
      } else {
        return true;
      }
    };

    // Iteration 0 is gated by the operator's condition input; later
    // iterations are gated by the condition produced by the previous body
    // run. Always true when the condition is not in use.
    auto condition_true = [this, first_iter_condition](
                              int64_t i, bool cond_value) {
      if (has_cond_) {
        if (i == 0) {
          return first_iter_condition;
        } else {
          return cond_value;
        }
      } else {
        return true;
      }
    };

    // Allocate scan_outputs for zero-iteration case
    for (int i = 0; i < num_scan_outputs; ++i) {
      Output(i + num_loop_carried_deps)->Resize(0);
      Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
    }

    // Use this to keep track of the sizes of the scan outputs and validate
    // they're the same across iterations.
    std::vector<std::vector<int64_t>> scan_outputs_sizes;

    Workspace* cur_ws = nullptr;
    bool cur_output_condition = false;

    while (true) {
      int64_t itr = scope_->iteration();
      if (!(valid_iter_num(itr) &&
            condition_true(itr, cur_output_condition))) {
        break;
      }

      if (!scope_->net()->Run()) {
        return false;
      }

      // Remember where this iteration's outputs live before (optionally)
      // switching to a fresh scope for the next iteration.
      cur_ws = scope_->workspace();
      cur_output_condition = scope_->template output_condition<CondVarType>();
      if (save_scopes_) {
        loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get();
        scope_ = std::make_shared<LocalScope>(
            loop_ws, body_net_def_, num_loop_carried_deps);
      }

      // Copy forward loop-carried dependencies
      for (int i = 0; i < num_loop_carried_deps; ++i) {
        Blob* b = cur_ws->GetBlob(scope_->net()->external_output()[i + 1]);
        const Tensor& t = b->template Get<Tensor>();
        scope_->lcd_tensor(i)->CopyFrom(t);
      }
      // Copy out scan_outputs
      for (int i = 0; i < num_scan_outputs; ++i) {
        int net_output_idx = i + 1 + num_loop_carried_deps;
        const Tensor& scan_output =
            cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx])
                ->template Get<Tensor>();
        auto* scan_output_target = Output(i + num_loop_carried_deps);
        if (itr == 0) {
          auto dims = scan_output.sizes().vec();
          scan_outputs_sizes.push_back(dims);
          dims.insert(dims.begin(), 1);
          scan_output_target->Resize(dims);
          // NOTE(review): CopyFrom may reset the target's shape to that of
          // scan_output, discarding the leading dimension set by Resize
          // above - confirm against Tensor::CopyFrom.
          scan_output_target->CopyFrom(scan_output);
        } else {
          const auto dims = scan_output.sizes().vec();
          CAFFE_ENFORCE_EQ(
              dims,
              scan_outputs_sizes[i],
              "Size of scan output changed across iterations");
          // Grow the target by one entry along its first dimension.
          scan_output_target->Extend(1, 100);

          // Number of elements produced per iteration.
          int64_t timestep_size = 1;
          for (const int64_t t : scan_outputs_sizes[i]) {
            timestep_size *= t;
          }

          // Append this iteration's data at byte offset
          // itr * (bytes per iteration) in the target buffer.
          const void* src_data = scan_output.raw_data();
          auto& sot_meta = scan_output_target->dtype();
          void* dst_data =
              static_cast<char*>(
                  scan_output_target->raw_mutable_data(sot_meta)) +
              timestep_size * scan_output.itemsize() * itr;
          memcpy(dst_data, src_data, timestep_size * scan_output.itemsize());
        }
      }
      scope_->set_iteration(itr + 1ll);
      scope_->template set_input_condition<CondVarType>(cur_output_condition);
    }

    // Copy out final loop-carried dependencies
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      Output(i)->CopyFrom(*scope_->lcd_tensor(i));
    }

    return true;
  }

 private:
  /**
   * Holds the blobs and the net needed to run the body inside one workspace:
   * the iteration counter, the input and output condition tensors, the
   * loop-carried dependency tensors, and the body net instance.
   */
  class LocalScope {
   public:
    /**
     * Creates (or looks up) the body's blobs and net in loop_ws.
     *
     * @param loop_ws workspace the body runs in; must be non-null.
     * @param body_net_def definition of the body net.
     * @param num_lcds number of loop-carried dependencies; they are taken
     *        from the body's external inputs starting at index 2.
     *
     * Throws (via CAFFE_ENFORCE*) if loop_ws is null, if the body net
     * declares too few external inputs/outputs, or if the net cannot be
     * obtained.
     */
    LocalScope(Workspace* loop_ws, const NetDef& body_net_def, size_t num_lcds)
        : loop_ws_(loop_ws) {
      CAFFE_ENFORCE(loop_ws_, "Failed to initialize local loop workspace");

      // The accesses below index external_input(0 .. num_lcds + 1) and
      // external_output(0); make sure they exist.
      CAFFE_ENFORCE_GE(
          static_cast<size_t>(body_net_def.external_input_size()),
          num_lcds + 2,
          "Body net must declare the iteration number, the condition and "
          "every loop-carried dependency as external inputs");
      CAFFE_ENFORCE_GE(
          body_net_def.external_output_size(),
          1,
          "Body net must declare the condition as its first external output");

      // Create loop-carried deps in Workspace
      lcd_tensors_.reserve(num_lcds);
      for (size_t i = 0; i < num_lcds; ++i) {
        Blob* b = loop_ws_->CreateBlob(
            body_net_def.external_input(static_cast<int>(i + 2)));
        Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType());
        lcd_tensors_.push_back(t);
      }
      // First body input is the iteration variable
      auto* iteration_var_blob =
          loop_ws_->CreateBlob(body_net_def.external_input(0));
      iteration_var_ =
          BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType());

      // Second body input is the condition fed into the iteration
      input_condition_var_ = BlobGetMutableTensor(
          loop_ws_->CreateBlob(body_net_def.external_input(1)),
          Context::GetDeviceType());

      // First body output is the recalculated condition
      auto* condition_var_blob =
          loop_ws_->CreateBlob(body_net_def.external_output(0));
      condition_var_ =
          BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType());
      condition_var_->Resize(1);
      condition_var_->template mutable_data<bool>();

      // Reuse a net with the same name if the workspace already has one.
      body_net_ = loop_ws_->GetNet(body_net_def.name());
      if (!body_net_) {
        body_net_ = loop_ws_->CreateNet(body_net_def, true);
      }
      CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet");
    }

    /// Body net instance.
    NetBase* net() const {
      return body_net_;
    }

    /// Workspace this scope was created in.
    Workspace* workspace() const {
      return loop_ws_;
    }

    /// Current value stored in the iteration tensor.
    int64_t iteration() const {
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      return *iteration_var_ptr;
    }

    /// Tensor backing the idx-th loop-carried dependency (no bounds check).
    Tensor* lcd_tensor(int idx) {
      return lcd_tensors_[idx];
    }

    /// Stores itr in the iteration tensor, after Resize() with no dimensions.
    void set_iteration(int64_t itr) {
      iteration_var_->Resize();
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      *iteration_var_ptr = itr;
    }

    /// Writes cond_value into the one-element input condition tensor.
    template <typename CondVarType>
    void set_input_condition(bool cond_value) {
      input_condition_var_->Resize(1);
      auto* input_condition_var_ptr =
          input_condition_var_->template mutable_data<CondVarType>();
      *input_condition_var_ptr = cond_value;
    }

    /// Reads the first element of the body's output condition tensor as bool.
    template <typename CondVarType>
    bool output_condition() const {
      auto* condition_var_ptr =
          condition_var_->template mutable_data<CondVarType>();
      return *condition_var_ptr;
    }

   private:
    Workspace* loop_ws_;

    NetBase* body_net_; // owned by a workspace
    Tensor* iteration_var_;
    Tensor* input_condition_var_;
    Tensor* condition_var_;

    std::vector<Tensor*> lcd_tensors_;
  };

  NetDef body_net_def_; // copy of the "body" argument (possibly renamed)
  Workspace* parent_ws_; // workspace the operator was created in
  detail::WorkspaceStack ws_stack_; // per-run stack of loop workspaces

  bool has_trip_count_;
  bool has_cond_;
  bool save_scopes_;
  bool disable_scopes_;
  int64_t num_loop_carried_deps_; // -1 means "derive from the input count"

  std::shared_ptr<LocalScope> scope_; // scope used by the current iteration
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_