# Extracted from the PyTorch package distribution, version 2.4.0.
import itertools
from dataclasses import dataclass
from typing import List, Set, Tuple
from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy
from torch.distributed._tensor.placement_types import (
DTensorSpec,
Partial,
Placement,
Replicate,
Shard,
)
from torch.distributed.device_mesh import DeviceMesh
@dataclass
class EinsumDims:
    """Categorization of the dimension characters of an einsum equation."""

    # dims present in the inputs but absent from the output (summed away)
    contracting_dims: List[str]
    # dims shared by every input and the output
    batch_dims: List[str]
    # dims appearing only in the lhs input (and the output)
    lhs_out_only_dims: List[str]
    # dims appearing only in the rhs input (and the output)
    rhs_out_only_dims: List[str]

    @classmethod
    def parse_equation(cls, equation: str) -> Tuple[List[str], str]:
        """
        Parse the einsum equation str to input dim chars and output dim char
        """
        in_part, out_part = equation.split("->")
        in_specs = in_part.split(",")
        out_specs = out_part.split(",")
        # NOTE: only support at most two inputs, and single output
        # extend to support more inputs if needed in future
        assert len(in_specs) <= 2, "Only support at most two inputs"
        assert len(out_specs) == 1, "Only support single output"
        return in_specs, out_specs[0]

    @classmethod
    def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
        """
        Parse the dims and extract the contracting, batch, and free dimensions
        for the left and right hand sides.
        """
        # gather the union of all dim chars; sort for a deterministic order
        seen: Set[str] = set()
        for spec in input_dims:
            seen.update(spec)

        contracting: List[str] = []
        batch: List[str] = []
        lhs_only: List[str] = []
        rhs_only: List[str] = []

        for ch in sorted(seen):
            if ch not in output_dim:
                # in some input(s) but not the output -> contracted away
                contracting.append(ch)
            elif all(ch in spec for spec in input_dims):
                # in the output and in every input -> batch dim
                batch.append(ch)
            else:
                # in the output but only one input -> free dim of that side
                assert (
                    len(input_dims) == 2
                ), "free dimension only supported for two inputs!"
                lhs, rhs = input_dims
                if ch in lhs:
                    lhs_only.append(ch)
                elif ch in rhs:
                    rhs_only.append(ch)
                else:
                    raise RuntimeError("Invalid dimension character")

        return cls(
            contracting_dims=contracting,
            batch_dims=batch,
            lhs_out_only_dims=lhs_only,
            rhs_out_only_dims=rhs_only,
        )
def gen_einsum_strategies(
    equation: str,
    mesh: DeviceMesh,
    *,
    linearity: bool = False,
) -> OpStrategy:
    """
    Generate a strategy list for the ops that follow einsum style notation.
    """
    # parse einop equation and classify each dim char
    input_dims, output_dim = EinsumDims.parse_equation(equation)
    edims = EinsumDims.parse_dims(input_dims, output_dim)

    # one placement per tensor: [output, input1, input2, ...]
    num_tensors = len(input_dims) + 1

    # per-mesh-dim candidate placement lists
    per_mesh_dim: List[List[List[Placement]]] = []
    for mesh_dim in range(mesh.ndim):
        # replicate-everything is always a valid candidate
        candidates: List[List[Placement]] = [[Replicate()] * num_tensors]

        if mesh.size(mesh_dim) > 1:
            # (a size-1 mesh dim only admits the replicate strategy)
            # TODO: see if this is valid for the submesh case

            # shard a batch dim across output and every input
            for bdim in edims.batch_dims:
                comb: List[Placement] = [Shard(output_dim.index(bdim))]
                comb.extend(Shard(spec.index(bdim)) for spec in input_dims)
                candidates.append(comb)

            # shard a contracting dim on the inputs; output becomes partial
            for cdim in edims.contracting_dims:
                comb = [Partial()]
                comb.extend(Shard(spec.index(cdim)) for spec in input_dims)
                candidates.append(comb)

            # shard an lhs-only free dim: S(i), R -> S(i)
            for ldim in edims.lhs_out_only_dims:
                out_idx = output_dim.index(ldim)
                candidates.append([Shard(out_idx), Shard(out_idx), Replicate()])

            # shard an rhs-only free dim: R, S(i) -> S(i)
            for rdim in edims.rhs_out_only_dims:
                out_idx = output_dim.index(rdim)
                candidates.append([Shard(out_idx), Replicate(), Shard(out_idx)])

            # for linear ops, partial inputs may stay partial
            if linearity:
                candidates.append([Partial()] * num_tensors)

        per_mesh_dim.append(candidates)

    # cross-product the per-mesh-dim choices into whole-mesh strategies
    # TODO: filter out invalid strategies, at this point we generate
    # all possible strategies without considering the whether the tensor
    # dim could be sharded or not, we would need to filter out invalid
    # strategies base on the actual tensor shape
    # (i.e. for Shard, tensor dim size must > mesh size)
    all_strategies = []
    for comb in itertools.product(*per_mesh_dim):
        # transpose: one placement-per-mesh-dim tuple for each tensor
        specs = [DTensorSpec(mesh, tuple(placements)) for placements in zip(*comb)]
        all_strategies.append(
            PlacementStrategy(output_specs=specs[0], input_specs=specs[1:])
        )
    return OpStrategy(all_strategies)