#ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
#define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(PiecewiseLinearTransform);
namespace caffe2 {
template <typename T, class Context>
class PiecewiseLinearTransformOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit PiecewiseLinearTransformOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {
binary_ = this->template GetSingleArgument<bool>("binary", false);
// Retrieve transform params (i.e., the linear functions).
bounds_from_arg_ = this->template GetRepeatedArgument<T>("bounds");
slopes_from_arg_ = this->template GetRepeatedArgument<T>("slopes");
intercepts_from_arg_ = this->template GetRepeatedArgument<T>("intercepts");
transform_param_from_arg_ = CheckTransParamFromArg();
}
bool RunOnDevice() override {
return binary_ ? TransformBinary() : TransformGeneral();
}
private:
// num_func_per_group is the number of pieces of linear functions of
// each group.
// num_group: The number of groups of linear functions. Each group is for
// transforming one column of predictions.
void InferNumFunctionsPerGroup(
const int64_t num_bounds,
const int64_t num_slopes,
const int64_t num_intercepts,
int64_t* num_func_per_group,
int64_t* num_group) {
CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);
// This is based on the facts:
// 1. in each group, the num of bounds minus the num of slopes is 1;
// 2. each group has the same number of pieces.
*num_group = num_bounds - num_slopes;
CAFFE_ENFORCE_GT(*num_group, 0);
if (binary_) {
CAFFE_ENFORCE_EQ(*num_group, 1);
}
*num_func_per_group = num_slopes / *num_group;
CAFFE_ENFORCE_GT(*num_func_per_group, 0);
CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
}
bool CheckBoundsSorted(
const T* bounds,
const int64_t num_bounds_per_group,
const int64_t num_group) {
const T* start = bounds;
for (int64_t i = 0; i < num_group; i++) {
if (!std::is_sorted(start, start + num_bounds_per_group)) {
return false;
}
start += num_bounds_per_group;
}
return true;
}
// Returns true if the transform params from arg are valid.
// Otherwise, we will assume the transform params will pass from Input blobs.
bool CheckTransParamFromArg() {
int good_param = 0;
good_param += bounds_from_arg_.size() > 0;
good_param += slopes_from_arg_.size() > 0;
good_param += intercepts_from_arg_.size() > 0;
CAFFE_ENFORCE(
good_param == 0 || good_param == 3,
"bounds, slopes, intercepts must be all set or all not set");
if (good_param == 3) {
int64_t num_func_per_group;
int64_t num_group;
InferNumFunctionsPerGroup(
bounds_from_arg_.size(),
slopes_from_arg_.size(),
intercepts_from_arg_.size(),
&num_func_per_group,
&num_group);
CAFFE_ENFORCE(
CheckBoundsSorted(
bounds_from_arg_.data(), num_func_per_group + 1, num_group),
"bounds must be sorted for each group");
}
return good_param == 3;
}
void setUpTensors(int64_t& num_func_per_group, int64_t& num_group, int64_t M);
void GetTransParamData(
const T** bounds,
const T** slopes,
const T** intercepts,
int64_t* num_func_per_group,
int64_t* num_group) {
int64_t num_bounds;
int64_t num_slopes;
int64_t num_intercepts;
if (transform_param_from_arg_) {
CAFFE_ENFORCE_EQ(InputSize(), 1);
*bounds = bounds_from_arg_.data();
*slopes = slopes_from_arg_.data();
*intercepts = intercepts_from_arg_.data();
num_bounds = bounds_from_arg_.size();
num_slopes = slopes_from_arg_.size();
num_intercepts = intercepts_from_arg_.size();
} else {
CAFFE_ENFORCE_EQ(InputSize(), 4);
auto& bounds_input = Input(BOUNDS);
auto& slopes_input = Input(SLOPES);
auto& intercepts_input = Input(INTERCEPTS);
*bounds = bounds_input.template data<T>();
*slopes = slopes_input.template data<T>();
*intercepts = intercepts_input.template data<T>();
num_bounds = bounds_input.numel();
num_slopes = slopes_input.numel();
num_intercepts = intercepts_input.numel();
}
InferNumFunctionsPerGroup(
num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
}
bool TransformGeneral() {
auto& X = Input(0);
CAFFE_ENFORCE_EQ(X.dim(), 2);
int64_t N = X.dim32(0);
int64_t M = X.dim32(1);
auto* Y = Output(0, X.sizes(), at::dtype<T>());
const auto* Xdata = X.template data<T>();
T* Ydata = Y->template mutable_data<T>();
const T* bounds;
const T* slopes;
const T* intercepts;
int64_t num_func_per_group;
int64_t num_group;
GetTransParamData(
&bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
CAFFE_ENFORCE_EQ(num_group, M);
for (int64_t j = 0; j < M; ++j) {
const T* bounds_group = bounds + j * (num_func_per_group + 1);
const T* slopes_group = slopes + j * num_func_per_group;
const T* intercepts_group = intercepts + j * num_func_per_group;
for (int64_t i = 0; i < N; ++i) {
Ydata[i * M + j] = PiecewiseLinearTransform(
Xdata[i * M + j],
bounds_group,
slopes_group,
intercepts_group,
num_func_per_group);
}
}
return true;
}
bool TransformBinary() {
auto& X = Input(PREDICTIONS);
CAFFE_ENFORCE(X.dim() == 1 || X.dim() == 2);
int64_t N = X.dim32(0);
int64_t M = X.dim() == 2 ? X.dim32(1) : 1;
CAFFE_ENFORCE(
M == 1 || M == 2,
"If binary is set to true, the input must be Nx2 or Nx1 tensor");
auto* Y = Output(0, X.sizes(), at::dtype<T>());
const auto* Xdata = X.template data<T>();
T* Ydata = Y->template mutable_data<T>();
const T* bounds;
const T* slopes;
const T* intercepts;
int64_t num_func_per_group;
int64_t num_group;
GetTransParamData(
&bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
CAFFE_ENFORCE_EQ(num_group, 1);
if (M == 1) {
for (int64_t i = 0; i < N; ++i) {
Ydata[i] = PiecewiseLinearTransform(
Xdata[i], bounds, slopes, intercepts, num_func_per_group);
}
} else {
for (int64_t i = 0; i < N; ++i) {
Ydata[i * M + 1] = PiecewiseLinearTransform(
Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
Ydata[i * M] = 1.0f - Ydata[i * M + 1];
}
}
return true;
}
T PiecewiseLinearTransform(
const T x,
const T* bounds,
const T* slopes,
const T* intercepts,
const int64_t num_func_per_group) {
T y = 0;
// deal with samples out of bounds
// make it the same as the upper/lower bound value
if (x <= bounds[0]) {
y = slopes[0] * bounds[0] + intercepts[0];
} else if (x >= bounds[num_func_per_group]) {
y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] +
intercepts[num_func_per_group - 1];
} else {
auto low_bound =
std::lower_bound(bounds, bounds + num_func_per_group + 1, x);
int bounds_idx = low_bound - bounds - 1;
// compute the piecewise linear transformation as Y
y = slopes[bounds_idx] * x + intercepts[bounds_idx];
}
return y;
}
private:
bool binary_;
vector<T> bounds_from_arg_;
vector<T> slopes_from_arg_;
vector<T> intercepts_from_arg_;
Tensor bounds_device_{Context::GetDeviceType()};
Tensor intercepts_device_{Context::GetDeviceType()};
Tensor slopes_device_{Context::GetDeviceType()};
bool gpu_copied_ = false;
// If true, the piecewise linear functions are passed through args,
// otherwise, they are passed through Input blobs.
bool transform_param_from_arg_;
INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_