# Package: edgify / torch (python)
# Version: 2.0.1+cpu
# File: / distributed / _tensor / ops / tensor_ops.py

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, List, Optional, Sequence, Tuple

import torch

from torch.distributed._tensor.api import (
    _Partial,
    DTensorSpec,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
from torch.distributed._tensor.ops.utils import register_prop_rule, normalize_dim


aten = torch.ops.aten

# NOTE: the default propagation rule should apply to
# any operator that does not return a DTensor, i.e.
# for operators that only return int/float/bool we
# still propagate the spec by default. This ensures
# that we only return None when sharding propagation
# failed, in which case we should auto-redistribute.
def default_prop_rule(op_schema: OpSchema) -> OutputSharding:
    """Use the spec of the first argument as the output sharding."""
    first_arg_spec = op_schema.args_spec[0]
    return OutputSharding(first_arg_spec)


def prop_create_like(op_schema: OpSchema) -> OutputSharding:
    """
    Propagation rule for operators that create a tensor with the same shape as
    their first input but whose content does not depend on that input.

    The input sharding is forwarded as-is, except that every partial placement
    is turned into a replicated one.
    """
    src_spec = op_schema.args_spec[0]

    # keep every placement, only swapping partial ones for Replicate
    new_placements = []
    for placement in src_spec.placements:
        if isinstance(placement, _Partial):
            new_placements.append(Replicate())
        else:
            new_placements.append(placement)

    return OutputSharding(
        output_spec=DTensorSpec(
            mesh=src_spec.mesh,
            placements=tuple(new_placements),
            ndim=src_spec.ndim,
            shape=src_spec.shape,
        )
    )


@register_prop_rule(aten._local_scalar_dense.default)
def no_shard_prop_rule(op_schema: OpSchema) -> OutputSharding:
    """
    Propagation rule for tensor ops that must not receive sharded input,
    e.g. local_scalar_dense, which requires numel == 1.

    Returns the first argument's spec when none of its placements is a shard;
    otherwise returns a failed OutputSharding carrying the offending placements.
    """
    spec = op_schema.args_spec[0]

    has_shard = any(p.is_shard() for p in spec.placements)
    if not has_shard:
        # nothing is sharded: default prop of the first arg spec
        return OutputSharding(spec)

    return OutputSharding(
        None,
        failed_reason="Op does not support input placements "
        "with `Shard`, but found placements: "
        f"{spec.placements}",
    )


def new_factory_rule(op_schema: OpSchema) -> OutputSharding:
    """
    Propagation rule for `new_*` factory ops: (self, size, ...).

    The output is always replicated on every mesh dimension of the input's
    mesh and takes its shape from the requested size.
    """
    # this op would benefit from backward sharding propagation!
    # Since we cannot do that yet, just return replicated
    self_spec = op_schema.args_schema[0]
    out_shape = torch.Size(cast(Sequence[int], op_schema.args_schema[1]))
    assert isinstance(self_spec, DTensorSpec)

    replicated = [Replicate()] * self_spec.mesh.ndim
    out_spec = DTensorSpec(
        mesh=self_spec.mesh,
        placements=replicated,
        shape=out_shape,
        ndim=len(out_shape),
    )
    return OutputSharding(output_spec=out_spec)


# Ops registered with default_prop_rule (output follows the first arg spec).
default_prop_ops = [
    aten._to_copy.default,
    aten.clone.default,
    aten.contiguous.default,
    aten.copy_.default,
    aten.detach.default,
    aten.is_same_size.default,
    aten.new_empty_strided.default,
]

# Ops registered with prop_create_like (partial placements become replicated).
create_like_ops = [
    aten.empty_like.default,
    aten.fill_.Scalar,
    aten.full_like.default,
    aten.ones_like.default,
    aten.zero_.default,
    aten.zeros_like.default,
]

# Ops registered with new_factory_rule (output is fully replicated).
new_factory_ops = [
    aten.new_full.default,
    aten.new_ones.default,
    aten.new_zeros.default,
]

# Register each group of ops with its rule, group by group, in list order.
for _op_group, _rule in (
    (default_prop_ops, default_prop_rule),
    (create_like_ops, prop_create_like),
    (new_factory_ops, new_factory_rule),
):
    for op in _op_group:
        register_prop_rule(op)(_rule)


@register_prop_rule(aten.bucketize.Tensor)
def prop_bucketize(op_schema: OpSchema) -> OutputSharding:
    """
    Point-wise on the first input (just propagate input sharding).
    Expect replicated for second input.

    When the boundaries are not replicated on every mesh dimension, no output
    spec is produced; instead a schema with fully replicated boundaries is
    suggested.
    """
    values_spec, boundaries_spec = op_schema.args_schema
    assert isinstance(values_spec, DTensorSpec)
    assert isinstance(boundaries_spec, DTensorSpec)

    needs_replication = any(
        not isinstance(p, Replicate) for p in boundaries_spec.placements
    )
    if not needs_replication:
        return OutputSharding(output_spec=values_spec)

    # same boundaries tensor, but replicated on every mesh dimension
    replicated_boundaries = DTensorSpec(
        mesh=boundaries_spec.mesh,
        placements=[Replicate()] * len(boundaries_spec.placements),
        ndim=boundaries_spec.ndim,
        shape=boundaries_spec.shape,
    )
    suggestion = OpSchema(
        func_schema=op_schema.func_schema,
        args_schema=(values_spec, replicated_boundaries),
        kwargs_schema=op_schema.kwargs_schema,
    )
    return OutputSharding(output_spec=None, schema_suggestions=[suggestion])


def unshard_tensor_dim(
    placements: Sequence[Placement], dim: int
) -> Sequence[Placement]:
    """
    Return the placements as a tuple in which every shard of tensor dimension
    `dim` is replaced by a replicate; all other placements are kept.
    """
    result = []
    for placement in placements:
        if isinstance(placement, Shard) and placement.dim == dim:
            result.append(Replicate())
        else:
            result.append(placement)
    return tuple(result)


def is_tensor_dim_sharded(
    spec: DTensorSpec, dim: int
) -> bool:
    """
    Return True if tensor dim is sharded, i.e. `dim` is below `spec.ndim` and
    its `spec.dim_map` entry is non-negative.
    """
    if not (dim < spec.ndim):
        return False
    return spec.dim_map[dim] >= 0


def _prop_all_but_dim(
    op_schema: OpSchema, dim: int, out_shape: torch.Size
) -> OutputSharding:
    """
    Considering an op that takes its input as first argument, forwards all shardings
    except for the given dimension.

    Args:
        op_schema: schema of the op; its first positional arg must be a DTensorSpec.
        dim: tensor dimension that must not be sharded.
        out_shape: shape to record on the output spec.

    Returns:
        An OutputSharding with an output spec when the input is not sharded on
        `dim`; otherwise an OutputSharding without an output spec that suggests
        an input spec where the shards on `dim` are replaced by replicates.
    """
    input_spec = op_schema.args_schema[0]
    assert isinstance(input_spec, DTensorSpec)

    output_placements = unshard_tensor_dim(input_spec.placements, dim=dim)

    # Compare as tuples: placements are only required to be a Sequence, and a
    # list never compares equal to a tuple even when the elements match, which
    # would otherwise force a needless redistribute suggestion.
    if tuple(input_spec.placements) == tuple(output_placements):
        return OutputSharding(
            output_spec=DTensorSpec(
                mesh=input_spec.mesh,
                placements=output_placements,
                shape=out_shape,
                ndim=input_spec.ndim,
            )
        )

    # the input is sharded on `dim`: suggest the same input without that shard
    suggested_input_spec = DTensorSpec(
        mesh=input_spec.mesh,
        placements=output_placements,
        ndim=input_spec.ndim,
        shape=input_spec.shape,
    )
    return OutputSharding(
        output_spec=None,
        schema_suggestions=[
            OpSchema(
                func_schema=op_schema.func_schema,
                args_schema=(suggested_input_spec,) + op_schema.args_schema[1:],
                kwargs_schema=op_schema.kwargs_schema,
            ),
        ],
    )


@register_prop_rule(aten.slice.Tensor)
def prop_slice(op_schema: OpSchema) -> OutputSharding:
    """
    Sharding propagation rule for slice: (input, dim, start, end, step).

    A slice that covers the whole dimension with step 1 keeps the input spec.
    Any other slice computes the sliced shape and forwards all shardings except
    the one on the sliced dimension.

    NOTE: can be further optimized (right now it replicates before slicing on a sharded dimension)
    """
    # defaults for the positional args that were not passed (input has none)
    defaults = (None, 0, None, None, 1)
    input_spec, dim, start, end, step = (
        op_schema.args_schema + defaults[len(op_schema.args_schema) :]
    )
    assert isinstance(input_spec, DTensorSpec)
    assert isinstance(dim, int)
    assert start is None or isinstance(start, int)
    assert end is None or isinstance(end, int)
    assert isinstance(step, int)

    # normalize arguments
    if dim < 0:
        dim += input_spec.ndim
    dim_size = input_spec.shape[dim]
    if start is None:
        start = 0
    if end is None:
        end = dim_size
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    # Clamp to 0 <= start <= end <= dim_size so that out-of-range bounds give
    # a truncated or empty slice (as sequence slicing does) instead of a
    # negative or oversized length in the propagated shape.
    start = min(max(start, 0), dim_size)
    end = min(max(end, start), dim_size)

    if start == 0 and end == dim_size and step == 1:
        # the slice is the whole dimension: nothing changes
        return OutputSharding(output_spec=input_spec)

    # shape propagation: ceil((end - start) / step) elements remain on `dim`
    slice_len = (end - start + step - 1) // step
    out_shape = torch.Size(
        tuple(input_spec.shape[0:dim])
        + (slice_len,)
        + tuple(input_spec.shape[dim + 1 :])
    )

    return _prop_all_but_dim(op_schema, dim=dim, out_shape=out_shape)


@register_prop_rule(aten.slice_scatter.default)
def prop_slice_scatter(op_schema: OpSchema) -> OutputSharding:
    """
    Sharding propagation rule for slice_scatter: (input, src, dim, start, end, step).

    Returns an output spec equal to the input's when the input has no partial
    placement, is not sharded on `dim`, and src has the same placements as the
    input. Otherwise returns a suggestion where both input and src use the
    input's placements with partials and shards on `dim` replaced by replicates.
    """
    # 1. number of dimensions in input and src need to match.
    # 2. number of elements on all non-dim need to match between input and src.
    # 3. number of elements in src in dim need to match the slice size.
    # Given the above:
    # - We suggest for src to follow the sharding of input, except on the scatter dimension,
    #   where our best bet for now is to make them replicated as a fall-back.
    #   TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding.

    # defaults for the positional args that were not passed; start, end and
    # step do not influence the sharding decision, so they are not bound here.
    defaults = (None, None, 0, None, None, 1)
    input_spec, src_spec, dim, *_ = (
        op_schema.args_schema + defaults[len(op_schema.args_schema) :]
    )
    assert isinstance(input_spec, DTensorSpec)
    assert isinstance(src_spec, DTensorSpec)
    assert isinstance(dim, int)

    if dim < 0:
        dim += input_spec.ndim

    # first, we keep the input sharding, except for the input dimension
    # also, we cannot allow partial sum anymore.
    input_suggestion = tuple(
        Replicate()
        if isinstance(p, _Partial) or (isinstance(p, Shard) and p.dim == dim)
        else p
        for p in input_spec.placements
    )

    # Compare everything as tuples: placements are only required to be a
    # Sequence, and a list never compares equal to a tuple.
    input_placements = tuple(input_spec.placements)
    if (
        input_suggestion == input_placements
        and tuple(src_spec.placements) == input_placements
    ):
        # if our sharding is correct, the output sharding will be the same as the input.
        return OutputSharding(
            output_spec=DTensorSpec(
                mesh=input_spec.mesh,
                placements=input_spec.placements,
                shape=input_spec.shape,
                ndim=input_spec.ndim,
            )
        )

    # otherwise, return the suggestion.
    return OutputSharding(
        output_spec=None,
        schema_suggestions=[
            OpSchema(
                func_schema=op_schema.func_schema,
                args_schema=(
                    DTensorSpec(
                        mesh=input_spec.mesh,
                        placements=input_suggestion,
                        shape=input_spec.shape,
                        ndim=input_spec.ndim,
                    ),
                    DTensorSpec(
                        mesh=src_spec.mesh,
                        placements=input_suggestion,
                        shape=src_spec.shape,
                        ndim=src_spec.ndim,
                    ),
                )
                + op_schema.args_schema[2:],
                kwargs_schema=op_schema.kwargs_schema,
            )
        ],
    )


@register_prop_rule(aten.index_select.default)
def prop_index_select(op_schema: OpSchema) -> OutputSharding:
    values_spec, dim, indices_spec = op_schema.args_schema

    assert isinstance(values_spec, DTensorSpec)
    assert isinstance(dim, int)
    assert isinstance(indices_spec, DTensorSpec)

    all_indices_spec: List[Optional[DTensorSpec]] = [
        indices_spec if dim == i else None for i in range(values_spec.ndim)
    ]

    result = prop_index(
        OpSchema(
            func_schema=op_schema.func_schema,
            args_schema=(values_spec, all_indices_spec),
            kwargs_schema=op_schema.kwargs_schema,
        )
    )
    if result.schema_suggestions:
        result.schema_suggestions = [
            OpSchema(
                func_schema=op_schema.func_schema,
                args_schema=(s.args_schema[0], dim, s.args_schema[1][dim]),
                kwargs_schema=op_schema.kwargs_schema,
            )
            for s in result.schema_suggestions
Loading ...