Repository URL to install this package:
|
Version:
2.0.0rc1 ▾
|
# isort: off
try:
import torch # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"PyTorch isn't installed. To install PyTorch, run 'pip install torch'"
)
# isort: on
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.train.torch.config import TorchConfig
from ray.train.torch.torch_predictor import TorchPredictor
from ray.train.torch.torch_trainer import TorchTrainer
from ray.train.torch.train_loop_utils import (
TorchWorkerProfiler,
accelerate,
backward,
enable_reproducibility,
get_device,
prepare_data_loader,
prepare_model,
prepare_optimizer,
)
__all__ = [
"TorchTrainer",
"TorchCheckpoint",
"TorchConfig",
"accelerate",
"get_device",
"prepare_model",
"prepare_optimizer",
"prepare_data_loader",
"backward",
"enable_reproducibility",
"TorchWorkerProfiler",
"TorchPredictor",
]