Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
torch-wrapper / util_torch.py
Size: Mime:
import random
import warnings

import numpy as np
import torch
from packaging import version


def seed_everything(seed: int):
    """Seed every global RNG used here, in the spirit of PyTorch Lightning.

    The seed is coerced with ``int()`` first and then handed, in order, to
    Python's ``random`` module, NumPy's global RNG (``np.random.seed``),
    torch's default generator (``torch.manual_seed``) and the generators of
    all CUDA devices (``torch.cuda.manual_seed_all``).

    :param seed: seed value; anything accepted by ``int()``.
    """
    value = int(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(value)


def seed_reproducible(seed: int, fail_on_nondeterministic_algorithms=True):
    """Seed all RNGs via ``seed_everything`` and make torch deterministic.

    Follows https://pytorch.org/docs/stable/notes/randomness.html:

    1. ``seed_everything(seed)`` seeds Python, NumPy, torch and all CUDA
       generators.
    2. ``torch.backends.cudnn.benchmark = False`` makes cuDNN select its
       algorithm deterministically, possibly at the cost of performance.
    3. ``CUBLAS_WORKSPACE_CONFIG`` is set to ``:4096:8`` unless the caller
       already exported a value. Deterministic cuBLAS on CUDA >= 10.2
       requires it. The variable is written to ``os.environ``, so child
       processes inherit it. It is only honoured if set before cuBLAS is
       first used, so call this function early in the program.
    4. ``torch.use_deterministic_algorithms(True, ...)`` avoids
       nondeterministic algorithms.

    :param seed: seed forwarded to ``seed_everything``.
    :param fail_on_nondeterministic_algorithms: if True (default), an op
        without a deterministic implementation raises an error. If False,
        such ops only warn (``warn_only=True``). When the installed torch's
        ``use_deterministic_algorithms`` does not accept ``warn_only``
        (torch < 1.11), a warning is emitted and strict mode is used instead.
    """
    import os

    import torch.backends.cudnn

    seed_everything(seed)
    # Causes cuDNN to deterministically select an algorithm, possibly at the
    # cost of reduced performance.
    torch.backends.cudnn.benchmark = False
    # Deterministic cuBLAS (CUDA >= 10.2) needs a fixed workspace config;
    # setdefault keeps any value the user already chose.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

    # Avoiding nondeterministic algorithms
    if fail_on_nondeterministic_algorithms:
        torch.use_deterministic_algorithms(True)
        return

    try:
        # Feature-detect `warn_only` (added in torch 1.11) rather than
        # comparing torch.__version__: a PEP 440 pre-release such as
        # '1.11.0a0' sorts below '1.11' and would be wrongly rejected.
        torch.use_deterministic_algorithms(True, warn_only=True)
    except TypeError:
        # Older signature has no `warn_only` keyword -> fall back to strict.
        warnings.warn('not fail_on_nondeterministic_algorithms is not supported due to '
                      'missing warn_only in torch < v1.11')
        torch.use_deterministic_algorithms(True)