Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / operator.h

#ifndef CAFFE2_CORE_OPERATOR_H_
#define CAFFE2_CORE_OPERATOR_H_

#include <array>
#include <cfenv>
#include <climits>
#include <cstddef>
#include <exception>
#include <functional>
#include <set>
#include <sstream>
#include <string>
#include <typeinfo>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Registry.h>
#include <c10/util/typeid.h>
#include "caffe2/core/blob.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator_gradient.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/core/types.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"

#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
#include <ATen/core/TensorBody.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#endif

// Boolean flags used by operator code. C10_DECLARE_bool only declares them;
// the definitions (and defaults) live in another translation unit.
// NOTE(review): judging by their names these control whether operators throw
// on floating-point exceptions (cf. the <cfenv> include) — confirm against
// the C10_DEFINE_bool sites.
C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions);
C10_DECLARE_bool(caffe2_operator_throw_if_fp_overflow_exceptions);
// Only declared when building against the GNU C library.
#ifdef __GNU_LIBRARY__
C10_DECLARE_bool(caffe2_operator_throw_on_first_occurrence_if_fp_exceptions);
#endif

// Forward declaration: OperatorBase::getFunctionSchema() names
// c10::FunctionSchema unconditionally, even in builds where
// <ATen/core/function_schema.h> is not included (see the #if above).
namespace c10 {
struct FunctionSchema;
}

namespace caffe2 {

class TORCH_API OperatorBase;
// Observer type that can be attached to an OperatorBase; OperatorBase itself
// derives from Observable<OperatorBase> below.
using OperatorObserver = ObserverBase<OperatorBase>;

class TORCH_API OperatorBase : public Observable<OperatorBase> {
 public:
  /** @brief Legacy constructor: builds the operator from a protobuf
   * OperatorDef and a Workspace.
   *
   * NOTE(review): presumably `ws` is the workspace whose blobs back the
   * operator's inputs/outputs; ownership is not visible here — confirm in
   * the out-of-line definition.
   */
  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);

  /*
   * Schema-based ("new-style") constructor. It is only compiled when c10/ATen
   * operators are exposed (see the surrounding #if).
   *
   * Notes: All output ivalues must be tensors. Input ivalue list must start
   * with all tensors ("inputs" in caffe2 terminology),
   * followed by non-tensors ("arguments" in caffe2 terminology).
   * Alternatively, inputs can be one tensor list ivalue followed by non-tensors
   * to represent operators with a variable number of inputs.
   */
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
  explicit OperatorBase(
      const c10::FunctionSchema& schema,
      std::vector<c10::IValue> inputs,
      c10::List<at::Tensor> outputs);
#endif

  // Virtual so that derived operators are destroyed correctly when deleted
  // through an OperatorBase pointer.
  virtual ~OperatorBase() noexcept;

  /** @brief Reports whether this operator is OperatorDef-based ("legacy").
   *
   * An operator counts as legacy exactly when it holds no FunctionSchema.
   * In builds where the schema-based constructor is compiled out, every
   * operator is legacy. New operators should be instantiated with a
   * FunctionSchema.
   */
  bool isLegacyOperator() const {
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
    const bool has_schema = static_cast<bool>(fn_schema_);
    return !has_schema;
#else
    return true;
#endif
  }

  /** @brief Returns the FunctionSchema this operator was constructed from.
   *
   * Enforces that the operator is not legacy. In builds without the
   * schema-based constructor every operator is legacy, so this always throws
   * there.
   */
  const c10::FunctionSchema& getFunctionSchema() const {
    CAFFE_ENFORCE(!isLegacyOperator());
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
    return *fn_schema_;
#else
    CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
  }

  /** @brief Returns whether the operator has an argument called `name`.
   *
   * Schema-based operators ask argumentIndexWithName(); legacy operators look
   * the name up in their OperatorDef (which must be non-null).
   */
  inline bool HasArgument(const string& name) const {
    if (!isLegacyOperator()) {
      return argumentIndexWithName(name).has_value();
    }
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasArgument(*operator_def_, name);
  }

  // Functions that deal with arguments. Basically, this allows us to map an
  // argument name to a specific type of argument that we are trying to access.
  /** @brief Reads the single-valued argument `name` as type T.
   *
   * Legacy operators delegate to ArgumentHelper on the OperatorDef, which
   * receives `default_value`. Schema-based operators locate the argument
   * among the new-style inputs and convert that IValue to T. On this path
   * `default_value` is not consulted: a name the schema does not know is an
   * enforce failure.
   */
  template <typename T>
  inline T GetSingleArgument(const string& name, const T& default_value) const {
    if (!isLegacyOperator()) {
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
      auto index = argumentIndexWithName(name);
      CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
      return newstyle_inputs_[index.value()].template to<T>();
#else
      CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
    }
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
        *operator_def_, name, default_value);
  }

  template <typename T>
  inline bool HasSingleArgumentOfType(const string& name) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
        *operator_def_, name);
  }
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
  /** @brief Converts an IValue holding a List<T> into a std::vector<T>.
   *
   * Uses IValue::to<List<T>>() (which throws if the IValue is not such a
   * list) followed by List::vec().
   */
  template <typename T>
  inline vector<T> GetVectorFromIValueList(const c10::IValue& value) const {
    auto as_list = value.template to<List<T>>();
    return as_list.vec();
  }
#endif

  /** @brief Reads the repeated (list-valued) argument `name` as vector<T>.
   *
   * Declared here, defined out of line (not visible in this chunk).
   * NOTE(review): presumably `default_value` (empty by default) is returned
   * when the argument is absent — confirm against the definition.
   */
  template <typename T>
  inline vector<T> GetRepeatedArgument(
      const string& name,
      const vector<T>& default_value = {}) const;

  // Get the inputs and outputs as specific types.
  /** @brief Returns input `idx` as a non-Tensor type T read from its blob.
   *
   * Tensors must go through Input<Tensor>(int, DeviceType) instead; a
   * static_assert enforces this at compile time. If fetching throws
   * EnforceNotMet and has_debug_def() is true, the offending blob name is
   * attached via TORCH_RETHROW. In every case the original exception object
   * is propagated.
   *
   * @param idx index into the operator's inputs (debug-checked only).
   * @return const reference to the T stored in the input blob.
   */
  template <typename T>
  inline const T& Input(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use Input<Tensor>(int, DeviceType) for "
        "Tensor.");
    DCHECK_LT(static_cast<size_t>(idx), inputs_.size());
    try {
      return inputs_.at(idx)->template Get<T>();
    } catch (::caffe2::EnforceNotMet& enf) {
      if (has_debug_def()) {
        TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
      }
      // Rethrow the in-flight exception itself. `throw enf;` would
      // copy-initialize a new EnforceNotMet and slice off any derived error
      // type the caller may be catching.
      throw;
    }
  }

  // TODO(jerryzh): Remove template
  // and the type argument?
  // This is to keep the API changes minimal and make refactoring
  // a bit easier
  /** @brief Returns input `idx` as a Tensor expected on device `type`.
   *
   * Legacy operators: reads the Tensor from the input blob. `type` is not
   * checked on this path yet (see the TODO inside). If fetching throws
   * EnforceNotMet and has_debug_def() is true, the blob name is attached
   * via TORCH_RETHROW. The original exception object is always propagated.
   *
   * Schema-based operators: if the first new-style input is a tensor list,
   * `idx` indexes into that list, so inputs after the list are unreachable.
   * Otherwise `idx` indexes the new-style inputs directly. The selected
   * IValue must hold a Tensor. That tensor is made contiguous if needed,
   * wrapped as a caffe2::Tensor, and enforced to be on device `type`. The
   * result is cached in input_tensors_[idx] so a reference can be returned.
   */
  template <typename T>
  inline const T& Input(int idx, DeviceType type) {
    if (isLegacyOperator()) {
      static_assert(
          std::is_same<T, Tensor>::value,
          "Input(int, DeviceType) is only available for Tensor");
      DCHECK_LT(static_cast<size_t>(idx), inputs_.size());
      try {
        // TODO(jerryzh): We'll need to check device type in Get<T>() later
        // Get<T>() -> Get<T>(type)
        const auto& tensor = inputs_.at(idx)->template Get<T>();
        return tensor;
      } catch (::caffe2::EnforceNotMet& enf) {
        if (has_debug_def()) {
          TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
        }
        // `throw;` rethrows the original object; `throw enf;` would copy it
        // as EnforceNotMet and slice away any derived error type.
        throw;
      }
    }
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
    DCHECK_LT(0U, newstyle_inputs_.size());
    IValue ival;
    if (newstyle_inputs_[0].isTensorList()) {
      // if the first input is a tensor list, we get input tensors by indexing
      // into that list. currently, this means that only tensors from that list
      // are accessible as inputs. any hypothetical input tensors that come
      // after the list are not accessible.
      auto tensorList = newstyle_inputs_[0].toTensorVector();
      DCHECK_LT(static_cast<size_t>(idx), tensorList.size());
      ival = tensorList[idx];
    } else {
      // if the first input is not a tensor list, we get input tensors by
      // indexing into the inputs.
      DCHECK_LT(static_cast<size_t>(idx), newstyle_inputs_.size());
      ival = newstyle_inputs_[idx];
    }
    CAFFE_ENFORCE(
        ival.isTensor(),
        "Input(int, DeviceType) is only available for IValues that store Tensors");
    auto t = ival.toTensor();
    if (!t.is_contiguous()) {
      t = t.contiguous();
    }
    Tensor tensor = caffe2::Tensor(std::move(t));
    CAFFE_ENFORCE_EQ(tensor.GetDeviceType(), type);
    // Cache the wrapper so the returned reference outlives this call.
    input_tensors_[idx] = std::move(tensor);
    return input_tensors_[idx];
#else
    CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
  }

  /** @brief Returns a mutable pointer to output `idx` as non-Tensor type T.
   *
   * Legacy operators only; schema-based operators must use XOutput instead.
   * Tensors must go through Output<Tensor>(int, DeviceType), which is
   * enforced at compile time.
   */
  template <typename T>
  inline T* Output(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use Output<Tensor>(int, DeviceType) for "
        "Tensor.");
    CAFFE_ENFORCE(
        isLegacyOperator(),
        "Output(idx) not supported for operators exported to c10. Please use XOutput instead.");
    auto& slot = outputs_.at(idx);
    return slot->template GetMutable<T>();
  }

  // TODO(jerryzh): Remove this template
  /** @brief Returns a mutable pointer to output `idx` as a Tensor on device
   * `type`.
   *
   * Legacy operators: delegates to BlobGetMutableTensor on the output blob.
   * The returned Tensor is not fully initialized.
   * Schema-based operators: if the stored output is undefined or has a
   * different device type, it is replaced by a newly constructed
   * Tensor(type). The result is written back to newstyle_outputs_[idx] and
   * mirrored in output_tensors_[idx], whose address is returned.
   */
  template <typename T>
  inline T* Output(int idx, DeviceType type) {
    if (!isLegacyOperator()) {
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
      at::Tensor output = newstyle_outputs_[idx];
      const bool needs_replacement =
          !output.defined() || caffe2::Tensor(output).GetDeviceType() != type;
      if (needs_replacement) {
        // Swap in a tensor of the requested device type.
        Tensor tensor = Tensor(type);
        output = at::Tensor(std::move(tensor.getIntrusivePtr()));
      }
      output_tensors_[idx] = caffe2::Tensor(output);
      newstyle_outputs_[idx] = std::move(output);
      return &output_tensors_[idx];
#else
      CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
    }
    static_assert(
        std::is_same<T, Tensor>::value,
        "Output(int, DeviceType) is only available for Tensor");
    // When you get a Tensor here it is not fully initialized
    return BlobGetMutableTensor(outputs_.at(idx), type);
  }

  /** @brief Returns output `idx`, shaped to `dims` with `options`, as a Tensor
   * handle.
   *
   * `options` must carry a device. Legacy operators go through
   * XBlobGetMutableTensor. Schema-based operators go through OutputTensor()
   * and return UnsafeSharedInstance() of the stored tensor.
   */
  inline Tensor
  XOutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
    CAFFE_ENFORCE_WITH_CALLER(
        options.device_opt() != c10::nullopt,
        "device must be provided in option.");
    if (!isLegacyOperator()) {
      Tensor* stored = OutputTensor(idx, dims, options);
      return stored->UnsafeSharedInstance();
    }
    return XBlobGetMutableTensor(outputs_.at(idx), dims, options);
  }

  /** @brief Replaces output `idx` with `tensor`.
   *
   * Legacy operators store it into the workspace blob via BlobSetTensor.
   * Schema-based operators store it into newstyle_outputs_[idx] and also
   * into the output_tensors_ mirror.
   */
  void SetOutputTensor(int idx, Tensor tensor) {
    if (isLegacyOperator()) {
      // update the tensor in the workspace
      BlobSetTensor(outputs_.at(idx), std::move(tensor));
      return;
    }
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
    newstyle_outputs_[idx] = at::Tensor(tensor);
    // keep the caffe2::Tensor mirror (the "hack") in sync
    output_tensors_[idx] = std::move(tensor);
#else
    CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
  }

  /** @brief Returns a handle to output `idx` without creating or resizing it.
   *
   * Schema-based operators return UnsafeSharedInstance() of the cached
   * output_tensors_[idx]. Legacy operators use BlobGetTensorOrUndefined on
   * the output blob. NOTE(review): presumably that yields an undefined
   * Tensor when the blob holds none — confirm.
   */
  Tensor OutputTensorOrUndefined(int idx) {
    if (!isLegacyOperator()) {
      return output_tensors_[idx].UnsafeSharedInstance();
    }
    return BlobGetTensorOrUndefined(*outputs_.at(idx));
  }

  /** @brief Returns a pointer to output `idx` sized to `dims` with `options`.
   *
   * Legacy operators require a device in `options` and use
   * BlobGetMutableTensor on the output blob. Schema-based operators create
   * the output with caffe2::empty when it is undefined. Otherwise they pass
   * it through GetSizedTensorWithOptions. The (possibly new) tensor is
   * stored back into newstyle_outputs_[idx] and mirrored in
   * output_tensors_[idx], whose address is returned.
   */
  inline Tensor*
  OutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
    if (!isLegacyOperator()) {
#if defined(EXPOSE_C2_OPS) || \
    !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
      at::Tensor output = newstyle_outputs_[idx];
      Tensor tensor = !output.defined()
          ? caffe2::empty(dims, options)
          : GetSizedTensorWithOptions(caffe2::Tensor(output), dims, options);
      // The sizing helper may hand back a different tensor; rebind to it.
      output = at::Tensor(std::move(tensor.getIntrusivePtr()));
      output_tensors_[idx] = caffe2::Tensor(output);
      newstyle_outputs_[idx] = std::move(output);
      return &output_tensors_[idx];
#else
      CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
    }
    CAFFE_ENFORCE_WITH_CALLER(
        options.device_opt() != c10::nullopt,
        "device must be provided in options.");
    return BlobGetMutableTensor(outputs_.at(idx), dims, options);
  }

  /** @brief Gets output `idx` shaped like `src` and copies `src` into it.
   *
   * `options` must carry a device. The output dtype always equals `src`'s:
   * it is filled in when `options` has none, and an explicit mismatch is an
   * enforce failure. The copy is `CopyFrom(src, async)`.
   *
   * @return pointer to the output tensor that received the copy.
   */
  Tensor* OutputTensorCopyFrom(
      int idx,
      at::TensorOptions options,
      const Tensor& src,
      bool async = false) {
    CAFFE_ENFORCE_WITH_CALLER(
        options.device_opt() != c10::nullopt,
        "device must be provided in options.");
    const bool dtype_unset = !options.has_dtype();
    if (dtype_unset) {
      options = options.dtype(src.dtype());
    }
    CAFFE_ENFORCE_WITH_CALLER(
        options.dtype() == src.dtype(),
        "We don't allow change of src data type in OutputTensorCopyFrom");
    Tensor* out = OutputTensor(idx, src.sizes(), options);
    out->CopyFrom(src, async);
    return out;
  }

  Tensor* OutputTensorAlias(int idx, const Tensor& src) {
    CAFFE_ENFORCE(
        isLegacyOperator(),
        "OutputTensorAlias(idx, src) not (yet) supported for operators exported to c10.");
Loading ...