import collections
import dataclasses
import enum
import itertools as it
import logging
from typing import (
Any,
cast,
DefaultDict,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
import torch
from torch._C import FunctionSchema
from torch._C._autograd import _ProfilerResult
from torch._C._profiler import (
_EventType,
_ExtraFields_Allocation,
_ExtraFields_TorchOp,
_ProfilerEvent,
_TensorMetadata,
RecordScope,
)
from torch._utils import _element_size
from torch.profiler import _utils
# Alias for a pair of a `TensorKey` and an integer.
# NOTE(review): the integer is presumably a tensor version counter -- confirm
# against the code that constructs these pairs.
TensorAndID = Tuple["TensorKey", int]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class Category(enum.Enum):
    """Closed set of category labels declared by this module.

    Member values are the consecutive integers starting at 1, in declaration
    order.
    """

    INPUT = 1
    TEMPORARY = 2
    ACTIVATION = 3
    GRADIENT = 4
    AUTOGRAD_DETAIL = 5
    PARAMETER = 6
    OPTIMIZER_STATE = 7
class Action(enum.Enum):
    """Closed set of action labels declared by this module.

    Member values are the consecutive integers starting at 1, in declaration
    order.
    """

    PREEXISTING = 1
    CREATE = 2
    INCREMENT_VERSION = 3
    DESTROY = 4
@dataclasses.dataclass
class _Storage:
    """Pair a storage data pointer with its allocation id.

    Identity (equality and hashing) is decided by `allocation_id` alone, and
    profiling logic should rely on that field. The raw pointer is carried
    along because it is handy in debug output and because unit tests
    sometimes look values up via the storage data pointer of a live Tensor.
    """

    ptr: int
    allocation_id: int

    def __repr__(self) -> str:
        # Right-align the hex pointer in an 18 character column.
        pointer_hex = hex(self.ptr)
        return f"{pointer_hex:>18} ({self.allocation_id})"

    def __eq__(self, other: Any) -> bool:
        # Anything that is not a `_Storage` never compares equal.
        if not isinstance(other, _Storage):
            return False
        return self.allocation_id == other.allocation_id

    def __hash__(self) -> int:
        # Must agree with `__eq__`, which ignores `ptr`.
        return hash(self.allocation_id)
@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
class TensorKey:
    """Hashable identifier for a storage which has been assigned an ID.

    A detailed description of Tensor IDs and why they are needed is given in
    `torch/csrc/profiler/collection.h` when `TensorID` is declared. To
    summarize, multiple Storage buffers can map to the same logical Tensor.
    This dataclass is used to refer to a concrete in-memory StorageImpl of
    a Tensor.
    """

    id: int
    storage: _Storage
    device: torch.device

    def __repr__(self) -> str:
        return f"id={self.id}: {repr(self.storage):<24} ({self.device})"

    def __lt__(self, other: "TensorKey") -> bool:
        """Order keys by `_as_sortable` so that collections of keys can be sorted."""
        return self._as_sortable < other._as_sortable

    @staticmethod
    def _make(
        tensor_id: Optional[int],
        storage_ptr: Optional[int],
        allocation_id: Optional[int],
        device: torch.device,
    ) -> Optional["TensorKey"]:
        """Build a key, or return `None` if any identifying field is missing."""
        if (
            tensor_id is not None
            and storage_ptr is not None
            and allocation_id is not None
        ):
            return TensorKey(tensor_id, _Storage(storage_ptr, allocation_id), device)
        return None

    @classmethod
    def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]:
        """Build a key from the fields of an allocation event, if fully identified."""
        return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device)

    @classmethod
    def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
        """Build a key from tensor metadata; `None` in gives `None` out."""
        if t is not None:
            return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
        return None

    @property
    def _as_sortable(self) -> Tuple[int, int, str, int]:
        """Tuple of plain comparable values used by `__lt__`.

        `torch.device.index` is `None` for devices created without an explicit
        index (e.g. `torch.device("cpu")`). `None` cannot be ordered against
        an `int`, so it is mapped to -1, which sorts before any real index.
        """
        device_index = self.device.index
        return (
            self.id,
            self.storage.allocation_id,
            self.device.type,
            -1 if device_index is None else device_index,
        )
def _extract_parameters_and_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], Optional[TensorKey]]]:
    """Yield `(parameter, gradient)` key pairs that can be read off `node`.

    Two kinds of events contribute:
      * A `torch::autograd::AccumulateGrad` backward op whose first child is
        `aten::detach` or `aten::add_` yields `(None, gradient)`.
      * A Python call event carrying module or optimizer information yields
        one pair per recorded parameter.

    Either element of a pair may be `None`.
    """
    children = node.children
    event_type = node.typed[0]

    if event_type == _EventType.TorchOp:
        # AccumulateGrad is how the Autograd engine handles gradient updates.
        # It shows up in one of two forms:
        #   1) The gradient Tensor is brand new. Nothing needs accumulating,
        #      so autograd just detaches the Tensor (aten::detach).
        #
        #   2) A gradient Tensor already exists and the new update is added
        #      into it with an in-place add (aten::add_; the trailing
        #      underscore denotes "in-place").
        #
        # TODO(robieta): Move away from load bearing names
        if (
            node.typed[1].scope != RecordScope.BACKWARD_FUNCTION
            or node.name != "torch::autograd::AccumulateGrad"
            or not children
        ):
            return

        first_child = children[0]
        if first_child.typed[0] != _EventType.TorchOp:
            return
        if first_child.name not in ("aten::detach", "aten::add_"):
            return

        child_inputs = first_child.typed[1].inputs
        if child_inputs and isinstance(child_inputs[0], _TensorMetadata):
            yield None, TensorKey.from_tensor(child_inputs[0])
        return

    if event_type != _EventType.PyCall:
        return

    # `torch.nn.Module` and `torch.optim.Optimizer` are instrumented directly.
    # NOTE: What the python tracer captures is cached; it is good for
    #       building labels but does not prove that a Tensor was live at any
    #       particular time.
    py_fields = node.typed[1]
    assert py_fields.module is None or py_fields.optimizer is None

    module_info = py_fields.module
    if module_info is not None:
        for _, param, grad in module_info.parameters:
            yield TensorKey.from_tensor(param), TensorKey.from_tensor(grad)

    optimizer_info = py_fields.optimizer
    if optimizer_info is not None:
        for param, grad, _ in optimizer_info.parameters:
            yield TensorKey.from_tensor(param), TensorKey.from_tensor(grad)
def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
    """Yield every non-`None` parameter key found on `node`."""
    pairs = _extract_parameters_and_gradients(node)
    yield from (param for param, _ in pairs if param is not None)
def extract_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
    """Yield `(parameter, gradient)` pairs from `node` whose gradient is known.

    The parameter element may still be `None`; pairs whose gradient is `None`
    are dropped.
    """
    pairs = _extract_parameters_and_gradients(node)
    yield from ((param, grad) for param, grad in pairs if grad is not None)
def get_scopes(event: Optional[_ProfilerEvent]) -> Tuple[RecordScope, ...]:
    """Collect the scopes of `event` and its ancestors, innermost first.

    Only events of type `TorchOp` contribute a scope; other events in the
    parent chain are skipped. A falsy `event` yields an empty tuple.
    """

    def lineage() -> Iterator[_ProfilerEvent]:
        # Walk from the event itself up through `.parent` until it runs out.
        current = event
        while current:
            yield current
            current = current.parent

    return tuple(
        ancestor.typed[1].scope
        for ancestor in lineage()
        if ancestor.typed[0] == _EventType.TorchOp
    )
class SchemaMatcher:
    """Lookup operator schema based on profiled name.

    When profiling we record the operator's name but not the schema. However
    some analysis requires that information. Fortunately we can look up
    registered schema from the recorded name. We do not, however, record the
    overload and so we must compare the profiled arguments with all overloads
    to determine viable matches.

    Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed
    this code will be obsolete.
    """

    @classmethod
    def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[Optional[bool], ...]:
        """Determine which inputs may have mutated based on function schema.

        Note that we don't need to resolve down to a single schema to perform
        this analysis. An input is mutable if it is mutable in any overload. In
        practice, however, it is overwhelmingly common to match a single
        overload. If we cannot find any valid schema then we must be
        conservative and assume all inputs are mutable.

        Returns:
            One entry per input. `True`/`False` when at least one schema
            matched; `None` for every input when no schema matched.
        """
        mutable: Optional[List[bool]] = None
        for schema in cls.match_schemas(t):
            # Lazily create the accumulator on the first matching schema.
            mutable = mutable or [False for _ in schema.arguments]
            for i, arg in enumerate(schema.arguments):
                # `alias_info` may be None, hence `getattr` with a default.
                mutable[i] |= getattr(arg.alias_info, "is_write", False)

        return tuple(mutable or (None for _ in t.inputs))

    @classmethod
    def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]:
        """Return every registered schema of `t.name` compatible with `t.inputs`.

        A schema is compatible when it has the same number of arguments as
        were recorded and every recorded value passes `_types_match` against
        the corresponding argument type.
        """

        def as_signature_entry(value: Any) -> Any:
            # Tensor
            if isinstance(value, _TensorMetadata):
                return TensorKey.from_tensor(value)
            # TensorList
            if isinstance(value, list):
                return [TensorKey.from_tensor(j) for j in value]
            # Scalar and uncaptured inputs.
            return value

        signature = tuple(as_signature_entry(i) for i in t.inputs)

        def matches(schema) -> bool:
            return len(schema.arguments) == len(signature) and all(
                cls._types_match(observed, schema_arg.type)
                for observed, schema_arg in zip(signature, schema.arguments)
            )

        return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))

    @classmethod
    def _types_match(cls, observed, schema_type) -> bool:
        """Return whether the recorded value `observed` fits `schema_type`."""
        if isinstance(schema_type, torch._C.OptionalType):
            schema_type = schema_type.getElementType()
            return observed is None or cls._types_match(observed, schema_type)

        if isinstance(schema_type, torch._C.AnyType):
            return True

        if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()):
            return isinstance(observed, list) and all(
                isinstance(i, TensorKey) for i in observed
            )

        type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] = (
            (torch._C.TensorType, TensorKey),
            (torch._C.NoneType, type(None)),
            (torch._C.BoolType, bool),
            (torch._C.IntType, int),
            (torch._C.FloatType, float),
            (torch._C.ComplexType, complex),
            (torch._C.NumberType, (bool, int, float, complex)),
        )

        for jit_type, py_types in type_map:
            if isinstance(schema_type, jit_type):
                return isinstance(observed, py_types)

        # Profiler only records a subset of possible argument types. If we
        # reach this point then the schema must call for a type that profiler
        # does not record. Thus, the schema can only be a match if `observed`
        # is also None.
        return observed is None

    @staticmethod
    def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]:
        """Return all registered schemas for `name`, or `None` if it cannot be an op.

        `None` means `name` is not a well formed operator name; a tuple (which
        may be empty) is returned otherwise.
        """
        # TODO(robieta):
        #   _jit_get_schemas_for_operator is quite expensive. (~100us / call)
        #   Consider adding `functools.lru_cache` if that becomes an issue.

        # Schemas must be namespaced, so a name without "::" can never be an
        # operator. record_function annotations also go through this path, so
        # such names are expected; reject them up front instead of paying for
        # a raised and caught exception on each one.
        if "::" not in name:
            return None

        try:
            # Schema lookup will throw if `name` is malformed in some other
            # way. We simply catch the exception and return `None` to denote
            # that `name` cannot be an operator name.
            return tuple(torch._C._jit_get_schemas_for_operator(name))
        except RuntimeError:
            return None
class OpTree:
    """Event tree of a profiler result plus a start-time-ordered view of it."""

    def __init__(self, result: _ProfilerResult) -> None:
        self._root_nodes = result.experimental_event_tree()

        # Flatten the tree once and keep the events ordered by start time.
        flattened = list(self.dfs())
        flattened.sort(key=lambda event: event.start_time_ns)
        self._sorted_nodes = tuple(flattened)

    def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]:
        """Iterate over all events via `_utils.traverse_dfs` from the roots.

        Extra positional and keyword arguments are forwarded unchanged.
        """
        for event in _utils.traverse_dfs(self._root_nodes, *args, **kwargs):
            yield event

    @property
    def sorted_nodes(self) -> Tuple[_ProfilerEvent, ...]:
        """All events of the tree, sorted by `start_time_ns`."""
        return self._sorted_nodes
class SizeMap:
def __init__(self, op_tree: OpTree) -> None:
self._values: Dict[TensorKey, int] = {}
for node in op_tree.sorted_nodes:
if node.typed[0] == _EventType.TorchOp:
for t in self._flat_tensor_inputs(node.typed[1]):
self._update_values(t)
elif node.typed[0] == _EventType.PyCall:
typed_fields = node.typed[1]
assert typed_fields.module is None or typed_fields.optimizer is None
if typed_fields.module is not None:
for _, p, p_grad in typed_fields.module.parameters:
self._update_values(p)
self._update_values(p_grad)
if typed_fields.optimizer is not None:
for p, p_grad, state in typed_fields.optimizer.parameters:
self._update_values(p)
self._update_values(p_grad)
for _, t in state:
self._update_values(t)
allocations: Dict[TensorKey, int] = {}
for node in op_tree.sorted_nodes:
if node.typed[0] == _EventType.Allocation:
alloc_fields = node.typed[1]
key = TensorKey.from_allocation(alloc_fields)
if key:
new_size = abs(alloc_fields.alloc_size)
prior_size = allocations.setdefault(key, new_size)
# It is possible to resize Storage in PyTorch, however we
# key on data pointer so most resizes will be treated as a
Loading ...