import functools
import weakref
import torch.nn
from torch.nn import Module
from .utils import ExactWeakKeyDictionary
class MutationTracker:
    """Per-object mutation counter plus the guarded code that depends on the object.

    Instances are stored in the class-level ``db`` mapping, keyed by the
    watched object.
    """

    db = ExactWeakKeyDictionary()

    def __init__(self):
        self.mutation_count = 0
        # weak references to guarded-code objects to invalidate on mutation
        self.watchers = []

    def on_mutation(self, name):
        """Count one mutation and invalidate every watcher that is still alive.

        ``name`` (the attribute being written) is accepted but not used here.
        """
        self.mutation_count += 1
        # Detach the current list before notifying: each registered watcher is
        # invalidated once, and anything registered during notification goes
        # into the fresh list.
        pending, self.watchers = self.watchers, []
        for watcher_ref in pending:
            target = watcher_ref()
            if target is None:
                # the guarded code has already been garbage-collected
                continue
            target.invalidate(watcher_ref)

    def track(self, guarded_code):
        """Register ``guarded_code`` (held only by weak reference) for invalidation."""
        self.watchers.append(weakref.ref(guarded_code))
def watch(obj, guarded_code):
    """Arrange for ``guarded_code`` to be invalidated when ``obj`` is mutated.

    Ensures ``type(obj).__setattr__`` is patched via ``ensure_patched`` and
    registers ``guarded_code`` with the ``MutationTracker`` for ``obj``. The
    tracker is created on first use.
    """
    ensure_patched(type(obj))
    registry = MutationTracker.db
    if obj in registry:
        tracker = registry[obj]
    else:
        tracker = registry[obj] = MutationTracker()
    tracker.track(guarded_code)
def ensure_patched(cls):
    """Wrap ``cls.__setattr__`` so attribute writes notify the instance's tracker.

    After patching, every ``setattr`` on an instance of ``cls`` first calls
    ``on_mutation(key)`` on the instance's ``MutationTracker`` (if it has one in
    ``MutationTracker.db``). It then performs the original assignment and
    returns its result. The class is marked with
    ``___needs_mutation_patch = False``, so repeated calls for the same class
    are no-ops.

    NOTE(review): the flag is read with ``getattr``, which also sees a value
    inherited from an already-patched base class. A subclass that overrides
    ``__setattr__`` is then not wrapped itself. Confirm this is intended.
    """
    if getattr(cls, "___needs_mutation_patch", True):
        cls.___needs_mutation_patch = False
        original_setattr = cls.__setattr__

        @functools.wraps(original_setattr)
        def custom_setattr(self, key, value):
            # Only the lookup is guarded: a missing tracker just means nobody
            # is watching this instance. A KeyError raised while notifying
            # watchers is a genuine error and must not be swallowed here.
            try:
                tracker = MutationTracker.db[self]
            except KeyError:
                pass
            else:
                tracker.on_mutation(key)
            return original_setattr(self, key, value)

        cls.__setattr__ = custom_setattr
class GenerationTracker:
    """Class-level bookkeeping that tags objects with the current generation.

    ``generation`` is a global counter. ``check`` reports an object as new
    when it was tagged while the counter had its current value.
    """

    generation = 0
    # classes registered through mark_class_dynamic -> True
    dynamic_classes = ExactWeakKeyDictionary()
    # tagged object -> value of ``generation`` when it was tagged
    generation_values = ExactWeakKeyDictionary()

    @classmethod
    def tag(cls, obj):
        """Record the current generation for ``obj``."""
        cls.generation_values[obj] = cls.generation

    @staticmethod
    def mark_class_dynamic(cls):
        """Register ``cls`` (which must subclass ``torch.nn.Module``) as dynamic."""
        assert issubclass(cls, torch.nn.Module)
        GenerationTracker.dynamic_classes[cls] = True

    @classmethod
    def get_generation_value(cls, obj):
        """Return the generation ``obj`` was tagged with, or -1 if it was never tagged."""
        values = cls.generation_values
        return values[obj] if obj in values else -1

    @classmethod
    def check(cls, obj):
        """Return whether ``obj`` was tagged during the current generation."""
        values = cls.generation_values
        if obj not in values:
            return False
        return values[obj] == cls.generation
def is_dynamic_nn_module(obj):
    """Decide whether ``obj`` is an nn.Module created dynamically or mutated.

    An explicit ``torchdynamo_force_dynamic`` attribute on ``obj`` wins and is
    returned as-is. Otherwise ``obj`` counts as dynamic in either case:

    - its exact class is registered in ``GenerationTracker.dynamic_classes``;
    - ``GenerationTracker.check`` reports it as tagged in the current generation.
    """
    if hasattr(obj, "torchdynamo_force_dynamic"):
        return obj.torchdynamo_force_dynamic
    class_flag = GenerationTracker.dynamic_classes.get(type(obj))
    if class_flag:
        return class_flag
    return GenerationTracker.check(obj)
def install_generation_tagging_init():
    """
    Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
    so we can detect nn.Module instances created dynamically inside forward methods.

    The patch is applied at most once, guarded by a flag stored on ``Module``.
    Every call still advances ``GenerationTracker.generation``, so modules
    tagged before the call no longer match the current generation.
    """
    if getattr(Module, "___needs_generation_tag_patch", True):
        init = Module.__init__

        # functools.wraps keeps the original name/docstring and exposes
        # ``__wrapped__``, so introspection of the patched methods still
        # reflects Module's own methods (mirrors ensure_patched).
        @functools.wraps(init)
        def patched_init(self, *args, **kwargs):
            init(self, *args, **kwargs)
            GenerationTracker.tag(self)

        Module.__init__ = patched_init

        setstate = Module.__setstate__

        @functools.wraps(setstate)
        def patched_setstate(self, state):
            setstate(self, state)
            # unpickled modules are tagged the same way as freshly built ones
            GenerationTracker.tag(self)

        Module.__setstate__ = patched_setstate

        Module.___needs_generation_tag_patch = False

    GenerationTracker.generation += 1