Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ quantization / fx / observed_module.py

import torch
import copy
from torch.fx import GraphModule  # type: ignore
from torch.fx.graph import Graph
from typing import Union, Dict, Any, List

class ObservedGraphModule(GraphModule):
    """GraphModule that keeps a fixed set of quantization attributes.

    The attributes listed by ``get_preserved_attr_names`` are read from
    ``root`` before the ``GraphModule`` constructor runs and re-attached to
    the new instance afterwards, so they survive both construction and
    ``copy.deepcopy``.
    """

    def get_preserved_attr_names(self) -> List[str]:
        """Return the names of the attributes carried over from ``root``."""
        return ['_activation_post_process_map',
                '_patterns',
                '_qconfig_map',
                '_prepare_custom_config_dict',
                '_node_name_to_scope']

    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
        """Build the module from ``root`` and ``graph``, keeping preserved attributes.

        Args:
            root: Module (or dict) handed to ``GraphModule.__init__``.
            graph: Graph handed to ``GraphModule.__init__``.
        """
        # Read the attributes before calling the parent constructor.
        # Attributes that ``root`` does not define are skipped instead of
        # raising AttributeError; a dict ``root`` (allowed by the signature)
        # defines none of them.
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.get_preserved_attr_names()
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr, value in preserved_attrs.items():
            setattr(self, attr, value)

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module.  So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        """Return a deep copy that keeps the preserved attributes.

        ``memo`` is forwarded so objects shared with the enclosing deepcopy
        stay shared in the result. The graph is deep-copied as well, so the
        copy does not share a mutable ``Graph`` with the original.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        # If the graph was already copied as part of ``__dict__``, the memo
        # lookup returns that same copy rather than building a second one.
        return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph, memo))

def mark_observed_module(module: GraphModule) -> GraphModule:
    """Wrap ``module`` in an ``ObservedGraphModule`` built on ``module.graph``."""
    graph = module.graph
    observed = ObservedGraphModule(module, graph)
    return observed

def is_observed_module(module: Any) -> bool:
    """Tell whether ``module`` is an instance of ``ObservedGraphModule``."""
    observed = isinstance(module, ObservedGraphModule)
    return observed

class ObservedStandaloneGraphModule(ObservedGraphModule):
    """ObservedGraphModule variant that preserves two extra attributes.

    On top of the names preserved by the parent class, it keeps
    ``_standalone_module_input_quantized_idxs`` and
    ``_standalone_module_output_quantized_idxs``.
    """

    def get_preserved_attr_names(self) -> List[str]:
        """Return the parent's preserved names plus the standalone-module ones."""
        return super().get_preserved_attr_names() + [
            "_standalone_module_input_quantized_idxs",
            "_standalone_module_output_quantized_idxs",
        ]

    def __deepcopy__(self, memo):
        """Return a deep copy of this module as an ``ObservedStandaloneGraphModule``.

        ``memo`` is forwarded so objects shared with the enclosing deepcopy
        stay shared in the result. The graph is deep-copied as well, so the
        copy does not share a mutable ``Graph`` with the original.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        # If the graph was already copied as part of ``__dict__``, the memo
        # lookup returns that same copy rather than building a second one.
        return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph, memo))

def mark_observed_standalone_module(module: GraphModule) -> GraphModule:
    """Wrap ``module`` in an ``ObservedStandaloneGraphModule`` built on ``module.graph``."""
    graph = module.graph
    observed = ObservedStandaloneGraphModule(module, graph)
    return observed

def is_observed_standalone_module(module: Any) -> bool:
    """Tell whether ``module`` is an instance of ``ObservedStandaloneGraphModule``."""
    observed = isinstance(module, ObservedStandaloneGraphModule)
    return observed