# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, List, Optional, Sequence, Tuple
import torch
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import pointwise_rule
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import (
_Partial,
DTensorSpec,
Placement,
Replicate,
Shard,
)
aten = torch.ops.aten # pyre-ignore
@register_prop_rule( # pyre-ignore
[
aten._foreach_neg.default,
aten._foreach_reciprocal.default,
aten._foreach_sqrt.default,
]
)
def _prop__foreach_unaop(op_schema: OpSchema) -> OutputSharding:
self = op_schema.args_schema[0]
assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
    # FIXME(@mrshenli): for sqrt, this is only mathematically correct for
    # Replicate and Shard tensors.
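    # A worked sketch of the issue (values made up): a _Partial tensor with
    # local shards 1.0 and 4.0 represents the value 5.0, so the true sqrt is
    # ~2.236, but applying sqrt shard-wise and reducing afterwards would give
    # 1.0 + 2.0 = 3.0.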
return OutputSharding(output_spec=self)
@register_prop_rule( # pyre-ignore
[
aten._foreach_add.List,
aten._foreach_div.List,
aten._foreach_mul.List,
]
)
def _prop__foreach_binop_list(op_schema: OpSchema) -> OutputSharding:
self, other = op_schema.args_schema[:2]
scalar = None if len(op_schema.args_schema) < 3 else op_schema.args_schema[2]
assert isinstance(self, list) and all(
isinstance(s, DTensorSpec) for s in self
), f"Expect a List[DTensorSpec] but got {self}"
assert isinstance(other, list) and all(
isinstance(o, DTensorSpec) for o in other
), f"Expect a List[DTensorSpec] but got {other}"
assert len(self) == len(other), (
"Two tensor lists must match in length, "
f"but got {len(self)} and {len(other)}"
)
if any(s != o for s, o in zip(self, other)):
        # If the DTensorSpecs of the two operands do not match, suggest using
        # self's DTensorSpec. This will trigger an allreduce if other is
        # partial and self is replicated.
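        # Hypothetical example on a 1-D mesh (spec(P) abbreviates a DTensorSpec
        # with placement P):
        #   self  = [spec(Shard(0)), spec(Replicate())]
        #   other = [spec(Shard(0)), spec(_Partial())]
        # The second pair differs, so no output_spec is returned and the
        # suggestion (self, self) asks the caller to redistribute `other` to
        # self's specs before re-running the op.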
return OutputSharding(
output_spec=None,
schema_suggestions=[
OpSchema(
func_schema=op_schema.func_schema,
args_schema=(self, self, scalar) if scalar else (self, self),
kwargs_schema=op_schema.kwargs_schema,
is_inplace=op_schema.is_inplace,
is_out_variant=op_schema.is_out_variant,
)
],
)
else:
return OutputSharding(output_spec=self)
@register_prop_rule( # pyre-ignore
[
aten._foreach_add.Scalar,
aten._foreach_div.Scalar,
aten._foreach_mul.Scalar,
aten._foreach_sub.Scalar,
]
)
def _prop__foreach_binop_scalar(op_schema: OpSchema) -> OutputSharding:
self, scalar = op_schema.args_schema
assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
assert not isinstance(scalar, list)
return OutputSharding(output_spec=self)
@register_prop_rule( # pyre-ignore
[
aten._foreach_addcdiv.Scalar,
aten._foreach_addcmul.Scalar,
]
)
def _prop__foreach_addcop_scalar(op_schema: OpSchema):
self, tensor1, tensor2 = op_schema.args_schema[:3]
scalar = None if len(op_schema.args_schema) < 4 else op_schema.args_schema[3]
assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
    assert isinstance(tensor1, list) and all(
        isinstance(s, DTensorSpec) for s in tensor1
    )
    assert isinstance(tensor2, list) and all(
        isinstance(s, DTensorSpec) for s in tensor2
    )
if any(s != t1 or s != t2 for s, t1, t2 in zip(self, tensor1, tensor2)):
        # If the DTensorSpecs of the operands do not match, suggest using
        # self's DTensorSpec. This will trigger an allreduce if tensor1 or
        # tensor2 is partial and self is replicated.
return OutputSharding(
output_spec=None,
schema_suggestions=[
OpSchema(
func_schema=op_schema.func_schema,
args_schema=(self, self, self, scalar)
if scalar
else (self, self, self),
kwargs_schema=op_schema.kwargs_schema,
is_inplace=op_schema.is_inplace,
is_out_variant=op_schema.is_out_variant,
)
],
)
else:
return OutputSharding(output_spec=self)
@register_prop_rule([aten._foreach_pow.ScalarAndTensor]) # pyre-ignore
def _prop__foreach_pow_scalar_and_tensor(op_schema: OpSchema):
    scalar, exponent = op_schema.args_schema
assert isinstance(exponent, list) and all(
isinstance(s, DTensorSpec) for s in exponent
)
return OutputSharding(output_spec=exponent)
@register_prop_rule([aten._fused_adam.default]) # pyre-ignore
def _prop__fused_adam(op_schema: OpSchema):
NT = 5
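    # The first NT positional args of aten._fused_adam are tensor lists
    # (presumably params/self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs);
    # the remaining positional args are passed through unchanged below.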
    tensor_list_args: Tuple[List[DTensorSpec]] = op_schema.args_schema[:NT] # type: ignore[assignment]
    assert all(isinstance(schema, list) for schema in tensor_list_args)
    assert all(
        isinstance(s, DTensorSpec) for schema in tensor_list_args for s in schema
    )
    tensor_schemas: Tuple[List[DTensorSpec]] = [ # type: ignore[assignment]
        schema for schema in tensor_list_args if len(schema)
]
assert all(len(s) == len(tensor_schemas[0]) for s in tensor_schemas), (
"expect the same number of gradients and states, but got "
f"{[len(s) for s in tensor_schemas]}."
)
if any(any(t != ts[0] for t in ts) for ts in zip(*tensor_schemas)):
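        # e.g. if the params are Shard(0) but one state list is Replicate, we
        # suggest re-using the params' specs (args_schema[0]) for every
        # non-empty tensor list; empty lists (typically max_exp_avg_sqs when
        # amsgrad is off) are passed through as-is.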
new_schemas: Tuple[List[DTensorSpec]] = tuple( # type: ignore[assignment]
            op_schema.args_schema[0] if len(s) else s for s in tensor_list_args
)
return OutputSharding(
output_spec=None,
schema_suggestions=[
OpSchema(
func_schema=op_schema.func_schema,
args_schema=new_schemas + op_schema.args_schema[NT:],
kwargs_schema=op_schema.kwargs_schema,
is_inplace=op_schema.is_inplace,
is_out_variant=op_schema.is_out_variant,
)
],
)
else:
return OutputSharding(output_spec=(op_schema.args_schema[0],) * NT) # type: ignore[arg-type]
@register_prop_rule(aten.nll_loss_forward.default) # pyre-ignore
def _prop_nll_loss_forward(op_schema: OpSchema) -> OutputSharding:
self, target = op_schema.args_schema[:2]
assert isinstance(self, DTensorSpec)
assert isinstance(target, DTensorSpec)
if self.placements != target.placements:
        # self and target must match in placements, which in data parallel use
        # cases should be sharded along the batch dimension. Force a
        # redistribute. We need to create a new spec for self instead of
        # returning (target, target), since target and self might not match
        # in shape.
new_self = DTensorSpec(
mesh=self.mesh,
placements=target.placements,
tensor_meta=self.tensor_meta,
)
return OutputSharding(
output_spec=None,
schema_suggestions=[
OpSchema(
func_schema=op_schema.func_schema,
args_schema=(new_self, target) + op_schema.args_schema[2:],
kwargs_schema=op_schema.kwargs_schema,
is_inplace=op_schema.is_inplace,
is_out_variant=op_schema.is_out_variant,
)
],
)
else:
return OutputSharding(
output_spec=(
# by default, nll_loss_forward conducts a reduction and returns
# a scalar tensor, and hence the _Partial placements.
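                # e.g. with reduction="sum" and the batch sharded across ranks,
                # each rank's local result is the sum over its own samples, so
                # the full loss is the sum of the local values, which is what
                # _Partial encodes (reduced lazily when needed).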
DTensorSpec(mesh=self.mesh, placements=(_Partial(),)),
# the 2nd output total_weight is always a scalar tensor
DTensorSpec(mesh=self.mesh, placements=(Replicate(),)),
)
)
@register_prop_rule(aten.nll_loss_backward.default) # pyre-ignore
def _prop_nll_loss_backward(op_schema: OpSchema) -> OutputSharding:
grad_output, self = op_schema.args_schema[:2]
assert isinstance(grad_output, DTensorSpec)
assert isinstance(self, DTensorSpec)
return OutputSharding(output_spec=self)
@register_prop_rule(aten.stack.default)
def _prop_stack(op_schema: OpSchema) -> OutputSharding:
tensors = op_schema.args_schema[0]
dim = 0 if len(op_schema.args_schema) == 1 else cast(int, op_schema.args_schema[1])
assert (
isinstance(tensors, list) and len(tensors) > 0
), "expect at least one tensor to stack"
assert all(
isinstance(t, DTensorSpec) for t in tensors
), f"expect a list of DTensorSpecs, but got {tensors}"
assert all(
t.shape == tensors[0].shape for t in tensors
), f"expect all tensors to have the same shape, but got {tensors}."
# TODO: provide schema_suggestions when placements do not match
assert all(
t.placements == tensors[0].placements for t in tensors
), f"expect all tensors to have the same placements, but got {tensors}."
assert all(
not p.is_shard(dim) for p in tensors[0].placements
), "DTensor does not support stack on sharded dimension."
return OutputSharding(
output_spec=DTensorSpec(mesh=tensors[0].mesh, placements=tensors[0].placements)
)
@register_prop_rule(aten.select.int)
def _prop_select(op_schema: OpSchema) -> OutputSharding:
tensor, dim = op_schema.args_schema[:2]
assert isinstance(tensor, DTensorSpec)
assert isinstance(dim, int)
placements: Sequence[Placement] = tensor.placements
assert all(
not p.is_shard(dim) for p in placements
), "DTensor does not support select on sharded dimension."
# select will remove one dimension, decrement dim of Shard placements by 1
# if they are larger than dim.
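    # e.g. selecting an index along dim=1 of a tensor placed (Shard(2),) on a
    # 1-D mesh removes tensor dim 1, so the former dim 2 becomes dim 1 and the
    # placement is rewritten to (Shard(1),).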
new_placements: List[Placement] = []
for p in placements:
# Using isinstance instead of is_shard so that mypy won't complain
# about accessing dim attribute.
if isinstance(p, Shard) and p.dim > dim:
new_placements.append(Shard(p.dim - 1))
else:
new_placements.append(p)
return OutputSharding(
output_spec=DTensorSpec(mesh=tensor.mesh, placements=tuple(new_placements))
)
@register_prop_rule(aten.native_layer_norm.default) # pyre-ignore
def _prop_native_layer_norm(op_schema: OpSchema) -> OutputSharding:
input, normalized_shape, weight, bias, eps = op_schema.args_schema
assert isinstance(input, DTensorSpec)
assert isinstance(normalized_shape, (tuple, list))
if weight is not None:
assert isinstance(weight, DTensorSpec)
assert all(isinstance(p, Replicate) for p in weight.placements)
if bias is not None:
assert isinstance(bias, DTensorSpec)
assert all(isinstance(p, Replicate) for p in bias.placements)
# only the left-most (non-normalized) dimensions of the input can be sharded
batch_ndim = len(input.shape) - len(normalized_shape)
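    # e.g. for an input of shape (N, S, H) with normalized_shape == (H,),
    # batch_ndim is 2, so Shard(0) or Shard(1) placements are fine but Shard(2)
    # is not, since normalization happens within each length-H row. The mean
    # and rstd stats are computed per non-normalized element and can therefore
    # reuse the input's placements.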
assert all(
        isinstance(p, Replicate) or (isinstance(p, Shard) and p.dim < batch_ndim)
for p in input.placements
)
stats_spec = DTensorSpec(
mesh=input.mesh,
placements=input.placements,
)
return OutputSharding(output_spec=(input, stats_spec, stats_spec))
@register_prop_rule(aten.native_layer_norm_backward.default) # pyre-ignore
def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding:
(
grad,
input,
normalized_shape,
result1,
result2,
weight,
bias,
grad_input_mask,
) = op_schema.args_schema
assert isinstance(grad, DTensorSpec)
assert isinstance(grad_input_mask, (list, tuple))
if weight is not None:
assert isinstance(weight, DTensorSpec)
assert all(isinstance(s, Replicate) for s in weight.placements)
if bias is not None:
assert isinstance(bias, DTensorSpec)
assert all(isinstance(s, Replicate) for s in bias.placements)
# ensure sharding on dim 0, which will trigger the "Partial" output on
# weight and bias grads
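    # e.g. with the batch sharded on dim 0, each rank accumulates grad_weight /
    # grad_bias only over its local batch slice; the true gradients are the sum
    # across ranks, which is what the _Partial placements below encode.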
assert any(
isinstance(s, Shard) and s.dim == 0 for s in grad.placements
), f"Got {grad.placements}"
weight_grad = (
DTensorSpec(
mesh=weight.mesh,
placements=tuple([_Partial()] * weight.mesh.ndim),
)
if weight
else None
)
bias_grad = (
DTensorSpec(
mesh=bias.mesh,
placements=tuple([_Partial()] * bias.mesh.ndim),
)
if bias
else None
)
return OutputSharding(
# NOTE: type errors below are legit. This is because DTensor currently
# doesn't support Optional return values. Need to be fixed in DTensor repo.
output_spec=(
grad if grad_input_mask[0] else None,
weight_grad if grad_input_mask[1] else None,
bias_grad if grad_input_mask[2] else None,
),
)
def _refine_sharding(
op_schema: OpSchema, active_dim: Optional[int]
) -> Sequence[Placement]:
"""
Considers 2 first inputs of op_schema as having same shape,
and returns suggested placement for a pointwise operation.
"""
    # Consider the operating dimension as a singleton to prevent sharding on
    # it. However, if active_dim is None, the two inputs have equal shapes and
    # we apply the pointwise rule exactly.
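    # e.g. for slice_scatter with input of shape (8, 16), src of shape (8, 4)
    # and active_dim=1, both tensor metas below are rewritten to shape (8, 1),
    # so pointwise_rule can compare the two specs dimension by dimension even
    # though their real sizes differ on the operating dimension.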
from torch.fx.passes.shape_prop import TensorMetadata
args_schema = []
for s in op_schema.args_schema[:2]:
assert isinstance(s, DTensorSpec) and s.tensor_meta is not None
args_schema.append(
DTensorSpec(
mesh=s.mesh, # type: ignore[attr-defined]
placements=s.placements, # type: ignore[attr-defined]
tensor_meta=TensorMetadata(
shape=torch.Size(
s.shape[0:active_dim] + (1,) + s.shape[active_dim + 1 :]
)
if active_dim is not None
else s.shape,
dtype=s.tensor_meta.dtype,
requires_grad=s.tensor_meta.requires_grad,
stride=s.tensor_meta.stride,
memory_format=s.tensor_meta.memory_format,
is_quantized=s.tensor_meta.is_quantized,
qparams=s.tensor_meta.qparams,
),
)
)
op_schema = OpSchema(
func_schema=op_schema.func_schema,
args_schema=args_schema, # type: ignore[arg-type]
kwargs_schema={},
is_inplace=op_schema.is_inplace,
is_out_variant=op_schema.is_out_variant,
)
output_sharding = pointwise_rule(op_schema, linearity=False)
if output_sharding.output_spec:
assert isinstance(output_sharding.output_spec, DTensorSpec)
return output_sharding.output_spec.placements
else:
assert output_sharding.schema_suggestions is not None
out_schema = output_sharding.schema_suggestions[0].args_schema[0]
assert isinstance(out_schema, DTensorSpec)
return tuple(out_schema.placements)
@register_prop_rule(aten.slice_scatter.default) # pyre-ignore
def prop_slice_scatter(op_schema: OpSchema) -> OutputSharding:
    # 1. The number of dimensions in input and src must match.
    # 2. The number of elements in every dimension other than dim must match
    #    between input and src.
    # 3. The number of elements of src in dim must match the slice size.
    # Given the above:
    # - We suggest that src follow the sharding of input, except on the scatter
    #   dimension, where our best bet for now is to fall back to replication.
    # TODO: Ideally we'd like to make sure the output is re-sharded afterwards
    #   to keep the input sharding.
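    # Hypothetical example: scattering a src of shape (8, 4) into dim 1 of an
    # input of shape (8, 16). dim 1 is the operating dimension, so any Shard(1)
    # in the refined suggestion below is replaced with Replicate() for both
    # input and src, while a Shard(0) placement is kept as-is.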
defaults = (None, None, 0, None, None, 1)
input, src, dim, start, end, step = (
op_schema.args_schema + defaults[len(op_schema.args_schema) :]
)
assert isinstance(input, DTensorSpec)
assert isinstance(src, DTensorSpec)
assert isinstance(dim, int)
if dim < 0:
dim += input.ndim
# if the input shape and the output shape are the same on the operating dimension,
# this is effectively a no-op, so we just propagate sharding as we would do for
# pointwise, no exceptions.
if input.shape[dim] == src.shape[dim]:
assert start == 0
assert end >= src.shape[dim] # type: ignore[operator]
dim = None
# apply sharding refinement as implemented in pointwise_rule
input_suggestion = list(_refine_sharding(op_schema, dim))
# apply the exception -- disallow sharding on the operating dimension.
for i, p in enumerate(input_suggestion):
if isinstance(p, Shard) and p.dim == dim:
input_suggestion[i] = Replicate()
input_suggestion = tuple(input_suggestion) # type: ignore[assignment]
if input_suggestion == tuple(input.placements) and src.placements == tuple(
input.placements
):
# if our sharding is correct, the output sharding will be the same as the input.
return OutputSharding(
output_spec=DTensorSpec(
mesh=input.mesh,
placements=input.placements,
)
)
else:
# otherwise, return the suggestion.
return OutputSharding(
output_spec=None,
schema_suggestions=[
OpSchema(
func_schema=op_schema.func_schema,
args_schema=(
DTensorSpec(
mesh=input.mesh,
placements=input_suggestion,
tensor_meta=input.tensor_meta,
),
DTensorSpec(
mesh=src.mesh,
placements=input_suggestion,
tensor_meta=src.tensor_meta,
),
)
+ op_schema.args_schema[2:],
kwargs_schema=op_schema.kwargs_schema,
)
],
)