Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
torch-wrapper / util_torch.py
Size: Mime:
import random

import numpy as np
import torch


def seed_everything(seed: int) -> int:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Loosely mimics PyTorch Lightning's ``seed_everything``.

    Args:
        seed: Seed value; anything ``int()`` accepts. Must lie in ``[0, 2**32 - 1]``
            because ``np.random.seed`` rejects values outside that range.

    Returns:
        The seed actually used, after ``int()`` conversion.

    Raises:
        ValueError: If the seed is outside ``[0, 2**32 - 1]``. The check runs before
            any generator is touched, so a bad seed never leaves the RNGs partially
            seeded (previously ``random`` was reseeded before NumPy rejected it).
    """
    seed = int(seed)
    if not 0 <= seed <= 2**32 - 1:
        raise ValueError(f"Seed must be between 0 and 2**32 - 1, got {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Recent PyTorch already seeds every device in manual_seed; this call keeps
    # older versions covered and is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    return seed


def seed_reproducible(seed: int, fail_on_nondeterministic_algorithms=True):
    """Seed every RNG via ``seed_everything`` and switch PyTorch to deterministic mode.

    Follows: https://pytorch.org/docs/stable/notes/randomness.html

    Args:
        seed: Passed unchanged to ``seed_everything``.
        fail_on_nondeterministic_algorithms: If true, calling an operation that has
            no deterministic implementation raises ``RuntimeError``. If false, PyTorch
            only emits a warning (``warn_only=True``) and runs the op anyway.
    """
    import os

    import torch.backends.cudnn

    seed_everything(seed)
    # On CUDA >= 10.2, cuBLAS is only deterministic with a fixed workspace config.
    # Without it, deterministic mode raises (or warns) on affected CUDA ops.
    # setdefault keeps any value the user already exported. The variable must be
    # set before cuBLAS is first used, so export it in the shell if CUDA work may
    # already have run in this process.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    # Causes cuDNN to deterministically select an algorithm, possibly at the cost
    # of reduced performance.
    torch.backends.cudnn.benchmark = False
    # Prefer deterministic implementations. For ops that have none, either raise
    # or only warn, depending on the flag.
    torch.use_deterministic_algorithms(
        True, warn_only=not fail_on_nondeterministic_algorithms
    )