Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ backends / cudnn / rnn.py

import torch.cuda

try:
    from torch._C import _cudnn
except ImportError:
    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
    # so it's safe to not emit any checks here.
    _cudnn = None  # type: ignore


# Public RNN mode names -> attribute names on ``_cudnn.RNNMode``.
_RNN_MODE_ATTRS = {
    'RNN_RELU': 'rnn_relu',
    'RNN_TANH': 'rnn_tanh',
    'LSTM': 'lstm',
    'GRU': 'gru',
}


def get_cudnn_mode(mode):
    """Translate an RNN mode name into cuDNN's integer RNN mode value.

    Args:
        mode: one of ``'RNN_RELU'``, ``'RNN_TANH'``, ``'LSTM'`` or ``'GRU'``
            (exact, case-sensitive match).

    Returns:
        int: the matching ``_cudnn.RNNMode`` member converted with ``int()``.

    Raises:
        ValueError: if ``mode`` is not one of the names above. ``ValueError``
            is a subclass of ``Exception``, so callers that previously caught
            ``Exception`` keep working.
    """
    try:
        attr = _RNN_MODE_ATTRS[mode]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list), which the old
        # if/elif chain also rejected as an unknown mode.
        raise ValueError("Unknown mode: {}".format(mode)) from None
    # Only touch ``_cudnn`` once the name is known to be valid; callers are
    # expected to have checked cuDNN availability before getting here.
    return int(getattr(_cudnn.RNNMode, attr))


# NB: We don't actually need this class anymore (in fact, we could serialize the
# dropout state for even better reproducibility), but it is kept for backwards
# compatibility for old models.
class Unserializable:
    """Wrapper whose payload is intentionally discarded by pickling.

    Pickling records only a placeholder string. An unpickled instance always
    holds ``None`` as its payload.
    """

    def __init__(self, inner):
        self.inner = inner

    def get(self):
        """Return the wrapped payload (``None`` after a pickle round-trip)."""
        payload = self.inner
        return payload

    def __getstate__(self):
        # The state must be truthy: Python 2 skips ``__setstate__`` when the
        # pickled state evaluates to False, so ``{}`` cannot be used here.
        placeholder = "<unserializable>"
        return placeholder

    def __setstate__(self, state):
        # The placeholder carries no information; the payload is not restored.
        self.inner = None


def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Return the cuDNN dropout state cached for the current CUDA device.

    The cache entry lives in ``dropout_state`` under the key
    ``'desc_<current device index>'`` and is wrapped in ``Unserializable``.
    A new entry is built when the key is missing or its wrapped value is
    ``None``.

    Args:
        dropout: dropout probability, used only when ``train`` is truthy.
        train: training flag; when falsy the effective probability is 0.
        dropout_seed: seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: mapping-like cache supporting ``in``, item lookup and
            item assignment; it is updated in place.

    Returns:
        The cached dropout state tensor, or ``None`` when the effective
        probability is 0 and nothing non-``None`` was cached before. An
        existing non-``None`` entry is returned as-is, even if ``train`` is
        now falsy.
    """
    desc_key = 'desc_' + str(torch.cuda.current_device())
    # Dropout only applies in training mode.
    effective_p = dropout if train else 0

    needs_init = (desc_key not in dropout_state
                  or dropout_state[desc_key].get() is None)
    if needs_init:
        if effective_p == 0:
            # Nothing to allocate; storing None makes the next call retry.
            fresh_state = None
        else:
            fresh_state = torch._cudnn_init_dropout_state(  # type: ignore
                effective_p, train, dropout_seed,
                self_ty=torch.uint8,
                device=torch.device('cuda'))
        dropout_state[desc_key] = Unserializable(fresh_state)

    return dropout_state[desc_key].get()