#pragma once
#include <c10/util/Exception.h>
#include <mutex>
#include <vector>
namespace torch {
namespace autograd {
namespace utils {
// Warning handler for multi-threaded contexts. Gather warnings from
// all threads into a single queue, then process together at the end
// in the main thread.
class DelayWarningHandler : public at::WarningHandler {
public:
~DelayWarningHandler() override = default;
void replay_warnings();
private:
void process(const c10::Warning& warning) override;
std::vector<c10::Warning> warnings_;
std::mutex mutex_;
};
} // namespace utils
} // namespace autograd
} // namespace torch