import torch
from torch.distributed._shard.sharded_tensor import (
    _sharded_op_impl,
    ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


def register_chunk_op(op):
    # Registration happens as a side effect of the ``_sharded_op_impl``
    # decorator, so nothing needs to be returned here.
    @_sharded_op_impl(op)
    def sharded_chunk(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the chunk op.

        If we chunk along a non-sharding dim, we simply chunk the local
        tensor and build a list of ShardedTensors from the results.

        Warning: chunking along the sharding dim is not supported.

        Args: same as ``torch.chunk``.

        Returns:
            List[ShardedTensor]: Chunk results as a list of ShardedTensors.
        """
        st = args[0]
        chunk_num = args[1]
        # ``dim`` may arrive positionally (as in ``torch.chunk``) or as a keyword;
        # ``kwargs`` can be ``None`` when the op is called without keywords.
        kwargs = kwargs if kwargs is not None else {}
        dim = args[2] if len(args) > 2 else kwargs.get("dim")
        dim = dim if dim is not None else 0

        # Validate types
        if not isinstance(st, ShardedTensor):
            raise TypeError(
                f"torch function '{op.__name__}', with args: {args} and "
                f"kwargs: {kwargs}, was called on a non-ShardedTensor!"
            )
        spec = st.sharding_spec()
        if not isinstance(spec, ChunkShardingSpec):
            raise NotImplementedError("Only ChunkShardingSpec is supported for chunk.")
        # The three comparisons cover negative values of either ``dim`` or
        # ``spec.dim`` by normalizing against the tensor's rank.
        if spec.dim == dim or st.dim() + spec.dim == dim or st.dim() + dim == spec.dim:  # type: ignore[operator]
            raise NotImplementedError("Chunk by sharding dim is not supported.")

        local_tensor = st.local_tensor()
        st_size = st.size()
        # Normalize a negative ``dim`` (``dim == 0`` must stay 0).
        dim = dim if dim >= 0 else st.dim() + dim
        results = []
        for chunk_tensor in local_tensor.chunk(chunk_num, dim=dim):
            # Since ``dim`` is not the sharding dim, every rank's local shard has
            # the same length along ``dim``, so the chunked global size only
            # changes along ``dim``.
            new_st_size = (*st_size[:dim], chunk_tensor.size(dim), *st_size[dim + 1 :])  # type: ignore[index]
            results.append(
                ShardedTensor._init_from_local_tensor(
                    chunk_tensor.contiguous(),
                    st.sharding_spec(),
                    new_st_size,
                    process_group=pg,
                )
            )
        return results


# Register the handler for both the functional and the Tensor-method form.
chunk_ops = [
    torch.chunk,
    torch.Tensor.chunk,
]
for op in chunk_ops:
    register_chunk_op(op)
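

# The sketch below is illustrative only (not part of the original module): it
# shows how, after the registrations above, ``torch.chunk`` on a ShardedTensor
# dispatches to ``sharded_chunk``. It assumes a 4-rank process group has
# already been initialized (e.g. via ``torchrun --nproc_per_node=4``) with one
# CUDA device per rank; the placement strings and sizes are arbitrary choices.
def _demo_chunk_sharded_tensor():
    import torch.distributed._shard.sharded_tensor as sharded_tensor

    # Shard a 16 x 8 tensor row-wise (dim 0) across 4 ranks.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = sharded_tensor.rand(spec, 16, 8)

    # Chunking along dim 1 (a non-sharding dim) goes through ``sharded_chunk``
    # and yields a list of ShardedTensors, each of global size 16 x 2.
    chunks = torch.chunk(st, 4, dim=1)
    assert all(isinstance(c, ShardedTensor) for c in chunks)

    # Chunking along the sharding dim (dim 0 here) would raise
    # NotImplementedError, per the check in ``sharded_chunk``.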