import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref
import torch
import torch._inductor.test_operators
try:
import torch._prims
# isort: split
# TODO: Hack to unblock simultaneous landing changes. Fix after https://github.com/pytorch/pytorch/pull/81088 lands
import torch._prims.utils
import torch._prims.wrappers
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
HAS_PRIMS_REFS = True
except ImportError:
HAS_PRIMS_REFS = False
from . import comptime, config, external_utils
"""
A note on skipfiles:
Dynamo consults this file to determine whether code should be compiled or skipped.
A skip applies at the frame boundary, meaning Dynamo either triggers a graph break
at the beginning of the frame or attempts to trace the whole frame. When a frame is
skipped, frames it calls recursively are still traced by Dynamo unless they are also skipped.
Skipfiles (skipped at the file level rather than the function level) still apply on a
frame-by-frame basis as Dynamo traces, but they apply to every function in that file.
@skip is a helper decorator that can be applied to a function to add it to the
set of skipped code described here.
"""
def _strip_init_py(s):
    """Remove a trailing ``__init__.py`` from the path string *s*.

    For a package file such as ``.../pkg/__init__.py`` this yields
    ``.../pkg/`` (the trailing separator is kept), which can then be used as a
    path prefix. Any other string is returned unchanged.
    """
    # The dot must be escaped: the previous pattern "__init__.py$" let "."
    # match any character, so names like "__init__Xpy" were stripped too.
    return re.sub(r"__init__\.py$", "", s)
def _module_dir(m: types.ModuleType):
    """Return the path prefix that covers the source of module *m*.

    For a package (whose ``__file__`` ends in ``__init__.py``) this is the
    package directory with a trailing separator; for a single-file module it
    is the module's file path itself.
    """
    module_file = m.__file__
    return _strip_init_py(module_file)
# Path prefixes whose code Dynamo skips. Entries are matched as string
# prefixes of a frame's filename (see _recompile_re / check).
SKIP_DIRS = [
    _module_dir(torch),  # torch.*
    os.path.dirname(__file__) + "/",  # torchdynamo.* (directory holding this file)
    # Pseudo filenames rather than real paths; matched as prefixes as well.
    "<frozen importlib",
    "<__array_function__ internals>",
]
# Skip some standard-library modules / packages too (order preserved).
SKIP_DIRS.extend(
    map(
        _module_dir,
        (
            abc,
            collections,
            contextlib,
            copy,
            copyreg,
            dataclasses,
            enum,
            functools,
            importlib,
            inspect,
            linecache,
            logging,
            multiprocessing,
            operator,
            os,
            posixpath,
            random,
            re,
            selectors,
            signal,
            tempfile,
            threading,
            tokenize,
            traceback,
            types,
            typing,
            unittest,
            weakref,
            _collections_abc,
            _weakrefset,
        ),
    )
)
# Exact filenames that check() never skips, even when they fall under a
# SKIP_DIRS prefix.
FILENAME_ALLOWLIST = {
    torch.nn.Sequential.__init__.__code__.co_filename,
    torch.set_rng_state.__code__.co_filename,
    torch._inductor.test_operators.__file__,
    # These are dynamo files!
    external_utils.__file__,
    comptime.__file__,  # Want to inline these helpers
}

# Include optimizer code for tracing: the source file of every class exposed
# in the torch.optim namespace, plus the functional optimizer implementations.
FILENAME_ALLOWLIST.update(
    inspect.getfile(optim_cls)
    for optim_cls in torch.optim.__dict__.values()
    if inspect.isclass(optim_cls)
)
FILENAME_ALLOWLIST.add(torch.optim._functional.__file__)
# Also allow the prims / refs implementation files, but only when those
# modules were importable (HAS_PRIMS_REFS is set by the guarded import above).
if HAS_PRIMS_REFS:
    FILENAME_ALLOWLIST.update(
        prims_or_refs_module.__file__
        for prims_or_refs_module in (
            torch._prims,
            torch._prims.utils,
            torch._prims.wrappers,
            torch._refs,
            torch._refs.special,
            torch._refs.nn.functional,
        )
    )
# Compiled regex matching any filename that starts with one of the SKIP_DIRS
# entries; stays None until _recompile_re() runs for the first time.
SKIP_DIRS_RE = None


def _recompile_re():
    """Rebuild SKIP_DIRS_RE from the current contents of SKIP_DIRS.

    Each entry is regex-escaped, so it is matched literally as a prefix.
    """
    global SKIP_DIRS_RE
    alternatives = "|".join(re.escape(prefix) for prefix in SKIP_DIRS)
    SKIP_DIRS_RE = re.compile("^(" + alternatives + ")")
def add(import_name: typing.Union[str, types.ModuleType]):
    """Append the source location of a module or package to SKIP_DIRS.

    Args:
        import_name: dotted module name, or an already-imported module object
            (in which case its ``__name__`` is used).

    Modules that cannot be located, or whose spec has no ``origin`` (e.g.
    namespace packages), are silently ignored. For a package, the trailing
    ``__init__.py`` is stripped so the whole package directory is skipped.
    SKIP_DIRS_RE is recompiled after every successful addition.
    """
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    # A plain ``import importlib`` does not guarantee that the
    # ``importlib.util`` submodule is loaded, so import it explicitly.
    import importlib.util

    try:
        module_spec = importlib.util.find_spec(import_name)
    except ModuleNotFoundError:
        # For a dotted name, find_spec imports the parent package and raises
        # if it is missing; treat that like any other module that isn't found.
        return
    if not module_spec:
        return
    origin = module_spec.origin
    if origin is None:
        return
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()
def check(filename, allow_torch=False):
    """Return True if code from *filename* should be skipped by Dynamo.

    A missing filename (None) is always skipped. Files listed in
    FILENAME_ALLOWLIST are never skipped, and neither are torch files when
    *allow_torch* is set (as decided by is_torch). Everything else is skipped
    exactly when it starts with one of the SKIP_DIRS prefixes.
    """
    if filename is None:
        return True
    explicitly_allowed = filename in FILENAME_ALLOWLIST or (
        allow_torch and is_torch(filename)
    )
    if explicitly_allowed:
        return False
    return SKIP_DIRS_RE.match(filename) is not None
# Skip common third-party libraries. add() quietly ignores any name that
# importlib cannot locate, so libraries that are not installed are harmless.
for _name in (
    "functorch", "intel_extension_for_pytorch", "networkx", "numpy",
    "omegaconf", "onnx", "onnxruntime", "onnx_tf", "pandas", "sklearn",
    "tabulate", "tensorflow", "tensorrt", "torch2trt", "tqdm", "tree",
    "tvm", "fx2trt_oss", "xarray",
):
    add(_name)
# add() only recompiles when it actually appends an entry; recompile once
# more so SKIP_DIRS_RE is set even if none of the libraries above were found.
_recompile_re()
def is_torch_inline_allowed(filename):
    """Return True if *filename* lies under the source location of any module
    listed in ``config.skipfiles_inline_module_allowlist``."""
    for allowed_module in config.skipfiles_inline_module_allowlist:
        if filename.startswith(_module_dir(allowed_module)):
            return True
    return False
@functools.lru_cache(maxsize=None)
def dynamo_dir():
    """Return the source-directory prefix of ``torch._dynamo`` (computed once)."""
    # Imported inside the function rather than at module level -- presumably
    # to avoid a circular import, since this module is part of dynamo.
    import torch._dynamo

    dynamo_package = torch._dynamo
    return _module_dir(dynamo_package)
def is_torch(filename):
    """Return True if *filename* is inside the torch package but not inside
    torch._dynamo (Dynamo's own files are never treated as torch files)."""
    return not filename.startswith(dynamo_dir()) and filename.startswith(
        _module_dir(torch)
    )