Repository URL to install this package:
|
Version:
1.1.3 ▾
|
import logging
import os
import random
from typing import Optional, Union, Tuple, cast
import torch.distributed as dist
import numpy as np
import torch
from sarus_llm.device import get_device
# Module-level logger named after this module; used for the debug messages emitted while seeding.
logger = logging.getLogger(__name__)
def set_seed(
    seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None
) -> int:
    """Seed Python's ``random``, NumPy and PyTorch, and optionally configure determinism.

    Every process seeds its generators with ``seed + rank``. Ranks therefore draw
    different random streams while sharing a single base seed. When ``seed`` is
    ``None``, each process draws a candidate with ``torch.randint``, and rank 0's value
    is broadcast so that all ranks agree on the base seed.
    Background: https://pytorch.org/docs/stable/notes/randomness.html.

    Args:
        seed (Optional[int]): Base seed. If ``None``, a random one is generated, synced
            across ranks and used.
        debug_mode (Optional[Union[str, int]]): Value forwarded to
            :func:`torch.set_deterministic_debug_mode`. It also drives the cuDNN settings.

            * ``None``: leave every PyTorch global setting untouched.
            * ``"default"`` or ``0``: neither warn nor error on nondeterministic ops;
              turn cuDNN deterministic mode off and cuDNN benchmark on.
            * ``"warn"`` or ``1``: warn on nondeterministic ops; turn cuDNN
              deterministic mode on, cuDNN benchmark off, and set
              ``CUBLAS_WORKSPACE_CONFIG``.
            * ``"error"`` or ``2``: error on nondeterministic ops; same cuDNN/cuBLAS
              settings as ``"warn"``.

            See https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms.

    Returns:
        int: The base seed (without the per-rank offset).

    Raises:
        ValueError: If ``seed`` lies outside ``[0, 2**32 - world_size]``.
    """
    world_size, rank = get_world_size_and_rank()
    uint32_bounds = np.iinfo(np.uint32)
    min_val = uint32_bounds.min
    # Reserve headroom so the largest per-rank seed (seed + world_size - 1) still fits in uint32.
    max_val = uint32_bounds.max - world_size + 1

    if seed is None:
        candidate = torch.randint(min_val, max_val, (1,))
        # All ranks must share one base seed, so adopt the value drawn on rank 0.
        seed = _broadcast_tensor(candidate, 0).item()  # type:ignore # sync seed across ranks
    seed = cast(int, seed)

    if seed < min_val or seed > max_val:
        raise ValueError(
            f"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]"
        )

    local_seed = seed + rank
    if rank == 0:
        logger.debug(
            f"Setting manual seed to local seed {local_seed}. Local seed is seed + rank = {seed} + {rank}"
        )
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(local_seed)

    if debug_mode is None:
        return seed

    logger.debug(f"Setting deterministic debug mode to {debug_mode}")
    torch.set_deterministic_debug_mode(debug_mode)
    if torch.get_deterministic_debug_mode() == 0:
        logger.debug("Disabling cuDNN deterministic mode")
        use_deterministic_cudnn = False
    else:
        logger.debug("Enabling cuDNN deterministic mode")
        use_deterministic_cudnn = True
        # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # Benchmark mode autotunes kernels and may pick nondeterministic ones, so it is
    # enabled exactly when deterministic mode is off.
    torch.backends.cudnn.deterministic = use_deterministic_cudnn
    torch.backends.cudnn.benchmark = not use_deterministic_cudnn
    return seed
def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcasts a tensor from a source to all other processes.
Args:
tensor (torch.Tensor): Tensor to broadcast.
src (int, optional): Source rank. Defaults to 0.
Returns:
torch.Tensor: Broadcasted tensor.
"""
if dist.is_available() and dist.is_initialized():
device = tensor.device
if dist.get_backend() == "nccl":
tensor = tensor.to(get_device("cuda"))
dist.broadcast(tensor, src=src, group=None)
return tensor.to(device)
else:
return tensor
def get_world_size_and_rank() -> Tuple[int, int]:
    """Return ``(world_size, rank)`` for this process in the default process group.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized, the process is treated as the only rank: ``(1, 0)``.

    Returns:
        Tuple[int, int]: world size (total number of ranks), rank of the current process
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1, 0
    return dist.get_world_size(), dist.get_rank()