from typing import Any, Dict, List, Optional
from collections import defaultdict
from warnings import warn
import torch
import torch.cuda
from torch._C._profiler import _ExperimentalConfig
from torch.autograd import (
_disable_profiler,
_enable_profiler,
_kineto_step,
_prepare_profiler,
_ProfilerResult,
_supported_activities,
DeviceType,
kineto_available,
ProfilerActivity,
ProfilerConfig,
ProfilerState,
)
from torch.autograd.profiler_util import (
_filter_name,
_filter_stack_entry,
_rewrite_name,
EventList,
FunctionEvent,
MEMORY_EVENT_NAME,
MemRecordsAcc,
OUT_OF_MEMORY_EVENT_NAME,
)
from torch.futures import Future
# Public API of this module. EventList, FunctionEvent and MemRecordsAcc are
# re-exported from torch.autograd.profiler_util (imported above); the remaining
# names are presumably defined further down in this file (not visible here).
__all__ = ["profile", "record_function", "emit_itt", "emit_nvtx", "load_nvprof", "EnforceUnique",
"parse_nvprof_trace", "KinetoStepTracker", "EventList", "FunctionEvent", "MemRecordsAcc"]
try:
    # contextlib.ContextDecorator ships with every Python >= 3.2; the branch
    # below is only a fallback for interpreters that lack it.
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        """Minimal stand-in for ``contextlib.ContextDecorator``.

        Subclasses supply ``__enter__``/``__exit__``; calling an instance on a
        function returns a wrapper that runs the function inside ``with self``.
        """

        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            @functools.wraps(func)
            def _run_inside_context(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return _run_inside_context
class profile:
"""Context manager that manages autograd profiler state and holds a summary of results.
Under the hood it just records events of functions being executed in C++ and
exposes those events to Python. You can wrap any code into it and it will
only report runtime of PyTorch functions.
Note: the profiler is thread-local and is automatically propagated into async tasks.
Args:
enabled (bool, optional): Setting this to False makes this context manager a no-op.
use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
Adds approximately 4us of overhead to each tensor operation.
record_shapes (bool, optional): If shapes recording is set, information
about input dimensions will be collected. This allows one to see which
dimensions have been used under the hood and further group by them
using prof.key_averages(group_by_input_shape=True). Please note that
shape recording might skew your profiling data. It is recommended to
use separate runs with and without shape recording to validate the timing.
Most likely the skew will be negligible for bottom most events (in a case
of nested function calls). But for higher level functions the total
self cpu time might be artificially increased because of the shape
collection.
with_flops (bool, optional): If with_flops is set, the profiler will estimate
the FLOPs (floating point operations) value using the operator's input shape.
This allows one to estimate the hardware performance. Currently,
this option only works for the matrix multiplication and 2D convolution operators.
profile_memory (bool, optional): track tensor memory allocation/deallocation.
with_stack (bool, optional): record source information (file and line number) for the ops.
with_modules (bool, optional): record module hierarchy (including function names)
corresponding to the callstack of the op. e.g. If module A's forward calls
module B's forward, which contains an aten::add op,
then aten::add's module hierarchy is A.B.
Note that this is currently supported only for TorchScript models
and not for eager-mode models.
use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.
use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.
experimental_config (_ExperimentalConfig, optional): A set of experimental options
used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
.. warning::
Enabling memory profiling or source attribution incurs additional profiler
overhead.
.. warning::
This context manager should not be called recursively, i.e. no nested
instances are allowed.
.. warning::
Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
one cannot use the profiler with ``use_cuda = True`` to benchmark
DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
please use ``use_cuda = False`` or ``num_workers = 0``.
Example:
>>> # xdoctest: +SKIP
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
>>> x = torch.randn((1, 1), requires_grad=True)
>>> with torch.autograd.profiler.profile() as prof:
>>> for _ in range(100): # any normal python code, really!
>>> y = x ** 2
>>> y.backward()
>>> # NOTE: some columns were removed for brevity
>>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
----------------------------------- --------------- --------------- ---------------
Name Self CPU total CPU time avg Number of Calls
----------------------------------- --------------- --------------- ---------------
mul 32.048ms 32.048ms 200
pow 27.041ms 27.041ms 200
PowBackward0 9.727ms 55.483ms 100
torch::autograd::AccumulateGrad 9.148ms 9.148ms 100
torch::autograd::GraphRoot 691.816us 691.816us 100
----------------------------------- --------------- --------------- ---------------
"""
def __init__(
self,
enabled=True,
*,
use_cuda=False,
record_shapes=False,
with_flops=False,
profile_memory=False,
with_stack=False,
with_modules=False,
use_kineto=False,
use_cpu=True,
experimental_config=None):
self.enabled: bool = enabled
if not self.enabled:
return
self.use_cuda = use_cuda
self.function_events: Optional[EventList] = None
self.entered = False
self.record_shapes = record_shapes
self.with_flops = with_flops
self.record_shapes |= self.with_flops
self.profile_memory = profile_memory
self.with_stack = with_stack
self.with_modules = with_modules
self.use_cpu = use_cpu
if experimental_config is None:
experimental_config = _ExperimentalConfig()
self.experimental_config = experimental_config
self.kineto_results: Optional[_ProfilerResult] = None
if not self.use_cpu:
assert use_kineto, \
"Device-only events supported only with Kineto (use_kineto=True)"
if self.use_cuda and not torch.cuda.is_available():
warn("CUDA is not available, disabling CUDA profiling")
self.use_cuda = False
self.kineto_activities = set()
if self.use_cpu:
self.kineto_activities.add(ProfilerActivity.CPU)
self.profiler_kind = ProfilerState.KINETO
if self.use_cuda:
if (not use_kineto or ProfilerActivity.CUDA not in
_supported_activities()):
assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
else:
self.kineto_activities.add(ProfilerActivity.CUDA)
assert len(self.kineto_activities) > 0, \
"No activities specified for the profiler"
def config(self):
return ProfilerConfig(
self.profiler_kind,
self.record_shapes,
self.profile_memory,
self.with_stack,
self.with_flops,
self.with_modules,
self.experimental_config)
def __enter__(self):
if not self.enabled:
return
if self.entered:
raise RuntimeError("Profiler context manager is not reentrant")
self._prepare_trace()
self._start_trace()
return self
def _prepare_trace(self):
self.entered = True
_prepare_profiler(self.config(), self.kineto_activities)
def _start_trace(self):
self.entered = True
_enable_profiler(self.config(), self.kineto_activities)
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
if self.use_cuda:
torch.cuda.synchronize()
self.kineto_results = _disable_profiler()
parsed_results = self._parse_kineto_results(self.kineto_results)
self.function_events = EventList(
parsed_results,
use_cuda=self.use_cuda,
profile_memory=self.profile_memory,
with_flops=self.with_flops)
self.function_events._build_tree()
return False
def __repr__(self):
if self.function_events is None:
return '<unfinished torch.autograd.profile>'
return repr(self.function_events)
def __str__(self):
if self.function_events is None:
return '<unfinished torch.autograd.profile>'
return str(self.function_events)
def _check_finish(self):
if self.function_events is None:
raise RuntimeError("Profiler didn't finish running")
def table(
self,
sort_by=None,
row_limit=100,
max_src_column_width=75,
max_name_column_width=55,
max_shapes_column_width=80,
header=None,
top_level_events_only=False
):
self._check_finish()
assert self.function_events is not None
return self.function_events.table(
sort_by=sort_by,
row_limit=row_limit,
max_src_column_width=max_src_column_width,
max_name_column_width=max_name_column_width,
max_shapes_column_width=max_shapes_column_width,
header=header,
top_level_events_only=top_level_events_only
)
table.__doc__ = EventList.table.__doc__
def export_chrome_trace(self, path):
self._check_finish()
if kineto_available():
self.kineto_results.save(path) # type: ignore[union-attr]
else:
return self.function_events.export_chrome_trace(path) # type: ignore[union-attr]
export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
self._check_finish()
assert self.function_events is not None, "Expected profiling results"
assert self.with_stack, "export_stacks() requires with_stack=True"
return self.function_events.export_stacks(path, metric)
def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
self._check_finish()
assert self.function_events is not None, "Expected profiling results"
return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
key_averages.__doc__ = EventList.key_averages.__doc__
def total_average(self):
self._check_finish()
assert self.function_events is not None, "Expected profiling results"
return self.function_events.total_average()
total_average.__doc__ = EventList.total_average.__doc__
@property
def self_cpu_time_total(self):
""" Returns total time spent on CPU obtained as a sum of
all self times across all the events.
"""
self._check_finish()
assert self.function_events is not None
return self.function_events.self_cpu_time_total
def _parse_kineto_results(self, result):
# result.events() has most of the events - PyTorch op-level and device-level events
trace_start_us = result.trace_start_us()
mem_records = [[evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME]
oom_records = [evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME]
mem_records_acc = MemRecordsAcc(mem_records)
def _cpu_memory_usage(mem_record):
return mem_record.nbytes() if \
mem_record.device_type() in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP] \
else 0
def _cuda_memory_usage(mem_record):
return mem_record.nbytes() if \
mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP] \
else 0
# Create and return FunctionEvent list
function_events = []
cuda_corr_map: Dict[int, List[FunctionEvent]] = {}
max_evt_id = 0
for kineto_event in result.events():
if _filter_name(kineto_event.name()):
continue
rel_start_us = kineto_event.start_us() - trace_start_us
rel_end_us = rel_start_us + kineto_event.duration_us()
abs_end_us = kineto_event.start_us() + kineto_event.duration_us()
cpu_memory_usage = 0
cuda_memory_usage = 0
if kineto_event.device_type() == DeviceType.CPU:
# find the corresponding memory allocation events
Loading ...