import sys
import torch
from torch.fx.graph import (
Graph,
Node,
)
from torch.ao.quantization.utils import Pattern
from .quantize_handler import (
QuantizeHandler,
)
from ..qconfig import (
QConfigAny,
)
from ..utils import (
MatchAllNode
)
from .graph_module import (
_is_observed_standalone_module,
)
from torch.nn.utils.parametrize import type_before_parametrizations
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
__all__: List[str] = []
# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
QConfigAny]
# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern
"""
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
arg_matches = []
else:
self_match = pattern
arg_matches = []
if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
return True
if node == pattern:
return True
if not isinstance(node, Node) or len(node.users) > max_uses:
return False
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != 'call_module':
return False
if not type_before_parametrizations(modules[node.target]) == self_match:
return False
elif callable(self_match):
if node.op != 'call_function' or node.target is not self_match:
return False
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif isinstance(self_match, str):
if node.op != 'call_method' or node.target != self_match:
return False
elif node.target != self_match:
return False
if not arg_matches:
return True
if len(arg_matches) != len(node.args):
return False
return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
def _find_matches(
graph: Graph,
modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
root_node_getter_mapping: Dict[Pattern, Callable],
standalone_module_names: List[str] = None,
standalone_module_classes: List[Type] = None,
custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.
Inputs:
- graph: an fx.Graph object
- modules: a mapping of fully qualified module name to instance,
for example, {'foo': ModuleFoo, ...}
- patterns: a mapping from a tuple of nodes in reverse order to
uninitialized QuantizeHandler subclass.
Outputs a map of
node_name ->
(node, matched_values, matched_pattern, QuantizeHandler instance,
qconfig)
For example, {
'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
<CopyNodeQuantizeHandler instance>, QConfig(...)),
...
}
"""
if custom_module_classes is None:
custom_module_classes = []
if standalone_module_classes is None:
standalone_module_classes = []
if standalone_module_names is None:
standalone_module_names = []
match_map: Dict[str, _MatchResult] = {}
all_matched : Set[str] = set()
def _recursive_record_node_in_match_map(
last_node,
match_map,
node_pattern,
matched_node_pattern,
pattern,
match_value):
if isinstance(node_pattern, Node):
match_map[node_pattern.name] = (
last_node, matched_node_pattern, pattern, match_value)
elif not isinstance(node_pattern, Iterable):
return
else:
for n in node_pattern:
_recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value)
# TODO: 1. merge with fuse matcher 2. document the code
def record_match(
pattern,
node,
last_node,
matched_node_pattern,
match_map):
if isinstance(pattern, tuple):
s, *args = pattern
is_single_arg = len(args) == 1
current_node_pattern: List[Node] = []
record_match(
s,
node,
last_node,
matched_node_pattern,
match_map)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
record_match(
subpattern,
arg,
node,
current_node_pattern,
match_map)
if len(current_node_pattern) > 1:
# current_node_pattern is the node pattern we get from matching
# the subpattern with arguments of the node
# we use is_single_arg to recover the original structure of the pattern
# if the original pattern has a single argument, we will have
# (original_op, (original_arg, ...))
# otherwise, we'll have a list of arguments
# (original_op, arg0, arg1, arg2, ...)
if is_single_arg:
matched_node_pattern.append(tuple(current_node_pattern))
else:
matched_node_pattern.extend(list(current_node_pattern))
else:
matched_node_pattern.append(current_node_pattern[0])
else:
matched_node_pattern.append(node)
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, quantize_handler_cls in patterns.items():
root_node_getter = root_node_getter_mapping.get(pattern, None)
if _is_match(modules, node, pattern) and node.name not in match_map:
matched_node_pattern: List[Node] = []
record_match(
pattern,
node,
node,
matched_node_pattern,
match_map)
quantize_handler = quantize_handler_cls( # type: ignore[operator]
matched_node_pattern,
modules,
root_node_getter)
last_node = node
# record the match for all nodes in the pattern
_recursive_record_node_in_match_map(
last_node,
match_map,
# we need to record all nodes in the matched pattern in the match_map
matched_node_pattern,
# this is a part of the value corresponding to the node
matched_node_pattern,
pattern,
quantize_handler)
break
# add custom module instances to the match result
assert modules is not None
for node in graph.nodes:
if node.op == 'call_module' and \
type(modules[node.target]) in custom_module_classes:
match_map[node.name] = (
node, node, None, QuantizeHandler(node, modules, is_custom_module=True))
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
assert modules is not None
return (
node_target in standalone_module_names or # type: ignore[operator]
type(modules[node_target]) in standalone_module_classes # type: ignore[operator]
)
# add standalone modules to the match
for node in graph.nodes:
if node.op == 'call_module' and \
(is_standalone_module(node.target, modules) or
_is_observed_standalone_module(modules[node.target])):
# add node to matched nodes
match_map[node.name] = (
node, node, None,
QuantizeHandler(node, modules, is_standalone_module=True))
return match_map