Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / video / video_input_op.h

#ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_

#include <exception>
#include <istream>
#include <ostream>
#include <random>
#include <string>

#include <c10/core/thread_pool.h>
#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/video/video_decoder.h>
#include <caffe2/video/video_io.h>

namespace caffe2 {

// Input operator that reads video (or image) entries from a DB, decodes them
// on a CPU thread pool, applies spatial/temporal transforms (scale, crop,
// mirror, mean/std normalization) and optionally computes optical flow, then
// prefetches the resulting tensors to the target device via PrefetchOperator.
template <class Context>
class VideoInputOp final : public PrefetchOperator<Context> {
 public:
  using OperatorBase::OutputSize;
  using PrefetchOperator<Context>::context_;
  using PrefetchOperator<Context>::prefetch_thread_;
  explicit VideoInputOp(const OperatorDef& operator_def, Workspace* ws);
  ~VideoInputOp() {
    // Join/stop the prefetch thread before members are destroyed.
    PrefetchOperator<Context>::Finalize();
  }

  // override methods
  // Fills the host-side prefetched_* buffers (runs on the prefetch thread).
  bool Prefetch() override;
  // Copies the prefetched buffers into the op's outputs (device copies when
  // Context is a GPU context).
  bool CopyPrefetched() override;

 private:
  // Validates the operator arguments with CAFFE_ENFORCE and logs the
  // resulting configuration.
  void CheckParamsAndPrint();

  // Decodes one DB value into per-frame RGB buffers and fills the label /
  // video-id / start-frame output slots. `height`/`width` are out-params set
  // to the decoded frame size. Returns false on decode failure.
  bool GetClipsAndLabelsFromDBValue(
      const std::string& value,
      int& height,
      int& width,
      std::vector<unsigned char*>& buffer_rgb,
      int* label_data,
      int64_t* video_id_data,
      int* start_frame_data,
      std::mt19937* randgen);

  // Thread-pool work item: decode one DB value and write the transformed RGB
  // clip, optical-flow clip, and metadata into the given prefetch pointers.
  void DecodeAndTransform(
      const std::string& value,
      float* clip_rgb_data,
      float* clip_of_data,
      int* label_data,
      int64_t* video_id_data,
      int* start_frame_data,
      std::mt19937* randgen,
      std::bernoulli_distribution* mirror_this_clip);

  // Extracts label value(s) from the label TensorProto into label_data.
  void GetLabelsFromProto(const TensorProto& label_proto, int* label_data);

  // Image-as-input variant of GetClipsAndLabelsFromDBValue (used when
  // image_as_input_ is set).
  bool GetImageAndLabelsFromDBValue(
      const std::string& value,
      int& height,
      int& width,
      std::vector<unsigned char*>& buffer_rgb,
      int* label_data);

  const db::DBReader* reader_;  // non-owning; input DB to read entries from
  CPUContext cpu_context_;      // context for host-side buffer operations
  // Host-side (CPU) prefetch buffers, filled by Prefetch().
  Tensor prefetched_clip_rgb_;
  Tensor prefetched_clip_of_;
  Tensor prefetched_label_;
  Tensor prefetched_video_id_;
  Tensor prefetched_start_frame_;
  // Device-side staging buffers used when Context is not CPU.
  Tensor prefetched_clip_rgb_on_device_{Context::GetDeviceType()};
  Tensor prefetched_clip_of_on_device_{Context::GetDeviceType()};
  Tensor prefetched_label_on_device_{Context::GetDeviceType()};
  Tensor prefetched_video_id_on_device_{Context::GetDeviceType()};
  Tensor prefetched_start_frame_on_device_{Context::GetDeviceType()};

  int batch_size_;                        // videos per batch
  int clip_per_video_;                    // clips sampled from each video
  std::vector<int> clip_start_positions_; // explicit clip start frames, if any
  // Per-channel normalization: output = (value - mean) * inv_std.
  std::vector<float> mean_rgb_;
  std::vector<float> inv_std_rgb_;
  std::vector<float> mean_of_;
  std::vector<float> inv_std_of_;
  int channels_rgb_;       // rgb channel count (must match mean/std sizes)
  int channels_of_;        // optical-flow channel count
  int crop_size_;          // spatial crop (square) applied to each frame
  int scale_h_;            // target height for USE_WIDTH_HEIGHT resizing
  int scale_w_;            // target width for USE_WIDTH_HEIGHT resizing
  int short_edge_;         // target short edge for USE_SHORT_EDGE resizing
  std::vector<int> jitter_scales_;  // candidate short edges for scale jitter
  int length_rgb_;         // frames per rgb clip
  int sampling_rate_rgb_;  // temporal stride between sampled rgb frames
  int random_sampling_rate_;   // if nonzero, max of randomized sampling rate
  int num_of_required_frame_;  // frames a video must have to be decodable
  int length_of_;          // frames per optical-flow clip
  int sampling_rate_of_;   // temporal stride for optical-flow frames
  int frame_gap_of_;       // frame gap used when computing optical flow
  bool random_mirror_;     // horizontally mirror clips at random
  int num_of_class_;       // number of classes (required for multi-label)
  bool use_local_file_;    // DB values are file paths rather than raw bytes
  bool random_crop_;       // random vs. center spatial crop
  int crop_per_clip_;      // number of spatial crops taken per clip
  int flow_data_type_;     // optical-flow data layout selector
  int flow_alg_type_;      // optical-flow algorithm selector
  int decode_type_;        // DecodeType: temporal jitter / start frame / uniform
  int video_res_type_;     // VideoResType: original / short edge / width-height
  bool do_flow_aggregation_;  // aggregate optical flow across frame gaps
  bool image_as_input_;    // input entries are single images, not videos
  bool get_rgb_;           // produce the rgb clip output
  bool get_optical_flow_;  // produce the optical-flow clip output
  bool get_video_id_;      // produce the video-id output
  bool get_start_frame_;   // produce the start-frame output
  bool do_multi_label_;    // labels are multi-hot vectors of num_of_class_

  // thread pool for parse + decode
  int num_decode_threads_;
  std::shared_ptr<TaskThreadPool> thread_pool_;
};

// Validates every operator argument combination with CAFFE_ENFORCE (aborting
// with a descriptive message on misconfiguration) and then logs the full
// configuration at INFO level. Called once at construction time.
template <class Context>
void VideoInputOp<Context>::CheckParamsAndPrint() {
  // check whether the input parameters are valid or not
  CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size should be positive.");
  CAFFE_ENFORCE_GT(
      clip_per_video_, 0, "Number of clips per video should be positive.");
  CAFFE_ENFORCE_GT(crop_size_, 0, "Must provide the cropping value.");

  if (!image_as_input_) {
    CAFFE_ENFORCE_GT(
        num_of_required_frame_,
        0,
        "Required number of frames must be positive.");
  }

  if (image_as_input_) {
    CAFFE_ENFORCE_EQ(
        video_res_type_,
        VideoResType::USE_WIDTH_HEIGHT,
        "Currently only USE_WIDTH_HEIGHT option is supported with images");
  }

  // The scaled size must be able to contain the crop in either resize mode.
  if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
    CAFFE_ENFORCE_GT(short_edge_, 0, "Must provide the short edge value.");
    CAFFE_ENFORCE_GE(
        short_edge_,
        crop_size_,
        "The short edge must be no smaller than the crop value.");
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(scale_h_, 0, "Must provide the scale height value.");
    CAFFE_ENFORCE_GT(scale_w_, 0, "Must provide the scale width value.");
    CAFFE_ENFORCE_GE(
        scale_h_,
        crop_size_,
        "The scaled height must be no smaller than the crop value.");
    CAFFE_ENFORCE_GE(
        scale_w_,
        crop_size_,
        "The scaled width must be no smaller than the crop value.");
  }

  if (!jitter_scales_.empty()) {
    // NOTE(review): GE (not EQ) relies on the ordering of the VideoResType
    // enum — it accepts USE_SHORT_EDGE and anything declared after it.
    // Confirm against the enum declaration that this is intended.
    CAFFE_ENFORCE_GE(
        video_res_type_,
        VideoResType::USE_SHORT_EDGE,
        "Scale jittering is used with short_edge scaling only");
  }

  if (get_rgb_) {
    CAFFE_ENFORCE_GT(length_rgb_, 0, "Must provide rgb clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_rgb_, 0, "4 frames for mc2; 2 frames for res3d.");
    // Per-channel normalization vectors must match the channel count.
    CAFFE_ENFORCE_EQ(
        channels_rgb_, mean_rgb_.size(), "Number rgb channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, inv_std_rgb_.size(), "Number rgb channels is wrong!");
  }

  if (get_optical_flow_) {
    CAFFE_ENFORCE_GT(length_of_, 0, "Must provide optical flow clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_of_, 0, "4 frames for mc2; 2 frames for res3d.");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        mean_of_.size(),
        "Number of optical flow channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        inv_std_of_.size(),
        "Number of optical flow channels is wrong!");
  }

  if (clip_per_video_ > 1) {
    CAFFE_ENFORCE_EQ(
        decode_type_,
        DecodeType::DO_UNIFORM_SMP,
        "Only uniformly sampling is supported when sampling multiple clips!");
  }

  if (do_multi_label_) {
    CAFFE_ENFORCE_GT(
        num_of_class_,
        0,
        "Number of classes must be set when using multiple labels.");
  }

  // print out the parameter settings
  LOG(INFO) << "Creating a clip input op with the following setting: ";
  LOG(INFO) << "    Input Type: " << (image_as_input_ ? "Image" : "Video");
  LOG(INFO) << "    Using " << num_decode_threads_ << " CPU threads;";
  LOG(INFO) << "    Outputting in batches of " << batch_size_ << " videos;";
  LOG(INFO) << "    Each video has " << clip_per_video_ << " clips;";
  LOG(INFO) << "    Scaling image to " << scale_h_ << "x" << scale_w_;
  LOG(INFO) << "    Cropping video frame to " << crop_size_
            << (random_mirror_ ? " with " : " without ") << "random mirroring;";
  LOG(INFO) << "    Using " << (random_crop_ ? "random" : "center") << " crop";
  LOG(INFO) << "    Using " << crop_per_clip_ << " spatial crop(s)";

  if (get_rgb_) {
    LOG(INFO) << "    Using a clip of " << length_rgb_ << " rgb frames "
              << "with " << channels_rgb_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_rgb_;
    if (random_sampling_rate_) {
      LOG(INFO) << "random sampling with max:" << random_sampling_rate_;
    }
    for (int i = 0; i < channels_rgb_; i++) {
      // inv_std_* stores 1/std, so invert for human-readable logging.
      LOG(INFO) << "    RGB " << i << "-th channel mean: " << mean_rgb_[i]
                << " std: " << 1.f / inv_std_rgb_[i];
    }
  }

  if (get_optical_flow_) {
    LOG(INFO) << "    Using a clip of " << length_of_ << " optical flow frames "
              << "with " << channels_of_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_of_
              << " flow_data_type_: " << flow_data_type_
              << " flow_alg_type_: " << flow_alg_type_;
    for (int i = 0; i < channels_of_; i++) {
      LOG(INFO) << "    Optical flow" << i
                << "-th channel mean: " << mean_of_[i]
                << " std: " << 1.f / inv_std_of_[i];
    }
  }

  if (video_res_type_ == VideoResType::ORIGINAL_RES) {
    LOG(INFO) << "    Use original resolution";
  } else if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
    LOG(INFO) << "    Resize and keep aspect ratio";
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    LOG(INFO) << "    Resize and ignore aspect ratio";
  } else {
    LOG(ERROR) << "    Unknown video resolution type";
  }

  if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
    if (!jitter_scales_.empty()) {
      LOG(INFO) << "Using scale jittering:";
      // size_t index avoids a signed/unsigned comparison warning against
      // jitter_scales_.size().
      for (size_t idx = 0; idx < jitter_scales_.size(); idx++) {
        LOG(INFO) << "scale " << idx << ": " << jitter_scales_[idx];
      }
    } else {
      LOG(INFO) << "No scale jittering is used.";
    }
  }

  if (decode_type_ == DecodeType::DO_TMP_JITTER) {
    LOG(INFO) << "    Do temporal jittering";
  } else if (decode_type_ == DecodeType::USE_START_FRM) {
    LOG(INFO) << "    Use start_frm for decoding";
  } else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
    LOG(INFO) << "    Do uniformly sampling";
  } else {
    LOG(ERROR) << "    Unknown video decoding type";
  }
  if (get_start_frame_) {
    CAFFE_ENFORCE_EQ(
        decode_type_,
        DecodeType::USE_START_FRM,
        "Only decoding with starting frame is supported w/ get start_frame!");
    CAFFE_ENFORCE_EQ(
        clip_per_video_, 1, "get start frame support only clip per video = 1");
  }
}

template <class Context>
VideoInputOp<Context>::VideoInputOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : PrefetchOperator<Context>(operator_def, ws),
      reader_(nullptr),
      batch_size_(
          OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
      clip_per_video_(
          OperatorBase::template GetSingleArgument<int>("clip_per_video", 1)),
      clip_start_positions_(OperatorBase::template GetRepeatedArgument<int>(
          "clip_start_positions",
          {})),
      channels_rgb_(
          OperatorBase::template GetSingleArgument<int>("channels_rgb", 3)),
      channels_of_(
          OperatorBase::template GetSingleArgument<int>("channels_of", 2)),
      crop_size_(OperatorBase::template GetSingleArgument<int>("crop_size", 0)),
      scale_h_(OperatorBase::template GetSingleArgument<int>("scale_h", 0)),
      scale_w_(OperatorBase::template GetSingleArgument<int>("scale_w", 0)),
      short_edge_(
          OperatorBase::template GetSingleArgument<int>("short_edge", 0)),
      jitter_scales_(
          OperatorBase::template GetRepeatedArgument<int>("jitter_scales", {})),
      length_rgb_(
          OperatorBase::template GetSingleArgument<int>("length_rgb", 0)),
      sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
          "sampling_rate_rgb",
          1)),
      random_sampling_rate_(OperatorBase::template GetSingleArgument<int>(
          "random_sampling_rate",
          0)),
      length_of_(OperatorBase::template GetSingleArgument<int>("length_of", 0)),
      sampling_rate_of_(
          OperatorBase::template GetSingleArgument<int>("sampling_rate_of", 1)),
      frame_gap_of_(
          OperatorBase::template GetSingleArgument<int>("frame_gap_of", 1)),
      random_mirror_(OperatorBase::template GetSingleArgument<bool>(
          "random_mirror",
          true)),
      num_of_class_(
          OperatorBase::template GetSingleArgument<int>("num_of_class", 0)),
      use_local_file_(OperatorBase::template GetSingleArgument<bool>(
          "use_local_file",
          false)),
      random_crop_(
          OperatorBase::template GetSingleArgument<bool>("random_crop", true)),
      crop_per_clip_(
          OperatorBase::template GetSingleArgument<int>("crop_per_clip", 1)),
      flow_data_type_(
          OperatorBase::template GetSingleArgument<int>("flow_data_type", 0)),
      flow_alg_type_(
          OperatorBase::template GetSingleArgument<int>("flow_alg_type", 0)),
      decode_type_(
          OperatorBase::template GetSingleArgument<int>("decode_type", 0)),
      video_res_type_(
          OperatorBase::template GetSingleArgument<int>("video_res_type", 0)),
      do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
          "do_flow_aggregation",
Loading ...