edgify / torch (Python)

Version: 2.0.1+cpu

distributed/_shard/sharded_tensor/_ops/chunk.py
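
This module registers a ``__torch_function__`` implementation of ``torch.chunk`` and ``torch.Tensor.chunk`` for ``ShardedTensor``: chunking along a non-sharding dimension is handled by chunking each rank's local shard and wrapping the pieces as new ``ShardedTensor``s.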

import torch
from torch.distributed._shard.sharded_tensor import (
    _sharded_op_impl,
    ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


def register_chunk_op(op):
    @_sharded_op_impl(op)
    def sharded_chunk(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the chunk op.
        If we chunk by the non-sharding dim, we just directly chunk the
        local tensor and create a list of sharded tensor based on them.

        Warnings: Chunk by the sharding dim is not supported.

        Args: same as ``torch.chunk``.

        Return:
            List[ShardedTensor]: Chunk results as a list of ShardedTensor.
        """
        st = args[0]
        chunk_num = args[1]
        dim = kwargs.get("dim") if kwargs is not None else None
        dim = dim if dim is not None else 0

        # Validate types
        if not isinstance(st, ShardedTensor):
            raise TypeError(
                f"torch function '{op.__name__}', with args: {args} and "
                f"kwargs: {kwargs}, was called on a non-ShardedTensor!"
            )
        spec = st.sharding_spec()
        if not isinstance(spec, ChunkShardingSpec):
            raise NotImplementedError("Only ChunkShardingSpec is supported for chunk.")
        if spec.dim == dim or st.dim() + spec.dim == dim or st.dim() + dim == spec.dim:  # type: ignore[operator]
            raise NotImplementedError("Chunk by sharding dim is not supported.")

        local_tensor = st.local_tensor()
        st_size = st.size()
        # Normalize a negative chunk dim to its non-negative equivalent
        # (``dim == 0`` must stay 0, hence ``>= 0`` rather than ``> 0``).
        dim = dim if dim >= 0 else st.dim() + dim
        results = []
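        # Chunk the local shard along ``dim``; each piece keeps the
        # original sharding spec, and its global size differs from
        # ``st_size`` only in that dimension.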
        for chunk_tensor in local_tensor.chunk(chunk_num, dim=dim):
            new_st_size = (*st_size[:dim], chunk_tensor.size(dim), *st_size[dim + 1 :])  # type: ignore[index]
            results.append(
                ShardedTensor._init_from_local_tensor(
                    chunk_tensor.contiguous(),
                    st.sharding_spec(),
                    new_st_size,
                    process_group=pg,
                )
            )
        return results


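# Register the sharded implementation for both the functional form and the
# Tensor method form of chunk.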
chunk_ops = [
    torch.chunk,
    torch.Tensor.chunk,
]
for op in chunk_ops:
    register_chunk_op(op)
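

# A minimal usage sketch (not part of this module), assuming an initialized
# process group with 4 ranks; the placement strings are illustrative:
#
#   from torch.distributed._shard import sharded_tensor
#   from torch.distributed._shard.sharding_spec import ChunkShardingSpec
#
#   spec = ChunkShardingSpec(
#       dim=0,
#       placements=["rank:0/cpu", "rank:1/cpu", "rank:2/cpu", "rank:3/cpu"],
#   )
#   st = sharded_tensor.rand(spec, 16, 8)  # global size (16, 8), sharded on dim 0
#   chunks = torch.chunk(st, 4, dim=1)     # chunk along the non-sharding dim
#   # ``chunks`` is a list of 4 ShardedTensors, each of global size (16, 2).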