Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ ao / quantization / fx / fuse_handler.py

import torch
from torch.ao.quantization.backend_config import BackendConfig
from torch.fx.graph import Node, Graph
from ..utils import _parent_name, NodePattern, Pattern
from ..fuser_method_mappings import get_fuser_method_new
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union
from .custom_config import FuseCustomConfig
from .match_utils import MatchAllNode
from torch.nn.utils.parametrize import type_before_parametrizations

__all__ = [
    "DefaultFuseHandler",
    "FuseHandler",
]


# ----------------------------
# Fusion Pattern Registrations
# ----------------------------

# Base Pattern Handler
class FuseHandler(ABC):
    """ Base handler class for the fusion patterns
    """
    def __init__(self, node: Node):
        pass

    @abstractmethod
    def fuse(self,
             load_arg: Callable,
             named_modules: Dict[str, torch.nn.Module],
             fused_graph: Graph,
             root_node: Node,
             extra_inputs: List[Any],
             matched_node_pattern: NodePattern,
             fuse_custom_config: FuseCustomConfig,
             fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
             is_qat: bool) -> Node:
        pass

class DefaultFuseHandler(FuseHandler):
    def __init__(
            self,
            node: Node):
        super().__init__(node)

    def fuse(self,
             load_arg: Callable,
             named_modules: Dict[str, torch.nn.Module],
             fused_graph: Graph,
             root_node: Node,
             extra_inputs: List[Any],
             matched_node_pattern: NodePattern,
             fuse_custom_config: FuseCustomConfig,
             fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
             is_qat: bool) -> Node:
        assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
        root_module = named_modules[str(root_node.target)]

        def get_modules(pattern):
            """ Given a node pattern, extract the corresponding modules
            e.g. input: (relu_node, (bn_node, conv_node))
                 output: (relu_module, (bn_module, conv_module))
            """
            if isinstance(pattern, (tuple, list)):
                n, *args = pattern
                modules: List[torch.nn.Module] = []
                modules.append(get_modules(n))
                for a in args:
                    modules.append(get_modules(a))
                return tuple(modules)
            else:
                n = pattern
                if n.op == "call_module":
                    return named_modules[n.target]
                elif n.op == "call_function" and n.target == torch.nn.functional.relu:
                    relu = torch.nn.ReLU()
                    relu.training = root_module.training
                    return relu
                elif n.op == "call_function" or n.op == "call_method":
                    return n.target
                else:
                    return MatchAllNode

        # since relu can be used multiple times, we'll need to create a relu module for each match
        matched_modules = get_modules(matched_node_pattern)

        def get_matched_types(m):
            if isinstance(m, tuple):
                return tuple(map(get_matched_types, m))
            if isinstance(m, torch.nn.Module):
                return type_before_parametrizations(m)
            return m

        matched_module_types = get_matched_types(matched_modules)
        module_parent_name, module_name = _parent_name(root_node.target)
        fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
        # TODO: change the signature for fuser_method to take matched module patterns
        # as input
        fused_module = fuser_method(is_qat, *matched_modules)
        setattr(named_modules[module_parent_name], module_name, fused_module)
        extra_args = []
        for input in extra_inputs:
            extra_args.append(load_arg(input))
        node = fused_graph.node_copy(root_node, load_arg)
        args = list(node.args)
        args.extend(extra_args)
        node.args = tuple(args)
        return node

def _get_fusion_pattern_to_fuse_handler_cls(
        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.fuser_method is not None:
            # TODO: is this logic right?
            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
    return fusion_pattern_to_fuse_handlers