Repository URL to install this package:
|
Version:
2.4.0 ▾
|
# mypy: allow-untyped-defs
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.debug.comm_mode import CommDebugMode
def get_sharding_prop_cache_info():
"""
Get the cache info for the sharding propagation cache, used for debugging purpose only.
This would return a named tuple showing hits, misses, maxsize and cursize of the sharding
propagator cache.
"""
return (
DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined]
)