#ifndef CAFFE2_OPERATORS_SPACE_BATCH_OP_H_
#define CAFFE2_OPERATORS_SPACE_BATCH_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <typename Context>
void spaceToBatch(
const Tensor& input,
int pad_t,
int pad_l,
int block_size,
Tensor* output,
Context* /*context*/) {
CAFFE_ENFORCE(input.dim() == 4);
CAFFE_ENFORCE(output->dim() == 4);
const int output_batch = output->dim32(0);
const int output_depth = output->dim32(1);
const int output_height = output->dim32(2);
const int output_width = output->dim32(3);
const int input_batch = input.dim32(0);
const int input_depth = input.dim32(1);
const int input_height = input.dim32(2);
const int input_width = input.dim32(3);
for (int out_b = 0; out_b < output_batch; ++out_b) {
const int in_b = out_b % input_batch;
const int offset_w = (out_b / input_batch) % block_size;
const int offset_h = (out_b / input_batch) / block_size;
for (int d = 0; d < input_depth; ++d) {
for (int out_h = 0; out_h < output_height; ++out_h) {
const int in_h = out_h * block_size + offset_h - pad_t;
for (int out_w = 0; out_w < output_width; ++out_w) {
const int in_w = out_w * block_size + offset_w - pad_l;
const auto output_offset =
((out_b * output_depth + d) * output_height + out_h) *
output_width +
out_w;
const auto input_offset =
((in_b * input_depth + d) * input_height + in_h) * input_width +
in_w;
if (in_h >= 0 && in_w >= 0 && in_h < input_height &&
in_w < input_width) {
output->template mutable_data<float>()[output_offset] =
input.template data<float>()[input_offset];
} else {
output->template mutable_data<float>()[output_offset] = 0.0;
}
}
}
}
}
}
template <typename Context>
void batchToSpace(
const Tensor& input,
int pad_t,
int pad_l,
int block_size,
Tensor* output,
Context* /*context*/) {
CAFFE_ENFORCE(input.dim() == 4);
CAFFE_ENFORCE(output->dim() == 4);
const int output_batch = output->dim32(0);
const int output_depth = output->dim32(1);
const int output_height = output->dim32(2);
const int output_width = output->dim32(3);
const int input_batch = input.dim32(0);
const int input_depth = input.dim32(1);
const int input_height = input.dim32(2);
const int input_width = input.dim32(3);
CAFFE_ENFORCE(input_depth == output_depth);
for (int in_b = 0; in_b < input_batch; ++in_b) {
const int out_b = in_b % output_batch;
const int offset_w = (in_b / output_batch) % block_size;
const int offset_h = (in_b / output_batch) / block_size;
for (int d = 0; d < input_depth; ++d) {
for (int in_h = 0; in_h < input_height; ++in_h) {
const int out_h = in_h * block_size + offset_h - pad_t;
for (int in_w = 0; in_w < input_width; ++in_w) {
const int out_w = in_w * block_size + offset_w - pad_l;
if (out_h >= 0 && out_w >= 0 && out_h < output_height &&
out_w < output_width) {
const auto output_offset =
((out_b * output_depth + d) * output_height + out_h) *
output_width +
out_w;
const auto input_offset =
((in_b * input_depth + d) * input_height + in_h) * input_width +
in_w;
output->template mutable_data<float>()[output_offset] =
input.template data<float>()[input_offset];
}
}
}
}
}
}
// Common base of SpaceToBatchOp and BatchToSpaceOp: reads the operator
// arguments both share and exposes them to the derived classes.
//
// Arguments read at construction:
//   pad        - default used for all four sides (default 0).
//   pad_t      - top padding; falls back to `pad`.
//   pad_l      - left padding; falls back to `pad`.
//   pad_b      - bottom padding; falls back to `pad`.
//   pad_r      - right padding; falls back to `pad`.
//   block_size - spatial block edge length (default 2); must be positive.
//   order      - storage order name (default "NCHW"); only NCHW is accepted.
//
// Construction fails a CAFFE_ENFORCE check for any other storage order or a
// non-positive block_size.
template <typename Context>
class SpaceBatchOpBase : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit SpaceBatchOpBase(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        // pad_ is declared first, so it is initialized before the per-side
        // members below use it as their default.
        pad_(this->template GetSingleArgument<int>("pad", 0)),
        pad_t_(this->template GetSingleArgument<int>("pad_t", pad_)),
        pad_l_(this->template GetSingleArgument<int>("pad_l", pad_)),
        pad_b_(this->template GetSingleArgument<int>("pad_b", pad_)),
        pad_r_(this->template GetSingleArgument<int>("pad_r", pad_)),
        block_size_(this->template GetSingleArgument<int>("block_size", 2)),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))) {
    CAFFE_ENFORCE(order_ == StorageOrder::NCHW);
    // The derived operators divide by block_size_.
    CAFFE_ENFORCE(
        block_size_ > 0, "block_size must be positive, got: ", block_size_);
  }

 protected:
  int pad_; // value of "pad"; default for the four per-side paddings
  int pad_t_; // top
  int pad_l_; // left
  int pad_b_; // bottom
  int pad_r_; // right
  int block_size_;
  StorageOrder order_;
};
// Operator wrapper around spaceToBatch(): sizes Output(0) from Input(0) and
// the padding / block-size arguments held by SpaceBatchOpBase, then fills it.
//
// With input dims (N, C, H, W) the output is resized to
//   (N * block_size^2, C,
//    (pad_t + pad_b + H) / block_size, (pad_l + pad_r + W) / block_size).
// RunOnDevice fails a CAFFE_ENFORCE check when the padded height or width is
// not a multiple of block_size.
template <typename Context>
class SpaceToBatchOp final : public SpaceBatchOpBase<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using SpaceBatchOpBase<Context>::SpaceBatchOpBase;

  bool RunOnDevice() override {
    const auto& X = Input(0);
    auto* Y = Output(0);

    const int num_images = X.dim32(0);
    const int channels = X.dim32(1);
    // Spatial extents after padding has been added on both sides.
    const int height = this->pad_b_ + this->pad_t_ + X.dim32(2);
    const int width = this->pad_l_ + this->pad_r_ + X.dim32(3);

    CAFFE_ENFORCE(
        height % this->block_size_ == 0,
        "Height: ",
        height,
        ", block size: ",
        this->block_size_);
    CAFFE_ENFORCE(width % this->block_size_ == 0);

    // Every spatial block contributes block_size^2 entries to the batch.
    const int blocks_per_image = this->block_size_ * this->block_size_;
    Y->Resize(
        num_images * blocks_per_image,
        channels,
        height / this->block_size_,
        width / this->block_size_);

    spaceToBatch<Context>(
        X, this->pad_t_, this->pad_l_, this->block_size_, Y, &context_);
    return true;
  }
};
// Operator wrapper around batchToSpace(): sizes Output(0) from Input(0) and
// the crop / block-size arguments held by SpaceBatchOpBase, then fills it.
//
// With input dims (N, C, H, W) the output is resized to
//   (N / block_size^2, C,
//    H * block_size - pad_t - pad_b, W * block_size - pad_l - pad_r).
// RunOnDevice fails a CAFFE_ENFORCE check when block_size is not positive,
// when N is not a multiple of block_size^2, or when the crops exceed the
// upscaled spatial size.
template <typename Context>
class BatchToSpaceOp final : public SpaceBatchOpBase<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using SpaceBatchOpBase<Context>::SpaceBatchOpBase;

  bool RunOnDevice() override {
    const auto& input = Input(0);
    auto* output = Output(0);
    const int batch = input.dim32(0);
    const int depth = input.dim32(1);
    const int height = input.dim32(2);
    const int width = input.dim32(3);

    // block_size_ is a divisor in everything that follows.
    CAFFE_ENFORCE(
        this->block_size_ > 0,
        "block_size must be positive, got: ",
        this->block_size_);
    const int blocks_per_image = this->block_size_ * this->block_size_;
    // Without this check the integer division below silently truncates, and
    // a batch smaller than block_size^2 yields an output batch of zero that
    // batchToSpace() then uses as a divisor.
    CAFFE_ENFORCE(
        batch % blocks_per_image == 0,
        "Batch: ",
        batch,
        ", block size: ",
        this->block_size_);

    const int output_batch = batch / blocks_per_image;
    const int output_height =
        height * this->block_size_ - this->pad_b_ - this->pad_t_;
    const int output_width =
        width * this->block_size_ - this->pad_l_ - this->pad_r_;
    CAFFE_ENFORCE(
        output_height >= 0 && output_width >= 0,
        "Crops exceed the output size; output height: ",
        output_height,
        ", output width: ",
        output_width);

    output->Resize(output_batch, depth, output_height, output_width);
    batchToSpace<Context>(
        input,
        this->pad_t_,
        this->pad_l_,
        this->block_size_,
        output,
        &context_);
    return true;
  }
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_SPACE_BATCH_OP_H_