Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / ATen / cuda / detail / DeviceThreadHandles.h

// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
// These handles are tied to device, and these libraries requires/recommends not to
// share handles across host threads.
//
// These libraries recommend using one handle per host thread. We may not want to do
// this because threads are relatively light-weight, but creating and destroying
// handles is expensive (destroying the handle causes synchronizations). DataParallel,
// for example, creates new threads for each forward pass.
//
// This file implements a handle pool mechanism. The handle pool returns handles on
// demand as threads request them. If all existing handles in the pool are in use,
// it creates a new one. As threads terminate, they release handles back into the pool.
// In this way, the handle pool never creates more handles than the high-water mark of
// active threads, so it's efficient with DataParallel.

#pragma once

#include <unordered_map>
#include <vector>
#include <utility>
#include <mutex>
#include <memory>

#include <ATen/cuda/Exceptions.h>

namespace at { namespace cuda { namespace {

template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    struct Handle {
    Handle_t handle;
    Handle(bool create = false) : handle(nullptr)
    {
        if(create) Create(&handle);
    }
    // std::vector.emplace() and push_back() may route through temporaries and call
    // copy/move constructors along the way.  If this is the case, we don't want
    // the destructors of temporaries to call cudnnDestroy on the handle.
    // We can achieve safety (for the narrow case of stashing within std::vectors)
    // by making Handle moveable but not copyable, and transferring handle ownership
    // to the latest constructed object.  This is not a substitute for full-blown
    // reference counting, but reference counting may be overkill here.
    // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
    // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
    Handle(const Handle& rhs) = delete;
    // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
    Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); }
    // operator= takes argument by value
    Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
    ~Handle() {
        if(handle) Destroy(handle);
    }
    };

    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. The other 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then.  This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (ie, before any of them have exited).  We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
    PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
    ~PoolWindow(){ release(); }

    Handle_t reserve(int device)
    {
        // If this thread already has a handle for this device, return it
        if(my_handles.find(device) != my_handles.end())
        return my_handles[device];

        // otherwise, either grab a handle from the pool if one is available,
        // or if not, create a new one.
        auto parent = weak_parent.lock();
        TORCH_CHECK(parent, "Cannot create handle during program termination");
        std::lock_guard<std::mutex> guard(parent->mutex);

        if(parent->available_handles[device].size() > 0)
        {
        my_handles[device] = parent->available_handles[device].back();
        parent->available_handles[device].pop_back();
        }
        else
        {
        // In local testing, I do observe that emplace_back sometimes routes through temporaries
        // that incur move-constructor and destructor calls.  See comments in Handle above.
        parent->created_handles[device].emplace_back(true /*create*/);
        my_handles[device] = parent->created_handles[device].back().handle;
        }

        return my_handles[device];
    }

    private:
    // Stores the per-device handles currently owned by this thread
    std::unordered_map<int, Handle_t> my_handles;

    std::weak_ptr<DeviceThreadHandlePool> weak_parent;

    // Called by the destructor.  Releases this thread's handles back into the pool.
    void release() {
        if(my_handles.size() > 0) {
            auto parent = weak_parent.lock();
            if (!parent) {
                // If this thread exits after atexit handlers have completed, the
                // cuda context itself may be invalid, so we must leak the handles.
                return;
            }

            std::lock_guard<std::mutex> guard(parent->mutex);
            for(auto d_h : my_handles)
                parent->available_handles[d_h.first].push_back(d_h.second);
        }
    }
    };

    // Warning:
    // If you want to change this function, be aware that this function will be called
    // by multiple threads and there is no mutex guarding the call of this function, so
    // make sure your implementation is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread local variable
        // so that different threads does not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};

}}}  // namespace at::cuda::detail::<anonymous>