import torch
import copy
from torch.fx import GraphModule # type: ignore
from torch.fx.graph import Graph
from typing import Union, Dict, Any, List
class ObservedGraphModule(GraphModule):
    """A ``GraphModule`` that keeps a fixed set of extra attributes from its root.

    The attributes named by :meth:`get_preserved_attr_names` are read from
    ``root`` before the parent constructor runs and are re-attached to the
    new instance afterwards.
    """

    def get_preserved_attr_names(self) -> List[str]:
        """Return the names of the attributes carried over from ``root``."""
        return ['_activation_post_process_map',
                '_patterns',
                '_qconfig_map',
                '_prepare_custom_config_dict',
                '_node_name_to_scope']

    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
        """Build the module from ``root`` and ``graph``, keeping the preserved attributes.

        Args:
            root: Object the preserved attributes are read from (via
                ``getattr``) and that is forwarded to ``GraphModule.__init__``.
            graph: Graph forwarded to ``GraphModule.__init__``.

        Raises:
            AttributeError: If ``root`` lacks one of the preserved attributes.
        """
        # Snapshot the attributes first: they are looked up on ``root`` and
        # must be available to re-attach once the parent constructor is done.
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.get_preserved_attr_names()
        }
        super().__init__(root, graph)
        for attr, value in preserved_attrs.items():
            setattr(self, attr, value)

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        """Return a deep copy that still carries the preserved attributes.

        ``memo`` is accepted to satisfy the ``copy.deepcopy`` protocol; it is
        not consulted here.
        """
        # A plain nn.Module whose __dict__ is a deep copy of ours acts as the
        # root, so the constructor can read the preserved attributes from it.
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        # Deep-copy the graph as well: handing over ``self.graph`` directly
        # would make the copy and the original share one Graph object.
        return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph))
def mark_observed_module(module: GraphModule) -> GraphModule:
    """Return an ``ObservedGraphModule`` built from ``module`` and ``module.graph``."""
    observed = ObservedGraphModule(module, module.graph)
    return observed
def is_observed_module(module: Any) -> bool:
    """Return ``True`` when ``module`` is an ``ObservedGraphModule`` instance."""
    observed = isinstance(module, ObservedGraphModule)
    return observed
class ObservedStandaloneGraphModule(ObservedGraphModule):
    """An ``ObservedGraphModule`` that preserves two additional attributes.

    On top of the names returned by the parent class it also keeps
    ``_standalone_module_input_quantized_idxs`` and
    ``_standalone_module_output_quantized_idxs``.
    """

    def get_preserved_attr_names(self) -> List[str]:
        """Return the parent's preserved names followed by the standalone-specific ones."""
        return super().get_preserved_attr_names() + [
            "_standalone_module_input_quantized_idxs",
            "_standalone_module_output_quantized_idxs",
        ]

    def __deepcopy__(self, memo):
        """Return a deep copy of this module as an ``ObservedStandaloneGraphModule``.

        ``memo`` is accepted to satisfy the ``copy.deepcopy`` protocol; it is
        not consulted here.
        """
        # A plain nn.Module holding a deep copy of our __dict__ serves as the
        # root passed to the constructor.
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        # Deep-copy the graph as well: handing over ``self.graph`` directly
        # would make the copy and the original share one Graph object.
        return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph))
def mark_observed_standalone_module(module: GraphModule) -> GraphModule:
    """Return an ``ObservedStandaloneGraphModule`` built from ``module`` and ``module.graph``."""
    observed = ObservedStandaloneGraphModule(module, module.graph)
    return observed
def is_observed_standalone_module(module: Any) -> bool:
    """Return ``True`` when ``module`` is an ``ObservedStandaloneGraphModule`` instance."""
    observed = isinstance(module, ObservedStandaloneGraphModule)
    return observed