"""JIT-related state
This module stores various pieces of Python-global state relating to the JIT.
This is not intended to be imported directly; please the exposed
functionalities in `torch.jit`.
"""
import torch
import os
import weakref
class EnabledProxy:
"""Stores whether the JIT is enabled or not.
This is just a wrapper for a bool, so that we get reference semantics
"""
def __init__(self):
self.enabled = self.parse_env(
"PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
)
def parse_env(self, name, default, true_message, false_message):
value = os.environ.get(name)
if value is None:
return default
if value.lower() in {"1", "true", "yes"}:
return True
elif value.lower() in {"0", "false", "no"}:
return False
if value == "1v":
print(true_message)
return True
elif value == "0v":
print(false_message)
return False
raise ValueError("Unknown setting of {}. Try using 0 or 1.".format(name))
def __bool__(self):
return self.enabled
_enabled = EnabledProxy()
def disable():
_enabled.enabled = False
def enable():
_enabled.enabled = True
# The Python CompilationUnit. All functions and modules defined in Python will
# live in here. It's defined in Python because doing in cpp creates static
# destruction order issues.
_python_cu = torch._C.CompilationUnit()
# qualified_name => ScriptClass mapping
_script_classes = {}
def _add_script_class(cls, name):
global _script_classes
_script_classes[name] = cls
def _get_script_class(name):
global _script_classes
if name not in _script_classes:
return None
return _script_classes[name]
# Caching: we currently cache compilation of free functions and overloaded functions.
# To cache free functions we hold a weak ref to the function object and
# map to the compiled fn's qualified name.
# To cache overloaded functions we hold a weak ref to the function obj and
# map to all of its overloaded compiled fns.
# In the future we could consider caching more types of objects so that
# aliasing is preserved across separate compilations of the same object.
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
def _try_get_jit_cached_overloads(key):
qual_names = _jit_function_overload_caching.get(key, None)
if qual_names:
return [_python_cu.find_function(qual_name) for qual_name in qual_names]
else:
return None
def _set_jit_overload_cache(key, compiled_fns):
_jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
def _try_get_jit_cached_function(key):
if getattr(key, "__disable_jit_function_caching__", False) is True:
return None
qual_name = _jit_caching_layer.get(key, None)
if qual_name:
return _python_cu.find_function(qual_name)
else:
return None
def _set_jit_function_cache(key, value):
# only free functions currently supported
assert isinstance(value, torch.jit.ScriptFunction)
_jit_caching_layer[key] = value.qualified_name