import torch
from torch.ao.quantization.backend_config import BackendConfig
from torch.fx.graph import Node, Graph
from ..utils import _parent_name, NodePattern, Pattern
from ..fuser_method_mappings import get_fuser_method_new
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union
from .custom_config import FuseCustomConfig
from .match_utils import MatchAllNode
from torch.nn.utils.parametrize import type_before_parametrizations
__all__ = [
    "DefaultFuseHandler",
    "FuseHandler",
]


# ----------------------------
# Fusion Pattern Registrations
# ----------------------------

# Base Pattern Handler
class FuseHandler(ABC):
    """Base handler class for the fusion patterns.

    A handler is constructed for each matched pattern and implements
    :meth:`fuse`, which replaces the matched nodes (e.g. ``(relu, (bn, conv))``)
    with a single fused module in the fused graph.
    """
    def __init__(self, node: Node):
        pass
    @abstractmethod
    def fuse(self,
             load_arg: Callable,
             named_modules: Dict[str, torch.nn.Module],
             fused_graph: Graph,
             root_node: Node,
             extra_inputs: List[Any],
             matched_node_pattern: NodePattern,
             fuse_custom_config: FuseCustomConfig,
             fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
             is_qat: bool) -> Node:
        """Fuse the matched pattern and return the node copied into ``fused_graph``.

        ``root_node`` is the "base" node of the match (e.g. the conv node in a
        (bn, conv) pattern), and ``load_arg`` remaps arguments from the original
        graph to their copies in ``fused_graph``.
        """
        pass
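
# A minimal sketch of a custom handler (illustrative only, not part of the API):
# it copies the root node into the fused graph unchanged, relying on `load_arg`
# to remap the node's args to their copies in `fused_graph`:
#
#     class CopyOnlyFuseHandler(FuseHandler):  # hypothetical name
#         def fuse(self, load_arg, named_modules, fused_graph, root_node,
#                  extra_inputs, matched_node_pattern, fuse_custom_config,
#                  fuser_method_mapping, is_qat) -> Node:
#             return fused_graph.node_copy(root_node, load_arg)
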
class DefaultFuseHandler(FuseHandler):
    """Default fuse handler: looks up a fuser method for the matched module
    types and swaps the fused module into the model in place of the root module.
    """
    def __init__(
            self,
            node: Node):
        super().__init__(node)
def fuse(self,
load_arg: Callable,
named_modules: Dict[str, torch.nn.Module],
fused_graph: Graph,
root_node: Node,
extra_inputs: List[Any],
matched_node_pattern: NodePattern,
fuse_custom_config: FuseCustomConfig,
fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
is_qat: bool) -> Node:
        assert root_node.op == "call_module", \
            f"Expected root_node to be a call_module Node, got op={root_node.op!r}"
        root_module = named_modules[str(root_node.target)]
        def get_modules(pattern):
            """Given a node pattern, extract the corresponding modules,
            e.g. input:  (relu_node, (bn_node, conv_node))
                 output: (relu_module, (bn_module, conv_module))
            """
            if isinstance(pattern, (tuple, list)):
                n, *args = pattern
                # entries may be modules, nested tuples, callables, or MatchAllNode
                modules: List[Any] = []
                modules.append(get_modules(n))
                for a in args:
                    modules.append(get_modules(a))
                return tuple(modules)
            else:
                n = pattern
                if n.op == "call_module":
                    return named_modules[str(n.target)]
                elif n.op == "call_function" and n.target == torch.nn.functional.relu:
                    # functional relu has no module to look up; materialize an
                    # nn.ReLU that matches the root module's training mode
                    relu = torch.nn.ReLU()
                    relu.training = root_module.training
                    return relu
                elif n.op == "call_function" or n.op == "call_method":
                    return n.target
                else:
                    return MatchAllNode
        # since relu can be used multiple times, a fresh relu module is created
        # for each match rather than shared across matches
matched_modules = get_modules(matched_node_pattern)
        def get_matched_types(m):
            """Map a (possibly nested) tuple of matched modules to a tuple of
            their types, unwrapping any parametrizations first.
            """
            if isinstance(m, tuple):
                return tuple(map(get_matched_types, m))
            if isinstance(m, torch.nn.Module):
                return type_before_parametrizations(m)
            return m
matched_module_types = get_matched_types(matched_modules)
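        # e.g. for matched_modules == (ReLU(), (BatchNorm2d(4), Conv2d(3, 4, 3)))
        # (hypothetical shapes), matched_module_types is
        # (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)), which is the lookup key into
        # fuser_method_mapping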
module_parent_name, module_name = _parent_name(root_node.target)
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
# TODO: change the signature for fuser_method to take matched module patterns
# as input
fused_module = fuser_method(is_qat, *matched_modules)
setattr(named_modules[module_parent_name], module_name, fused_module)
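        # e.g. for root_node.target == "features.conv0" (hypothetical name), this
        # assigns the fused module to named_modules["features"].conv0, so the
        # copied call_module node below dispatches to the fused module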
        extra_args = []
        for extra_input in extra_inputs:
            extra_args.append(load_arg(extra_input))
node = fused_graph.node_copy(root_node, load_arg)
args = list(node.args)
args.extend(extra_args)
node.args = tuple(args)
return node
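
# Sketch of the calling convention (illustrative; the actual driver is the fuse
# pass in fuse.py): the caller keeps an env from node names in the original
# graph to the copied nodes in fused_graph, and load_arg remaps args through it:
#
#     env: Dict[str, Node] = {}
#
#     def load_arg(a):
#         return torch.fx.map_arg(a, lambda node: env[node.name])
#
#     env[root_node.name] = handler.fuse(
#         load_arg, named_modules, fused_graph, root_node, extra_inputs,
#         matched_node_pattern, fuse_custom_config, fuser_method_mapping, is_qat)
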
def _get_fusion_pattern_to_fuse_handler_cls(
        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
    """Return a map from each fusion pattern in ``backend_config`` that defines
    a fuser method to the handler class used to fuse it.
    """
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.fuser_method is not None:
            # TODO: is this logic right?
            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
    return fusion_pattern_to_fuse_handlers
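
# Example usage (illustrative): collect the fuse handler class for every fusable
# pattern registered in the native backend config (get_native_backend_config is
# provided by torch.ao.quantization.backend_config):
#
#     from torch.ao.quantization.backend_config import get_native_backend_config
#     handlers = _get_fusion_pattern_to_fuse_handler_cls(get_native_backend_config())
#     # e.g. {(nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): DefaultFuseHandler, ...}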