# Learn more » Push, build, and install RubyGems, npm packages, Python packages,
# Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages,
# RPM packages, NuGet packages.
#
# edgify / torch (python)
#
# Repository URL to install this package:
#
# Path: distributed/remote_device.py

from typing import Optional, Union

import torch


class _remote_device:
    """
    Represents a device on a remote worker.

    Args:
        remote_device (str or torch.device): Represents a device on a remote worker.
            The string format should be one of the following:

                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                   In addition, the device field can be optional and the default value is "cpu".
                2. "rank:<rank>/<device>", where <rank> is the rank of the
                   process and device can be parsed as torch.device type.
                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
                3. "<device>", i.e. both <workername> and <rank> omitted,
                   e.g. "cpu" or "cuda:1"; this represents a local device.

            A ``torch.device`` instance is likewise treated as a local device.
            A single field that is not a valid device string (e.g. "trainer0")
            is interpreted as a worker name with device "cpu"; a worker name
            that is itself a valid device string (e.g. "cpu") therefore needs
            an explicit "/<device>" suffix. Worker names may not contain ":"
            unless they have the form "rank:<rank>".

    Raises:
        TypeError: if ``remote_device`` is neither a ``str`` nor a ``torch.device``.
        ValueError: if the string does not match any of the formats above
            (more than one "/", an empty worker name, or a malformed
            "<prefix>:<suffix>" worker field).
        Whatever ``torch.device`` raises when the device field is not a
        valid device string.
    """

    def __init__(self, remote_device: Union[str, torch.device]):
        PARSE_ERROR = (
            f"Could not parse remote_device: {remote_device}. The valid format is "
            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
        )
        self._worker_name: Optional[str] = None
        self._rank: Optional[int] = None
        # Temporarily holds the raw device spec (str or torch.device); it is
        # normalized to a torch.device once parsing is done.
        self._device: Optional[Union[str, int, torch.device]] = None

        if isinstance(remote_device, torch.device):
            self._device = remote_device
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # A lone field is a local device if torch can parse it,
                # otherwise it is a worker name on the default "cpu" device.
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
                raise ValueError(PARSE_ERROR)
        else:
            raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')

        # Do some basic sanity check (no empty string)
        if self._worker_name is not None and not self._worker_name:
            raise ValueError(PARSE_ERROR)

        # Validate the device; torch.device raises for unparseable strings.
        self._device = torch.device(self._device)

        # Check for rank based format.
        if self._worker_name is not None:
            fields = self._worker_name.split(":")
            if len(fields) == 2:
                # "rank:<rank>/<device>" format: extract the rank. isdecimal()
                # (unlike isdigit()) rejects characters such as "²" that int()
                # cannot parse, so a malformed rank reports PARSE_ERROR.
                if fields[0] == "rank" and fields[1].isdecimal():
                    self._rank = int(fields[1])
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
            elif len(fields) > 2:
                raise ValueError(PARSE_ERROR)

    @staticmethod
    def _is_valid_local_device(device):
        """Return True if ``torch.device(device)`` succeeds, False otherwise."""
        try:
            torch.device(device)
        except (RuntimeError, TypeError, ValueError):
            return False
        return True

    def worker_name(self) -> Optional[str]:
        """
        Returns the name of remote worker representing the remote device.
        Returns ``None`` if no worker name is available.
        """
        return self._worker_name

    def rank(self) -> Optional[int]:
        """
        Returns the rank of remote worker representing the remote device.
        Returns ``None`` if no rank is available.
        """
        return self._rank

    def device(self) -> torch.device:
        """
        Returns the local device on the remote worker.
        """
        return self._device  # type: ignore[return-value]

    def __repr__(self):
        """
        Return this device in the same syntax the constructor accepts:
        "<workername>/<device>", "rank:<rank>/<device>" or "<device>".
        """
        if self._worker_name is not None:
            prefix: Optional[str] = self._worker_name
        elif self._rank is not None:
            prefix = f'rank:{self._rank}'
        else:
            prefix = None

        if self._device is None:
            # __init__ always sets a device; this only guards instances that
            # bypassed it.
            if prefix is None:
                raise RuntimeError('Invalid state!')
            return prefix
        if prefix is None:
            return str(self._device)
        return f'{prefix}/{self._device}'

    def __eq__(self, other):
        # Defer to the other operand for foreign types (Python data-model
        # convention); `==` still evaluates to False for them.
        if not isinstance(other, _remote_device):
            return NotImplemented
        return (
            self._worker_name == other._worker_name
            and self._device == other._device
            and self._rank == other._rank
        )

    def __hash__(self):
        # Hash exactly the fields __eq__ compares, combined as a tuple so
        # equal component hashes do not cancel out as they would under XOR.
        return hash((self._worker_name, self._device, self._rank))