# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import itertools
import operator
from typing import cast, Iterable, List, Sequence, Tuple, Union
import torch
from torch.distributed._tensor._collective_utils import redistribute_cost
from torch.distributed._tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementStrategy,
RuntimeSchemaInfo,
)
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import (
DTensorSpec,
Partial,
Placement,
Replicate,
Shard,
)
# convenient wrapper to register sharding propagation rules
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def register_prop_rule(op, schema_info=None):
# pyre-fixme[53]: Captured variable `func` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def wrapper(impl):
overloads = op if isinstance(op, list) else [op]
for overload in overloads:
DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(
overload, impl, schema_info
)
return impl
return wrapper
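
# Illustrative usage sketch for ``register_prop_rule`` (the op overload and rule
# name below are hypothetical, shown only to document the decorator shape):
#
#   @register_prop_rule(torch.ops.aten.my_op.default)
#   def my_op_rule(op_schema: OpSchema) -> OutputSharding:
#       ...
#
# Passing a list of overloads registers the same rule for each of them.
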
def register_op_strategy(op, schema_info=None):
# pyre-fixme[53]: Captured variable `func` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
    # For every ATen op that accepts any arg in this list, the arg itself can
    # impact the strides (and potentially the sharding strategy) of the output
    # tensor. Thus, we detect ATen schemas with any of these args and ensure
    # that they get specialized here.
arg_names_that_require_specializing_cache_strategy = [
"memory_format",
]
def wrapper(impl):
if isinstance(op, list):
overloads = op
else:
overloads = [op]
for overload in overloads:
curr_schema_info = None
if schema_info is None:
specialized_args = [
a.name
for a in overload._schema.arguments
if a.name in arg_names_that_require_specializing_cache_strategy
]
if any(specialized_args):
curr_schema_info = RuntimeSchemaInfo(
static_kwargkey=specialized_args
)
else:
curr_schema_info = schema_info
DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
overload, impl, curr_schema_info
)
return impl
return wrapper
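
# Illustrative usage sketch for ``register_op_strategy`` (hypothetical op and
# strategy names; real strategy functions live in the op-specific modules):
#
#   @register_op_strategy(torch.ops.aten.my_op.default)
#   def my_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
#       ...
#
# When no ``schema_info`` is given, overloads that take a ``memory_format``
# argument are registered with
# ``RuntimeSchemaInfo(static_kwargkey=["memory_format"])`` automatically.
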
def as_list(
x: Union[List[object], object]
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
# which is an object but treated as a list by the tracer. Therefore, keep
# `immutable_list` intact here as well.
if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list):
return x
else:
return [x]
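
# Doctest-style sketch for ``as_list`` (not executed):
#
#   >>> as_list(3)
#   [3]
#   >>> as_list([0, 1])   # lists (and fx immutable_lists) pass through unchanged
#   [0, 1]
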
def normalize_dim(dim: int, ndim: int) -> int:
return dim if dim >= 0 else dim + ndim
def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]:
"""Normalize a dim or a sequence of dims, so that they are all positive."""
if isinstance(dims, int):
dims = (normalize_dim(dims, ndim),)
elif isinstance(dims, list):
dims = [normalize_dim(dim, ndim) for dim in dims]
elif isinstance(dims, tuple):
dims = tuple([normalize_dim(dim, ndim) for dim in dims])
return dims
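
# Doctest-style sketch for ``normalize_dim`` / ``normalize_dims`` (not executed),
# assuming a 4-D tensor (ndim=4):
#
#   >>> normalize_dim(-1, 4)
#   3
#   >>> normalize_dims((-1, 0), 4)
#   (3, 0)
#   >>> normalize_dims(2, 4)
#   (2,)
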
def normalize_to_torch_size(size) -> torch.Size:
"""
Unify variable types of size argument to torch.Size
Acceptable types include:
int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
or torch.Size
"""
if isinstance(size, torch.Size):
return size
if isinstance(size, int):
torch_size = [size]
elif len(size) == 1 and isinstance(size[0], Sequence):
torch_size = list(size[0])
else:
torch_size = list(size)
return torch.Size(torch_size)
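
# Doctest-style sketch for ``normalize_to_torch_size`` (not executed):
#
#   >>> normalize_to_torch_size(4)
#   torch.Size([4])
#   >>> normalize_to_torch_size([2, 3])
#   torch.Size([2, 3])
#   >>> normalize_to_torch_size(([2, 3],))   # a single nested sequence is unwrapped
#   torch.Size([2, 3])
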
def prod(xs: Iterable[int]) -> int:
return functools.reduce(operator.mul, xs, 1)
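
# Doctest-style sketch for ``prod`` (not executed):
#
#   >>> prod([2, 3, 4])
#   24
#   >>> prod([])   # empty input yields the multiplicative identity
#   1
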
def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is shardable according to the spec."""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
shards_map[shard_dim] *= spec.mesh.size(i)
for i, dim_size in enumerate(shape):
# TODO: maybe we should determine is_shardable based on
# whether it's evenly sharded or not
if shards_map[i] > 1 and dim_size < shards_map[i]:
return False
return True
def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is evenly shardable according to the spec."""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
shards_map[shard_dim] *= spec.mesh.size(i)
for i, dim_size in enumerate(shape):
if shards_map[i] > 1 and (dim_size % shards_map[i] != 0):
return False
return True
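
# Worked sketch for the two shardability checks above (illustrative numbers):
# assume ``spec`` places ``Shard(0)`` on a 1-D mesh of size 4, so shards_map == [4, 1]:
#   - shape (8, 5): shardable and evenly shardable (8 >= 4 and 8 % 4 == 0)
#   - shape (6, 5): shardable but not evenly shardable (6 >= 4 but 6 % 4 != 0)
#   - shape (2, 5): not shardable (only 2 elements on dim 0 for 4 shards)
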
def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool:
"""Return True if tensor dim is sharded."""
return any(p.is_shard(dim) for p in spec.placements)
def is_tensor_partial(spec: DTensorSpec) -> bool:
"""Return True if tensor is partial on the mesh."""
return any(p.is_partial() for p in spec.placements)
def infer_broadcast_dims_map(
common_shape: torch.Size, input_shape: torch.Size
) -> List[int]:
    # Infer the broadcast dims map, which maps each dim of the common shape to
    # the corresponding dim of the input shape (or -1 if there is none); this
    # follows broadcast semantics (trailing dims are aligned).
common_ndim = len(common_shape)
input_ndim = len(input_shape)
broadcast_dims_map = [-1] * common_ndim
for idx in range(-1, -1 - input_ndim, -1):
if input_shape[idx] == common_shape[idx]:
broadcast_dims_map[common_ndim + idx] = input_ndim + idx
return broadcast_dims_map
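
# Worked sketch (illustrative shapes): broadcasting input_shape (3, 4) against
# common_shape (2, 3, 4) aligns trailing dims, so the map is [-1, 0, 1]
# (common dim 0 has no counterpart in the input; common dims 1 and 2 map to
# input dims 0 and 1). For input_shape (1, 4) the map is [-1, -1, 1], since the
# size-1 dim is broadcast rather than matched.
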
def map_placements_after_broadcast(
placements: Tuple[Placement, ...],
shape: torch.Size,
broadcast_dims_map: List[int],
) -> Tuple[Placement, ...]:
"""Map each placement based on the output shape after broadcast."""
new_placements: List[Placement] = []
for placement in placements:
if isinstance(placement, (Replicate, Partial)):
new_placements.append(placement)
else:
assert isinstance(placement, Shard)
shard_dim = normalize_dim(placement.dim, len(shape))
new_shard_dim = broadcast_dims_map[shard_dim]
if new_shard_dim != -1:
# there's a map from the common shape shard dim to
# the input shape shard dim before broadcasting,
# use that instead
new_placements.append(Shard(new_shard_dim))
else:
                # there's no mapping from this common shape shard dim to
                # an input shape shard dim before broadcasting, which
                # means implicit broadcasting happens on this dim; we can
                # simply mark it as Replicate and the implicit broadcast
                # will expand it to the sharded shape automatically
new_placements.append(Replicate())
return tuple(new_placements)
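
# Worked sketch continuing the example above: with common shape (2, 3, 4),
# input shape (3, 4) and broadcast_dims_map == [-1, 0, 1]:
#   - Shard(2) on the common shape maps to Shard(1) on the input
#   - Shard(0) has no counterpart (map entry -1), so it becomes Replicate()
#   - Replicate() and Partial() placements are kept as-is
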
def generate_redistribute_costs(
src_strategy: OpStrategy, dst_spec: DTensorSpec
) -> List[float]:
redistribute_costs: List[float] = []
for strat in src_strategy.strategies:
redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
return redistribute_costs
def expand_to_full_mesh_op_strategy(
mesh: DeviceMesh,
op_schema: OpSchema,
single_mesh_dim_strategies: List[List[Placement]],
*,
input_index: int = 1,
inplace_op: bool = False,
) -> OpStrategy:
# Expand the single_mesh_dim_strategies to full mesh dim strategies.
all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = []
for specs in zip(*strategy_comb):
spec_list.append(DTensorSpec(mesh, tuple(specs)))
input_specs = spec_list[input_index:]
input_args_strategy = op_schema.args_strategy
assert len(input_specs) == len(input_args_strategy)
self_spec = input_args_strategy[0].strategies[0].output_spec
if inplace_op and self_spec.placements != input_specs[0].placements:
            # for an in-place op, only add the placement strategy when the input
            # spec matches the first argument's runtime sharding; otherwise skip
            # this combination
continue
# check inputs shardable
inputs_shardable = all(
is_tensor_shardable(inp.shape, s)
for inp, s in zip(input_args_strategy, input_specs)
)
# only add to the all_strategies list when all inputs are shardable
if inputs_shardable:
redistribute_cost = [
generate_redistribute_costs(input_strategy, input_spec)
for input_strategy, input_spec in zip(input_args_strategy, input_specs)
]
strategy = PlacementStrategy(
output_specs=tuple(spec_list[:input_index])
if input_index > 1
else spec_list[0],
input_specs=input_specs,
redistribute_cost=redistribute_cost,
)
all_strategies.append(strategy)
return OpStrategy(all_strategies)
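
# Worked sketch for ``expand_to_full_mesh_op_strategy`` (illustrative numbers):
# for a 2-D mesh and
#
#   single_mesh_dim_strategies = [
#       [Replicate(), Replicate()],   # per-mesh-dim choice: (output, input0) replicated
#       [Shard(0), Shard(0)],         # per-mesh-dim choice: (output, input0) sharded on dim 0
#   ]
#
# itertools.product picks one inner list per mesh dim, yielding 2 ** 2 = 4
# combinations; zip(*strategy_comb) then regroups placements per tensor, so one
# candidate output spec would be DTensorSpec(mesh, (Replicate(), Shard(0))).
# Candidates whose inputs are not shardable for the runtime shapes are filtered
# out before being added to the returned OpStrategy.
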