Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ distributed / _tensor / dispatch.py

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, cast, Dict, Tuple, Union, Optional

import torch

import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.op_schema import (
    ArgsType,
    KwargsType,
    OutputSpecType,
)
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.redistribute import redistribute_dtensor
from torch.utils._pytree import tree_flatten, tree_unflatten


"""
If _ENABLE_FALLBACK set to False, dispatch will fail when an op doesn't
have a sharding rule registered.
"""
_ENABLE_FALLBACK = False


def wrap(res: object, spec: OutputSpecType) -> object:
    if isinstance(res, torch.Tensor):
        assert spec is not None and isinstance(
            spec, DTensorSpec
        ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
        return dtensor.DTensor(
            res,
            spec.mesh,
            spec.placements,
            size=spec.shape,
            requires_grad=res.requires_grad,
        )
    elif isinstance(res, list):
        assert spec is not None and isinstance(
            spec, list
        ), f"output spec does not match with output! Expected list, got {spec}."
        return [
            dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
            for e, s in zip(res, spec)
        ]
    elif isinstance(res, tuple):
        assert spec is not None and isinstance(
            spec, tuple
        ), f"output spec does not match with output! Expected tuple, got {spec}"

        # NOTE: local results might return Optional Tensor from ATen op, so we need to
        # handle that case and make sure we don't wrap None with DTensor.
        # (i.e. native_layer_norm.backward)
        return tuple(
            dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
            if e is not None and s is not None
            else None
            for e, s in zip(res, spec)
        )
    else:
        # if the res contains only non tensor values, we simply return it without rewrapping
        return res


def pack_args_kwargs_with_local_tensor(
    args: Union[ArgsType, KwargsType],
    args_schema: Union[ArgsType, KwargsType],
    redistribute_with_schema: bool = False,
) -> Union[ArgsType, KwargsType]:
    flatten_args, args_tree_spec = tree_flatten(args)
    flatten_args_schema, _ = tree_flatten(args_schema)

    for i, arg in enumerate(flatten_args):
        if isinstance(arg, dtensor.DTensor):
            if redistribute_with_schema:
                target_spec = flatten_args_schema[i]
                arg = redistribute_dtensor(
                    arg, target_spec.mesh, target_spec.placements
                )

            # reuse the schema list and update it with local tensor
            flatten_args_schema[i] = arg._local_tensor

    return tree_unflatten(flatten_args_schema, args_tree_spec)


def _reshape_alias(
    x: torch.Tensor, shape: Tuple[int, ...], strides: Tuple[int, ...]
) -> torch.Tensor:
    return torch.ops.aten.view(x, shape)


_CURRENT_DECOMPOSITION_TABLE: Dict[Callable[..., object], Callable[..., object]] = {
    torch.ops.aten._reshape_alias.default: _reshape_alias,
}


def operator_dispatch(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
    sharding_propagator: ShardingPropagator,
    custom_dispatch_ops: Optional[Dict[str, Callable[..., object]]] = None,
) -> object:
    # first we need to lift some private aten aliases to public calls
    if op_call in _CURRENT_DECOMPOSITION_TABLE:
        return _CURRENT_DECOMPOSITION_TABLE[op_call](*args, **kwargs)

    # STEP 0. See if there's a user defined custom aten operator
    # implementations. Custom operators take the highest priority
    if custom_dispatch_ops is not None and str(op_call) in custom_dispatch_ops:
        # dispatch to user defined custom distributed tensor ops
        return custom_dispatch_ops[str(op_call)](*args, **kwargs)

    # unwrap the args/kwargs schema
    op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)

    output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)

    # if the schema suggestion from sharding prop is not the same instance as the
    # input op_schema, it indicates a reshard, we need to redistribute the input
    # tensors before calling the local op
    assert output_sharding.schema_suggestions is not None
    needs_redistribute = output_sharding.schema_suggestions[0] is not op_schema
    suggested_input_schema = output_sharding.schema_suggestions[0]

    local_tensor_args = pack_args_kwargs_with_local_tensor(
        args,
        suggested_input_schema.args_schema,
        redistribute_with_schema=needs_redistribute,
    )
    local_tensor_kwargs = pack_args_kwargs_with_local_tensor(
        kwargs,
        suggested_input_schema.kwargs_schema,
        redistribute_with_schema=needs_redistribute,
    )

    # run local op computation with potentially modified args/kwargs
    local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
    local_tensor_kwargs = cast(Dict[str, object], local_tensor_kwargs)
    local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

    if suggested_input_schema.is_inplace:
        # inplace op should return self instead of re-wrapping
        self = cast(dtensor.DTensor, args[0])
        self._spec = cast(DTensorSpec, output_sharding.output_spec)
        return self
    elif suggested_input_schema.is_out_variant:
        # out variant could possibly have multiple out args (i.e. lu_unpack.out)
        output_specs = (
            (output_sharding.output_spec,)
            if not isinstance(output_sharding.output_spec, tuple)
            else output_sharding.output_spec
        )
        out_dts = []
        spec_idx = 0
        for arg in suggested_input_schema.func_schema.arguments:
            if arg.is_out:
                out_dt = cast(dtensor.DTensor, kwargs[arg.name])
                out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                out_dts.append(out_dt)
                spec_idx += 1

        assert len(out_dts) >= 1, "out variant should have at least one out arg"
        return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
    else:
        return wrap(local_results, output_sharding.output_spec)