import collections
import contextlib
import cProfile
import functools
import itertools
import logging
import os.path
import pstats
import shutil
import subprocess
import sys
from typing import Any, List
from unittest.mock import patch
from functorch.compile import (
config as functorch_config,
draw_graph,
get_aot_graph_name,
get_graph_being_compiled,
)
import torch
from torch import fx as fx
from torch._dynamo import config as dynamo_config
from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir, init_logging
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from . import config, ir # noqa: F811, this is needed
from .scheduler import (
BaseSchedulerNode,
FusedSchedulerNode,
NopKernelSchedulerNode,
OutputNode,
SchedulerNode,
)
from .virtualized import V
log = logging.getLogger(__name__)
@functools.lru_cache(None)
def has_dot():
    """
    Return True when a `dot` executable (graphviz) can be found on PATH.

    The lookup uses ``shutil.which`` rather than spawning the external
    ``which`` program: that program is not available on every platform, and
    when it is missing ``subprocess`` raises ``FileNotFoundError``, which is
    not a ``SubprocessError`` and would therefore escape the old handler.

    The result is cached, so PATH is only inspected on the first call.
    """
    return shutil.which("dot") is not None
def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.
    nodes is a list of SchedulerNode objects.

    When fname is None the value of get_graph_being_compiled() is used
    instead.  If print_graph is true, the intermediate fx graph is printed
    before it is drawn.  If `dot` is unavailable, a warning is logged and
    nothing is drawn.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    output_name = get_graph_being_compiled() if fname is None else fname
    graph = create_fx_from_snodes(nodes)

    for fx_node in graph.nodes:
        # Only nodes tagged by create_fx_from_snodes carry fusion metadata.
        if "fusion_meta" in fx_node.meta:
            group = fx_node.meta["fusion_meta"].group
            # Tuple-valued groups contribute only their second element.
            shape_slot = group[1] if isinstance(group, tuple) else group

            # gather meta data
            # NOTE(review): fx_node comes from an fx graph, so this isinstance
            # check looks like it can never match and dtype stays None -- confirm.
            if isinstance(fx_node, ir.ComputedBuffer):
                dtype = fx_node.data.dtype
            else:
                dtype = None

            # The group is passed as the first TensorMetadata field; the
            # remaining fields (other than dtype) are left as None.
            fx_node.meta["tensor_meta"] = TensorMetadata(
                shape_slot, dtype, None, None, None, None, None
            )

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(gm, output_name, clear_meta=False)
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.

    Every scheduler node becomes a ``call_function`` fx node whose target is
    a dummy function named after the node's kind ("extern", "template",
    "nop", "compute" or "fused") and whose name is ``snode.get_name()``.
    Each fx node carries ``meta["fusion_meta"]``, a
    ``FusionMeta(group, snodes, type)`` tuple.  Reads that are not produced
    by any of the given nodes are turned into placeholders inserted before
    the first created node.  The graph output is the single fx node that
    feeds an OutputNode, or a tuple of them when there is not exactly one.

    Raises:
        RuntimeError: if a scheduler node matches none of the known kinds.
    """

    def get_fake_func(name):
        # Dummy callable; only its __name__ matters.
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    def in_output(snode):
        # A fused node reaches the output if any of its members does.
        if isinstance(snode, FusedSchedulerNode):
            return any(in_output(x) for x in snode.snodes)
        return any(isinstance(user.node, OutputNode) for user in snode.users)

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])

    # "template" must be present: the classification below can yield it, and
    # a missing entry would fail the lookup with KeyError.
    func_dict = {
        s: get_fake_func(s)
        for s in ["extern", "template", "nop", "compute", "fused"]
    }
    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        node_func = func_dict[node_type]
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        if in_output(snode):
            outputs.append(fx_node)

        name = snode.get_name()
        fx_node.name = name
        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)

        # Members of a fused node resolve to the fused fx node as well.
        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                # Unknown producer: model it as a graph input placed ahead of
                # all call_function nodes.
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
@contextlib.contextmanager
def enable_aot_logging():
    """
    Context manager that turns on AOT autograd logging for its duration.

    While active, ``functorch.compile.config.log_level`` is patched to
    DEBUG.  If ``debug_graphs`` or ``debug_joint`` is already set in the
    functorch config, the AOT autograd logger also streams to stdout.  If
    the TORCH_COMPILE_DEBUG environment variable is set, the partitioner,
    graph and joint-graph debug flags are forced on and the logger also
    writes to ``<debug dir>/aot_torchinductor/aot_<graph name>_debug.log``.

    All patches and handlers are registered on one ExitStack, so they are
    undone in reverse order on exit, and also if setup fails partway.  The
    file handler is closed as well as removed so its file descriptor is
    not leaked.
    """
    # Any non-empty value enables compile debug, because the raw string is
    # passed through bool() (so "0" counts as enabled too).
    compile_debug = bool(os.environ.get("TORCH_COMPILE_DEBUG", False))
    debug_graphs = functorch_config.debug_graphs
    debug_joint_graphs = functorch_config.debug_joint

    import torch._functorch.aot_autograd

    aot_log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    with contextlib.ExitStack() as stack:
        stack.enter_context(
            patch("functorch.compile.config.log_level", logging.DEBUG)
        )

        # if user has specified they want to see graphs via either env var
        # add stream to std out
        if debug_graphs or debug_joint_graphs:
            stdout_handler = logging.StreamHandler(sys.stdout)
            aot_log.addHandler(stdout_handler)
            stack.callback(aot_log.removeHandler, stdout_handler)

        if compile_debug:
            # Enable all graphs to be logged to a file by setting the flags
            # to True and the log level of the file logger to DEBUG
            stack.enter_context(
                patch("functorch.compile.config.debug_partitioner", True)
            )
            stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
            stack.enter_context(patch("functorch.compile.config.debug_joint", True))

            path = os.path.join(get_debug_dir(), "aot_torchinductor")
            # exist_ok avoids failing when another process creates the
            # directory between a check and the creation.
            os.makedirs(path, exist_ok=True)

            fh = logging.FileHandler(
                os.path.join(
                    path,
                    f"aot_{get_aot_graph_name()}_debug.log",
                )
            )
            # Registered first so it runs last: the handler is detached from
            # the logger before its file is closed.
            stack.callback(fh.close)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(
                logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
            )
            aot_log.addHandler(fh)
            stack.callback(aot_log.removeHandler, fh)

        yield
class DebugContext:
_counter = itertools.count()
@staticmethod
def wrap(fn):
    """
    Return `fn` wrapped so that every call runs inside a new DebugContext.

    The wrapper is then passed through wrap_compiler_debug with
    compiler_name="inductor", and that result is returned.
    """

    @functools.wraps(fn)
    def run_in_debug_context(*args, **kwargs):
        # A fresh context per call; it is exited before the value is returned.
        with DebugContext():
            return fn(*args, **kwargs)

    debug_wrapped = wrap_compiler_debug(
        run_in_debug_context, compiler_name="inductor"
    )
    return debug_wrapped
@staticmethod
def create_debug_dir(folder_name):
    """
    Create and return a new directory for one debug trace.

    Candidates are ``<debug dir>/aot_torchinductor/<folder_name>.<n>`` with
    ``n`` drawn from the class-wide counter; the first candidate that can be
    created is returned.  Creation is attempted directly and an existing
    path is skipped, instead of checking for existence first, so a
    directory created concurrently by another process cannot cause a
    FileExistsError to escape.
    """
    for n in DebugContext._counter:
        dirname = os.path.join(
            get_debug_dir(),
            "aot_torchinductor",
            f"{folder_name}.{n}",
        )
        try:
            os.makedirs(dirname)
        except FileExistsError:
            # Name already taken; move on to the next counter value.
            continue
        return dirname
def __init__(self):
self._prof = None
self._path = None
self._stack = contextlib.ExitStack()
def rename(self, new_path: str):
if not self._path:
return
assert new_path.endswith(".debug"), new_path
if os.path.exists(new_path):
shutil.rmtree(new_path)
try:
os.rename(self._path, new_path)
self._path = new_path
except OSError:
# other OS might have troubling renaming dir with open files
pass
def fopen(self, filename):
assert self._path
return open(os.path.join(self._path, filename), "w")
def filename(self, suffix):
return os.path.join(self._path, suffix)
def upload_tar(self):
    """
    Archive the trace directory and pass the archive to the upload hook.

    Does nothing when config.trace.upload_tar is None.  Otherwise a gzip
    tarball named ``<dir name>.tar.gz`` is written inside the trace
    directory, containing that directory under its base name, and its path
    is passed to config.trace.upload_tar.
    """
    if config.trace.upload_tar is None:
        return

    import tarfile

    assert self._path
    base_name = os.path.basename(self._path)
    tar_file = os.path.join(self._path, f"{base_name}.tar.gz")
    with tarfile.open(tar_file, "w:gz") as archive:
        archive.add(self._path, arcname=base_name)
    config.trace.upload_tar(tar_file)
def __enter__(self):
    """
    Activate the debug context.

    Initializes logging if the "torch._inductor" logger has no handlers.
    When config.debug is set, raises the dynamo log level to DEBUG and
    registers a callback that restores the previous level.  Installs this
    object as the debug handler on V.  When config.trace.enabled is set,
    creates the trace directory, starts the configured log captures and,
    if requested, starts a cProfile profiler.

    Returns None, as before.  If any step raises, everything already
    registered on the exit stack is unwound before the exception
    propagates; __exit__ is not called when __enter__ fails, so the
    log-level reset and debug handler would otherwise never be undone.
    """
    inductor_log = logging.getLogger("torch._inductor")
    if not inductor_log.handlers:
        init_logging()

    try:
        if config.debug:

            def reset_log_level(level):
                dynamo_config.log_level = level

            # Capture the current level now so it can be restored on exit.
            self._stack.callback(reset_log_level, dynamo_config.log_level)
            dynamo_config.log_level = logging.DEBUG

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()
    except BaseException:
        # Undo partial setup, then let the original error propagate.
        self._stack.close()
        raise
def _setup_log_capture(self, filename, level):
log = logging.getLogger("torch._inductor")
fd = self._stack.enter_context(self.fopen(filename))
ch = logging.StreamHandler(fd)
ch.setLevel(level)
ch.setFormatter(
logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
)
log.addHandler(ch)
log.setLevel(min(log.level, level))
self._stack.callback(log.removeHandler, ch)
def __exit__(self, exc_type, exc_val, exc_tb):
if self._prof:
self._prof.disable()
self._save_profile_data()
if self._path:
self.upload_tar()
log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
self._stack.close()
def _save_profile_data(self):
self._prof.dump_stats(self.filename("compile.prof"))
with self.fopen("compile.stats") as fd:
stats = pstats.Stats(self._prof, stream=fd)
Loading ...