#ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#include <exception>
#include <istream>
#include <ostream>
#include <random>
#include <string>
#include <c10/core/thread_pool.h>
#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/video/video_decoder.h>
#include <caffe2/video/video_io.h>
namespace caffe2 {
// Input operator that reads serialized video (or image) records from a DB,
// decodes and transforms them on a CPU thread pool, and prefetches batched
// clip tensors (RGB and/or optical flow) together with labels and optional
// per-clip metadata (video id, start frame). Prefetching overlaps decode
// with downstream compute via PrefetchOperator.
template <class Context>
class VideoInputOp final : public PrefetchOperator<Context> {
 public:
  using OperatorBase::OutputSize;
  using PrefetchOperator<Context>::context_;
  using PrefetchOperator<Context>::prefetch_thread_;
  explicit VideoInputOp(const OperatorDef& operator_def, Workspace* ws);
  ~VideoInputOp() {
    // Join/stop the prefetch thread before members are destroyed.
    PrefetchOperator<Context>::Finalize();
  }

  // override methods
  // Fills the prefetched_* host tensors for the next batch.
  bool Prefetch() override;
  // Copies the prefetched host tensors to the op outputs (device copies
  // go through the *_on_device_ staging tensors).
  bool CopyPrefetched() override;

 private:
  // Validates the operator arguments and logs the effective configuration.
  void CheckParamsAndPrint();

  // Decodes one DB value into per-clip RGB frame buffers and fills the
  // label / video-id / start-frame outputs. `height`/`width` receive the
  // decoded frame dimensions. Returns false on failure.
  // NOTE(review): buffer ownership and layout of `buffer_rgb` are defined
  // by the .cc implementation — confirm there before relying on them.
  bool GetClipsAndLabelsFromDBValue(
      const std::string& value,
      int& height,
      int& width,
      std::vector<unsigned char*>& buffer_rgb,
      int* label_data,
      int64_t* video_id_data,
      int* start_frame_data,
      std::mt19937* randgen);

  // Decodes one record and writes the transformed (cropped / mirrored /
  // normalized) clip data into the provided output pointers. Runs on the
  // decode thread pool.
  void DecodeAndTransform(
      const std::string& value,
      float* clip_rgb_data,
      float* clip_of_data,
      int* label_data,
      int64_t* video_id_data,
      int* start_frame_data,
      std::mt19937* randgen,
      std::bernoulli_distribution* mirror_this_clip);

  // Extracts label values from a TensorProto into label_data.
  void GetLabelsFromProto(const TensorProto& label_proto, int* label_data);

  // Image-as-input variant of GetClipsAndLabelsFromDBValue (single frame).
  bool GetImageAndLabelsFromDBValue(
      const std::string& value,
      int& height,
      int& width,
      std::vector<unsigned char*>& buffer_rgb,
      int* label_data);

  // Non-owning handle to the DB reader provided via input blob.
  const db::DBReader* reader_;
  CPUContext cpu_context_;

  // Host-side staging tensors filled by Prefetch().
  Tensor prefetched_clip_rgb_;
  Tensor prefetched_clip_of_;
  Tensor prefetched_label_;
  Tensor prefetched_video_id_;
  Tensor prefetched_start_frame_;
  // Device-side staging tensors used by CopyPrefetched() when Context is
  // a device context.
  Tensor prefetched_clip_rgb_on_device_{Context::GetDeviceType()};
  Tensor prefetched_clip_of_on_device_{Context::GetDeviceType()};
  Tensor prefetched_label_on_device_{Context::GetDeviceType()};
  Tensor prefetched_video_id_on_device_{Context::GetDeviceType()};
  Tensor prefetched_start_frame_on_device_{Context::GetDeviceType()};

  int batch_size_;
  int clip_per_video_;
  // Explicit clip start positions (frames); empty means sampled positions.
  std::vector<int> clip_start_positions_;
  // Per-channel normalization statistics (mean and 1/std) for RGB and
  // optical flow; sizes must match channels_rgb_ / channels_of_.
  std::vector<float> mean_rgb_;
  std::vector<float> inv_std_rgb_;
  std::vector<float> mean_of_;
  std::vector<float> inv_std_of_;
  int channels_rgb_;
  int channels_of_;
  // Spatial geometry: crop size and pre-crop scaling parameters.
  int crop_size_;
  int scale_h_;
  int scale_w_;
  int short_edge_;
  // Candidate short-edge sizes for scale jittering (USE_SHORT_EDGE only).
  std::vector<int> jitter_scales_;
  // Temporal geometry for the RGB clip.
  int length_rgb_;
  int sampling_rate_rgb_;
  // When nonzero, sampling rate is drawn at random up to this maximum.
  int random_sampling_rate_;
  int num_of_required_frame_;
  // Temporal geometry for the optical flow clip.
  int length_of_;
  int sampling_rate_of_;
  int frame_gap_of_;
  bool random_mirror_;
  int num_of_class_;
  bool use_local_file_;
  bool random_crop_;
  int crop_per_clip_;
  // Enum-valued options (see video_decoder.h for the enum definitions).
  int flow_data_type_;
  int flow_alg_type_;
  int decode_type_;
  int video_res_type_;
  bool do_flow_aggregation_;
  bool image_as_input_;
  // Output selection flags.
  bool get_rgb_;
  bool get_optical_flow_;
  bool get_video_id_;
  bool get_start_frame_;
  bool do_multi_label_;

  // thread pool for parse + decode
  int num_decode_threads_;
  std::shared_ptr<TaskThreadPool> thread_pool_;
};
// Validates all operator arguments (crop/scale geometry, clip sampling,
// channel statistics, decode mode) via CAFFE_ENFORCE, then logs the
// effective configuration. Called once at construction time; throws
// (via CAFFE_ENFORCE) on any invalid combination of arguments.
template <class Context>
void VideoInputOp<Context>::CheckParamsAndPrint() {
  // check whether the input parameters are valid or not
  CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size should be positive.");
  CAFFE_ENFORCE_GT(
      clip_per_video_, 0, "Number of clips per video should be positive.");
  CAFFE_ENFORCE_GT(crop_size_, 0, "Must provide the cropping value.");

  if (!image_as_input_) {
    CAFFE_ENFORCE_GT(
        num_of_required_frame_,
        0,
        "Required number of frames must be positive.");
  }
  if (image_as_input_) {
    CAFFE_ENFORCE_EQ(
        video_res_type_,
        VideoResType::USE_WIDTH_HEIGHT,
        "Currently only USE_WIDTH_HEIGHT option is supported with images");
  }

  // Resolution handling: either keep the aspect ratio via short edge, or
  // force an explicit width/height; both must leave room for the crop.
  if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
    CAFFE_ENFORCE_GT(short_edge_, 0, "Must provide the short edge value.");
    CAFFE_ENFORCE_GE(
        short_edge_,
        crop_size_,
        "The short edge must be no smaller than the crop value.");
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(scale_h_, 0, "Must provide the scale height value.");
    CAFFE_ENFORCE_GT(scale_w_, 0, "Must provide the scale width value.");
    CAFFE_ENFORCE_GE(
        scale_h_,
        crop_size_,
        "The scaled height must be no smaller than the crop value.");
    CAFFE_ENFORCE_GE(
        scale_w_,
        crop_size_,
        "The scaled width must be no smaller than the crop value.");
  }

  if (!jitter_scales_.empty()) {
    // NOTE(review): GE also admits values above USE_SHORT_EDGE (e.g.
    // USE_WIDTH_HEIGHT) even though the message says short_edge only —
    // confirm whether this should be an EQ check.
    CAFFE_ENFORCE_GE(
        video_res_type_,
        VideoResType::USE_SHORT_EDGE,
        "Scale jittering is used with short_edge scaling only");
  }

  if (get_rgb_) {
    CAFFE_ENFORCE_GT(length_rgb_, 0, "Must provide rgb clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_rgb_, 0, "4 frames for mc2; 2 frames for res3d.");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, mean_rgb_.size(), "Number rgb channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, inv_std_rgb_.size(), "Number rgb channels is wrong!");
  }

  if (get_optical_flow_) {
    CAFFE_ENFORCE_GT(length_of_, 0, "Must provide optical flow clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_of_, 0, "4 frames for mc2; 2 frames for res3d.");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        mean_of_.size(),
        "Number of optical flow channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        inv_std_of_.size(),
        "Number of optical flow channels is wrong!");
  }

  if (clip_per_video_ > 1) {
    CAFFE_ENFORCE_EQ(
        decode_type_,
        DecodeType::DO_UNIFORM_SMP,
        "Only uniformly sampling is supported when sampling multiple clips!");
  }

  if (do_multi_label_) {
    CAFFE_ENFORCE_GT(
        num_of_class_,
        0,
        "Number of classes must be set when using multiple labels.");
  }

  // print out the parameter settings
  LOG(INFO) << "Creating a clip input op with the following setting: ";
  LOG(INFO) << "  Input Type: " << (image_as_input_ ? "Image" : "Video");
  LOG(INFO) << "  Using " << num_decode_threads_ << " CPU threads;";
  LOG(INFO) << "  Outputting in batches of " << batch_size_ << " videos;";
  LOG(INFO) << "  Each video has " << clip_per_video_ << " clips;";
  LOG(INFO) << "  Scaling image to " << scale_h_ << "x" << scale_w_;
  LOG(INFO) << "  Cropping video frame to " << crop_size_
            << (random_mirror_ ? " with " : " without ") << "random mirroring;";
  LOG(INFO) << "  Using " << (random_crop_ ? "random" : "center") << " crop";
  LOG(INFO) << "  Using " << crop_per_clip_ << " spatial crop(s)";

  if (get_rgb_) {
    LOG(INFO) << "  Using a clip of " << length_rgb_ << " rgb frames "
              << "with " << channels_rgb_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_rgb_;
    if (random_sampling_rate_) {
      LOG(INFO) << "random sampling with max:" << random_sampling_rate_;
    }
    for (int i = 0; i < channels_rgb_; i++) {
      LOG(INFO) << "    RGB " << i << "-th channel mean: " << mean_rgb_[i]
                << " std: " << 1.f / inv_std_rgb_[i];
    }
  }

  if (get_optical_flow_) {
    LOG(INFO) << "  Using a clip of " << length_of_ << " optical flow frames "
              << "with " << channels_of_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_of_
              << " flow_data_type_: " << flow_data_type_
              << " flow_alg_type_: " << flow_alg_type_;
    for (int i = 0; i < channels_of_; i++) {
      LOG(INFO) << "    Optical flow" << i
                << "-th channel mean: " << mean_of_[i]
                << " std: " << 1.f / inv_std_of_[i];
    }
  }

  if (video_res_type_ == VideoResType::ORIGINAL_RES) {
    LOG(INFO) << "  Use original resolution";
  } else if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
    LOG(INFO) << "  Resize and keep aspect ratio";
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    LOG(INFO) << "  Resize and ignore aspect ratio";
  } else {
    LOG(ERROR) << "  Unknown video resolution type";
  }

  if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
    if (!jitter_scales_.empty()) {
      LOG(INFO) << "Using scale jittering:";
      // size_t index: avoids the signed/unsigned comparison against size().
      for (size_t idx = 0; idx < jitter_scales_.size(); idx++) {
        LOG(INFO) << "scale " << idx << ": " << jitter_scales_[idx];
      }
    } else {
      LOG(INFO) << "No scale jittering is used.";
    }
  }

  if (decode_type_ == DecodeType::DO_TMP_JITTER) {
    LOG(INFO) << "  Do temporal jittering";
  } else if (decode_type_ == DecodeType::USE_START_FRM) {
    LOG(INFO) << "  Use start_frm for decoding";
  } else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
    LOG(INFO) << "  Do uniformly sampling";
  } else {
    LOG(ERROR) << "  Unknown video decoding type";
  }

  // Start-frame output is only meaningful with explicit start frames and a
  // single clip per video.
  if (get_start_frame_) {
    CAFFE_ENFORCE_EQ(
        decode_type_,
        DecodeType::USE_START_FRM,
        "Only decoding with starting frame is supported w/ get start_frame!");
    CAFFE_ENFORCE_EQ(
        clip_per_video_, 1, "get start frame support only clip per video = 1");
  }
}
template <class Context>
VideoInputOp<Context>::VideoInputOp(
const OperatorDef& operator_def,
Workspace* ws)
: PrefetchOperator<Context>(operator_def, ws),
reader_(nullptr),
batch_size_(
OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
clip_per_video_(
OperatorBase::template GetSingleArgument<int>("clip_per_video", 1)),
clip_start_positions_(OperatorBase::template GetRepeatedArgument<int>(
"clip_start_positions",
{})),
channels_rgb_(
OperatorBase::template GetSingleArgument<int>("channels_rgb", 3)),
channels_of_(
OperatorBase::template GetSingleArgument<int>("channels_of", 2)),
crop_size_(OperatorBase::template GetSingleArgument<int>("crop_size", 0)),
scale_h_(OperatorBase::template GetSingleArgument<int>("scale_h", 0)),
scale_w_(OperatorBase::template GetSingleArgument<int>("scale_w", 0)),
short_edge_(
OperatorBase::template GetSingleArgument<int>("short_edge", 0)),
jitter_scales_(
OperatorBase::template GetRepeatedArgument<int>("jitter_scales", {})),
length_rgb_(
OperatorBase::template GetSingleArgument<int>("length_rgb", 0)),
sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
"sampling_rate_rgb",
1)),
random_sampling_rate_(OperatorBase::template GetSingleArgument<int>(
"random_sampling_rate",
0)),
length_of_(OperatorBase::template GetSingleArgument<int>("length_of", 0)),
sampling_rate_of_(
OperatorBase::template GetSingleArgument<int>("sampling_rate_of", 1)),
frame_gap_of_(
OperatorBase::template GetSingleArgument<int>("frame_gap_of", 1)),
random_mirror_(OperatorBase::template GetSingleArgument<bool>(
"random_mirror",
true)),
num_of_class_(
OperatorBase::template GetSingleArgument<int>("num_of_class", 0)),
use_local_file_(OperatorBase::template GetSingleArgument<bool>(
"use_local_file",
false)),
random_crop_(
OperatorBase::template GetSingleArgument<bool>("random_crop", true)),
crop_per_clip_(
OperatorBase::template GetSingleArgument<int>("crop_per_clip", 1)),
flow_data_type_(
OperatorBase::template GetSingleArgument<int>("flow_data_type", 0)),
flow_alg_type_(
OperatorBase::template GetSingleArgument<int>("flow_alg_type", 0)),
decode_type_(
OperatorBase::template GetSingleArgument<int>("decode_type", 0)),
video_res_type_(
OperatorBase::template GetSingleArgument<int>("video_res_type", 0)),
do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
"do_flow_aggregation",
Loading ...