from typing import Callable, Dict, List, Set
import torch
import torch.fx as fx
import torch.utils._pytree as pytree
from torch import Tensor
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
from torch.distributed._tensor.ops.view_ops import (
DimSpec,
InputDim,
ops as view_op_rules,
)
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec
aten = torch.ops.aten
class BatchDimAnalyzer:
"""
This class is used to analyze the batch dimension of each tensor/node in the
graph. We need to know the batch dimension of each tensor/node so that we know
exactly the sharding layout of intermediate tensors.
We possibly should evaluate using symbolic shapes to track the batch dimension.
We can experiment it later with dynamo integration (as dynamo have mark_dynamic
API which allows marking batch dimension only) or try to use FakeTensorMode to
mark the batch dimension. For now, let's just use the batch dimension of the first
input tensor as the hint to track the batch dimension of all tensors/nodes in
the graph.
"""
def __init__(self, batch_dim: int = 0) -> None:
self.batch_dim = batch_dim
self.batch_dim_map: Dict[fx.Node, int] = {}
        # batch_dim_size tracks the batch dim size of the input tensor;
        # -1 means it has not been initialized yet
        self.batch_dim_size = -1
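        # Map from ATen view-like op overloads to their functional counterparts,
        # used to look up the corresponding entry in view_op_rules so that the
        # batch dim can be traced through view/reshape/permute-style ops.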
self.dim_rule_map: Dict[torch._ops.OpOverload, Callable[..., torch.Tensor]] = {
aten.squeeze.default: torch.squeeze,
aten.squeeze.dim: torch.squeeze,
aten.view.default: Tensor.view,
aten.reshape.default: torch.reshape,
aten._unsafe_view.default: Tensor.view,
aten.unsqueeze.default: torch.unsqueeze,
aten.expand.default: Tensor.expand,
aten.permute.default: torch.permute,
aten.repeat.default: Tensor.repeat,
aten.transpose.int: torch.transpose,
}
def init_batch_dim_size(self, batch_dim_size: int) -> None:
"""
initialize batch dim size base on the first input batch size
"""
if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
raise RuntimeError(
f"batch dim size is already initialized! "
f"Found new batch size: {batch_dim_size} not "
f"matching existing batch dim size: {self.batch_dim_size}!"
)
self.batch_dim_size = batch_dim_size
def set_batch_dim(self, node: fx.Node, batch_dim: int) -> None:
self.batch_dim_map[node] = batch_dim
def get_batch_dim(self, node: fx.Node) -> int:
if node not in self.batch_dim_map:
raise RuntimeError(f"batch dim analysis failed on node: {node}!")
return self.batch_dim_map[node]
def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
"""
compute the batch dimension for the `node`
"""
assert self.batch_dim_size != -1, "batch dim size is not initialized!"
if node in self.batch_dim_map:
# if batch dim already computed, simply return it
return self.batch_dim_map[node]
if node.target in self.dim_rule_map:
view_op_rule = view_op_rules[self.dim_rule_map[node.target]] # type: ignore[index]
args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args)
kwargs_val = pytree.tree_map_only(
fx.Node, lambda n: n.meta["val"], node.kwargs
)
output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val)
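            # recursively walk a DimSpec and collect every InputDim index that
            # contributes to it, so we can tell which input dims feed each
            # output dim of the view-like op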
def collect_input_dim(cmd: DimSpec, input_dims: Set[int]):
if isinstance(cmd, InputDim):
input_dims.add(cmd.input_dim)
for inp in cmd.inputs():
collect_input_dim(inp, input_dims)
output_dim_to_input_dims: List[Set[int]] = []
for inp in output_dim_rules:
input_dims: Set[int] = set()
collect_input_dim(inp, input_dims=input_dims)
output_dim_to_input_dims.append(input_dims)
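            # view-like ops in dim_rule_map take a single tensor operand; map
            # its batch dim through the output dim rules to find the output
            # dim it lands on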
operand = node.all_input_nodes[0]
operand_batch_dim = self.get_batch_dim(operand)
for output_dim, input_dims in enumerate(output_dim_to_input_dims):
if operand_batch_dim in input_dims:
self.set_batch_dim(node, output_dim)
                    # update the batch dim size before returning, since the
                    # batch dim size might change in the middle of the graph
self.batch_dim_size = node.meta["val"].shape[output_dim]
return output_dim
        # if there are no hints from output_dim_rules, infer from the output
        # shape whether there is a batch dim, and shard accordingly
node_val = node.meta["val"]
if isinstance(node_val, (list, tuple)):
shapes = [val.shape for val in node_val]
else:
shapes = [node_val.shape]
        # For a reduction op that reduces over the sharded batch dim we don't
        # generate Partial; instead we generate Shard. This is because data
        # parallel is never meant to do a full reduction across the batch
        # dimension, so the reduced activation stays sharded.
full_reduction = False
# loop through the dim size to find the output batch dim
for shape in shapes:
if len(shape) == 0:
full_reduction = True
for i, dim_size in enumerate(shape):
if dim_size == self.batch_dim_size:
self.set_batch_dim(node, i)
return i
operands = node.all_input_nodes
if not operands:
            # if there are no operands, this must be a factory op; the tensor is
            # generated for computation and should be marked as replicated
self.set_batch_dim(node, -1)
# -1 means replicated
return -1
else:
            # if there are operands, check whether any of them has a batch dim.
            # If an operand has a batch dim but the output does not, it's either
            # a full reduction, where we should stay sharded, or a reduction
            # over the batch dim only, where we should produce partial
operand_batch_dim = -1
for operand in operands:
if operand in self.batch_dim_map:
operand_batch_dim = self.get_batch_dim(operand)
if operand_batch_dim < 0:
                # if no operand has a batch dim, this node doesn't have one either
self.set_batch_dim(node, operand_batch_dim)
return operand_batch_dim
elif full_reduction:
self.set_batch_dim(node, operand_batch_dim)
return operand_batch_dim
else:
                # if an operand has a batch dim but the output does not, the
                # node produces a partial result; we use -2 to indicate partial
self.set_batch_dim(node, -2)
return -2
def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
"""
This function first compute the batch dimension for the current node,
then generate the sharding spec that shards on the batch dimension.
"""
node_batch_dim = self.compute_batch_dim(node)
if node_batch_dim == -1:
# indicate this activation is replicated
act_spec = DTensorSpec(mesh=mesh, placements=(Replicate(),))
elif node_batch_dim == -2:
# indicate this activation is partial
act_spec = DTensorSpec(mesh=mesh, placements=(_Partial(),))
else:
            # indicate this activation is sharded on the batch dim
act_spec = DTensorSpec(mesh=mesh, placements=(Shard(node_batch_dim),))
return act_spec
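

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the upstream module).
# It traces a tiny function with make_fx, seeds the analyzer with the input's
# batch dim, and queries the sharding spec of each intermediate activation.
# The example function, mesh, and single-rank gloo process group below are
# assumptions for demonstration; the sketch also assumes make_fx populates
# node.meta["val"], which compute_batch_dim relies on.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.distributed as dist
    from torch.fx.experimental.proxy_tensor import make_fx

    # single-process "distributed" setup so a DeviceMesh can be constructed
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    mesh = DeviceMesh("cpu", [0])

    def fn(x: Tensor) -> Tensor:
        return torch.relu(x).sum(dim=1)

    batch = torch.randn(8, 4)
    gm = make_fx(fn)(batch)

    analyzer = BatchDimAnalyzer(batch_dim=0)
    analyzer.init_batch_dim_size(batch.shape[0])

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            # seed the analysis with the known batch dim of the graph input
            analyzer.set_batch_dim(node, analyzer.batch_dim)
        elif node.op == "call_function":
            spec = analyzer.compute_act_spec(node, mesh)
            print(node.name, spec.placements)

    dist.destroy_process_group()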