// include/caffe2/operators/sparse_to_dense_op.h

#ifndef CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
#define CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

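// SparseToDenseOp converts a sparse representation (INDICES, VALUES) into a
// dense tensor: each block of VALUES is scatter-added into the output row
// named by the corresponding entry of INDICES, unreferenced rows stay zero,
// and duplicated indices accumulate. The output's first dimension is taken
// from the "output_first_dim" argument, the optional DATA_TO_INFER_DIM
// input, or (max index + 1). Indices may be int32 or int64; values may be
// float, int32, or int64. For example, INDICES = [1, 3, 1] and
// VALUES = [a, b, c] yield the dense output [0, a + c, 0, b].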
template <class Context>
class SparseToDenseOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_DISPATCH_HELPER;

  template <class... Args>
  explicit SparseToDenseOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        output_first_dim_(
            this->template GetSingleArgument<int>("output_first_dim", 0)) {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

 private:
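  // Determines the first dimension of the dense output, in priority order:
  // the "output_first_dim" argument when positive, the first dimension of
  // the optional DATA_TO_INFER_DIM input, or (max sparse index + 1); an
  // empty index list yields 0.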
  template <typename TInd>
  int GetOutputFirstDim(
      const TInd* sparse_indices_vec,
      const int32_t sparse_indices_len) {
    if (output_first_dim_ > 0) {
      CAFFE_ENFORCE_EQ(InputSize(), 2);
      return output_first_dim_;
    }
    if (InputSize() == 3) {
      auto& data_to_infer_dim = Input(DATA_TO_INFER_DIM);
      CAFFE_ENFORCE_GE(data_to_infer_dim.dim(), 1);
      return data_to_infer_dim.dim32(0);
    }
    if (sparse_indices_len <= 0) {
      return 0;
    }

    // Awkward way to get the max element to make it work with both CUDA
    // and CPU.
    ReinitializeTensor(
        &max_element_, {1}, at::dtype<TInd>().device(Context::GetDeviceType()));
    TInd* max_element_ptr = max_element_.template mutable_data<TInd>();
    math::ReduceMax<TInd>(
        sparse_indices_len, sparse_indices_vec, max_element_ptr,
        &scratch_, &context_);
    max_element_host_.CopyFrom(max_element_);
    return 1 + max_element_host_.template data<TInd>()[0];
  }

  template <typename TInd>
  bool DoRunWithType() {
    return DispatchHelper<
        TensorTypes2<
            float,
            int32_t,
            int64_t,
            GenericTensorImplementation>,
        TInd>::call(this, Input(VALUES));
  }

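  // Zero-fills the dense output, then scatter-adds each block of VALUES into
  // the output row selected by the matching entry of INDICES, so duplicated
  // indices accumulate rather than overwrite.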
  template <typename TInd, typename TData>
  bool DoRunWithType2() {
    auto& sparse_indices = Input(INDICES);
    CAFFE_ENFORCE_EQ(sparse_indices.dim(), 1);
    auto& sparse_values = Input(VALUES);
    CAFFE_ENFORCE_GE(sparse_values.dim(), 1);
    CAFFE_ENFORCE_EQ(sparse_indices.numel(), sparse_values.size(0));

    const TInd* sparse_indices_vec = sparse_indices.template data<TInd>();
    const int32_t sparse_indices_len = sparse_indices.dim32(0);
    const int output_first_dim =
        GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);

    auto shape = sparse_values.sizes().vec();
    shape[0] = output_first_dim;

    auto* output = Output(0, shape, at::dtype<TData>());

    TData* output_data = output->template mutable_data<TData>();
    if (!output_first_dim) {
      return true;
    }
    memset(output_data, 0, output->nbytes());
    const auto block_nitems = sparse_values.size_from_dim(1);
    const TData* sparse_values_vec = sparse_values.template data<TData>();

    for (int32_t i = 0; i < sparse_indices_len; i++) {
      const TInd idx = sparse_indices_vec[i];
      CAFFE_ENFORCE_GE(idx, 0);
      CAFFE_ENFORCE_LT(idx, output_first_dim);
      math::Add(
          block_nitems,
          output_data + idx * block_nitems,
          sparse_values_vec + i * block_nitems,
          output_data + idx * block_nitems,
          &context_);
    }
    return true;
  }

  template <typename TInd>
  bool DoRunWithOtherType2() {
    CAFFE_THROW(
        "SparseToDense is not implemented on tensor of type ",
        Input(VALUES).dtype().name(),
        "consider adding it as a type in the DispatchHelper list or "
        "implementing a generic version (which won't work for "
        "duplicated indices though)");
  }

 private:
  int output_first_dim_;
  Tensor scratch_{Context::GetDeviceType()};
  Tensor max_element_host_{CPU};
  Tensor max_element_;

  INPUT_TAGS(INDICES, VALUES, DATA_TO_INFER_DIM);
};
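
// A minimal, CPU-only sketch of the same scatter-add semantics, kept purely
// for illustration: the name and signature below are hypothetical and not
// part of the operator above. It assumes 1-D values and a caller-allocated
// output of length output_first_dim, mirroring the memset-then-math::Add
// loop in DoRunWithType2.
template <typename TInd, typename TData>
void SparseToDenseReferenceSketch(
    const TInd* indices,
    const TData* values,
    const int64_t indices_len,
    const int64_t output_first_dim,
    TData* dense_out) {
  // Start from an all-zero dense output, as the operator does with memset.
  for (int64_t j = 0; j < output_first_dim; ++j) {
    dense_out[j] = TData(0);
  }
  // Accumulate each value into the row named by its index; duplicates sum.
  for (int64_t i = 0; i < indices_len; ++i) {
    dense_out[indices[i]] += values[i];
  }
}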

} // namespace caffe2

#endif // CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_