import copy
import functools
import getpass
import logging
import os
import shutil
import subprocess
import textwrap
import uuid
from collections import Counter
from importlib import import_module
from tempfile import TemporaryFile
import torch
import torch.fx as fx
from torch._prims_common import is_float_dtype
from . import config
from .backends.registry import lookup_backend, register_debug_backend
from .utils import clone_inputs, get_debug_dir
log = logging.getLogger(__name__)

inductor_config = import_module("torch._inductor.config")
# Buck-specific repro support (TARGETS files, custom-op libraries) is only
# enabled when inductor reports an fbcode build.
use_buck = inductor_config.is_fbcode()

# Buck targets that generated repros depend on: they are listed as ``cpp_deps``
# in the generated TARGETS file and loaded with ``torch.ops.load_library`` in
# the generated repro script. Empty outside fbcode.
extra_deps = (
    [
        "//caffe2/fb/custom_ops/sparsenn:sparsenn-all_operators",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    if use_buck
    else []
)
# Newline-separated ``torch.ops.load_library(...)`` calls ("" when no deps),
# spliced verbatim into generated repro scripts.
extra_imports = "\n".join(f'torch.ops.load_library("{x}")' for x in extra_deps)
class BuckTargetWriter:
    """Write a Buck ``TARGETS`` file so a generated repro script can be run with ``buck2``.

    Paths are derived from the part of the repro file's location that follows
    the ``fbcode/`` directory, so this is only meaningful inside an fbcode
    checkout.
    """

    def __init__(self, filename):
        """
        Args:
            filename: path of the generated repro ``.py`` file. Assumed to live
                somewhere under an ``fbcode/`` directory (TODO confirm for all
                callers).
        """
        self.subdir, self.py_file = os.path.split(filename)
        # Strip only a trailing ".py": str.replace would also delete any ".py"
        # that appears earlier in the name (e.g. "foo.pyramid.py").
        if self.py_file.endswith(".py"):
            self.target = self.py_file[: -len(".py")]
        else:
            self.target = self.py_file
        # Dotted module path relative to fbcode, used as ``main_module``.
        # NOTE(review): if "fbcode." is absent, find() returns -1 and this
        # yields "" -- callers are expected to pass paths under fbcode/.
        dotted = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = dotted[dotted.find("fbcode.") :][len("fbcode.") :]
        # Buck label "//<dir relative to fbcode>:<target>" for the command line.
        rel_dir = self.subdir[self.subdir.find("fbcode/") :][len("fbcode/") :]
        self.cmd_line_path = f"//{rel_dir}:{self.target}"

    def build(self):
        """Return the text of a ``TARGETS`` file defining a ``python_binary`` for the repro."""
        extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
python_binary(
name="{self.target}",
srcs = ["{self.py_file}"],
compile = False,
deps = [
"//caffe2:torch",
"//caffe2/functorch:functorch",
"//triton:triton",
],
cpp_deps = [
{extra_cpp_deps}
],
main_module = "{self.path}",
)
"""
        )

    def write(self, print_msg=True):
        """Write ``TARGETS`` next to the repro file and return the ``buck2 run`` command.

        Args:
            print_msg: when true, log the command the user can run to reproduce.

        Returns:
            The ``buck2 run`` command as an argv list.
        """
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        cmd = ["buck2", "run", "@mode/dev-nosan", self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd),
            )
        return cmd
def minifier_dir():
    """Return the directory for minifier artifacts, creating it if needed.

    Uses ``<debug dir>/minifier`` when ``get_debug_dir()`` returns a path and
    falls back to ``<system temp dir>/minifier_<user>`` when it returns ``None``.
    """
    import tempfile

    debug_dir = get_debug_dir()
    # Check the debug dir itself: os.path.join() never returns None (and raises
    # TypeError when given None), so the fallback must be decided before joining.
    if debug_dir is None:
        path = os.path.join(tempfile.gettempdir(), f"minifier_{getpass.getuser()}")
    else:
        path = os.path.join(debug_dir, "minifier")
    os.makedirs(path, exist_ok=True)
    return path
class NNModuleToString:
    """Render an ``fx.GraphModule`` as the source of a standalone ``Repro`` module."""

    # Module types whose repr has been tested as constructor source; other
    # types are still emitted by ``convert`` but trigger a warning in
    # ``can_convert_to_string``.
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        """Warn about direct children of *gm* whose repr has not been tested.

        Always returns ``True``: untested modules are assumed to repr correctly.
        """
        cant_convert = {
            module
            for _, module in gm.named_children()
            if type(module) not in NNModuleToString.safe_reprs
        }
        if cant_convert:
            log.warning("We have not tested reprs of some modules - %s", cant_convert)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        """Return Python source defining ``class Repro(torch.nn.Module)`` mirroring *gm*.

        * Direct child modules are re-created from their ``repr`` (with
          ``.cuda()`` appended when their first parameter is on CUDA).
        * Top-level buffers and parameters are replaced by freshly generated
          tensors of the same shape and dtype; their values are not preserved.
        * ``gm.code`` becomes the body of ``Repro``.
        """
        from torch.nn.modules.module import _addindent

        tab = " " * 4
        # Header for the generated class: ``__init__`` one level deep and its
        # body two levels deep, matching the ``tab * 2`` prefix of the
        # attribute lines appended below and the 4-space indent of ``gm.code``.
        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = repr(module)
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab * 2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            if torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab * 2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))"
            if param.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab * 2}self.{param_name} = {tensor_str}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
@functools.lru_cache(None) # subprocess is expensive
def _cuda_system_info_comment():
if not torch.cuda.is_available():
return "# torch.cuda.is_available()==False, no GPU info collected\n"
model_str = "# CUDA Info: \n"
try:
cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
cuda_version_lines = cuda_version_out.stdout.decode().split("\n")
cuda_version_out = "".join(
[f"# {s} \n" for s in cuda_version_lines if s not in [""]]
)
model_str += f"{cuda_version_out}\n"
except FileNotFoundError:
model_str += "# nvcc not found\n"
gpu_names = Counter(
torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
)
model_str += "# GPU Hardware Info: \n"
for name, count in gpu_names.items():
model_str += f"# {name} : {count} \n"
model_str += "\n"
return model_str
def generate_config_string():
    """Return Python source that restores the current dynamo/inductor/functorch configs.

    The snippet imports each config module and calls its ``load_config`` with
    the ``repr`` of that module's current ``save_config()`` value, so a repro
    script runs under the same settings as the process that generated it.
    """
    import torch._functorch.config
    import torch._inductor.config

    config_modules = (
        ("torch._dynamo.config", torch._dynamo.config),
        ("torch._inductor.config", torch._inductor.config),
        ("torch._functorch.config", torch._functorch.config),
    )
    snippet_lines = [f"import {name}" for name, _ in config_modules]
    snippet_lines += [
        f"{name}.load_config({module.save_config()!r})"
        for name, module in config_modules
    ]
    return textwrap.dedent("\n".join(snippet_lines) + "\n")
# Marker line emitted into every generated repro's import preamble.
TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES"


def generate_compiler_repro_string(gm, args):
    """Return the source of a standalone script that rebuilds *gm* as ``mod``.

    The script contains:

    * imports, the config-restoring snippet from ``generate_config_string()``,
      ``TEST_REPLACEABLE_COMMENT`` and any ``extra_imports``;
    * torch version comments and the CUDA info comment block;
    * the ``Repro`` module produced by ``NNModuleToString.convert(gm)``;
    * ``args`` re-created with ``rand_strided`` from each input's shape,
      stride, dtype and device type (values are not preserved);
    * ``mod = make_fx(Repro(), tracing_mode=...)(*args)``, tracing
      symbolically when ``config.dynamic_shapes`` is set.

    *args* are assumed to be tensors (their ``shape``/``stride()``/``dtype``/
    ``device`` are read).
    """
    preamble_lines = [
        "",
        "import torch",
        "from torch import tensor, device",
        "import torch.fx as fx",
        "from torch._dynamo.testing import rand_strided",
        "from math import inf",
        "from torch.fx.experimental.proxy_tensor import make_fx",
        generate_config_string(),
        TEST_REPLACEABLE_COMMENT,
        extra_imports,
        "",
    ]
    pieces = [textwrap.dedent("\n".join(preamble_lines))]

    pieces.append(f"# torch version: {torch.version.__version__}\n")
    if hasattr(torch.version, "cuda"):
        pieces.append(f"# torch cuda version: {torch.version.cuda}\n")
    if hasattr(torch.version, "git_version"):
        pieces.append(f"# torch git version: {torch.version.git_version}\n\n\n")
    pieces.append(_cuda_system_info_comment())

    pieces.append(NNModuleToString.convert(gm))

    arg_specs = [
        (tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type) for a in args
    ]
    pieces.append(f"args = {arg_specs!r}\n")
    pieces.append(
        "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n"
    )

    # TODO: fake may be better for performance here
    tracing_mode = "symbolic" if config.dynamic_shapes else "real"
    pieces.append(f"mod = make_fx(Repro(), tracing_mode={tracing_mode!r})(*args)\n")
    return "".join(pieces)
# Import preamble written into inductor repro scripts.
INDUCTOR_IMPORT = (
    "\n"
    "from torch._inductor.compile_fx import compile_fx_inner\n"
    "from torch._dynamo.debug_utils import same_two_models\n"
)

# compiler_name -> (import preamble, name of the compile function called in the
# repro, name of the matching *_fails helper -- presumably defined elsewhere in
# this module; verify).
COMPILER_REPRO_OPTIONS = {
    compiler: (INDUCTOR_IMPORT, "compile_fx_inner", fails_fn)
    for compiler, fails_fn in (
        ("inductor", "inductor_fails"),
        ("inductor_accuracy", "inductor_accuracy_fails"),
    )
}
def dump_compiler_graph_state(gm, args, compiler_name):
    """Save a standalone repro script for *gm* and copy it to ``./repro.py``.

    The script is written by ``save_graph_repro`` to
    ``<minifier dir>/checkpoints/<number of graph nodes>.py``. A copy is then
    placed at ``repro.py`` in the current working directory and, for Buck
    (fbcode) builds, a ``TARGETS`` file is written next to the checkpoint.
    ``OSError``s while writing the copy or the ``TARGETS`` file are logged,
    not raised.

    Args:
        gm: the ``fx.GraphModule`` to reproduce.
        args: example inputs for *gm*.
        compiler_name: key into ``COMPILER_REPRO_OPTIONS`` selecting the
            compiler the repro exercises.
    """
    subdir = os.path.join(minifier_dir(), "checkpoints")
    os.makedirs(subdir, exist_ok=True)
    num_nodes = len(gm.graph.nodes)
    file_name = os.path.join(subdir, f"{num_nodes}.py")
    log.warning("Writing checkpoint with %s nodes to %s", num_nodes, file_name)
    with open(file_name, "w") as fd:
        save_graph_repro(fd, gm, args, compiler_name)

    repro_path = os.path.join(os.getcwd(), "repro.py")
    try:
        shutil.copyfile(file_name, repro_path)
    except OSError:
        log.warning("No write permissions for %s", repro_path)
        return
    log.warning("Copying repro file for convenience to %s", repro_path)

    if use_buck:
        # Handled separately so a TARGETS-write failure is not misreported as a
        # failure to write repro.py.
        try:
            BuckTargetWriter(file_name).write()
        except OSError:
            log.warning("Failed to write Buck TARGETS file next to %s", file_name)
def save_graph_repro(fd, gm, args, compiler_name):
sync_line = ""
for arg in args:
if arg.is_cuda:
sync_line = "torch.cuda.synchronize() # Ensures that segfaults are surfaced"
break
if "inductor" in compiler_name:
fd.write("import torch._inductor.overrides\n")
fd.write(generate_compiler_repro_string(gm, args))
fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
if "_accuracy" in compiler_name:
fd.write(
textwrap.dedent(
f"""
compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
class AccuracyError(Exception):
pass
if not same_two_models(mod, compiled, args, only_fwd=True):
raise AccuracyError("Bad accuracy detected")
"""
)
)
else:
fd.write(
textwrap.dedent(
f"""
compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
ref = compiled(args)
{sync_line}
"""
)
Loading ...