# Package: torch-wrapper (version 0.0.11)
# Module: util_torch.py
# Repository URL to install this package: not included in this copy.
import random
import numpy as np
import torch
def seed_everything(seed: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value.

    Mimics PyTorch Lightning's ``seed_everything``: ``random``,
    ``numpy.random`` and ``torch`` (CPU plus all CUDA devices) are all seeded
    with the same integer.

    Args:
        seed: Seed value; anything accepted by ``int()``. Must lie in
            ``[0, 2**32 - 1]``, the range ``numpy.random.seed`` accepts.

    Returns:
        The seed actually used, as an ``int``.

    Raises:
        ValueError: If ``seed`` cannot be converted by ``int()`` or is outside
            ``[0, 2**32 - 1]``. The range check runs before any generator is
            touched, so a rejected seed leaves every RNG state unchanged.
    """
    seed = int(seed)
    max_seed = 2**32 - 1
    # numpy.random.seed only accepts [0, 2**32 - 1]. Checking up front avoids
    # leaving `random` re-seeded while numpy/torch keep their old state.
    if not 0 <= seed <= max_seed:
        raise ValueError(f"seed must be in [0, {max_seed}], got {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible CUDA device; silently ignored when CUDA is absent.
    torch.cuda.manual_seed_all(seed)
    return seed
def seed_reproducible(seed: int, fail_on_nondeterministic_algorithms=True):
    """Seed all RNGs and configure PyTorch for deterministic execution.

    Follows https://pytorch.org/docs/stable/notes/randomness.html. It first
    calls ``seed_everything(seed)``. It then disables cuDNN benchmarking,
    restricts cuDNN to deterministic algorithms, and enables
    ``torch.use_deterministic_algorithms``.

    Args:
        seed: Seed forwarded to ``seed_everything``.
        fail_on_nondeterministic_algorithms: If true, an operation without a
            deterministic implementation raises ``RuntimeError``. If false,
            PyTorch only warns about it.

    Note:
        Sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable to
        ``":4096:8"`` unless it is already set. The PyTorch notes require it
        (CUDA >= 10.2) for deterministic cuBLAS ops; without it those ops
        raise/warn in deterministic mode. It only takes effect if set before
        cuBLAS is first initialised in this process.
    """
    # Local imports; note this also binds `torch` locally, so it must come
    # before any use of `torch` in this function.
    import os
    import torch.backends.cudnn

    seed_everything(seed)
    # Keep a user-provided value (e.g. ":16:8") instead of overwriting it.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    # Causes cuDNN to deterministically select an algorithm, possibly at the
    # cost of reduced performance.
    torch.backends.cudnn.benchmark = False
    # The selected algorithm may itself be nondeterministic unless this is set.
    torch.backends.cudnn.deterministic = True
    # Avoid nondeterministic algorithms: raise on them, or only warn.
    torch.use_deterministic_algorithms(
        True, warn_only=not fail_on_nondeterministic_algorithms
    )