from typing import Callable, Dict, Optional, Tuple
import torch
import torch.distributed._tensor.api as dtensor
from torch._ops import OpOverload
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.utils._pytree import tree_map
"""
Print information on ops input shape and sharding for debugging purposes.
"""
_DEBUG_VERBOSE = False


def unwrap_schema(e: object) -> object:
    """Unwrap a DTensor argument to its DTensorSpec; pass other values through unchanged."""
    return e._spec if isinstance(e, dtensor.DTensor) else e


class ShardingPropagator:
    def __init__(self) -> None:
        # mapping from an op overload to its registered sharding propagation rule
        self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}

def register_sharding_prop_rule(
self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding]
):
"""
Register a sharding propagation rule for an operator.
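
        Example (illustrative sketch; ``passthrough_rule`` is a hypothetical rule
        that simply forwards the first input spec to the output)::

            def passthrough_rule(op_schema: OpSchema) -> OutputSharding:
                return OutputSharding(output_spec=op_schema.args_schema[0])

            propagator = ShardingPropagator()
            propagator.register_sharding_prop_rule(
                torch.ops.aten.clone.default, passthrough_rule
            )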
"""
self.op_to_rules[op_overload] = rule_func

    def prepare_op_schema(
self,
op_call: OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object]
) -> OpSchema:
"""
        This unwraps the DTensors in args/kwargs to their DTensorSpecs and packs
        them into an OpSchema for use by sharding propagation.
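
        Example (sketch, assuming ``dt`` is an existing DTensor and ``propagator``
        is a ShardingPropagator)::

            op_schema = propagator.prepare_op_schema(
                torch.ops.aten.sum.default, (dt,), {}
            )
            # op_schema.args_schema now holds dt._spec rather than dt itself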
"""
args_schema = tree_map(unwrap_schema, args)
kwargs_schema = tree_map(unwrap_schema, kwargs)
op_schema = OpSchema(op_call._schema, args_schema, kwargs_schema)
if _DEBUG_VERBOSE and torch.distributed.get_rank() == 0:
print(f"OpSchema({op_schema})")
local_shapes = tree_map(
lambda t: t.to_local().shape if isinstance(t, dtensor.DTensor) else None,
args,
)
print(f" local shapes: {local_shapes}")
return op_schema

    def propagate_op_sharding(
self, op_overload: OpOverload, op_schema: OpSchema
) -> OutputSharding:
"""
Propagate the sharding for an operator given the op_schema.
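
        Example (sketch; ``propagator`` and ``op_schema`` as prepared above, with a
        rule registered for ``aten.sum.default``)::

            output_sharding = propagator.propagate_op_sharding(
                torch.ops.aten.sum.default, op_schema
            )
            if output_sharding.schema_suggestions[0] is not op_schema:
                # the suggested schema differs from the input schema, so the
                # inputs need to be redistributed before running the op
                ...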
"""
sharding_prop_func = self.op_to_rules.get(op_overload, None)
if sharding_prop_func is None:
            # step 1. If no sharding rule is registered for
            # the operator, we error out.
raise NotImplementedError(
f"Operator {op_overload} does not have a DistributedTensor rule registered."
)
        # step 2. there is a sharding propagation rule; run
        # it to get the output sharding
try:
output_sharding = sharding_prop_func(op_schema)
except Exception as e:
raise RuntimeError(
f"Sharding propagation failed on op {op_overload}.\n"
f"Input schema: {op_schema}.\n"
f"Error: {e}"
) from e
        # step 3. if we can't get an output_spec from sharding
        # propagation (i.e. no rule applies to the current input
        # placements), we return the output sharding with schema
        # suggestions, which can be used to decide how to
        # redistribute the inputs
if output_sharding.output_spec is None:
if output_sharding.schema_suggestions is None:
raise RuntimeError(
f"Sharding propagation failed on op {op_overload}!"
f"Input schema: {op_schema}."
f"Failed reason: {output_sharding.failed_reason}"
)
else:
                # we auto-redistribute the inputs if necessary to get an
                # eligible input schema, picking one of the schema
                # suggestions based on redistribute cost.
                # For now we simply pick the first suggestion.
                # TODO: implement full auto distribute with a
                # simple cost estimation model
suggested_input_schema = output_sharding.schema_suggestions[0]
# run sharding propagation again with suggested schema
propagation_res = sharding_prop_func(suggested_input_schema)
                # we set the output sharding with the new propagation result
                # so that dispatching knows both output_spec and schema_suggestions
                # exist, which indicates that a reshard is needed
output_sharding.output_spec = propagation_res.output_spec
else:
            # if sharding propagation succeeded, we set the schema suggestion to
            # the default op_schema, which indicates that no reshard is needed
output_sharding.schema_suggestions = [op_schema]
return output_sharding


class _CachingPropagator(ShardingPropagator):
"""
A sharding propagator that caches the propagation results.
This is currently experimental for Tensor Parallel usage.
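
    Example (sketch; ``propagator`` and ``op_schema`` as in the examples above)::

        cached = _CachingPropagator(op_to_rules=propagator.op_to_rules)
        # the first call runs the registered rule; subsequent calls with an
        # identical OpSchema return the cached OutputSharding instead
        out = cached.propagate_op_sharding(torch.ops.aten.sum.default, op_schema)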
"""

    def __init__(
        self,
        op_to_rules: Optional[Dict[OpOverload, Callable[[OpSchema], OutputSharding]]] = None,
    ) -> None:
        super().__init__()
        if op_to_rules is not None:
            self.op_to_rules = op_to_rules
# cache table for sharding propagation results, we might need to
# limit the size of the cache table in the future
self.cached_prop_results: Dict[OpSchema, OutputSharding] = {}

    def propagate_op_sharding(
self, op_overload: OpOverload, op_schema: OpSchema
) -> OutputSharding:
"""
Propagate the sharding for an operator given the op_schema.
Cache the propagation results to avoid running propagation again.
"""
if op_schema in self.cached_prop_results:
return self.cached_prop_results[op_schema]
else:
# call DTensor's propagate_op_sharding to get the prop result
output_sharding = super().propagate_op_sharding(op_overload, op_schema)
# update cached table
self.cached_prop_results[op_schema] = output_sharding
return output_sharding