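# Header text captured from the hosting package-registry web page (not part of the stub):
#   "Learn more » Push, build, and install RubyGems, npm packages, Python packages,
#    Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages,
#    RPM packages, NuGet packages"
#
# Package: neilisaac / torch (python)
# Repository URL to install this package: (URL not present in the captured text)
# Version: 1.8.0
# File: _C/_distributed_c10d.pyi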

from torch import Tensor
from enum import Enum
from typing import Optional, List, Any, overload
from datetime import timedelta

# This module is defined in torch/csrc/distributed/c10d/init.cpp

# Module-level values exposed by the native binding.
# Byte size for the first gradient bucket, per the name — TODO confirm against init.cpp.
_DEFAULT_FIRST_BUCKET_BYTES: int
# Default `timeout` of Work.wait(); presumably a sentinel meaning "no timeout",
# per the name — confirm.
_DEFAULT_NO_TIMEOUT: timedelta

# Selects one of the natively implemented communication hooks; passed as
# `comm_hook_type` to _register_builtin_comm_hook.
class BuiltinCommHookType(Enum):
    ALLREDUCE = ...
    # presumably compresses gradients to float16 before communicating — TODO confirm
    FP16_COMPRESS = ...

# Attach a user-supplied communication hook to `reducer`.
# `state` and `comm_hook` are typed Any: the stub does not describe the hook's
# expected call signature. `state` is presumably forwarded to the hook — confirm.
def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
# Attach one of the built-in hooks (chosen by a BuiltinCommHookType member) to `reducer`.
def _register_builtin_comm_hook(reducer: Reducer, comm_hook_type: BuiltinCommHookType): ...

# Fetch DDP logging data associated with `reducer` (per the name).
# The return type is not annotated here, so type checkers see an untyped value.
def _get_ddp_logging_data(reducer: Reducer): ...
# Record construction-time settings on `reducer` for logging (per the name).
# The parameters presumably mirror the DistributedDataParallel constructor
# arguments of the same names — confirm against the caller.
def _set_construction_logging_data(
    reducer: Reducer,
    module_name: str,
    device_ids: List[int],
    output_device: int,
    broadcast_buffers: bool): ...

# A bucket of gradient tensors (per the name). It is built from a list of
# tensors, and get_tensors() returns them.
class _GradBucket:
    def __init__(self, tensors: List[Tensor]): ...
    def get_tensors(self) -> List[Tensor]: ...

# Binding for the native gradient Reducer. Only the constructor and
# initialize_buckets are typed; the bare `...` at the end leaves the rest of
# the class undeclared.
class Reducer:
    def __init__(
        self,
        # presumably one list of parameter tensors per model replica — confirm
        replicas: List[List[Tensor]],
        # each inner list presumably holds the parameter indices of one bucket — confirm
        bucket_indices: List[List[int]],
        process_group: ProcessGroup,
        # same nesting as `replicas`; presumably flags parameters with sparse grads
        expect_sparse_gradients: List[List[bool]],
        # presumably the maximum bucket size in bytes, per the name
        bucket_bytes_cap: int,
        find_unused_parameters: bool,
        gradient_as_bucket_view: bool,
    ): ...
    # Re-bucket using a new index grouping (same shape as the constructor's bucket_indices).
    def initialize_buckets(self, bucket_indices: List[List[int]]): ...
    ...

# Reduction operator used by the *Options classes (reduceOp fields) and by the
# `op` parameters of ProcessGroup.allreduce / ProcessGroup.reduce.
class ReduceOp(Enum):
    SUM = ...
    PRODUCT = ...
    MIN = ...
    MAX = ...
    # bitwise AND / OR / XOR, per the names
    BAND = ...
    BOR = ...
    BXOR = ...
    # presumably a placeholder not meant to be passed to collectives — confirm
    UNUSED = ...

# Options for ProcessGroup.broadcast (list-of-tensors overload).
class BroadcastOptions:
    # rank that supplies the data
    rootRank: int
    # presumably the index of the source tensor within the root's tensor list — confirm
    rootTensor: int
    timeout: timedelta

# Options for ProcessGroup.allreduce (list-of-tensors + options overload).
class AllreduceOptions:
    reduceOp: ReduceOp
    timeout: timedelta

# Options for ProcessGroup.allreduce_coalesced. Inherits every field of
# AllreduceOptions and declares nothing extra.
class AllreduceCoalescedOptions(AllreduceOptions):
    ...

# Options for ProcessGroup.reduce (list-of-tensors overload).
class ReduceOptions:
    reduceOp: ReduceOp
    # rank that receives the reduced result, per the name
    rootRank: int
    # presumably an index into the root's tensor list — confirm
    rootTensor: int
    timeout: timedelta

# Options for ProcessGroup.allgather and ProcessGroup.allgather_coalesced.
class AllGatherOptions:
    timeout: timedelta

# Options for ProcessGroup.gather (list-of-tensors overload).
class GatherOptions:
    # rank that collects the gathered tensors, per the name
    rootRank: int
    timeout: timedelta

# Options for ProcessGroup.scatter (list-of-tensors overload).
class ScatterOptions:
    # rank that supplies the scattered tensors, per the name
    rootRank: int
    timeout: timedelta

# Options for ProcessGroup.reduce_scatter (list-of-tensors overload).
class ReduceScatterOptions:
    reduceOp: ReduceOp
    timeout: timedelta

# Options for ProcessGroup.barrier.
class BarrierOptions:
    # presumably device indices the barrier should use — confirm
    device_ids: List[int]
    timeout: timedelta

# Options for ProcessGroup.alltoall and ProcessGroup.alltoall_base.
class AllToAllOptions:
    timeout: timedelta

# Base key/value store interface; FileStore, HashStore, TCPStore and
# PrefixStore derive from it.
class Store:
    # Note the asymmetry: set() is typed to take str, but get() returns bytes.
    def set(self, key: str, value: str): ...
    def get(self, key: str) -> bytes: ...
    # Returns an int; presumably the key's counter value after adding `value` — confirm.
    def add(self, key: str, value: int) -> int: ...
    # Returns a bool; presumably whether the key existed and was removed — confirm.
    def delete_key(self, key: str) -> bool: ...
    def num_keys(self) -> int: ...
    # presumably the default timeout for this store's blocking calls — confirm
    def set_timeout(self, timeout: timedelta): ...
    # presumably blocks until every key in `keys` is present; the second form
    # takes an explicit per-call timeout — confirm
    @overload
    def wait(self, keys: List[str]): ...
    @overload
    def wait(self, keys: List[str], timeout: timedelta): ...

# Store backed by a file at `path` (per the name); `numWorkers` is presumably
# the number of processes that share the file — confirm.
class FileStore(Store):
    def __init__(
        self,
        path: str,
        numWorkers: int
    ): ...

# Store constructed with no arguments; presumably in-process (hash-map backed),
# per the name — confirm.
class HashStore(Store):
    def __init__(self): ...

# Store reached over TCP at host_name:port (per the names).
class TCPStore(Store):
    def __init__(
        self,
        host_name: str,
        port: int,
        world_size: int,
        # presumably True on the single process that hosts the store — confirm
        is_master: bool,
        timeout: timedelta,
    ): ...

# Store that wraps another `store`; presumably namespaces every key with
# `prefix` before delegating — confirm.
class PrefixStore(Store):
    def __init__(
        self,
        prefix: str,
        store: Store
    ): ...

# Handle returned by every ProcessGroup communication method; it can be
# polled or waited on. The bare `...` at the end leaves other members undeclared.
class Work:
    def is_completed(self) -> bool: ...
    def is_success(self) -> bool: ...
    # typed Any; presumably the error raised by the operation, if any — confirm
    def exception(self) -> Any: ...
    # Default timeout is the module-level _DEFAULT_NO_TIMEOUT value.
    def wait(self, timeout: timedelta = _DEFAULT_NO_TIMEOUT) -> bool: ...
    # presumably the sender's rank for receive operations — confirm
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> List[Tensor]: ...
    def synchronize(self): ...
    ...

class ProcessGroup:
    # Base binding for a communicator over a group of ranks. Every collective and
    # point-to-point method returns a Work handle (poll with is_completed, or wait).
    #
    # Most collectives are declared as an @overload pair:
    #   * a general form taking tensor lists plus an *Options object, and
    #   * a convenience form taking plain positional arguments.
    # Every `opts` / `op` parameter is annotated explicitly. An unannotated
    # parameter in an otherwise-annotated signature is treated as Any by type
    # checkers, which silently disables checking of those arguments.
    # Parameter names are kept exactly as declared, even where singular/plural
    # looks inconsistent, because callers may pass them by keyword.
    def __init__(self): ...
    # presumably this process's rank within the group — confirm
    def rank(self) -> int: ...
    # presumably the number of ranks in the group — confirm
    def size(self) -> int: ...
    @overload
    def broadcast(
        self,
        tensors: List[Tensor],
        opts: BroadcastOptions = BroadcastOptions(),
    ) -> Work: ...
    @overload
    def broadcast(
        self,
        tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        opts: AllreduceOptions = AllreduceOptions(),
    ) -> Work:  ...
    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        op: ReduceOp = ReduceOp.SUM,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensor: Tensor,
        op: ReduceOp = ReduceOp.SUM,
    ) -> Work: ...
    def allreduce_coalesced(
        self,
        tensors: List[Tensor],
        opts: AllreduceCoalescedOptions = AllreduceCoalescedOptions(),
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensors: List[Tensor],
        opts: ReduceOptions = ReduceOptions(),
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensor: Tensor,
        root: int,
        op: ReduceOp = ReduceOp.SUM,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts: AllGatherOptions = AllGatherOptions(),
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
    ) -> Work: ...
    # Shares AllGatherOptions with allgather; no dedicated options class is declared.
    def allgather_coalesced(
        self,
        output_lists: List[List[Tensor]],
        input_list: List[Tensor],
        opts: AllGatherOptions = AllGatherOptions(),
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts: GatherOptions = GatherOptions(),
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts: ScatterOptions = ScatterOptions(),
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensor: Tensor,
        input_tensors: List[Tensor],
        root: int,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts: ReduceScatterOptions = ReduceScatterOptions(),
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        # a single Tensor despite the plural name; name kept for keyword callers
        output_tensors: Tensor,
        input_tensor: List[Tensor],
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts: AllToAllOptions = AllToAllOptions(),
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output: Tensor,
        input: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output_tensor: List[Tensor],
        input_tensor: List[Tensor],
        opts: AllToAllOptions = AllToAllOptions(),
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output: List[Tensor],
        input: List[Tensor],
    ) -> Work: ...
    # Point-to-point operations; `tag` presumably pairs matching send/recv calls — confirm.
    def send(
        self,
        tensors: List[Tensor],
        dstRank: int,
        tag: int,
    ) -> Work: ...
    def recv(
        self,
        tensors: List[Tensor],
        srcRank: int,
        tag: int,
    ) -> Work: ...
    # Receive from any rank; Work.source_rank() presumably reports the sender — confirm.
    def recv_anysource(
        self,
        tensors: List[Tensor],
        tag: int
    ) -> Work: ...
    def barrier(
        self,
        opts: BarrierOptions = BarrierOptions()
    ) -> Work: ...

# ProcessGroup produced by _round_robin_process_groups; it declares no members
# beyond those of ProcessGroup.
class ProcessGroupRoundRobin(ProcessGroup): ...
# Combine `process_groups` into one ProcessGroupRoundRobin. Presumably successive
# operations are dispatched to the member groups in rotation, per the name — confirm.
def _round_robin_process_groups(
    process_groups: List[ProcessGroup],
) -> ProcessGroupRoundRobin: ...


class ProcessGroupGloo(ProcessGroup):
    # Gloo-backed ProcessGroup (per the name). It is constructed from a Store
    # (presumably used for rendezvous — confirm), this process's rank, the group
    # size and an operation timeout.
    # The bare `...` at the end leaves other members undeclared.

    # Opaque handle type returned by create_device(); no members are declared.
    class Device: ...
    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ): ...
    # Build a Device selected by `hostname` or network `interface`; both default
    # to the empty string. Both are annotated as str; previously they had only
    # `str()` defaults, so type checkers treated them as Any.
    @staticmethod
    def create_device(hostname: str = "", interface: str = "") -> Device: ...
    ...

# NCCL-backed ProcessGroup (per the name). It is constructed from a Store, this
# process's rank, the group size and an operation timeout. The bare `...` at
# the end leaves other members undeclared.
class ProcessGroupNCCL(ProcessGroup):
    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ): ...
    # presumably open/close a batch of grouped NCCL operations
    # (ncclGroupStart/ncclGroupEnd) — TODO confirm against init.cpp
    @staticmethod
    def _group_start() -> None: ...
    @staticmethod
    def _group_end() -> None: ...
    ...

class ProcessGroupMPI(ProcessGroup):
    def __init__(
        self,
        rank: int,
        size: int,
Loading ...