# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, cast, Dict, Tuple, Union, Optional
import torch
import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.op_schema import (
ArgsType,
KwargsType,
OutputSpecType,
)
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.redistribute import redistribute_dtensor
from torch.utils._pytree import tree_flatten, tree_unflatten
"""
If _ENABLE_FALLBACK is set to False, dispatch will fail when an op doesn't
have a sharding rule registered.
"""
_ENABLE_FALLBACK = False
def wrap(res: object, spec: OutputSpecType) -> object:
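    """
    Wrap raw local op results back into DTensors using the output spec from
    sharding propagation. Handles a single tensor, a list or tuple of tensors
    (tuple entries may be None, e.g. for native_layer_norm.backward), and
    returns non-tensor results unchanged.
    """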
if isinstance(res, torch.Tensor):
assert spec is not None and isinstance(
spec, DTensorSpec
), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
return dtensor.DTensor(
res,
spec.mesh,
spec.placements,
size=spec.shape,
requires_grad=res.requires_grad,
)
elif isinstance(res, list):
assert spec is not None and isinstance(
spec, list
), f"output spec does not match with output! Expected list, got {spec}."
return [
dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
for e, s in zip(res, spec)
]
elif isinstance(res, tuple):
assert spec is not None and isinstance(
spec, tuple
), f"output spec does not match with output! Expected tuple, got {spec}"
        # NOTE: some ATen ops return Optional[Tensor] entries in their tuple
        # results (e.g. native_layer_norm.backward), so make sure we don't
        # wrap None with DTensor.
return tuple(
dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
if e is not None and s is not None
else None
for e, s in zip(res, spec)
)
else:
        # if the result contains only non-tensor values, simply return it without re-wrapping
return res
def pack_args_kwargs_with_local_tensor(
args: Union[ArgsType, KwargsType],
args_schema: Union[ArgsType, KwargsType],
redistribute_with_schema: bool = False,
) -> Union[ArgsType, KwargsType]:
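    """
    Flatten ``args`` (either the args tuple or the kwargs dict) and replace
    every DTensor leaf with its local shard so the op can run on plain
    tensors. If ``redistribute_with_schema`` is True, each DTensor is first
    redistributed to the mesh/placements given by the matching entry in
    ``args_schema`` before its local tensor is taken.
    """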
flatten_args, args_tree_spec = tree_flatten(args)
flatten_args_schema, _ = tree_flatten(args_schema)
for i, arg in enumerate(flatten_args):
if isinstance(arg, dtensor.DTensor):
if redistribute_with_schema:
target_spec = flatten_args_schema[i]
arg = redistribute_dtensor(
arg, target_spec.mesh, target_spec.placements
)
# reuse the schema list and update it with local tensor
flatten_args_schema[i] = arg._local_tensor
return tree_unflatten(flatten_args_schema, args_tree_spec)
def _reshape_alias(
x: torch.Tensor, shape: Tuple[int, ...], strides: Tuple[int, ...]
) -> torch.Tensor:
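    # strides are intentionally ignored: the private aten._reshape_alias op is
    # decomposed into a plain view over the requested shape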
return torch.ops.aten.view(x, shape)
_CURRENT_DECOMPOSITION_TABLE: Dict[Callable[..., object], Callable[..., object]] = {
torch.ops.aten._reshape_alias.default: _reshape_alias,
}
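# Ops in this table are rewritten at the very top of operator_dispatch and
# returned directly, before sharding propagation is attempted for the original
# overload. Extending it is a one-line registration (sketch only; both the aten
# overload and ``my_decomp`` below are hypothetical — the value must be a
# callable with the same signature as the overload it replaces):
#
#     _CURRENT_DECOMPOSITION_TABLE[torch.ops.aten.some_alias.default] = my_decomp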
def operator_dispatch(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
sharding_propagator: ShardingPropagator,
custom_dispatch_ops: Optional[Dict[str, Callable[..., object]]] = None,
) -> object:
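    """
    Dispatch a single ATen op call made on DTensor arguments:

    1. rewrite ops listed in ``_CURRENT_DECOMPOSITION_TABLE`` and return early;
    2. dispatch to a user-defined custom op implementation if one is registered
       for this op (custom ops take the highest priority);
    3. run sharding propagation to get the output spec and a suggested input
       schema;
    4. replace DTensor args/kwargs with their local shards, redistributing
       first if the suggested schema differs from the input schema;
    5. run the op on the local shards, then return ``self`` for inplace ops,
       the mutated out= DTensors for out-variant ops, or the local results
       wrapped back into DTensors otherwise.
    """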
    # first, rewrite some private aten aliases into calls to public ops
if op_call in _CURRENT_DECOMPOSITION_TABLE:
return _CURRENT_DECOMPOSITION_TABLE[op_call](*args, **kwargs)
    # STEP 0. See if there's a user-defined custom aten operator
    # implementation. Custom operators take the highest priority.
if custom_dispatch_ops is not None and str(op_call) in custom_dispatch_ops:
# dispatch to user defined custom distributed tensor ops
return custom_dispatch_ops[str(op_call)](*args, **kwargs)
    # extract the op schema from the args/kwargs
op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)
output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)
    # if the schema suggestion from sharding prop is not the same instance as the
    # input op_schema, it indicates a reshard: redistribute the input tensors
    # before calling the local op
assert output_sharding.schema_suggestions is not None
needs_redistribute = output_sharding.schema_suggestions[0] is not op_schema
suggested_input_schema = output_sharding.schema_suggestions[0]
local_tensor_args = pack_args_kwargs_with_local_tensor(
args,
suggested_input_schema.args_schema,
redistribute_with_schema=needs_redistribute,
)
local_tensor_kwargs = pack_args_kwargs_with_local_tensor(
kwargs,
suggested_input_schema.kwargs_schema,
redistribute_with_schema=needs_redistribute,
)
# run local op computation with potentially modified args/kwargs
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
local_tensor_kwargs = cast(Dict[str, object], local_tensor_kwargs)
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
if suggested_input_schema.is_inplace:
# inplace op should return self instead of re-wrapping
self = cast(dtensor.DTensor, args[0])
self._spec = cast(DTensorSpec, output_sharding.output_spec)
return self
elif suggested_input_schema.is_out_variant:
        # an out variant could have multiple out args (e.g. lu_unpack.out)
output_specs = (
(output_sharding.output_spec,)
if not isinstance(output_sharding.output_spec, tuple)
else output_sharding.output_spec
)
out_dts = []
spec_idx = 0
for arg in suggested_input_schema.func_schema.arguments:
if arg.is_out:
out_dt = cast(dtensor.DTensor, kwargs[arg.name])
out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
out_dts.append(out_dt)
spec_idx += 1
assert len(out_dts) >= 1, "out variant should have at least one out arg"
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
else:
return wrap(local_results, output_sharding.output_spec)
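# Rough call-path sketch, for orientation only (the actual wiring lives in
# torch.distributed._tensor.api): DTensor's __torch_dispatch__ forwards every
# intercepted ATen call into operator_dispatch, roughly as:
#
#     operator_dispatch(
#         torch.ops.aten.add.Tensor,   # the intercepted OpOverload
#         (dtensor_a, dtensor_b),      # DTensor args from the call site
#         {},                          # kwargs
#         sharding_propagator,         # DTensor's shared ShardingPropagator
#         custom_dispatch_ops=None,
#     )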