# edgify/torch: distributed/_tensor/api.py

# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import warnings
from typing import Callable, cast, Dict, Optional, Sequence

import torch
import torch.nn as nn

import torch.distributed._tensor.dispatch as op_dispatch
from torch.distributed._tensor.device_mesh import DeviceMesh, get_global_device_mesh
from torch.distributed._tensor.placement_types import (
    _Partial,
    DTensorSpec,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.redistribute import Redistribute
from torch.utils._pytree import tree_flatten


__all__ = ["DTensor", "distribute_tensor", "distribute_module"]

# NOTE [Autograd interaction between torch.Tensor]
#
# The autograd functions defined below are used by the public-facing
# APIs (i.e. from_local, to_local) to ensure our DTensor works together
# with torch.Tensor within the autograd engine. This allows
# DistributedTensor to exist on part of the module hierarchy and still
# be able to calculate gradients across the torch.Tensor and
# DistributedTensor boundary.
# As an example, suppose we have a module that consists of submodules
# A, B, and C; the execution flow would be:
#  input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
#
# Suppose we only want to make Module B a sharded module with
# DistributedTensor params; we would need the following flow
# to work:
#
#  input(torch.Tensor) -> Module A
#       -> DTensor input -> Sharded Module B -> DTensor output
#           -> output (torch.Tensor) -> Module C -> output (torch.Tensor)
#
# We need the conversion from Module A's output to the DTensor input,
# which is `from_local`, and the conversion from the DTensor output back
# to a torch.Tensor, which is `to_local`; thus these two functions must
# be autograd functions.
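#
# A minimal sketch of that flow, kept in comments so it stays illustrative
# (`mesh`, `module_a`, `sharded_module_b`, and `module_c` are assumed
# placeholders, not defined in this file):
#
#   x = module_a(inp)                                # plain torch.Tensor
#   dt_in = DTensor.from_local(x, mesh, [Shard(0)])  # torch.Tensor -> DTensor
#   dt_out = sharded_module_b(dt_in)                 # compute on DTensor
#   y = dt_out.to_local()                            # DTensor -> torch.Tensor
#   out = module_c(y)
#   out.sum().backward()                             # gradients flow across both
#                                                    # boundaries via the autograd
#                                                    # functions defined below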
#
class _ToTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: "DTensor"):  # type: ignore[override]
        ctx.dtensor_device_mesh = input.device_mesh
        ctx.dtensor_placements = input.placements
        ctx.dtensor_shape = input.shape
        ctx.dtensor_requires_grad = input.requires_grad
        return input._local_tensor.detach()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
        device_mesh = ctx.dtensor_device_mesh
        placements = ctx.dtensor_placements
        return DTensor(
            grad_output,
            device_mesh,
            placements,
            size=ctx.dtensor_shape,
            requires_grad=grad_output.requires_grad,
        )


class _FromTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,  # pyre-ignore[2]: Parameter must be annotated.
        input: torch.Tensor,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement],
        run_check: bool,
    ) -> "DTensor":
        ctx.previous_placement = placements
        ctx.previous_device_mesh = device_mesh

        if run_check:
            # TODO: by default, check tensor metas across ranks
            # TODO: See if we need to make this run_check logic
            # have a corresponding backward.
            for idx, placement in enumerate(placements):
                if placement.is_replicate():
                    # broadcast rank 0 tensor to all ranks
                    # only broadcast if run_check is True
                    input = input.contiguous()
                    device_mesh.broadcast(input, mesh_dim=idx)

        # if run_check is False, we assume the user is certain that each
        # rank has the same tensor shape, and we just use that to calculate
        # the global shape
        tensor_shape = list(input.size())
        for idx, placement in enumerate(placements):
            if placement.is_shard():
                shard_dim = cast(Shard, placement).dim
                local_dim_size = tensor_shape[shard_dim]
                tensor_shape[shard_dim] = local_dim_size * device_mesh.size(idx)

        dist_tensor = DTensor(
            input,
            device_mesh,
            placements,
            size=torch.Size(tensor_shape),
            # requires_grad of the dist tensor depends on whether the input
            # requires grad or not
            requires_grad=input.requires_grad,
        )
        return dist_tensor

    @staticmethod
    def backward(ctx, grad_output: "DTensor"):  # type: ignore[override]
        previous_placement = ctx.previous_placement
        previous_device_mesh = ctx.previous_device_mesh

        # reshard to the placements used when creating the DistributedTensor
        # so that the gradient layout matches, and we can return the
        # local gradients directly
        if grad_output.placements != previous_placement:
            # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
            grad_output = Redistribute.apply(
                grad_output, previous_device_mesh, previous_placement
            )

        # TODO: backward is also differentiable now, add a test
        # to test higher level gradients.
        return grad_output.to_local(), None, None, None


class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
    _local_tensor: torch.Tensor
    _spec: DTensorSpec
    __slots__ = ["_local_tensor", "_spec"]

    # class attribute that handles the propagation of operator placements;
    # its rules are keyed by aten op name, value is the propagation func
    _propagator: ShardingPropagator = ShardingPropagator()

    # class attribute that handles custom registered ops; all handled
    # custom ops should appear in this table, overriding the default
    # operators covered by _op_to_rules or the fallbacks
    # (custom operators have the highest priority when dispatching).
    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
    _custom_dispatch_ops: Dict[str, Callable] = {}

    @staticmethod
    def __new__(
        cls,
        local_tensor: torch.Tensor,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement],
        *,
        size: torch.Size,
        requires_grad: bool = False,
    ) -> "DTensor":
        """
        Construct a DTensor from a local tensor, device mesh, and placements,
        along with other tensor properties (i.e. shape, requires_grad, strides, etc.).
        Note: This is not a public API and it's only supposed to be used by
            operator implementations and internals. If you want to construct a
            DTensor from a local tensor, consider using `DTensor.from_local`; if
            you want to construct a DTensor from a "global" tensor (where you
            already have the tensor initialized and want to shard it),
            consider using `distribute_tensor`.
        """
        # recover tensor strides from local tensor strides and global size info
        # in the case of sharding
        # TODO: we should try to use meta tensor for shape and stride calculation
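        # Worked example (illustrative): a global row-major (8, 4) tensor
        # sharded on dim 1 across 4 ranks has local size (8, 1) and local
        # stride (1, 1); the loop below rescales dim 0's stride to
        # (1 // 1) * 4 = 4, recovering the global stride (4, 1).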
        tensor_stride = list(local_tensor.stride())
        local_size = list(local_tensor.size())
        for placement in placements:
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                # recover the tensor stride by rescaling any stride that is
                # >= the current stride on the shard_dim
                for i in range(len(tensor_stride)):
                    if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
                        # rescale the stride by the shard size
                        tensor_stride[i] = (
                            tensor_stride[i] // local_size[shard_dim]
                        ) * size[shard_dim]
            elif not isinstance(placement, (Replicate, _Partial)):
                raise RuntimeError(f"placement type {type(placement)} not supported!")

        if requires_grad != local_tensor.requires_grad:
            warnings.warn(
                "To construct DTensor from torch.Tensor, it's recommended to "
                "use local_tensor.detach() and make requires_grad consistent."
            )

        # construct the wrapper tensor from local_tensor and add the
        # placement spec; this does not do any actual data distribution
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            size,
            strides=tensor_stride,
            dtype=local_tensor.dtype,
            device=local_tensor.device,
            layout=local_tensor.layout,
            requires_grad=requires_grad,
        )
        # deepcopy and set spec
        r._spec = DTensorSpec(device_mesh, copy.deepcopy(placements), shape=r.size())
        # detach local tensor from autograd graph as we initialize the
        # distributed tensor and autograd will be working on top of
        # the wrapper tensor directly instead of local torch.Tensor
        r._local_tensor = local_tensor.detach()
        return r

    # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"

    @classmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # check that we are not getting mixed vanilla and Distributed tensors
        arg_list, _ = tree_flatten(args)
        for arg in arg_list:
            if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
                raise RuntimeError(
                    f"{func}: got mixed distributed and non-distributed tensors."
                )

        if kwargs is None:
            kwargs = {}

        return op_dispatch.operator_dispatch(
            func,
            args,
            kwargs,
            DTensor._propagator,
            DTensor._custom_dispatch_ops,
        )

    @classmethod
    def from_local(
        cls,
        local_tensor: torch.Tensor,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        run_check: bool = True,
    ) -> "DTensor":
        """
        Create a :class:`DTensor` from a local torch.Tensor on each rank
        according to the `device_mesh` and `placements` specified.

        Args:
            local_tensor (torch.Tensor): local torch.Tensor on each rank.
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                tensor on; if not specified, it must be called under a DeviceMesh
                context manager, default: None
            placements (List[:class:`Placement`], optional): the placements that
                describe how to place the local torch.Tensor on the DeviceMesh; must
                have the same number of elements as `device_mesh.ndim`. If not
                specified, we will by default replicate the tensor across the
                `device_mesh` from the first rank of each dimension of the `device_mesh`.
            run_check (bool, optional): indicates whether to run a check across ranks
                on the meta information and data. If there is a :class:`Replicate` in
                `placements`, the data on the first rank of that device mesh dimension
                will be broadcast to the other ranks.

        Returns:
            A :class:`DTensor` object

        .. note:: `from_local` is differentiable; the `requires_grad` of the created
            `DTensor` object will depend on whether `local_tensor` requires grad or not.
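
        Example (an illustrative sketch; assumes an initialized process group
        with a world size of 4 and a 1-D `DeviceMesh`)::

            >>> # each rank holds its own (4, 8) shard of a (16, 8) global tensor
            >>> mesh = DeviceMesh("cuda", list(range(4)))
            >>> local_tensor = torch.randn(4, 8)
            >>> dt = DTensor.from_local(local_tensor, mesh, [Shard(0)])
            >>> dt.size()
            torch.Size([16, 8])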
        """
        # if the shape/dtype are the same across ranks, there is no need to run_check;
        # if not, we must allgather the metadata to check the size/dtype across ranks.
        # There should be no data communication unless there is a replication
        # strategy, where we broadcast the replicated data from the first rank
        # in the mesh dimension
        device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
        # convert the local tensor to the desired device based on the device mesh's device_type
        if not local_tensor.is_meta:
            local_tensor = local_tensor.to(device_mesh.device_type)

        # set default placements to replicated if not specified
        if placements is None:
            placements = [Replicate() for _ in range(device_mesh.ndim)]

        # `from_local` is differentiable, and the gradients of the dist tensor this
        # function creates should flow back to the local_tensor, so we use an autograd
        # function to construct the dist tensor.
        return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
            local_tensor, device_mesh, placements, run_check
        )

    def to_local(self) -> torch.Tensor:
        """
        Get the local tensor of this DTensor on its current rank. For sharding, it
        returns a local shard of the logical tensor view; for replication, it returns
        the replica on its current rank.

        Returns:
            A :class:`torch.Tensor` object that represents the local tensor of its current rank.

        .. note:: `to_local` is differentiable; the `requires_grad` of the returned local
            tensor will depend on whether the `DTensor` requires grad or not.
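
        Example (an illustrative sketch; assumes a DTensor `dt` built from (4, 8)
        local shards with `Shard(0)` on a 1-D mesh of 4 ranks)::

            >>> local_shard = dt.to_local()
            >>> local_shard.shape  # back to the per-rank local shape
            torch.Size([4, 8])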
        """
        return _ToTorchTensor.apply(self)  # pyre-ignore[16]: autograd func

    def redistribute(
        self,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
    ) -> "DTensor":
        """
        `redistribute` performs the necessary collective operations to redistribute the
        current DTensor from its current placements to new placements, or from its
        current DeviceMesh to a new DeviceMesh, i.e. we can turn a sharded DTensor into
        a replicated DTensor by specifying a Replicate placement for each dimension
        of the DeviceMesh.

        Args:
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                DTensor on; if not specified, it must be called under a DeviceMesh
                context manager, default: None
            placements (List[:class:`Placement`], optional): the new placements that
                describe how to place the DTensor onto the DeviceMesh, must
                have the same number of elements as `device_mesh.ndim`.

        Returns:
            A :class:`DTensor` object

        .. note:: `redistribute` is differentiable.
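
        Example (an illustrative sketch; assumes a 1-D `DeviceMesh` `mesh` of 4
        ranks and a DTensor `dt` of global shape (16, 8) sharded on dim 0)::

            >>> replicated = dt.redistribute(mesh, [Replicate()])
            >>> replicated.to_local().shape  # every rank now holds the full tensor
            torch.Size([16, 8])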
        """
        # This API performs the necessary transformations and returns
        # a new DTensor with the new spec, i.e. for
        # sharding it is a reshard behavior.
        # Note that redistribute currently only supports out-of-place
        # redistribution, i.e. it always creates a new
        # DTensor object and leaves the original one unchanged.
        device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
        # raise error if new placements not specified
        if placements is None:
            raise RuntimeError("placements is needed for redistribute!")
# ... (remainder of the file not shown)