Repository URL to install this package:
|
Version:
0.0.26 ▾
|
torch-wrapper
/
util_torch.py
|
|---|
import random
import warnings
import numpy as np
import torch
from packaging import version
def seed_everything(seed: int):
    """Seed every random number generator this module relies on.

    Modeled on the PyTorch Lightning helper of the same name. The value is
    coerced with ``int()`` and then handed, in order, to Python's ``random``
    module, NumPy's global generator, the torch CPU generator, and the
    generators of all CUDA devices.

    Args:
        seed: Seed value; anything accepted by ``int()``.
    """
    seed_value = int(seed)
    # Order matters only for error behaviour: a seeder that rejects the value
    # stops the chain after the earlier generators have already been seeded.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)
def seed_reproducible(seed: int, fail_on_nondeterministic_algorithms=True):
    """Seed all generators and switch torch to deterministic algorithms.

    Follows: https://pytorch.org/docs/stable/notes/randomness.html

    Steps, in order:
      1. ``seed_everything(seed)``.
      2. Turn off cuDNN benchmarking.
      3. Call ``torch.use_deterministic_algorithms(True)``.

    Args:
        seed: Seed forwarded to ``seed_everything``.
        fail_on_nondeterministic_algorithms: When true, deterministic mode is
            enabled without ``warn_only``. When false, ``warn_only=True`` is
            passed if the installed torch is at least v1.11; on older
            versions a warning is issued and deterministic mode is enabled
            without ``warn_only``.
    """
    seed_everything(seed)

    # Makes cuDNN select its algorithm deterministically, possibly at the
    # cost of reduced performance.
    # NOTE: this function-local import binds ``torch`` as a local name, so it
    # has to stay ahead of every other use of ``torch`` in this function.
    import torch.backends.cudnn
    torch.backends.cudnn.benchmark = False

    # Avoid nondeterministic algorithms.
    if fail_on_nondeterministic_algorithms:
        torch.use_deterministic_algorithms(True)
        return

    # ``warn_only`` is only available from torch v1.11 onwards.
    has_warn_only = version.parse(torch.__version__) >= version.parse('1.11')
    if has_warn_only:
        torch.use_deterministic_algorithms(True, warn_only=True)
        return

    warnings.warn(
        'not fail_on_nondeterministic_algorithms is not supported due to '
        'missing warn_only in torch < v1.11'
    )
    torch.use_deterministic_algorithms(True)