import collections
import contextlib
import copy
import cProfile
import dataclasses
import datetime
import dis
import enum
import functools
import gc
import inspect
import itertools
import logging.config
import math
import operator
import os
import pstats
import re
import sys
import time
import types
import typing
import weakref
from contextlib import contextmanager
from functools import lru_cache, wraps
from typing import Any, Dict, List
try:
import numpy as np
HAS_NUMPY = True
except ModuleNotFoundError:
np = None # type: ignore[assignment]
HAS_NUMPY = False
import importlib
import torch
import torch.fx.experimental.symbolic_shapes
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import FakeTensor
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._pytree import tree_flatten, tree_map
from . import config, logging as torchdynamo_logging
# Two-level counter table: counters[k1][k2] is an integer count; first-level
# entries are created on first access (defaultdict of Counter).
counters = collections.defaultdict(collections.Counter)
# Link to the TorchDynamo troubleshooting guide.
troubleshooting_url = "https://pytorch.org/docs/master/dynamo/troubleshooting.html"
log = logging.getLogger(__name__)
# profiling compilation time
# Maps a decorated function's __qualname__ to the list of per-call durations
# (seconds) appended by @dynamo_timed; rendered by compile_times().
compilation_metrics = collections.OrderedDict()
# Source of increasing integers used by dynamo_profiled to number its
# profile dump files.
timer_counter = itertools.count()
def tabulate(rows, headers):
    """Render ``rows`` beneath ``headers`` as plain text.

    When the third-party ``tabulate`` package can be imported it does the
    formatting. Otherwise the fallback emits one line per row (headers
    first), with cells converted via ``str`` and separated by ", ".
    """
    try:
        import tabulate as tabulate_pkg

        return tabulate_pkg.tabulate(rows, headers=headers)
    except ImportError:
        lines = []
        for row in itertools.chain([headers], rows):
            lines.append(", ".join(str(cell) for cell in row))
        return "\n".join(lines)
def dynamo_profiled(func):
    """Decorator that runs every call of ``func`` under cProfile.

    For each call, it prints the top 20 entries sorted by internal time and
    then by cumulative time. It also dumps the raw stats to
    ``<func.__name__><N>.profile`` (a path relative to the current working
    directory), where ``N`` is drawn from the module-level ``timer_counter``.
    The printed header uses the same ``N``.

    Returns the wrapped function; the wrapper returns ``func``'s result.
    """

    @wraps(func)
    def profile_wrapper(*args, **kwargs):
        # Draw the counter once so the printed header and the dump file name
        # refer to the same iteration (previously two draws made them differ).
        iteration = next(timer_counter)
        datafn = f"{func.__name__}{iteration}.profile"  # Name the data file sensibly
        prof = cProfile.Profile()
        # runcall() enables the profiler itself and disables it in a
        # ``finally``, so explicit enable()/disable() calls are redundant.
        retval = prof.runcall(func, *args, **kwargs)
        print(f"### Cprofile for {func.__name__} iter {iteration} ###")
        ps = pstats.Stats(prof)
        ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
        ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)
        prof.dump_stats(datafn)
        return retval

    return profile_wrapper
# Per-frame phase durations: frame_phase_timing[str(frame number)][phase_name]
# -> seconds. Filled by dynamo_timed(phase_name=...), summed by
# print_time_report(), cleared by reset_frame_count().
frame_phase_timing = collections.OrderedDict()
# Current frame number, used as the key into frame_phase_timing; advanced by
# increment_frame() and reset to 0 by reset_frame_count().
curr_frame = 0
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def increment_frame():
    """Advance the module-level ``curr_frame`` counter by one.

    ``dynamo_timed`` keys per-phase timings on ``str(curr_frame)``, so this
    starts a fresh timing bucket.
    """
    global curr_frame
    curr_frame += 1
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def reset_frame_count():
    """Set ``curr_frame`` back to 0 and drop every recorded per-frame phase timing."""
    global curr_frame
    curr_frame = 0
    frame_phase_timing.clear()
# Running total accumulated through increment_op_count().
op_count = 0
def increment_op_count(cnt):
    """Add ``cnt`` to the module-level ``op_count`` running total."""
    global op_count
    op_count = op_count + cnt
# Print a report of time spent so far, summed over all frames, e.g.:
# TIMING: entire_frame_compile:8.57463 backend_compile:5.26806
def print_time_report():
    """Print the total time recorded for each compilation phase.

    Sums ``frame_phase_timing`` (filled by ``dynamo_timed(phase_name=...)``)
    across all frames. It prints one line of the form
    ``TIMING: <phase>:<seconds> <phase>:<seconds> ...``, with each total
    rounded to 5 decimal places and phases listed in first-seen order.
    With no recorded timings it prints just ``TIMING:``.
    """
    total_by_key = {}
    # Only the per-phase timings matter; the frame key itself is not reported.
    for timings in frame_phase_timing.values():
        for key, timing in timings.items():
            if key in total_by_key:
                total_by_key[key] += timing
            else:
                total_by_key[key] = timing
    parts = ["TIMING:"]
    parts.extend(f"{key}:{round(value, 5)}" for key, value in total_by_key.items())
    print(" ".join(parts))
# dynamo_timed API works as a function decorator
# By wrapping a function in dynamo_timed, we can store a record in compilation_metrics
# where the key is the function's qualified name (__qualname__).
# For example:
#
# def bar():
#     @dynamo_timed
#     def _foo(...):
#         ...
#
# Would show up as an entry in our timing dict:
# OrderedDict([('bar.<locals>._foo', [0.083690, 0.23949, 3.1425e-05])])
# This is extremely useful for granular debugging.
#
# For a higher-level mode, pass a phase_name into dynamo_timed
# phase_names record an extra record into a separate compilation timing structure,
# one keyed on frame+name rather than function.
# The frame is incremented outside of this function, in def increment_frame() above.
def dynamo_timed(original_function=None, phase_name=None):
    """Decorator that records how long each call of the wrapped function takes.

    Usable bare (``@dynamo_timed``) or with arguments
    (``@dynamo_timed(phase_name="...")``).

    Each call that returns normally appends its elapsed time in seconds to
    ``compilation_metrics[func.__qualname__]``. When ``phase_name`` is
    truthy, the same duration is also stored in
    ``frame_phase_timing[str(curr_frame)][phase_name]``. Recording the same
    phase twice for one frame fails an assertion.

    If the wrapped function raises, the exception propagates and no duration
    is appended, although the key is already registered in
    ``compilation_metrics`` with an empty list.
    """

    def dynamo_timed_inner(func):
        @wraps(func)
        def time_wrapper(*args, **kwargs):
            key = func.__qualname__
            compilation_metrics.setdefault(key, [])
            # perf_counter() is monotonic and high-resolution; time.time()
            # can jump when the system clock is adjusted, skewing durations.
            t0 = time.perf_counter()
            r = func(*args, **kwargs)
            time_spent = time.perf_counter() - t0
            compilation_metrics[key].append(time_spent)
            if phase_name:
                frame_key = str(curr_frame)
                phase_timings = frame_phase_timing.setdefault(frame_key, {})
                assert (
                    phase_name not in phase_timings
                ), f"Duplicate phase name {phase_name} for frame {frame_key}"
                phase_timings[phase_name] = time_spent
            return r

        return time_wrapper

    # Bare use (@dynamo_timed) passes the function positionally; the
    # parameterized form returns the decorator to be applied later.
    if original_function is not None:
        return dynamo_timed_inner(original_function)
    return dynamo_timed_inner
def compile_times(repr="str", aggregate=False):
    """
    Report torchdynamo frontend/backend compilation times gathered by
    functions decorated with `@dynamo_timed`.

    repr='str' yields a human-readable table (one row per timed function).
    repr='csv' yields a (headers, values) pair suitable for logging.
    Any other value of repr returns None.

    With aggregate=True, each function's samples (e.g. from split graphs)
    are summed into a single number. Otherwise every sample is listed,
    comma-separated.
    """

    def fmt_fn(values, item_fn=lambda x: x):
        if not aggregate:
            return ", ".join(item_fn(sample) for sample in values)
        return item_fn(sum(values))

    if repr == "str":
        table_rows = [
            (name, fmt_fn(samples, item_fn=lambda x: f"{x:.4f}"))
            for name, samples in compilation_metrics.items()
        ]
        table = tabulate(table_rows, headers=("Function", "Runtimes (s)"))
        return "TorchDynamo compilation metrics:\n" + table
    if repr == "csv":
        headers = list(compilation_metrics.keys())
        values = [
            fmt_fn(samples, item_fn=lambda x: f"{x:.6f}")
            for samples in compilation_metrics.values()
        ]
        return headers, values
# Legacy typed tensor classes (torch.FloatTensor, ...) mapped to a tuple of
# the torch dtype object(s) each one corresponds to. Where a dtype has a
# second public name (e.g. torch.float alongside torch.float32), both are
# listed.
tensortype_to_dtype = {
    torch.FloatTensor: (torch.float32, torch.float),
    torch.DoubleTensor: (torch.float64, torch.double),
    torch.HalfTensor: (torch.float16, torch.half),
    torch.BFloat16Tensor: (torch.bfloat16,),
    torch.ByteTensor: (torch.uint8,),
    torch.CharTensor: (torch.int8,),
    torch.LongTensor: (torch.int64, torch.long),
    torch.IntTensor: (torch.int32, torch.int),
    torch.ShortTensor: (torch.int16, torch.short),
    torch.BoolTensor: (torch.bool,),
}
class DuplicateWarningChecker:
    """Bounded, recency-ordered memory of keys, presumably used to deduplicate warnings.

    ``add(key)`` answers "should this be reported?". A key not currently
    remembered is stored and yields True. A key already remembered is moved
    to the most-recent end and yields True only when ``config.verbose`` is
    truthy. When more than ``maxsize`` keys are held, the oldest are dropped.
    """

    def __init__(self, maxsize=4096):
        self.maxsize = maxsize
        self.reset()

    def reset(self):
        """Forget every remembered key."""
        self.set = collections.OrderedDict()

    def add(self, key):
        """Remember ``key`` and return whether it should be reported."""
        seen = self.set
        if key not in seen:
            seen[key] = None
            # Trim from the oldest end until back within the size bound.
            while len(seen) > self.maxsize:
                seen.popitem(last=False)
            return True
        # Repeat sighting: refresh its recency.
        seen.move_to_end(key, last=True)
        return True if config.verbose else False
# Module-wide DuplicateWarningChecker (default maxsize); reset by init_logging().
# Presumably consulted before emitting graph-break warnings — callers are not
# visible here, confirm.
graph_break_dup_warning_checker = DuplicateWarningChecker()
def init_logging():
    """(Re)initialise torchdynamo logging from ``config`` and clear duplicate-warning history.

    Passes ``config.log_level`` and ``config.log_file_name`` to
    ``torchdynamo_logging.init_logging``, then resets
    ``graph_break_dup_warning_checker``.
    """
    level = config.log_level
    log_file = config.log_file_name
    torchdynamo_logging.init_logging(level, log_file_name=log_file)
    graph_break_dup_warning_checker.reset()
def format_graph_tabular(graph):
    """Return a text table with one row per node in ``graph.nodes``.

    The columns are opcode, name, target, args and kwargs, taken from each
    node's ``op``, ``name``, ``target``, ``args`` and ``kwargs``. ``graph``
    is expected to be an fx.Graph or anything exposing ``.nodes`` the same
    way.
    """
    columns = ["opcode", "name", "target", "args", "kwargs"]
    table = []
    for node in graph.nodes:
        table.append([node.op, node.name, node.target, node.args, node.kwargs])
    return tabulate(table, headers=columns)
def format_bytecode(prefix, name, filename, line_no, code):
    """Return a one-line header followed by the disassembly of ``code``.

    The header is ``"<prefix> <name> <filename> line <line_no> "``. It is
    followed by a newline, the ``dis.Bytecode(code).dis()`` text, and a
    final newline plus one space (kept for output compatibility).
    """
    # Built with an explicit " line" separator instead of a backslash-newline
    # inside the literal. Such a continuation splices the next source line's
    # leading whitespace into the string, which dropped the space before "line".
    header = f"{prefix} {name} {filename} line {line_no} \n"
    return f"{header}{dis.Bytecode(code).dis()}\n "
def gen_record_file_name(exc, code):
    """Return the path at which to save an execution record for ``exc`` raised in ``code``.

    The path is
    ``<get_debug_dir()>/error_recordings/<co_name>_<exception type name>_<co_firstlineno>.rec``.
    """
    # Adjacent f-strings rather than a backslash-newline inside one literal,
    # so indenting the second source line can never leak spaces into the path.
    return (
        f"{get_debug_dir()}/error_recordings/"
        f"{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec"
    )
def write_record_to_file(filename, exec_record):
    """Serialize ``exec_record`` into a newly created file at ``filename``.

    Missing parent directories are created. An existing file is never
    overwritten: if ``filename`` already exists, a warning is logged and
    nothing is written. Any other failure is logged with its traceback and
    swallowed (best effort; this function does not raise).

    Args:
        filename: destination path; may be a bare file name.
        exec_record: object whose ``dump(f)`` writes it to a binary file object.
    """
    try:
        dirname = os.path.dirname(filename)
        # os.makedirs("") raises, which used to make bare file names fail.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        try:
            # "xb" creates the file atomically and fails if it already exists,
            # closing the race between a separate exists() check and open().
            f = open(filename, "xb")
        except FileExistsError:
            log.warning(
                "Unable to write execution record %s; file already exists.", filename
            )
            return
        with f:
            exec_record.dump(f)
    except Exception:
        log.error("Unable to write execution record %s", filename, exc_info=True)
def count_calls(g: fx.Graph):
    """Return how many nodes of ``g`` have an ``op`` string containing ``"call"``.

    For fx graphs this counts the call_function / call_method / call_module
    nodes.
    """
    return sum(1 for node in g.nodes if "call" in node.op)
def identity(x):
    """Return the argument itself, unchanged (the same object)."""
    result = x
    return result
def nothing(*args, **kwargs):
    """Accept any positional/keyword arguments, ignore them, and return None."""
    return None
class ExactWeakKeyDictionary:
    """Mapping keyed on object identity (``id``) that does not keep its keys alive.

    Unlike ``weakref.WeakKeyDictionary``, keys are matched with ``is``
    semantics: two distinct objects that compare ``==`` are separate
    entries. Each key must support weak references. When a key is
    garbage-collected, its entry is removed by a weakref callback.
    """

    def __init__(self):
        self.values = {}  # id(key) -> stored value
        self.refs = {}  # id(key) -> weakref to key (callback drops the entry)

    def __getitem__(self, key):
        return self.values[id(key)]

    def get(self, key, default=None):
        return self.values.get(id(key), default)

    def __contains__(self, key):
        return id(key) in self.values

    def __setitem__(self, key, value):
        ident = id(key)
        if ident not in self.refs:
            # One weakref per live key; its callback fires when the key dies.
            self.refs[ident] = weakref.ref(key, lambda _dead: self._remove_id(ident))
        self.values[ident] = value

    def _remove_id(self, idx):
        self.values.pop(idx, None)
        self.refs.pop(idx, None)

    def clear(self):
        self.refs.clear()
        self.values.clear()
def istype(obj, allowed_types):
"""isinstance() without subclasses"""
if isinstance(allowed_types, (tuple, list, set)):
Loading ...