import json
import math
import os
import re
from typing import Dict, List, Optional, Set
import torch
from torch.profiler import profile
import torch.utils.benchmark as benchmark
from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs
from torch._C._profiler import (_ProfilerEvent, _ExtraFields_TorchOp,
_ExtraFields_PyCCall, _ExtraFields_PyCall,
_EventType)
class Pattern:
'''
    Base class for all patterns. Subclass this class and implement match()
    to define custom patterns.
    In the subclass, define the description and skip properties.
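
    A minimal sketch of a custom subclass (hypothetical names, for
    illustration only):

        class HypotheticalAddPattern(Pattern):
            def __init__(self, prof: profile, should_benchmark: bool = False):
                super().__init__(prof, should_benchmark)
                self.name = "Hypothetical Add Pattern"
                self.description = "Matched an aten::add event."

            def match(self, event: _ProfilerEvent):
                return event.name == "aten::add"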
'''
def __init__(self, prof: profile, should_benchmark: bool = False):
self.prof = prof
self.should_benchmark = should_benchmark
self.name = "Please specify a name for pattern"
self.description = "Please specify a description for pattern"
self.url = ""
assert prof.profiler is not None and prof.profiler.kineto_results is not None
self.event_tree = prof.profiler.kineto_results.experimental_event_tree(
)
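        # Group root events by the thread that started them; siblings_of()
        # falls back to this mapping for events that have no parent.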
self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
for event in self.event_tree:
self.tid_root.setdefault(event.start_tid, []).append(event)
@property
def skip(self):
return False
def report(self, event: _ProfilerEvent):
msg = f"{self.description}\n[Source Code Location] {source_code_location(event)}"
return msg
def eventTreeTraversal(self):
'''
Traverse the event tree and yield all events.
        Override this method in a subclass to customize the traversal.
'''
yield from traverse_dfs(self.event_tree)
def summary(self, events: List[_ProfilerEvent]):
default_summary = f"{self.name}: {len(events)} events matched."
if self.should_benchmark:
            # Use the benchmark-based summary if benchmark() is implemented.
return self.benchmark_summary(
events) if hasattr( # type: ignore[attr-defined]
self, 'benchmark') else default_summary
return default_summary
def benchmark_summary(self, events: List[_ProfilerEvent]):
        def format_time(time_ns: int):
            unit_lst = ["ns", "us", "ms"]
            for unit in unit_lst:
                if time_ns < 1000:
                    return f"{time_ns:.2f} {unit}"
                # True division keeps sub-unit precision for the formatting below.
                time_ns /= 1000
            return f"{time_ns:.2f} s"
assert hasattr(self, 'benchmark'), 'Please implement benchmark()'
shapes_factor_map = self.benchmark( # type: ignore[attr-defined]
events)
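        # Each factor estimates (new_time / original_time) for one input shape,
        # so values below 1.0 indicate an expected speedup.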
original_time = sum(event.duration_time_ns for event in events)
new_time = sum(shapes_factor_map[input_shapes(event)] *
event.duration_time_ns for event in events)
return (
f"{self.name}: {len(events)} events matched. "
f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)"
)
def match(self, event: _ProfilerEvent):
'''
Return True if the event matches the pattern.
        This method should be overridden in a subclass.
'''
raise NotImplementedError
def matched_events(self):
if self.skip:
return []
matched_events = []
for event in self.eventTreeTraversal():
if self.match(event):
matched_events.append(event)
return matched_events
def root_of(self, event: _ProfilerEvent):
while event.parent:
event = event.parent
return event
def siblings_of(self, event: _ProfilerEvent):
if event.parent:
children = event.parent.children
else:
children = self.tid_root[event.start_tid]
index = children.index(event)
return children[:index], children[index + 1:]
def next_of(self, event: _ProfilerEvent):
_, next_events = self.siblings_of(event)
return next_events[0] if next_events else None
def prev_of(self, event: _ProfilerEvent):
prev_events, _ = self.siblings_of(event)
return prev_events[-1] if prev_events else None
def go_up_until(self, event: _ProfilerEvent, predicate):
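        # Walk toward the root until the predicate holds; if no ancestor
        # satisfies it, the root itself is returned.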
if not event:
return None
while event.parent and not predicate(event):
event = event.parent
return event
# Patterns
class NamePattern(Pattern):
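    '''
    Matches any event whose name matches a given regular expression.

    A minimal usage sketch (assuming an already-collected ``profile`` object
    named ``prof``):

        pattern = NamePattern(prof, r"aten::mm")
        for event in pattern.matched_events():
            print(pattern.report(event))
    '''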
def __init__(self,
prof: profile,
name: str,
should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.description = f"Matched Name Event: {name}"
self.name = name
def match(self, event: _ProfilerEvent):
return re.search(self.name, event.name) is not None
class ExtraCUDACopyPattern(Pattern):
'''
    This pattern identifies cases where we create a constant tensor on the CPU
    and immediately move it to the GPU.
    example: torch.zeros((100, 100)).to("cuda")
    Pattern:
    built-in method                 |built-in method
        ...                         |    aten::to
            aten::fill_/aten::zero_ |        aten::_to_copy
    Algorithm:
    We start at node aten::to, go up to its parent and over to the parent's
    previous sibling, and check whether we find an aten::fill_/aten::zero_
    as we keep going down the tree.
    We always select the last child in the children list when we go down the tree.
    If any step fails, it is not a match.
'''
def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "Extra CUDA Copy Pattern"
self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
self.init_ops = {
"aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_"
}
@property
def skip(self):
return not self.prof.with_stack or not self.prof.record_shapes
def match(self, event):
# TODO: We should also check tensor identities
if event.name != "aten::to":
return False
to_event = event
if not event.children:
return False
event = event.children[-1]
if event.name != "aten::_to_copy":
return False
if not event.children:
return False
event = event.children[-1]
if event.name != "aten::copy_":
return False
        # The first two arguments of aten::copy_ should have the same dtype
dtypes = input_dtypes(event)
if len(dtypes) < 2:
return False
if dtypes[0] is None or dtypes[0] != dtypes[1]:
return False
event = to_event
# Up one level
event = event.parent
if event is None:
return False
        # Check if we have an aten::fill_ in the previous leaf
event = self.prev_of(event)
if event is None:
return False
while event.children:
event = event.children[-1]
            # aten::zero_ is a special optimization case where fill_ is not called
if event.name in self.init_ops:
return True
return event.name in self.init_ops
# TODO: Check if tensor is reused
def benchmark(self, events: List[_ProfilerEvent]):
shapes_factor_map = {input_shapes(event): 0.0 for event in events}
for shape in shapes_factor_map:
size = shape[0]
to_timer = benchmark.Timer(stmt='torch.ones(size).to("cuda")',
globals={'size': size})
de_timer = benchmark.Timer(stmt='torch.ones(size, device="cuda")',
globals={'size': size})
to_time = to_timer.timeit(10).mean
de_time = de_timer.timeit(10).mean
shapes_factor_map[shape] = de_time / to_time
return shapes_factor_map
class ForLoopIndexingPattern(Pattern):
'''
    This pattern identifies when a for loop is used to index a tensor, an
    access pattern that can often be vectorized.
example:
tensor = torch.empty((100, 100))
for i in range(100):
tensor[i] = i
Pattern:
aten::select | ... | aten::select | ... (Repeat)
Algorithm:
    We start at node aten::select, and check whether we can find this alternating pattern.
    We also keep a set of visited events to avoid matching the same loop more than once.
'''
def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "For Loop Indexing Pattern"
self.description = "For loop indexing detected. Vectorization recommended."
self.visited: Set[int] = set()
def eventTreeTraversal(self):
'''
        We need to use BFS traversal order to avoid duplicate matches.
'''
yield from traverse_bfs(self.event_tree)
def match(self, event: _ProfilerEvent):
if event.name != "aten::select":
return False
if event.id in self.visited:
return False
repeat_count = 1
        _, next_events = self.siblings_of(event)
        if len(next_events) <= 1:
            return False

        # Custom event list matching
        def same_ops(list1, list2):
            if len(list1) != len(list2):
                return False
            for op1, op2 in zip(list1, list2):
                if op1.name != op2.name:
                    return False
            return True

        # Record the ops between two aten::select
        next_select_idx = index_of_first_match(
            next_events, lambda e: e.name == "aten::select")
        if next_select_idx is None:
            return False
        indexing_ops = [event] + next_events[:next_select_idx]
        next_events = next_events[len(indexing_ops) - 1:]
        for i in range(0, len(next_events), len(indexing_ops)):
            if same_ops(indexing_ops, next_events[i:i + len(indexing_ops)]):
                repeat_count += 1
                self.visited.add(next_events[i].id)
else:
break
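        # Only report when the indexing sequence repeats enough times to
        # plausibly come from a Python-level for loop.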
return repeat_count >= 10
class FP32MatMulPattern(Pattern):
def __init__(self, prof: profile, should_benchmark: bool = False):
super().__init__(prof, should_benchmark)
self.name = "FP32 MatMul Pattern"
self.description = (
"You are currently using GPU that supports TF32. "
"Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
)
self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
@property
def skip(self):
if torch.version.hip is not None:
has_tf32 = False
else:
            # Anything older than sm_80 predates Ampere, which introduced TF32 support
has_tf32 = all(
int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
return has_tf32 is False or super().skip or not self.prof.record_shapes
def match(self, event: _ProfilerEvent):
        # Only TorchOp events carry the op-level metadata we need here
if event.tag != _EventType.TorchOp:
return False
assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
if event.name == "aten::mm":
if event.extra_fields.allow_tf32_cublas is False:
return True
return False
def report(self, event: _ProfilerEvent):
return self.description
def benchmark(self, events: List[_ProfilerEvent]):
shapes_factor_map = {input_shapes(event): 0.0 for event in events}
for shape in shapes_factor_map:
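            # input_shapes() records one shape per operand, so shape[0] and
            # shape[1] are the shapes of the two matmul inputs.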
matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
fp32_timer = benchmark.Timer(stmt='torch.mm(matrixA, matrixB)',
globals={
"matrixA": matrixA,
"matrixB": matrixB
})
tf32_timer = benchmark.Timer(
stmt='torch.mm(matrixA, matrixB)',
setup='torch.backends.cuda.matmul.allow_tf32 = True',
globals={
"matrixA": matrixA,
"matrixB": matrixB
})
            torch.backends.cuda.matmul.allow_tf32 = False
            fp32_time = fp32_timer.timeit(10).mean
            tf32_time = tf32_timer.timeit(10).mean
            shapes_factor_map[shape] = tf32_time / fp32_time
        # The TF32 timer's setup leaves allow_tf32 enabled globally; restore the
        # profiled state (match() only fires when TF32 was disabled).
        torch.backends.cuda.matmul.allow_tf32 = False
        return shapes_factor_map