#ifndef CAFFE2_CORE_OPERATOR_H_
#define CAFFE2_CORE_OPERATOR_H_
#include <array>
#include <cfenv>
#include <climits>
#include <cstddef>
#include <exception>
#include <functional>
#include <set>
#include <sstream>
#include <string>
#include <typeinfo>
#include <vector>
#include <c10/macros/Macros.h>
#include <c10/util/Registry.h>
#include <c10/util/typeid.h>
#include "caffe2/core/blob.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator_gradient.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/core/types.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
#include <ATen/core/TensorBody.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#endif
C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions);
C10_DECLARE_bool(caffe2_operator_throw_if_fp_overflow_exceptions);
#ifdef __GNU_LIBRARY__
C10_DECLARE_bool(caffe2_operator_throw_on_first_occurrence_if_fp_exceptions);
#endif
namespace c10 {
struct FunctionSchema;
}
namespace caffe2 {
class TORCH_API OperatorBase;
typedef ObserverBase<OperatorBase> OperatorObserver;
class TORCH_API OperatorBase : public Observable<OperatorBase> {
public:
explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
/*
 * Note: all output ivalues must be tensors. The input ivalue list must start
 * with all tensors ("inputs" in caffe2 terminology), followed by non-tensors
 * ("arguments" in caffe2 terminology). Alternatively, the inputs can be one
 * tensor-list ivalue followed by non-tensors, to represent operators with a
 * variable number of inputs.
 */
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
explicit OperatorBase(
const c10::FunctionSchema& schema,
std::vector<c10::IValue> inputs,
c10::List<at::Tensor> outputs);
#endif
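// Usage sketch for the FunctionSchema constructor above (the values are
// hypothetical; they only illustrate the required ivalue layout of tensor
// inputs first, then non-tensor arguments):
//
//   std::vector<c10::IValue> inputs{
//       c10::IValue(at::ones({2, 3})), // a tensor "input"
//       c10::IValue(1.5)};             // a non-tensor "argument"
//   c10::List<at::Tensor> outputs({at::Tensor()}); // one output slot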
virtual ~OperatorBase() noexcept;
/** @brief Returns true if the operator was instantiated with an OperatorDef.
 * New operators should be instantiated with a FunctionSchema.
 */
bool isLegacyOperator() const {
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
return !fn_schema_;
#else
return true;
#endif
}
const c10::FunctionSchema& getFunctionSchema() const {
CAFFE_ENFORCE(!isLegacyOperator());
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
return *fn_schema_.get();
#else
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
}
/** @brief Checks if the operator has an argument of the given name.
*/
inline bool HasArgument(const string& name) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
}
return argumentIndexWithName(name).has_value();
}
// Functions that deal with arguments. Basically, this allows us to map an
// argument name to a specific type of argument that we are trying to access.
template <typename T>
inline T GetSingleArgument(const string& name, const T& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
*operator_def_, name, default_value);
}
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
auto index = argumentIndexWithName(name);
CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument: ", name);
const auto& value = newstyle_inputs_[index.value()];
return value.template to<T>();
#else
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
}
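// Usage sketch ("axis" is a hypothetical argument name):
//
//   if (HasArgument("axis")) {
//     const int axis = GetSingleArgument<int>("axis", /* default */ 0);
//   }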
template <typename T>
inline bool HasSingleArgumentOfType(const string& name) const {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
*operator_def_, name);
}
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
template <typename T>
inline vector<T> GetVectorFromIValueList(const c10::IValue& value) const {
return value.template to<List<T>>().vec();
}
#endif
template <typename T>
inline vector<T> GetRepeatedArgument(
const string& name,
const vector<T>& default_value = {}) const;
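// Usage sketch ("dims" is a hypothetical argument name):
//
//   const std::vector<int64_t> dims =
//       GetRepeatedArgument<int64_t>("dims", {});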
// Get the inputs and outputs as specific types.
template <typename T>
inline const T& Input(int idx) {
static_assert(
!std::is_same<T, Tensor>::value,
"You should use Input<Tensor>(int, DeviceType) for "
"Tensor.");
DCHECK_LT((size_t)idx, inputs_.size());
try {
return inputs_.at(idx)->template Get<T>();
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
}
throw enf;
}
}
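// Usage sketch: non-tensor inputs are fetched by the type stored in the blob.
// MyState is a hypothetical non-tensor blob type:
//
//   const auto& state = Input<MyState>(0);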
// TODO(jerryzh): Remove the template and the type argument? This is to keep
// the API changes minimal and make refactoring a bit easier.
template <typename T>
inline const T& Input(int idx, DeviceType type) {
if (isLegacyOperator()) {
static_assert(
std::is_same<T, Tensor>::value,
"Input(int, DeviceType) is only available for Tensor");
DCHECK_LT((size_t)idx, inputs_.size());
try {
// TODO(jerryzh): We'll need to check device type in Get<T>() later
// Get<T>() -> Get<T>(type)
const auto& tensor = inputs_.at(idx)->template Get<T>();
return tensor;
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
}
throw enf;
}
}
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
DCHECK_LT(0U, newstyle_inputs_.size());
IValue ival;
if (newstyle_inputs_[0].isTensorList()) {
// If the first input is a tensor list, we get input tensors by indexing
// into that list. Currently, this means that only tensors from that list
// are accessible as inputs; any hypothetical input tensors that come
// after the list are not accessible.
auto tensorList = newstyle_inputs_[0].toTensorVector();
DCHECK_LT((size_t)idx, tensorList.size());
ival = tensorList[idx];
} else {
// if the first input is not a tensor list, we get input tensors by
// indexing into the inputs.
DCHECK_LT((size_t)idx, newstyle_inputs_.size());
ival = newstyle_inputs_[idx];
}
CAFFE_ENFORCE(
ival.isTensor(),
"Input(int, DeviceType) is only available for IValues that store Tensors");
auto t = ival.toTensor();
if (!t.is_contiguous()) {
t = t.contiguous();
}
Tensor tensor = caffe2::Tensor(std::move(t));
CAFFE_ENFORCE_EQ(tensor.GetDeviceType(), type);
input_tensors_[idx] = std::move(tensor);
return input_tensors_[idx];
#else
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
}
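// Usage sketch, e.g. from inside an operator's RunOnDevice() on CPU:
//
//   const auto& X = Input<Tensor>(0, CPU); // enforces a CPU tensor at slot 0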
template <typename T>
inline T* Output(int idx) {
CAFFE_ENFORCE(
isLegacyOperator(),
"Output(idx) not supported for operators exported to c10. Please use XOutput instead.");
static_assert(
!std::is_same<T, Tensor>::value,
"You should use Output<Tensor>(int, DeviceType) for "
"Tensor.");
return outputs_.at(idx)->template GetMutable<T>();
}
// TODO(jerryzh): Remove this template
template <typename T>
inline T* Output(int idx, DeviceType type) {
if (isLegacyOperator()) {
static_assert(
std::is_same<T, Tensor>::value,
"Output(int, DeviceType) is only available for Tensor");
// When you get a Tensor here it is not fully initialized
return BlobGetMutableTensor(outputs_.at(idx), type);
}
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
at::Tensor output = newstyle_outputs_[idx];
if (!output.defined() || caffe2::Tensor(output).GetDeviceType() != type) {
// Fix tensor type
Tensor tensor = Tensor(type);
output = at::Tensor(std::move(tensor.getIntrusivePtr()));
}
output_tensors_[idx] = caffe2::Tensor(output);
newstyle_outputs_[idx] = std::move(output);
return &output_tensors_[idx];
#else
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
}
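// Usage sketch: the returned tensor is not fully initialized; size it before
// writing (X is a hypothetical input tensor):
//
//   Tensor* Y = Output<Tensor>(0, CPU);
//   Y->ResizeLike(X);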
inline Tensor
XOutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
CAFFE_ENFORCE_WITH_CALLER(
options.device_opt() != c10::nullopt,
"device must be provided in option.");
if (isLegacyOperator()) {
return XBlobGetMutableTensor(outputs_.at(idx), dims, options);
}
return OutputTensor(idx, dims, options)->UnsafeSharedInstance();
}
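// Usage sketch (the shape and dtype are illustrative):
//
//   Tensor y = XOutputTensor(0, {4, 4}, at::dtype<float>().device(CPU));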
void SetOutputTensor(int idx, Tensor tensor) {
if (!isLegacyOperator()) {
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
newstyle_outputs_[idx] = at::Tensor(tensor);
// also update the tensor in the hack
output_tensors_[idx] = std::move(tensor);
#else
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
} else {
// update the tensor in the workspace
BlobSetTensor(outputs_.at(idx), std::move(tensor));
}
}
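// Usage sketch (`result` is a hypothetical caffe2::Tensor computed earlier):
//
//   SetOutputTensor(0, std::move(result));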
Tensor OutputTensorOrUndefined(int idx) {
if (isLegacyOperator()) {
return BlobGetTensorOrUndefined(*outputs_.at(idx));
}
return output_tensors_[idx].UnsafeSharedInstance();
}
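// Usage sketch: probe an output that may not have been written yet:
//
//   Tensor maybe = OutputTensorOrUndefined(0);
//   if (maybe.defined()) { /* safe to use */ }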
inline Tensor*
OutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
if (isLegacyOperator()) {
CAFFE_ENFORCE_WITH_CALLER(
options.device_opt() != c10::nullopt,
"device must be provided in options.");
return BlobGetMutableTensor(outputs_.at(idx), dims, options);
}
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
at::Tensor output = newstyle_outputs_[idx];
Tensor tensor = output.defined()
? GetSizedTensorWithOptions(caffe2::Tensor(output), dims, options)
: caffe2::empty(dims, options);
// assign it back in case it changed
output = at::Tensor(std::move(tensor.getIntrusivePtr()));
output_tensors_[idx] = caffe2::Tensor(output);
newstyle_outputs_[idx] = std::move(output);
return &output_tensors_[idx];
#else
CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2");
#endif
}
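// Usage sketch: allocate (or resize) output 0 as a 2x3 float CPU tensor:
//
//   Tensor* Y = OutputTensor(0, {2, 3}, at::dtype<float>().device(CPU));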
// Get the output Tensor of the operator and CopyFrom the given Tensor.
Tensor* OutputTensorCopyFrom(
int idx,
at::TensorOptions options,
const Tensor& src,
bool async = false) {
CAFFE_ENFORCE_WITH_CALLER(
options.device_opt() != c10::nullopt,
"device must be provided in options.");
// The output Tensor will always have the same data type as `src`.
if (!options.has_dtype()) {
options = options.dtype(src.dtype());
}
CAFFE_ENFORCE_WITH_CALLER(
options.dtype() == src.dtype(),
"We don't allow change of src data type in OutputTensorCopyFrom");
Tensor* t = OutputTensor(idx, src.sizes(), options);
t->CopyFrom(src, async);
return t;
}
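// Usage sketch: copy a hypothetical `src` tensor into output 0 on CPU,
// keeping src's data type:
//
//   Tensor* Y = OutputTensorCopyFrom(0, at::device(CPU), src);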
Tensor* OutputTensorAlias(int idx, const Tensor& src) {
CAFFE_ENFORCE(
isLegacyOperator(),
"OutputTensorAlias(idx, src) not (yet) supported for operators exported to c10.");
return BlobSetTensor(OutputBlob(idx), src.Alias());
}