Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / Exceptions.h

#pragma once

#include <exception>
#include <string>
#include <memory>
#include <queue>
#include <mutex>

#include <c10/util/Exception.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/THP_export.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <c10/util/StringUtil.h>
#include <ATen/detail/FunctionTraits.h>

static inline void PyErr_SetString(PyObject* type, const std::string& message) {
  PyErr_SetString(type, message.c_str());
}
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set python warnings immediately
/// as it requires acquiring the GIL (potential deadlock)
/// and would need to cleanly exit if the warning raised a
/// python error. To solve this, we buffer the warnings and
/// process them when we go back to python.
/// This requires the two try/catch blocks below to handle the
/// following cases:
///   - If there is no Error raised in the inner try/catch, the
///     buffered warnings are processed as python warnings.
///     - If they don't raise an error, the function process with the
///       original return code.
///     - If any of them raise an error, the error is set (PyErr_*) and
///       the destructor will raise a cpp exception python_error() that
///       will be caught by the outer try/catch that will be able to change
///       the return value of the function to reflect the error.
///   - If an Error was raised in the inner try/catch, the inner try/catch
///     must set the python error. The buffered warnings are then
///     processed as cpp warnings as we cannot predict before hand
///     whether a python warning will raise an error or not and we
///     cannot handle two errors at the same time.
/// This advanced handler will only be used in the current thread.
/// If any other thread is used, warnings will be processed as
/// cpp warnings.
#define HANDLE_TH_ERRORS                                             \
  try {                                                              \
    torch::PyWarningHandler __enforce_warning_buffer;                \
    try {

// Only catch torch-specific exceptions
#define CATCH_TH_ERRORS(retstmnt)                                    \
    catch (python_error & e) {                                       \
      e.restore();                                                   \
      retstmnt;                                                      \
    }                                                                \
    catch (const c10::IndexError& e) {                               \
      auto msg = torch::get_cpp_stacktraces_enabled() ?              \
                    e.what() : e.what_without_backtrace();           \
      PyErr_SetString(PyExc_IndexError, torch::processErrorMsg(msg)); \
      retstmnt;                                                      \
    }                                                                \
    catch (const c10::ValueError& e) {                               \
      auto msg = torch::get_cpp_stacktraces_enabled() ?              \
                    e.what() : e.what_without_backtrace();           \
      PyErr_SetString(PyExc_ValueError, torch::processErrorMsg(msg)); \
      retstmnt;                                                      \
    }                                                                \
    catch (const c10::TypeError& e) {                               \
      auto msg = torch::get_cpp_stacktraces_enabled() ?              \
                    e.what() : e.what_without_backtrace();           \
      PyErr_SetString(PyExc_TypeError, torch::processErrorMsg(msg)); \
      retstmnt;                                                      \
    }                                                                \
    catch (const c10::Error& e) {                                    \
      auto msg = torch::get_cpp_stacktraces_enabled() ?              \
                    e.what() : e.what_without_backtrace();           \
      PyErr_SetString(PyExc_RuntimeError, torch::processErrorMsg(msg)); \
      retstmnt;                                                      \
    }                                                                \
    catch (torch::PyTorchError & e) {                                \
      auto msg = torch::processErrorMsg(e.what());                   \
      PyErr_SetString(e.python_type(), msg);                         \
      retstmnt;                                                      \
    }

#define CATCH_ALL_ERRORS(retstmnt)                                   \
    CATCH_TH_ERRORS(retstmnt)                                        \
    catch (const std::exception& e) {                                \
      auto msg = torch::processErrorMsg(e.what());                   \
      PyErr_SetString(PyExc_RuntimeError, msg);                      \
      retstmnt;                                                      \
    }

#define END_HANDLE_TH_ERRORS_PYBIND                                      \
    }                                                                    \
    catch(...) {                                                         \
      __enforce_warning_buffer.set_in_exception();                       \
      throw;                                                             \
    }                                                                    \
  }                                                                      \
  catch (py::error_already_set & e) {                                    \
    throw;                                                               \
  }                                                                      \
  catch (py::builtin_exception & e) {                                    \
    throw;                                                               \
  }                                                                      \
  catch (torch::jit::JITException & e) {                                 \
    throw;                                                               \
  }                                                                      \
  CATCH_ALL_ERRORS(throw py::error_already_set())

#define END_HANDLE_TH_ERRORS_RET(retval)                             \
    }                                                                \
    catch(...) {                                                     \
      __enforce_warning_buffer.set_in_exception();                   \
      throw;                                                         \
    }                                                                \
  }                                                                  \
  CATCH_ALL_ERRORS(return retval)

#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)

extern PyObject *THPException_FatalError;

// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
struct python_error : public std::exception {
  python_error() : type(nullptr), value(nullptr), traceback(nullptr) {}

  python_error(const python_error& other)
      : type(other.type),
        value(other.value),
        traceback(other.traceback),
        message(other.message) {
    pybind11::gil_scoped_acquire gil;
    Py_XINCREF(type);
    Py_XINCREF(value);
    Py_XINCREF(traceback);
  }

  python_error(python_error&& other) {
    type = other.type;
    value = other.value;
    traceback = other.traceback;
    message = std::move(other.message);
    other.type = nullptr;
    other.value = nullptr;
    other.traceback = nullptr;
  }

  ~python_error() override {
    if (type || value || traceback) {
      pybind11::gil_scoped_acquire gil;
      Py_XDECREF(type);
      Py_XDECREF(value);
      Py_XDECREF(traceback);
    }
  }

  virtual const char* what() const noexcept override {
    return message.c_str();
  }

  void build_message() {
    // Ensure we have the GIL.
    pybind11::gil_scoped_acquire gil;

    // No errors should be set when we enter the function since PyErr_Fetch
    // clears the error indicator.
    TORCH_INTERNAL_ASSERT(!PyErr_Occurred());

    // Default message.
    message = "python_error";

    // Try to retrieve the error message from the value.
    if (value != nullptr) {
      // Reference count should not be zero.
      TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);

      PyObject* pyStr = PyObject_Str(value);
      if (pyStr != nullptr) {
        PyObject* encodedString =
            PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
        if (encodedString != nullptr) {
          char* bytes = PyBytes_AS_STRING(encodedString);
          if (bytes != nullptr) {
            // Set the message.
            message = std::string(bytes);
          }
          Py_XDECREF(encodedString);
        }
        Py_XDECREF(pyStr);
      }
    }

    // Clear any errors since we don't want to propagate errors for functions
    // that are trying to build a string for the error message.
    PyErr_Clear();
  }

  /** Saves the exception so that it can be re-thrown on a different thread */
  inline void persist() {
    if (type) return; // Don't overwrite exceptions
    // PyErr_Fetch overwrites the pointers
    pybind11::gil_scoped_acquire gil;
    Py_XDECREF(type);
    Py_XDECREF(value);
    Py_XDECREF(traceback);
    PyErr_Fetch(&type, &value, &traceback);
    build_message();
  }

  /** Sets the current Python error from this exception */
  inline void restore() {
    if (!type) return;
    // PyErr_Restore steals references
    pybind11::gil_scoped_acquire gil;
    Py_XINCREF(type);
    Py_XINCREF(value);
    Py_XINCREF(traceback);
    PyErr_Restore(type, value, traceback);
  }

  PyObject* type;
  PyObject* value;
  PyObject* traceback;

  // Message to return to the user when 'what()' is invoked.
  std::string message;
};

bool THPException_init(PyObject *module);

namespace torch {

THP_CLASS std::string processErrorMsg(std::string str);

THP_API bool get_cpp_stacktraces_enabled();

// Abstract base class for exceptions which translate to specific Python types
struct PyTorchError : public std::exception {
  PyTorchError(const std::string& msg_ = std::string()): msg(msg_) {}
  virtual PyObject* python_type() = 0;
  const char* what() const noexcept override {
    return msg.c_str();
  }
  std::string msg;
};

// Declare a printf-like function on gcc & clang
// The compiler can then warn on invalid format specifiers
#ifdef __GNUC__
#  define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
    __attribute__((format (printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#  define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif

// Translates to Python IndexError
struct IndexError : public PyTorchError {
  using PyTorchError::PyTorchError;
  IndexError(const char *format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_IndexError;
  }
};

// Translates to Python TypeError
struct TypeError : public PyTorchError {
  using PyTorchError::PyTorchError;
  TORCH_API TypeError(const char *format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_TypeError;
  }
};

// Translates to Python ValueError
struct ValueError : public PyTorchError {
  using PyTorchError::PyTorchError;
  ValueError(const char *format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_ValueError;
  }
};

// Translates to Python NotImplementedError
struct NotImplementedError : public PyTorchError {
  NotImplementedError() {}
  PyObject* python_type() override {
    return PyExc_NotImplementedError;
  }
};

// Translates to Python AttributeError
struct AttributeError : public PyTorchError {
  AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
  PyObject* python_type() override {
    return PyExc_AttributeError;
  }
};

struct WarningMeta {
  WarningMeta(const c10::SourceLocation& _source_location,
      const std::string& _msg,
      const bool _verbatim) :
      source_location_{_source_location},
      msg_{_msg},
      verbatim_{_verbatim} { }


  const c10::SourceLocation source_location_;
  const std::string msg_;
  const bool verbatim_;
};

// ATen warning handler for Python
struct PyWarningHandler: at::WarningHandler {
public:
/// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
  TORCH_API PyWarningHandler() noexcept(true);
  TORCH_API ~PyWarningHandler() noexcept(false) override;

  void process(const at::SourceLocation &source_location,
               const std::string &msg,
               const bool verbatim) override;

  /** Call if an exception has been thrown

   *  Necessary to determine if it is safe to throw from the desctructor since
   *  std::uncaught_exception is buggy on some platforms and generally
   *  unreliable across dynamic library calls.
   */
  void set_in_exception() {
    in_exception_ = true;
  }

private:
  using warning_buffer_t = std::vector<WarningMeta>;

  warning_buffer_t warning_buffer_;

  at::WarningHandler* prev_handler_;
  bool in_exception_;
};
Loading ...