#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/stack.h>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/six.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#endif
#include <ATen/core/function_schema.h>
#include <c10/core/Stream.h>
#ifdef USE_C10D_NCCL
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
#endif
#include <c10/util/Exception.h>
#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>
// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
// On _WIN32 builds the macro expands to nothing; on all other builds it
// applies the GCC/Clang `visibility("hidden")` attribute.
#ifdef _WIN32
#define VISIBILITY_HIDDEN
#else
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif
namespace torch {
namespace jit {
// Declared here, defined elsewhere. NOTE(review): presumably drops any Python
// instances registered for `ptr`; confirm against the definition.
void clear_registered_instances(void* ptr);
// Converts a Python object into an IValue of the requested `type`.
// NOTE(review): the meaning of the optional `N` is not visible in this header
// (possibly an expected length for fixed-size list types); confirm against the
// definition.
IValue toIValue(
py::handle obj,
const TypePtr& type,
c10::optional<int32_t> N = c10::nullopt);
// Converts an IValue into a new py::object.
// NB: does not acquire the GIL itself; call it with the GIL held.
py::object toPyObject(IValue ivalue);
// Owns a py::function and guarantees that the final reference drop happens
// while the GIL is held, even if the guard itself is destroyed from a thread
// that does not currently hold the GIL.
//
// NB: VISIBILITY_HIDDEN is needed to silence the compiler error
// 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
// type of its field 'torch::jit::PythonFunctionGuard::func_'
struct VISIBILITY_HIDDEN PythonFunctionGuard {
  explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}

  ~PythonFunctionGuard() {
    pybind11::gil_scoped_acquire gil;
    // release() detaches the PyObject* from func_ (leaving it null), so the
    // py::object destructor that runs afterwards has nothing left to decref.
    // The detached handle is then decref'd here, under the GIL.
    // See Note [Destructing py::object] in python_ivalue.h
    func_.release().dec_ref();
  }

  py::function func_;
};
// The PythonFutureWrapper for ivalue::Future
//
// Wraps a c10::ivalue::Future for use from Python (torch._C.Future) and adds
// Python-aware helpers:
//   - value()/wait() convert the result into a py::object;
//   - then()/add_done_callback() run Python callables once the future is done;
//   - markCompleted() completes the future with a Python value.
//
// NB: VISIBILITY_HIDDEN is for silencing compiling error,
// "error: 'torch::jit::PythonFutureWrapper' declared with greater visibility
// than the type of its field 'torch::jit::PythonFutureWrapper::unwrap_func'
// [-Werror=attributes]"
//
// NB: inherit from enable_shared_from_this because then(py::function) and
// add_done_callback(py::function) need to get a shared_ptr from this pointer.
struct VISIBILITY_HIDDEN PythonFutureWrapper
    : std::enable_shared_from_this<PythonFutureWrapper> {
  using UnwrapFunc = std::function<void(py::object)>;

  // `unwrap_func`, when provided, is applied to the py::object produced by
  // value() before it is returned.
  explicit PythonFutureWrapper(
      c10::intrusive_ptr<c10::ivalue::Future> fut,
      c10::optional<UnwrapFunc> unwrap_func = c10::nullopt)
      : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}

  // Non-copyable.
  PythonFutureWrapper(const PythonFutureWrapper&) = delete;
  PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;

  // Returns whether the underlying future has completed.
  bool done() {
    return fut->completed();
  }

  // Converts the future's value into a py::object, passes it to unwrap_func
  // (if set), and returns it. Acquires the GIL for the duration of the call.
  py::object value() {
    // acquiring GIL as toPyObject creates new py::object
    // without grabbing the GIL.
    py::gil_scoped_acquire acquire;
    py::object py_obj = toPyObject(fut->value());
    // unwrap_func is a general compositional function that takes in a
    // py::object and executes some python function. It is currently mostly used
    // to throw python exceptions.
    if (unwrap_func) {
      (*unwrap_func)(py_obj);
    }
    return py_obj;
  }

  // Waits for the future to complete and returns value(). While tracing, it
  // also records an aten::wait node and maps the future's value to its output.
  py::object wait() {
    fut->wait();
    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;
      Value* fut_val = jit::tracer::getValueTrace(fut);
      auto output = graph->insert(aten::wait, {fut_val});
      jit::tracer::setValueTrace(fut->value(), output);
    }
    return value();
  }

  // The py::function cb arg must take a std::shared_ptr<PythonFutureWrapper>
  // (i.e., torch._C.Future) as the only argument. If the type mismatches, an
  // error will be thrown when waiting for the value of this returned Future.
  //
  // Returns a new wrapper around the future produced by fut->then(). The new
  // future's value is cb's return value, converted to a PyObject IValue. If cb
  // raises a Python exception, the Python error indicator is cleared and a
  // std::runtime_error carrying the error message is thrown instead.
  std::shared_ptr<PythonFutureWrapper> then(py::function cb) {
    // We need this additional layer of wrapping to guard the destruction of
    // the py::function object: the Future owns a reference to the
    // py::function in its callback vector, but Future does not acquire the GIL
    // on destruction.
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));
    return std::make_shared<jit::PythonFutureWrapper>(fut->then(
        // Capture a copy of the ivalue::Future instead of the `this` pointer
        // because the PythonFutureWrapper object could have been deleted
        // when the callbacks are fired. For example, RPC only captures the
        // ivalue::Future instead of PythonFutureWrapper in JitFuture's
        // callback functions. Hence, if user code does not hold a reference to
        // this PythonFutureWrapper object, there is no guarantee that the
        // PythonFutureWrapper is still valid when running the callback.
        [pyFut(this->getPtr()), pf(std::move(pf))]() -> IValue {
          try {
            pybind11::gil_scoped_acquire ag;
            return toIValue(pf->func_(pyFut), PyObjectType::get());
          } catch (py::error_already_set& e) {
            auto err = std::runtime_error(c10::str(
                "Got the following error when running the callback: ",
                e.what()));
            {
              pybind11::gil_scoped_acquire ag;
              // Release ownership on py::objects and also restore Python
              // Error Indicator.
              e.restore();
              // Clear the Python Error Indicator; the error message has
              // already been copied into `err`, which is rethrown below.
              PyErr_Clear();
            }
            throw err;
          }
        },
        PyObjectType::get()));
  }

  // Registers cb (called with this wrapper as its only argument) to run when
  // the future completes. Any exception raised by cb is logged via VLOG(1) and
  // otherwise ignored.
  void add_done_callback(py::function cb) {
    // Same GIL-safe destruction guard as in then().
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));
    // Same capture strategy as then(): the closure owns a shared_ptr to this
    // wrapper and to the guarded function, so both outlive the caller's
    // references.
    fut->addCallback([pyFut(this->getPtr()), pf(std::move(pf))]() {
      try {
        pybind11::gil_scoped_acquire ag;
        pf->func_(pyFut);
      } catch (py::error_already_set& e) {
        {
          pybind11::gil_scoped_acquire ag;
          // Release ownership on py::objects and also restore Python
          // Error Indicator.
          e.restore();
          // Clear the Python Error Indicator; the error is only logged below.
          PyErr_Clear();
        }
        // Log and ignore exceptions raised through the callback
        VLOG(1) << "Got the following error when running the callback: "
                << e.what();
      } catch (std::exception& e) {
        // Log and ignore exceptions raised through the callback
        VLOG(1) << "Got the following error when running the callback: "
                << e.what();
      }
    });
  }

  // Completes the future with pyValue, converted to a PyObject IValue.
  // The caller must hold the GIL (checked with DCHECK). The GIL is released
  // while fut->markCompleted() runs.
  void markCompleted(const py::object& pyValue) {
    DCHECK(PyGILState_Check());
    IValue value = toIValue(pyValue, PyObjectType::get());
    // NOTE(review): presumably released so that callbacks fired from
    // markCompleted (which acquire the GIL themselves, see then() and
    // add_done_callback()) cannot deadlock; confirm.
    py::gil_scoped_release release;
    fut->markCompleted(std::move(value));
  }

  c10::intrusive_ptr<c10::ivalue::Future> fut;
  // unwrap_func works like a callback for the value returned by
  // PythonFutureWrapper::wait().
  c10::optional<UnwrapFunc> unwrap_func;

 private:
  // Requires that this object is already owned by a std::shared_ptr.
  std::shared_ptr<PythonFutureWrapper> getPtr() {
    return shared_from_this();
  }
};
// error reporting: when reporting user-caused errors, these functions should
// not use AT_ERROR macros, since these macros add stack trace information
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.
inline std::shared_ptr<CompilationUnit> get_python_cu() {
return py::module::import("torch.jit._state")
.attr("_python_cu")
.cast<std::shared_ptr<CompilationUnit>>();
}
// An (IValue, TypePtr) pair: a value together with its TorchScript type.
// All std::pair constructors are inherited. ivalue() and type() are named
// accessors for `first` and `second`; const overloads allow reading from a
// const TypedIValue.
struct TypedIValue : public std::pair<IValue, TypePtr> {
  using pair::pair;

  IValue& ivalue() {
    return this->first;
  }
  const IValue& ivalue() const {
    return this->first;
  }

  TypePtr& type() {
    return this->second;
  }
  const TypePtr& type() const {
    return this->second;
  }
};
// Converts a Python dictionary key into a TypedIValue.
// The key is checked against str, then int, then float, in that order.
// Any other key type raises an error via AT_ERROR.
inline TypedIValue toDictKeyIValue(py::handle key) {
  if (py::isinstance<py::str>(key)) {
    std::string text = py::cast<std::string>(key);
    return TypedIValue(
        ConstantString::create(std::move(text)), StringType::create());
  }
  if (py::isinstance<py::int_>(key)) {
    return TypedIValue(py::cast<int64_t>(key), IntType::create());
  }
  if (py::isinstance<py::float_>(key)) {
    return TypedIValue(py::cast<double>(key), FloatType::create());
  }
  AT_ERROR("Dictionary inputs may only have string, int, or float keys");
}
// Folds `unify` into the accumulated type `accum`.
// If nothing has been accumulated yet (accum is null), `unify` becomes the
// accumulated type. Otherwise the result is unifyTypes(accum, unify).
inline c10::optional<TypePtr> unifyOrInitializeType(
    const TypePtr& accum,
    const TypePtr& unify) {
  if (accum) {
    return unifyTypes(accum, unify);
  }
  return unify;
}
// Lets code in torch::jit refer to c10::InferredType unqualified.
using InferredType = c10::InferredType;
// Declared here, defined elsewhere. NOTE(review): judging by the name, this
// infers the type of a container input (list/dict/tuple); confirm against the
// definition.
InferredType tryToInferContainerType(py::handle input);
// Try to infer the type of a Python object
// The type cannot be inferred if:
// input is a None
// input is an empty container (list, dict)
// input is an list with element types that cannot be unified
// input is an dict with key or value types that cannot be unified
inline InferredType tryToInferType(py::handle input) {
// Try tensor types
if (THPVariable_Check(input.ptr())) {
return InferredType(TensorType::get());
}
if (input.is(py::none())) {
return InferredType(NoneType::get());
}
if (py::isinstance<StrongFunctionPtr>(input)) {
auto fn = py::cast<StrongFunctionPtr>(input).function_;
return InferredType(FunctionType::create(fn));
}
// Try basic types first
if (py::isinstance<py::bool_>(input)) {
return InferredType(BoolType::get());
} else if (py::isinstance<py::int_>(input)) {
return InferredType(IntType::get());
} else if (py::isinstance<py::float_>(input)) {
return InferredType(FloatType::get());
} else if (PyComplex_CheckExact(input.ptr())) {
return InferredType(ComplexType::get());
} else if (py::isinstance<py::str>(input)) {
return InferredType(StringType::get());
} else if (THPLayout_Check(input.ptr())) {
return InferredType(IntType::get());
} else if (THPDevice_Check(input.ptr())) {
return InferredType(DeviceObjType::get());
} else if (THPStream_Check(input.ptr())) {
return InferredType(StreamObjType::get());
} else if (THPDtype_Check(input.ptr())) {
return InferredType(IntType::get());
} else if (THPQScheme_Check(input.ptr())) {
return InferredType(IntType::get());
} else if (THPLayout_Check(input.ptr())) {
return InferredType(IntType::get());
}
auto enum_type = py::module::import("enum").attr("Enum");
py::bool_ isEnumValue = py::isinstance(input, enum_type);
if (py::cast<bool>(isEnumValue)) {
auto enum_class = input.attr("__class__");
auto enum_type = py::cast<TypePtr>(
py::module::import("torch.jit.annotations")
.attr("try_ann_to_type")(enum_class, SourceRange()));
return InferredType(enum_type);
}
py::bool_ isClass =
py::module::import("inspect").attr("isclass")(input.get_type());
if (py::cast<bool>(isClass)) {
py::str qualifiedName = py::module::import("torch._jit_internal")
.attr("_qualified_name")(input.get_type());
auto pyClass = py::module::import("torch.jit._state")
.attr("_get_script_class")(qualifiedName);
if (!pyClass.is_none()) {
auto cu = get_python_cu();
const auto classname =
c10::QualifiedName(py::cast<std::string>(qualifiedName));
auto class_type = cu->get_class(classname);
TORCH_INTERNAL_ASSERT(class_type);
return InferredType(class_type);
Loading ...