Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / torch / csrc / Exceptions.h
Size: Mime:
#pragma once

#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <system_error>

#include <ATen/detail/FunctionTraits.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <torch/csrc/utils/pybind.h>

#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
#include <torch/csrc/distributed/c10d/exception.h>
#endif

/// Convenience overload of `PyErr_SetString` taking the message as a
/// `std::string`. The string's NUL-terminated buffer is handed to the
/// `const char*` overload, so the message must only stay alive for the
/// duration of this call.
static inline void PyErr_SetString(PyObject* type, const std::string& message) {
  const char* const text = message.c_str();
  PyErr_SetString(type, text);
}
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set python warnings immediately
/// as it requires acquiring the GIL (potential deadlock)
/// and would need to cleanly exit if the warning raised a
/// python error. To solve this, we buffer the warnings and
/// process them when we go back to python.
/// This requires the two try/catch blocks below to handle the
/// following cases:
///   - If there is no Error raised in the inner try/catch, the
///     buffered warnings are processed as python warnings.
///     - If they don't raise an error, the function proceeds with the
///       original return code.
///     - If any of them raise an error, the error is set (PyErr_*) and
///       the destructor will raise a cpp exception python_error() that
///       will be caught by the outer try/catch that will be able to change
///       the return value of the function to reflect the error.
///   - If an Error was raised in the inner try/catch, the inner try/catch
///     must set the python error. The buffered warnings are then
///     processed as cpp warnings as we cannot predict beforehand
///     whether a python warning will raise an error or not and we
///     cannot handle two errors at the same time.
/// This advanced handler will only be used in the current thread.
/// If any other thread is used, warnings will be processed as
/// cpp warnings.
// Opens the error-handling scope described in
// NOTE [ Conversion Cpp Python Warning ]: an outer `try`, a local
// torch::PyWarningHandler named `__enforce_warning_buffer`, and an inner `try`
// that wraps the code following the macro. Both `try` blocks are left open on
// purpose, so every use must be closed by one of the END_HANDLE_TH_ERRORS*
// macros, which refer to `__enforce_warning_buffer` by name.
#define HANDLE_TH_ERRORS                              \
  try {                                               \
    torch::PyWarningHandler __enforce_warning_buffer; \
    try {
// Expands to a `catch` handler for `c10::ErrorType` that sets a Python error
// of type `PythonErrorType` and then executes `retstmnt`.
//   ErrorType       - unqualified name of an exception class in namespace c10.
//   PythonErrorType - PyObject* exception type handed to PyErr_SetString.
//   retstmnt        - statement executed after the Python error has been set.
// The message is `e.what()` when torch::get_cpp_stacktraces_enabled() is true
// and `e.what_without_backtrace()` otherwise; it is passed through
// torch::processErrorMsg before being given to Python.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt) \
  catch (const c10::ErrorType& e) {                                \
    auto msg = torch::get_cpp_stacktraces_enabled()                \
        ? e.what()                                                 \
        : e.what_without_backtrace();                              \
    PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
    retstmnt;                                                      \
  }

// Only catch torch-specific exceptions.
//
// Expands to a sequence of `catch` handlers; each one sets the Python error
// and then executes `retstmnt`:
//   - python_error and py::error_already_set: call `e.restore()`.
//   - c10::IndexError / ValueError / TypeError / NotImplementedError: the
//     matching PyExc_* type.
//   - c10::LinAlgError / OutOfMemoryError / DistBackendError: the matching
//     THPException_* type.
//   - c10::Error: PyExc_RuntimeError.
//   - torch::PyTorchError: the type returned by `e.python_type()`.
// Handlers are tried in the order written, so the order below is significant.
// NOTE(review): c10::Error is presumably the base class of the more specific
// c10 errors listed before it, which is why it has to stay after them --
// confirm against c10/util/Exception.h.
// The name `py` must be visible at the expansion site.
#define CATCH_CORE_ERRORS(retstmnt)                                     \
  catch (python_error & e) {                                            \
    e.restore();                                                        \
    retstmnt;                                                           \
  }                                                                     \
  catch (py::error_already_set & e) {                                   \
    e.restore();                                                        \
    retstmnt;                                                           \
  }                                                                     \
  _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt)          \
  _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt)          \
  _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt)            \
  _CATCH_GENERIC_ERROR(                                                 \
      NotImplementedError, PyExc_NotImplementedError, retstmnt)         \
  _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
  _CATCH_GENERIC_ERROR(                                                 \
      OutOfMemoryError, THPException_OutOfMemoryError, retstmnt)        \
  _CATCH_GENERIC_ERROR(                                                 \
      DistBackendError, THPException_DistBackendError, retstmnt)        \
  _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt)             \
  catch (torch::PyTorchError & e) {                                     \
    auto msg = torch::processErrorMsg(e.what());                        \
    PyErr_SetString(e.python_type(), msg);                              \
    retstmnt;                                                           \
  }

// Catch handlers for c10d (distributed) exceptions. When both USE_DISTRIBUTED
// and USE_C10D are defined, c10d::TimeoutError is reported as Python
// TimeoutError and c10d::C10dError as Python RuntimeError, each followed by
// `retstmnt`. Otherwise the macro expands to nothing.
// NOTE(review): TimeoutError is handled before C10dError, presumably because
// it derives from it -- confirm against torch/csrc/distributed/c10d/exception.h.
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
#define CATCH_C10D_ERRORS(retstmnt)              \
  catch (const c10d::TimeoutError& e) {          \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_TimeoutError, msg);    \
    retstmnt;                                    \
  }                                              \
  catch (const c10d::C10dError& e) {             \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg);    \
    retstmnt;                                    \
  }
#else
#define CATCH_C10D_ERRORS(retstmnt)
#endif

// All torch-specific handlers: CATCH_CORE_ERRORS followed by
// CATCH_C10D_ERRORS, each forwarding `retstmnt`. Exceptions matched by neither
// group are not caught by this macro.
#define CATCH_TH_ERRORS(retstmnt) \
  CATCH_CORE_ERRORS(retstmnt)     \
  CATCH_C10D_ERRORS(retstmnt)

// CATCH_TH_ERRORS plus a final fallback: any other std::exception is reported
// as Python RuntimeError (message passed through torch::processErrorMsg) and
// then `retstmnt` is executed. The fallback has to come last so the more
// specific handlers get the first chance to match.
#define CATCH_ALL_ERRORS(retstmnt)               \
  CATCH_TH_ERRORS(retstmnt)                      \
  catch (const std::exception& e) {              \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg);    \
    retstmnt;                                    \
  }

// Closes a scope opened by HANDLE_TH_ERRORS, for functions that report errors
// to their caller by throwing.
//
// Inner handler: any exception marks `__enforce_warning_buffer` as being in an
// exception (set_in_exception) and is rethrown to the outer handlers.
// Outer handlers:
//   - py::error_already_set, py::builtin_exception and
//     torch::jit::JITException are rethrown unchanged.
//     NOTE(review): presumably so that exception translators further up the
//     stack handle them -- confirm.
//   - any other std::exception is converted with
//     torch::translate_exception_to_python and replaced by a freshly thrown
//     py::error_already_set.
//   - exceptions not derived from std::exception propagate unchanged.
// The catch parameters are intentionally unnamed: none of the handlers reads
// the exception object, and naming it only produces unused-variable warnings
// in every translation unit that expands the macro.
// The name `py` must be visible at the expansion site.
#define END_HANDLE_TH_ERRORS_PYBIND                                 \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const py::error_already_set&) {                            \
    throw;                                                          \
  }                                                                 \
  catch (const py::builtin_exception&) {                            \
    throw;                                                          \
  }                                                                 \
  catch (const torch::jit::JITException&) {                         \
    throw;                                                          \
  }                                                                 \
  catch (const std::exception&) {                                   \
    torch::translate_exception_to_python(std::current_exception()); \
    throw py::error_already_set();                                  \
  }

// Closes a scope opened by HANDLE_TH_ERRORS, for functions that report errors
// to their caller through a return value.
//   retval - expression returned from the enclosing function once the Python
//            error has been set.
//
// Inner handler: any exception marks `__enforce_warning_buffer` as being in an
// exception (set_in_exception) and is rethrown to the outer handler.
// Outer handler: a std::exception is converted with
// torch::translate_exception_to_python and the enclosing function returns
// `retval`. Exceptions not derived from std::exception propagate unchanged.
// The catch parameter is intentionally unnamed: the handler never reads the
// exception object, and naming it only produces unused-variable warnings in
// every translation unit that expands the macro.
#define END_HANDLE_TH_ERRORS_RET(retval)                            \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const std::exception&) {                                   \
    torch::translate_exception_to_python(std::current_exception()); \
    return retval;                                                  \
  }

// Variant for functions whose error return value is a null pointer.
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)

// Python exception type objects, declared here and defined in another
// translation unit. THPException_LinAlgError, THPException_OutOfMemoryError
// and THPException_DistBackendError are the Python types CATCH_CORE_ERRORS
// raises for the c10 errors of the same name.
// NOTE(review): presumably created by THPException_init() -- confirm against
// its definition.
extern PyObject *THPException_FatalError, *THPException_LinAlgError,
    *THPException_OutOfMemoryError, *THPException_DistBackendError;

// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
//
// The object can additionally carry the error itself: persist() moves the
// interpreter's current error (type / value / traceback) into this object and
// restore() installs it again as the interpreter's current error. The three
// PyObject pointers are owned references; every member function that changes
// a reference count acquires the GIL first.
struct python_error : public std::exception {
  // Starts out empty: all three pointers are null and the message is empty.
  python_error() = default;

  // Copies the message and shares the Python objects, taking one additional
  // reference on each non-null pointer.
  python_error(const python_error& other)
      : type(other.type),
        value(other.value),
        traceback(other.traceback),
        message(other.message) {
    pybind11::gil_scoped_acquire gil;
    Py_XINCREF(type);
    Py_XINCREF(value);
    Py_XINCREF(traceback);
  }

  // Takes over the references held by `other` and leaves it empty. No
  // reference count changes, so the GIL is not needed.
  python_error(python_error&& other) noexcept
      : type(other.type),
        value(other.value),
        traceback(other.traceback),
        message(std::move(other.message)) {
    other.type = nullptr;
    other.value = nullptr;
    other.traceback = nullptr;
  }

  // Assignment is not supported. The user-declared move constructor already
  // suppresses the implicit assignment operators; deleting them explicitly
  // documents that this is intended, because a memberwise assignment would
  // copy the pointers without adjusting their reference counts.
  python_error& operator=(const python_error&) = delete;
  python_error& operator=(python_error&&) = delete;

  // Drops the owned references. The GIL is only acquired when there is at
  // least one reference to release.
  ~python_error() override {
    if (type || value || traceback) {
      pybind11::gil_scoped_acquire gil;
      Py_XDECREF(type);
      Py_XDECREF(value);
      Py_XDECREF(traceback);
    }
  }

  // Returns the text produced by build_message(); empty if build_message()
  // has never run on this object or on the object it was copied/moved from.
  const char* what() const noexcept override {
    return message.c_str();
  }

  // Fills `message` with the UTF-8 encoded str() of `value`, or with the
  // default text "python_error" when `value` is null or the conversion fails.
  // Must be called while no Python error is set (asserted below).
  void build_message() {
    // Ensure we have the GIL.
    pybind11::gil_scoped_acquire gil;

    // No errors should be set when we enter the function since PyErr_Fetch
    // clears the error indicator.
    TORCH_INTERNAL_ASSERT(!PyErr_Occurred());

    // Default message.
    message = "python_error";

    // Try to retrieve the error message from the value.
    if (value != nullptr) {
      // Reference count should not be zero.
      TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);

      PyObject* pyStr = PyObject_Str(value);
      if (pyStr != nullptr) {
        PyObject* encodedString =
            PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
        if (encodedString != nullptr) {
          char* bytes = PyBytes_AS_STRING(encodedString);
          if (bytes != nullptr) {
            // Set the message.
            message = std::string(bytes);
          }
          Py_XDECREF(encodedString);
        }
        Py_XDECREF(pyStr);
      }
    }

    // Clear any errors since we don't want to propagate errors for functions
    // that are trying to build a string for the error message.
    PyErr_Clear();
  }

  /** Saves the exception so that it can be re-thrown on a different thread */
  inline void persist() {
    if (type)
      return; // Don't overwrite exceptions
    // PyErr_Fetch overwrites the pointers, so release whatever is still held
    // before fetching.
    pybind11::gil_scoped_acquire gil;
    Py_XDECREF(type);
    Py_XDECREF(value);
    Py_XDECREF(traceback);
    PyErr_Fetch(&type, &value, &traceback);
    build_message();
  }

  /** Sets the current Python error from this exception */
  inline void restore() {
    if (!type)
      return;
    // PyErr_Restore steals references, so take an extra one on each object
    // first; this object keeps its own references and can be restored again.
    pybind11::gil_scoped_acquire gil;
    Py_XINCREF(type);
    Py_XINCREF(value);
    Py_XINCREF(traceback);
    PyErr_Restore(type, value, traceback);
  }

  // Owned references to the saved error; null while nothing is stored.
  PyObject* type{nullptr};
  PyObject* value{nullptr};
  PyObject* traceback{nullptr};

  // Message to return to the user when 'what()' is invoked.
  std::string message;
};

// Declared here; the definition lives in another translation unit.
// NOTE(review): presumably creates the THPException_* exception types and
// registers them on `module`, returning false on failure -- confirm against
// the definition.
bool THPException_init(PyObject* module);

namespace torch {

// Set python current exception from a C++ exception.
// Used by the END_HANDLE_TH_ERRORS* macros, which pass
// std::current_exception() from inside a `catch (const std::exception&)`
// handler. Defined out of line.
TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);

// Post-processes an error message before it is handed to PyErr_SetString; all
// CATCH_* macros in this header route their messages through it.
// NOTE(review): the transformation itself is defined out of line and not
// visible here -- see the definition for what is rewritten.
TORCH_PYTHON_API std::string processErrorMsg(std::string str);

// Abstract base class for exceptions which translate to specific Python types.
// CATCH_CORE_ERRORS catches this type and raises `python_type()` with the
// text returned by `what()`.
struct PyTorchError : public std::exception {
  PyTorchError() = default;
  // Takes ownership of the message text.
  PyTorchError(std::string msg_) : msg(std::move(msg_)) {}
  // Python exception type that this C++ exception should be raised as.
  virtual PyObject* python_type() = 0;
  // Returns the stored message; the pointer stays valid while `msg` is
  // neither modified nor destroyed.
  const char* what() const noexcept override {
    return msg.c_str();
  }
  // Message text reported through what().
  std::string msg;
};

// Declare a printf-like function on gcc & clang
// The compiler can then warn on invalid format specifiers
//   FORMAT_INDEX  - 1-based position of the format string parameter.
//   VA_ARGS_INDEX - 1-based position of the first variadic argument.
// For non-static member functions and constructors the implicit `this`
// occupies position 1, which is why the constructors below use (2, 3).
// Expands to nothing when __GNUC__ is not defined.
#ifdef __GNUC__
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif

// Translates to Python IndexError.
// Inherits PyTorchError's constructors and adds a printf-style constructor
// that is declared here and defined out of line.
struct IndexError : public PyTorchError {
  using PyTorchError::PyTorchError;
  // `format` is a printf-style format string followed by its arguments.
  IndexError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_IndexError;
  }
};

// Translates to Python TypeError.
// Inherits PyTorchError's constructors and adds an exported printf-style
// constructor that is declared here and defined out of line.
struct TypeError : public PyTorchError {
  using PyTorchError::PyTorchError;
  // `format` is a printf-style format string followed by its arguments.
  TORCH_PYTHON_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_TypeError;
  }
};

// Translates to Python ValueError.
// Inherits PyTorchError's constructors and adds an exported printf-style
// constructor that is declared here and defined out of line.
struct ValueError : public PyTorchError {
  using PyTorchError::PyTorchError;
  // `format` is a printf-style format string followed by its arguments.
  TORCH_PYTHON_API ValueError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_ValueError;
  }
};

// Translates to Python NotImplementedError.
// Unlike IndexError/TypeError/ValueError this type does not inherit
// PyTorchError's constructors; it can be built from a printf-style format
// (constructor defined out of line) or default-constructed with an empty
// message.
struct NotImplementedError : public PyTorchError {
  // `format` is a printf-style format string followed by its arguments.
  NotImplementedError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
  NotImplementedError() = default;
  PyObject* python_type() override {
    return PyExc_NotImplementedError;
  }
};

// Translates to Python AttributeError.
// The only constructor declared here is the printf-style one, which is
// defined out of line.
struct AttributeError : public PyTorchError {
  // `format` is a printf-style format string followed by its arguments.
  AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_AttributeError;
  }
};

// Translates to Python LinAlgError, i.e. the THPException_LinAlgError type
// object declared at the top of this header.
// The only constructor declared here is the printf-style one, which is
// defined out of line.
struct LinAlgError : public PyTorchError {
  // `format` is a printf-style format string followed by its arguments.
  LinAlgError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return THPException_LinAlgError;
  }
};

// ATen warning handler for Python.
// HANDLE_TH_ERRORS creates one instance per wrapped call (named
// `__enforce_warning_buffer`); see NOTE [ Conversion Cpp Python Warning ] for
// how the buffered warnings are processed when the instance is destroyed.
struct PyWarningHandler {
  // Move actual handler into a separate class with a noexcept
  // destructor. Otherwise, we need to force all WarningHandler
  // subclasses to have a noexcept(false) destructor.
  struct InternalHandler : at::WarningHandler {
    ~InternalHandler() override = default;
    // Defined out of line.
    // NOTE(review): presumably appends `warning` to warning_buffer_ -- confirm.
    void process(const c10::Warning& warning) override;

    // Warnings held by this handler.
    std::vector<c10::Warning> warning_buffer_;
  };

 public:
  /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
  TORCH_PYTHON_API PyWarningHandler() noexcept(true);
  // The destructor is allowed to throw: per the NOTE, it raises python_error
  // when processing a buffered warning sets a Python error.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  TORCH_PYTHON_API ~PyWarningHandler() noexcept(false);

  /** Call if an exception has been thrown

   *  Necessary to determine if it is safe to throw from the destructor since
   *  std::uncaught_exception is buggy on some platforms and generally
   *  unreliable across dynamic library calls.
   */
  void set_in_exception() {
    in_exception_ = true;
  }

 private:
  // Handler object that owns the warning buffer.
  InternalHandler internal_handler_;
  // NOTE(review): presumably the handler that was active before this object
  // was constructed, kept so it can be put back -- confirm in the constructor
  // and destructor definitions.
  at::WarningHandler* prev_handler_;
  // Set by set_in_exception(). Not initialized in this header; the
  // out-of-line constructor is expected to do so.
  bool in_exception_;
};

namespace detail {
// Type of the i-th argument of callable type `Func`, as reported by
// invoke_traits.
template <typename Func, size_t i>
using Arg = typename invoke_traits<Func>::template arg<i>::type;

// Builds the callable returned by wrap_pybind_function and
// wrap_pybind_function_no_gil.
//   f           - callable to wrap; stored inside the returned lambda.
//   (unnamed)   - index sequence 0..arity-1, used to spell out the argument
//                 types of `f` as the lambda's parameter list.
//   release_gil - when true, the GIL is released for the duration of the call
//                 to `f`.
// The returned lambda forwards its arguments to `f` inside a
// HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS_PYBIND scope and returns whatever
// `f` returns.
template <typename Func, size_t... Is>
auto wrap_pybind_function_impl_(
    Func&& f,
    std::index_sequence<Is...>,
    bool release_gil) {
  // The error-handling macros refer to pybind11 as `py`.
  namespace py = pybind11;
  using result_type = typename invoke_traits<Func>::result_type;

  // The init-capture stores `f` in the closure so the wrapper does not depend
  // on the lifetime of the argument.
  return [f = std::forward<Func>(f),
          release_gil](Arg<Func, Is>... args) -> result_type {
    HANDLE_TH_ERRORS
    if (!release_gil) {
      return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
    }
    // Declared inside the inner try block, so the GIL is re-acquired before
    // any handler of END_HANDLE_TH_ERRORS_PYBIND runs.
    py::gil_scoped_release no_gil;
    return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
    END_HANDLE_TH_ERRORS_PYBIND
  };
}
} // namespace detail

// Wrap a function with TH error and warning handling.
// Returns a function object suitable for registering with pybind11.
template <typename Func>
auto wrap_pybind_function(Func&& f) {
  using traits = invoke_traits<Func>;
  return torch::detail::wrap_pybind_function_impl_(
      std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, false);
}

// Wrap a function with TH error, warning handling and releases the GIL.
// Returns a function object suitable for registering with pybind11.
template <typename Func>
auto wrap_pybind_function_no_gil(Func&& f) {
  using traits = invoke_traits<Func>;
  return torch::detail::wrap_pybind_function_impl_(
      std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, true);
}

} // namespace torch