Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / python / pybind_utils.h

#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/stack.h>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/six.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#endif

#include <ATen/core/function_schema.h>
#include <c10/core/Stream.h>
#ifdef USE_C10D_NCCL
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
#endif
#include <c10/util/Exception.h>

#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
// When _WIN32 is defined the macro expands to nothing; otherwise it expands
// to the `visibility("hidden")` attribute.
#ifdef _WIN32
#define VISIBILITY_HIDDEN
#else
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif

namespace torch {
namespace jit {

// NOTE(review): only the declaration is visible in this header. Judging by
// the name, this presumably clears pybind11's registered-instance entries for
// `ptr` -- confirm against the definition.
void clear_registered_instances(void* ptr);

// Converts the Python object `obj` into an IValue, using `type` as the target
// JIT type.
// NOTE(review): the meaning of `N` is not visible from this header
// (presumably a size hint for list-like arguments) -- confirm against the
// definition.
IValue toIValue(
    py::handle obj,
    const TypePtr& type,
    c10::optional<int32_t> N = c10::nullopt);

// Converts an IValue into a Python object.
// Callers are expected to hold the GIL: as noted in
// PythonFutureWrapper::value(), this creates a new py::object without
// grabbing the GIL itself.
py::object toPyObject(IValue ivalue);

// Wraps a Python function so that its reference is dropped while the GIL is
// held, regardless of which thread destroys the wrapper.
// NB: VISIBILITY_HIDDEN is needed to silence the compiler error
// 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
// type of its field 'torch::jit::PythonFunctionGuard::func_'
struct VISIBILITY_HIDDEN PythonFunctionGuard {
  explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}

  ~PythonFunctionGuard() {
    // The reference count must only be touched while holding the GIL.
    py::gil_scoped_acquire gil;
    // release() hands back the raw handle and resets func_'s PyObject* to
    // nullptr, so py::object's own dtor will not decref the PyObject again.
    // See Note [Destructing py::object] in python_ivalue.h
    func_.release().dec_ref();
  }

  py::function func_;
};

// The PythonFutureWrapper for ivalue::Future
//
// NB: VISIBILITY_HIDDEN is for silencing the compiler error,
// "error: 'torch::jit::PythonFutureWrapper' declared with greater visibility
// than the type of its field 'torch::jit::PythonFutureWrapper::unwrap_func'
// [-Werror=attributes]"
//
// NB: inherit from enable_shared_from_this because then(py::function) and
//     add_done_callback(py::function) need to get a shared_ptr from the this
//     pointer. As a consequence, an instance must already be owned by a
//     std::shared_ptr when either of those methods is called.
struct VISIBILITY_HIDDEN PythonFutureWrapper
    : std::enable_shared_from_this<PythonFutureWrapper> {
  // Hook invoked on the Python value before value()/wait() return it.
  using UnwrapFunc = std::function<void(py::object)>;

  explicit PythonFutureWrapper(
      c10::intrusive_ptr<c10::ivalue::Future> fut,
      c10::optional<UnwrapFunc> unwrap_func = c10::nullopt)
      : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}

  // Non-copyable.
  PythonFutureWrapper(const PythonFutureWrapper&) = delete;
  PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;

  // Returns whether the wrapped Future has completed.
  bool done() {
    return fut->completed();
  }

  // Returns the wrapped Future's value converted to a Python object, after
  // running unwrap_func on it (if one was provided).
  py::object value() {
    // acquiring GIL as toPyObject creates new py::object
    // without grabbing the GIL.
    py::gil_scoped_acquire acquire;
    py::object py_obj = toPyObject(fut->value());
    // unwrap_func is a general compositional function that takes in a
    // py::object and executes some python function. It is currently mostly used
    // to throw python exceptions.
    if (unwrap_func) {
      (*unwrap_func)(py_obj);
    }
    return py_obj;
  }

  // Waits on the wrapped Future and returns its value (see value()).
  // When the tracer is active, an aten::wait node is additionally inserted
  // into the traced graph and associated with the Future's value.
  py::object wait() {
    fut->wait();
    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;

      Value* fut_val = jit::tracer::getValueTrace(fut);
      auto output = graph->insert(aten::wait, {fut_val});
      jit::tracer::setValueTrace(fut->value(), output);
    }
    return value();
  }

  // The py::function cb arg must take a std::shared_ptr<PythonFutureWrapper>
  // (i.e., torch._C.Future) as the only argument. If the type mismatches, an
  // error will be thrown when waiting for the value of this returned Future.
  //
  // Returns a new PythonFutureWrapper around the Future produced by
  // fut->then(...), whose value is the result of cb converted to an IValue.
  std::shared_ptr<PythonFutureWrapper> then(py::function cb) {
    // We need this additional wrapper layer here to guard the destruction
    // of the py::function object, because the Future owns a reference to the
    // py::function in its callback vector, but Future does not acquire the
    // GIL on destruction.
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));

    return std::make_shared<jit::PythonFutureWrapper>(fut->then(
        // Capture a shared_ptr to this wrapper (obtained through
        // shared_from_this) instead of the raw `this` pointer, because user
        // code may no longer hold a reference to this PythonFutureWrapper
        // when the callbacks are fired. For example, RPC only captures the
        // ivalue::Future instead of PythonFutureWrapper in JitFuture's
        // callback functions. The captured shared_ptr keeps the wrapper
        // alive and is the object handed to the Python callback.
        [pyFut(this->getPtr()), pf(std::move(pf))]() -> IValue {
          try {
            pybind11::gil_scoped_acquire ag;
            return toIValue(pf->func_(pyFut), PyObjectType::get());
          } catch (py::error_already_set& e) {
            // Build the C++ error first; it is rethrown once the Python
            // error state has been cleaned up below.
            auto err = std::runtime_error(c10::str(
                "Got the following error when running the callback: ",
                e.what()));
            {
              pybind11::gil_scoped_acquire ag;
              // Release ownership on py::objects and also restore Python
              // Error Indicator.
              e.restore();
              // Clear the Python Error Indicator as we have recorded the
              // exception in the response message.
              PyErr_Clear();
            }

            throw err;
          }
        },
        PyObjectType::get()));
  }

  // Registers cb to be invoked with this wrapper (as a
  // std::shared_ptr<PythonFutureWrapper>) by the wrapped Future's callback
  // mechanism. Unlike then(), no new Future is returned, and exceptions
  // raised by cb are logged at VLOG(1) and otherwise ignored.
  void add_done_callback(py::function cb) {
    // Same GIL-guarded wrapper as in then(): the Future stores the callback
    // but does not acquire the GIL when destroying it.
    auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));
    // Move the guard into the closure with an init-capture, mirroring then().
    fut->addCallback([pyFut(this->getPtr()), pf(std::move(pf))]() {
      try {
        pybind11::gil_scoped_acquire ag;
        pf->func_(pyFut);
      } catch (py::error_already_set& e) {
        {
          pybind11::gil_scoped_acquire ag;
          // Release ownership on py::objects and also restore Python
          // Error Indicator.
          e.restore();
          // Clear the Python Error Indicator; the exception is only logged
          // below and must not leak into unrelated Python code.
          PyErr_Clear();
        }
        // Log and ignore exceptions raised through the callback
        VLOG(1) << "Got the following error when running the callback: "
                << e.what();

      } catch (const std::exception& e) {
        // Log and ignore exceptions raised through the callback
        VLOG(1) << "Got the following error when running the callback: "
                << e.what();
      }
    });
  }

  // Completes the wrapped Future with the given Python value.
  // Must be called with the GIL held (checked by the DCHECK below).
  void markCompleted(const py::object& pyValue) {
    DCHECK(PyGILState_Check());
    // Convert while the GIL is still held.
    IValue value = toIValue(pyValue, PyObjectType::get());

    // Drop the GIL before completing the Future.
    // NOTE(review): presumably because callbacks registered through
    // then()/add_done_callback() acquire the GIL themselves and may run as
    // part of markCompleted -- confirm against ivalue::Future.
    py::gil_scoped_release release;
    fut->markCompleted(std::move(value));
  }

  c10::intrusive_ptr<c10::ivalue::Future> fut;
  // unwrap_func works like a callback for the value returned by
  // PythonFutureWrapper::wait().
  c10::optional<UnwrapFunc> unwrap_func;

 private:
  // Returns a shared_ptr that shares ownership of this wrapper.
  std::shared_ptr<PythonFutureWrapper> getPtr() {
    return shared_from_this();
  }
};

// error reporting: when reporting user-caused errors, these functions should
// not use AT_ERROR macros, since these macros add stack trace information
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.

inline std::shared_ptr<CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit._state")
      .attr("_python_cu")
      .cast<std::shared_ptr<CompilationUnit>>();
}

// An (IValue, TypePtr) pair with named accessors for its two members.
struct TypedIValue : public std::pair<IValue, TypePtr> {
  // Inherit all of std::pair's constructors.
  using std::pair<IValue, TypePtr>::pair;

  // The value half of the pair.
  IValue& ivalue() {
    return first;
  }
  // The type half of the pair.
  TypePtr& type() {
    return second;
  }
};

// Converts a Python dictionary key into a TypedIValue (value plus its type).
// Python str, int and float keys are handled, checked in that order; any
// other key raises through AT_ERROR.
inline TypedIValue toDictKeyIValue(py::handle key) {
  if (py::isinstance<py::str>(key)) {
    auto text = py::cast<std::string>(key);
    return TypedIValue(
        ConstantString::create(std::move(text)), StringType::create());
  }
  if (py::isinstance<py::int_>(key)) {
    return TypedIValue(py::cast<int64_t>(key), IntType::create());
  }
  if (py::isinstance<py::float_>(key)) {
    return TypedIValue(py::cast<double>(key), FloatType::create());
  }
  AT_ERROR("Dictionary inputs may only have string, int, or float keys");
}

// Folds `unify` into the running type `accum`.
// If `accum` is still null, `unify` becomes the result; otherwise the result
// is whatever unifyTypes(accum, unify) returns.
inline c10::optional<TypePtr> unifyOrInitializeType(
    const TypePtr& accum,
    const TypePtr& unify) {
  if (accum) {
    return unifyTypes(accum, unify);
  }
  // Nothing accumulated yet: start from the incoming type.
  return unify;
}

// Shorthand so code in torch::jit can refer to c10::InferredType unqualified.
using InferredType = c10::InferredType;

// Declared here, defined elsewhere.
// NOTE(review): judging by the name, this is the container (list/dict/...)
// counterpart of tryToInferType() -- confirm against the definition.
InferredType tryToInferContainerType(py::handle input);

// Try to infer the type of a Python object.
// A None input is inferred as NoneType.
// The type cannot be inferred if:
//   input is an empty container (list, dict)
//   input is a list with element types that cannot be unified
//   input is a dict with key or value types that cannot be unified
inline InferredType tryToInferType(py::handle input) {
  // Try tensor types
  if (THPVariable_Check(input.ptr())) {
    return InferredType(TensorType::get());
  }

  if (input.is(py::none())) {
    return InferredType(NoneType::get());
  }

  if (py::isinstance<StrongFunctionPtr>(input)) {
    auto fn = py::cast<StrongFunctionPtr>(input).function_;
    return InferredType(FunctionType::create(fn));
  }

  // Try basic types first
  if (py::isinstance<py::bool_>(input)) {
    return InferredType(BoolType::get());
  } else if (py::isinstance<py::int_>(input)) {
    return InferredType(IntType::get());
  } else if (py::isinstance<py::float_>(input)) {
    return InferredType(FloatType::get());
  } else if (PyComplex_CheckExact(input.ptr())) {
    return InferredType(ComplexType::get());
  } else if (py::isinstance<py::str>(input)) {
    return InferredType(StringType::get());
  } else if (THPLayout_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPDevice_Check(input.ptr())) {
    return InferredType(DeviceObjType::get());
  } else if (THPStream_Check(input.ptr())) {
    return InferredType(StreamObjType::get());
  } else if (THPDtype_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPQScheme_Check(input.ptr())) {
    return InferredType(IntType::get());
  } else if (THPLayout_Check(input.ptr())) {
    return InferredType(IntType::get());
  }

  auto enum_type = py::module::import("enum").attr("Enum");
  py::bool_ isEnumValue = py::isinstance(input, enum_type);
  if (py::cast<bool>(isEnumValue)) {
    auto enum_class = input.attr("__class__");
    auto enum_type = py::cast<TypePtr>(
        py::module::import("torch.jit.annotations")
            .attr("try_ann_to_type")(enum_class, SourceRange()));
    return InferredType(enum_type);
  }

  py::bool_ isClass =
      py::module::import("inspect").attr("isclass")(input.get_type());
  if (py::cast<bool>(isClass)) {
    py::str qualifiedName = py::module::import("torch._jit_internal")
                                .attr("_qualified_name")(input.get_type());
    auto pyClass = py::module::import("torch.jit._state")
                       .attr("_get_script_class")(qualifiedName);
    if (!pyClass.is_none()) {
      auto cu = get_python_cu();
      const auto classname =
          c10::QualifiedName(py::cast<std::string>(qualifiedName));
      auto class_type = cu->get_class(classname);
      TORCH_INTERNAL_ASSERT(class_type);
      return InferredType(class_type);
Loading ...