# Repository URL to install this package: (not captured)
# Version: 2.0.0rc1
import logging
from typing import TypeVar, Dict
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI
# Type variable standing in for the encoded form that ``Backend.encode_data``
# returns and ``Backend.decode_data`` accepts.
EncodedData = TypeVar("EncodedData")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@DeveloperAPI
class BackendConfig:
    """Base class from which training backend configurations derive."""

    @property
    def backend_cls(self):
        """Return the backend class paired with this configuration.

        This base implementation returns the generic ``Backend`` class.
        NOTE(review): subclasses presumably override this to point at their
        own backend -- confirm against the concrete configs.
        """
        backend_type = Backend
        return backend_type
@DeveloperAPI
class Backend(metaclass=Singleton):
    """Distributed communication backend, created as a singleton.

    In this base class the lifecycle hooks do nothing and the data codecs
    hand their argument back unchanged.

    Attributes:
        share_cuda_visible_devices: When True, every worker process gets
            CUDA_VISIBLE_DEVICES set to the visible device IDs of all
            workers on the same node for this training instance. When
            False, each worker gets CUDA_VISIBLE_DEVICES set to the device
            IDs that Ray allocated for that worker.
    """

    share_cuda_visible_devices: bool = False

    def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig):
        """Hook for the logic that starts this backend (no-op here)."""

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
        """Hook for the logic that shuts this backend down (no-op here)."""

    @staticmethod
    def encode_data(data_dict: Dict) -> EncodedData:
        """Encode a data dict before it is sent to the driver.

        Called on the workers for any data that is sent to the driver via
        ``session.report()``. This base implementation performs no encoding
        and returns the very same object it was given.
        """
        encoded = data_dict
        return encoded

    @staticmethod
    def decode_data(encoded_data: EncodedData) -> Dict:
        """Decode an encoded data dict.

        Called on the driver after it receives the encoded data dict from
        a worker. This base implementation performs no decoding and returns
        the very same object it was given.
        """
        decoded = encoded_data
        return decoded