#ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
#define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/create_scope_op.h"
namespace caffe2 {
// Runs the NetDef given in the "body" argument repeatedly.
//
// Operator
//   Inputs:  max trip count (read as int64_t), condition (int / bool / long),
//            initial loop-carried dependencies
//   Outputs: final loop-carried dependencies, scan_outputs
// Body net
//   Inputs:  iteration number, condition, loop-carried dependencies
//   Outputs: condition, loop-carried dependencies, scan_outputs
//
// Arguments:
//   body                  - NetDef executed once per iteration (required).
//   has_trip_count        - if nonzero, stop once the iteration index reaches
//                           Input(0).
//   has_cond              - if nonzero, stop once the condition is false
//                           (Input(1) for iteration 0, afterwards the body's
//                           first output).
//   save_scopes           - if nonzero, every iteration gets a fresh workspace
//                           pushed onto ws_stack_ instead of reusing one.
//   disable_scopes        - if nonzero, the body runs directly in the parent
//                           workspace; incompatible with save_scopes.
//   num_loop_carried_deps - number of loop-carried deps; defaults (-1) to
//                           InputSize() - 2.
template <class Context>
class ONNXWhileOp final : public Operator<Context> {
 public:
  explicit ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        parent_ws_(ws),
        has_trip_count_(
            this->template GetSingleArgument<int64_t>("has_trip_count", 0)),
        has_cond_(this->template GetSingleArgument<int64_t>("has_cond", 0)),
        save_scopes_(
            this->template GetSingleArgument<int64_t>("save_scopes", 0)),
        disable_scopes_(
            this->template GetSingleArgument<int64_t>("disable_scopes", 0)),
        num_loop_carried_deps_(this->template GetSingleArgument<int64_t>(
            "num_loop_carried_deps",
            -1)) {
    CAFFE_ENFORCE(
        this->template HasSingleArgumentOfType<NetDef>("body"),
        "body net must be specified in ONNXWhile operator");
    if (disable_scopes_) {
      CAFFE_ENFORCE(
          !save_scopes_, "Cannot save scopes when disable_scopes=True");
    }
    body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef());

    // Unnamed body nets get a generated name: the first one is "loop_net",
    // later ones "loop_net.<k>". LocalScope looks nets up by name
    // (GetNet/CreateNet), so distinct names presumably keep different
    // operators from sharing one body net.
    // NOTE(review): this static is shared by every ONNXWhileOp<Context> and
    // is not synchronized.
    static int64_t counter = -1;
    if (!body_net_def_.has_name()) {
      const int64_t id = ++counter;
      if (id == 0) {
        body_net_def_.set_name("loop_net");
      } else {
        body_net_def_.set_name("loop_net." + c10::to_string(id));
      }
    }
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Dispatches on the element type of the initial condition input, Input(1).
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int, bool, long>>::call(this, Input(1));
  }

  // Executes the loop; CondVarType is the element type of Input(1).
  // Returns false if any run of the body net fails.
  template <typename CondVarType>
  bool DoRunWithType() {
    // Clear workspaces from previous invocations of the loop and set up a
    // local scope for the first iteration (or reuse the parent workspace
    // when scopes are disabled).
    ws_stack_.clear();
    Workspace* loop_ws = !disable_scopes_
        ? ws_stack_.pushForwardWorkspace(parent_ws_).get()
        : parent_ws_;

    // First input is the maximum trip count. Second input is the condition
    // variable (for the first iteration). The rest of the inputs are
    // loop-carried dependencies.
    constexpr int64_t num_inputs_before_lcds = 2;
    const int64_t num_loop_carried_deps = num_loop_carried_deps_ != -1
        ? num_loop_carried_deps_
        : InputSize() - num_inputs_before_lcds;

    const int64_t max_trip_count = *Input(0).template data<int64_t>();
    const bool first_iter_condition = *Input(1).template data<CondVarType>();

    scope_ = std::make_shared<LocalScope>(
        loop_ws, body_net_def_, num_loop_carried_deps);

    // Body graph has 1+N+K outputs: recalculated condition variable, N
    // loop-carried dependencies, and K scan_outputs. Computed in signed
    // arithmetic so that too few outputs yields a negative count instead of
    // an unsigned wrap-around.
    const int64_t num_scan_outputs =
        static_cast<int64_t>(scope_->net()->external_output().size()) -
        num_loop_carried_deps - 1;
    CAFFE_ENFORCE_GE(
        num_scan_outputs,
        0,
        "Body graph must have 1+N+K outputs, where N is the number "
        "of loop-carried dependencies and K is the number of scan "
        "outputs");

    // Copy initial loop-carried dependencies into the body's inputs.
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds));
    }
    scope_->set_iteration(0);
    scope_->template set_input_condition<CondVarType>(first_iter_condition);

    // Iteration i may run only while the trip count (if used) allows it...
    auto valid_iter_num = [this, max_trip_count](int64_t i) {
      return !has_trip_count_ || i < max_trip_count;
    };
    // ...and the condition (if used) holds: Input(1) for the first iteration,
    // the body's previous condition output afterwards.
    auto condition_true = [this, first_iter_condition](
                              int64_t i, bool cond_value) {
      if (!has_cond_) {
        return true;
      }
      return i == 0 ? first_iter_condition : cond_value;
    };

    // Allocate empty scan_outputs so the zero-iteration case has valid
    // outputs.
    for (int i = 0; i < num_scan_outputs; ++i) {
      Output(i + num_loop_carried_deps)->Resize(0);
      Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
    }

    // Per-step shapes of the scan outputs, recorded on the first iteration
    // and used to validate every later iteration.
    std::vector<std::vector<int64_t>> scan_outputs_sizes;
    Workspace* cur_ws = nullptr;
    bool cur_output_condition = false;
    while (true) {
      const int64_t itr = scope_->iteration();
      if (!valid_iter_num(itr) || !condition_true(itr, cur_output_condition)) {
        break;
      }
      if (!scope_->net()->Run()) {
        return false;
      }
      cur_ws = scope_->workspace();
      cur_output_condition = scope_->template output_condition<CondVarType>();
      if (save_scopes_) {
        // Keep this iteration's workspace on the stack and run the next
        // iteration in a fresh one.
        loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get();
        scope_ = std::make_shared<LocalScope>(
            loop_ws, body_net_def_, num_loop_carried_deps);
      }

      // Copy forward loop-carried dependencies (body outputs 1..N).
      for (int i = 0; i < num_loop_carried_deps; ++i) {
        Blob* b = cur_ws->GetBlob(scope_->net()->external_output()[i + 1]);
        const Tensor& t = b->template Get<Tensor>();
        scope_->lcd_tensor(i)->CopyFrom(t);
      }

      // Append this iteration's scan_outputs (body outputs N+1..N+K) along a
      // new leading (time) dimension.
      for (int i = 0; i < num_scan_outputs; ++i) {
        const int net_output_idx = i + 1 + num_loop_carried_deps;
        const Tensor& scan_output =
            cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx])
                ->template Get<Tensor>();
        auto* scan_output_target = Output(i + num_loop_carried_deps);
        auto dims = scan_output.sizes().vec();
        if (itr == 0) {
          scan_outputs_sizes.push_back(dims);
          dims.insert(dims.begin(), 1);
          // CopyFrom resizes the target to scan_output's own shape, so the
          // leading time dimension must be added back after the copy.
          scan_output_target->CopyFrom(scan_output);
          scan_output_target->Reshape(dims);
        } else {
          CAFFE_ENFORCE_EQ(
              dims,
              scan_outputs_sizes[i],
              "Size of scan output changed across iterations");
          // Grow the time dimension from itr to itr + 1 and write this
          // step's values into the new last slice.
          scan_output_target->Extend(1, 100);
          int64_t timestep_size = 1;
          for (const int64_t d : scan_outputs_sizes[i]) {
            timestep_size *= d;
          }
          const size_t step_nbytes = timestep_size * scan_output.itemsize();
          const void* src_data = scan_output.raw_data();
          const auto& sot_meta = scan_output_target->dtype();
          void* dst_data =
              static_cast<char*>(scan_output_target->raw_mutable_data(sot_meta)) +
              step_nbytes * itr;
          memcpy(dst_data, src_data, step_nbytes);
        }
      }
      scope_->set_iteration(itr + 1);
      scope_->template set_input_condition<CondVarType>(cur_output_condition);
    }

    // Copy out final loop-carried dependencies.
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      Output(i)->CopyFrom(*scope_->lcd_tensor(i));
    }
    return true;
  }

 private:
  // Binds the body net and its input/output tensors inside one workspace.
  class LocalScope {
   public:
    // Creates (in loop_ws) the blobs for the body's inputs 0..num_lcds+1 and
    // its first output, then fetches the body net by name or creates it.
    LocalScope(Workspace* loop_ws, const NetDef& body_net_def, size_t num_lcds)
        : loop_ws_(loop_ws) {
      CAFFE_ENFORCE(loop_ws_, "Failed to initialize local loop workspace");
      // Body inputs 2..num_lcds+1 are the loop-carried dependencies.
      for (int i = 0; i < static_cast<int>(num_lcds); ++i) {
        Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i + 2));
        lcd_tensors_.push_back(
            BlobGetMutableTensor(b, Context::GetDeviceType()));
      }
      // First input is the iteration variable, second the input condition.
      iteration_var_ = BlobGetMutableTensor(
          loop_ws_->CreateBlob(body_net_def.external_input(0)),
          Context::GetDeviceType());
      input_condition_var_ = BlobGetMutableTensor(
          loop_ws_->CreateBlob(body_net_def.external_input(1)),
          Context::GetDeviceType());
      // First output is the recomputed condition; pre-allocate it as bool[1].
      condition_var_ = BlobGetMutableTensor(
          loop_ws_->CreateBlob(body_net_def.external_output(0)),
          Context::GetDeviceType());
      condition_var_->Resize(1);
      condition_var_->template mutable_data<bool>();

      body_net_ = loop_ws_->GetNet(body_net_def.name());
      if (!body_net_) {
        body_net_ = loop_ws_->CreateNet(body_net_def, true);
      }
      CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet");
    }

    NetBase* net() const {
      return body_net_;
    }

    Workspace* workspace() const {
      return loop_ws_;
    }

    // Current value of the iteration-number input (int64_t).
    int64_t iteration() const {
      return *iteration_var_->template mutable_data<int64_t>();
    }

    // Tensor backing loop-carried dependency idx (body input idx + 2).
    Tensor* lcd_tensor(int idx) {
      return lcd_tensors_[idx];
    }

    // Stores itr as a 0-dim int64_t tensor in the iteration input.
    void set_iteration(int64_t itr) {
      iteration_var_->Resize();
      *iteration_var_->template mutable_data<int64_t>() = itr;
    }

    // Stores cond_value as a one-element CondVarType tensor in the condition
    // input.
    template <typename CondVarType>
    void set_input_condition(bool cond_value) {
      input_condition_var_->Resize(1);
      *input_condition_var_->template mutable_data<CondVarType>() = cond_value;
    }

    // Reads the body's condition output as CondVarType, converted to bool.
    template <typename CondVarType>
    bool output_condition() const {
      return *condition_var_->template mutable_data<CondVarType>();
    }

   private:
    Workspace* loop_ws_;
    NetBase* body_net_; // owned by a workspace
    Tensor* iteration_var_;
    Tensor* input_condition_var_;
    Tensor* condition_var_;
    std::vector<Tensor*> lcd_tensors_;
  };

  NetDef body_net_def_;
  Workspace* parent_ws_;
  detail::WorkspaceStack ws_stack_;
  bool has_trip_count_;
  bool has_cond_;
  bool save_scopes_;
  bool disable_scopes_;
  int64_t num_loop_carried_deps_;
  std::shared_ptr<LocalScope> scope_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_