import collections
import enum
import torch
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Graph, Node
from torch.ao.quantization.utils import getattr_from_fqn
from .ns_types import NSSubgraph, NSNodeTargetType
from .mappings import (
get_base_name_to_sets_of_related_ops,
get_unmatchable_types_map,
)
from .pattern_utils import (
get_type_a_related_to_b,
get_reversed_fusions,
end_node_matches_reversed_fusion,
)
from torch.ao.quantization import (
ObserverBase,
FakeQuantizeBase,
)
from typing import Dict, Tuple, List, Optional, Set, Any
def _get_output_nodes(g: Graph) -> List[Node]:
return [n for n in g.nodes if n.op == 'output']
class _NSGraphMatchableSubgraphsIterator:
"""
Iterates through the graph of gm, starting with the output nodes
and continuing backwards.
1. Returns matchable subgraphs, in order. A subgraph is defined by
(start_node, end_node).
2. Skips over non-matchable subgraphs
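    Example (a minimal sketch; `M` and the node names below are hypothetical,
    and in practice the non_matchable_* sets are typically derived from
    get_unmatchable_types_map()):

        gm = torch.fx.symbolic_trace(M())  # e.g. M computes linear -> relu -> sigmoid
        it = _NSGraphMatchableSubgraphsIterator(
            gm,
            non_matchable_functions=set(),
            non_matchable_modules=set(),
            non_matchable_methods=set(),
        )
        # iterating from the output backwards yields NSSubgraph objects, e.g.
        # first (start_node=sigmoid, end_node=sigmoid, base_op_node=sigmoid),
        # then (start_node=linear, end_node=relu, base_op_node=linear) if
        # linear-relu is a known fusion
        subgraphs = list(it)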
"""
def __init__(
self,
gm: GraphModule,
non_matchable_functions: Set[NSNodeTargetType],
non_matchable_modules: Set[NSNodeTargetType],
non_matchable_methods: Set[NSNodeTargetType],
):
self.gm: GraphModule = gm
self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
self.seen_nodes: Set[Node] = set()
self.stack: List[Node] = []
for start_node in _get_output_nodes(self.gm.graph):
self.stack.append(start_node)
def __iter__(self):
return self
def __next__(self) -> NSSubgraph:
"""
Returns the next matchable subgraph.
"""
while len(self.stack) > 0:
cur_end_node = self.stack.pop()
if cur_end_node in self.seen_nodes:
continue
# for subgraphs which are single nodes, start_node == end_node
            # for subgraphs with more than one node, start_node != end_node
cur_start_node = cur_end_node
# Subgraphs like linear-relu have the base node as the start node.
# Subgraphs like dequantize-linear-relu-to(torch.float16) have the
# base node as the second node.
            # cur_base_op_node will be updated to point at the actual base op
            # node during the fusion matching below.
cur_base_op_node = cur_end_node
# Check for potential fusions. For now, we are greedy
# and always skip all non-base nodes of a fusion. For example,
# if we match linear-relu backwards, we will always skip the
# relu node and attempt to match the linear node. This can
# be made configurable later if needed.
for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
is_match = end_node_matches_reversed_fusion(
cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes)
if is_match:
# navigate to the base node
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
self.seen_nodes.add(cur_start_node)
# for now, assume that there are no other nodes
# which need to be added to the stack
cur_start_node = cur_start_node.args[0] # type: ignore[assignment]
# if the base op index matches the current node, set it
rev_base_op_idx = \
len(_reverse_fusion_ops) - 2 - base_op_idx
if rev_fusion_idx == rev_base_op_idx:
cur_base_op_node = cur_start_node
break
self.seen_nodes.add(cur_start_node)
# add args of previous nodes to stack
for arg in cur_start_node.all_input_nodes:
self._recursively_add_node_arg_to_stack(arg)
            # skip unmatchable nodes
            # note: this check is done on the base op node, i.e.
            # if we are matching linear-relu in reverse, this does the
            # matchable check on the linear node
if not self._is_matchable(cur_base_op_node):
continue
# If an observer or a fake_quant was not matched as a part of
# a pattern of multiple nodes, ignore it. One case where this is
# relevant is an observer on a graph input, which was added because
# it is necessary for the next node.
if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type]
if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
continue
return NSSubgraph(
start_node=cur_start_node, end_node=cur_end_node,
base_op_node=cur_base_op_node)
raise StopIteration
def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
"""
Adds all of the nodes in this arg to the stack, properly navigating
        through lists, dicts, and tuples.
"""
if isinstance(arg, Node):
self.stack.append(arg)
elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
for inner_arg in arg:
self._recursively_add_node_arg_to_stack(inner_arg)
elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
for key, value in arg.items():
self._recursively_add_node_arg_to_stack(value)
def _is_matchable(self, node: Node) -> bool:
if node.op == 'call_function':
            return node.target not in self.non_matchable_functions
elif node.op == 'call_module':
assert isinstance(node.target, str)
target_mod = getattr_from_fqn(self.gm, node.target)
            return not any(
                isinstance(target_mod, t)  # type: ignore[arg-type]
                for t in self.non_matchable_modules)
elif node.op == 'call_method':
            return node.target not in self.non_matchable_methods
else:
return False
class GraphMatchingException(Exception):
"""
Exception raised when two graphs cannot be matched.
"""
pass
class SubgraphTypeRelationship(enum.Enum):
# same type, known
# example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
EQUAL = enum.auto()
    # same type, but the type is not known to Numeric Suite
    # (user-defined type, etc).
EQUAL_BUT_UKNOWN = enum.auto()
# known, same subgraph_relationship set, but not the same type
# example: F.linear and toq.linear
RELATED_BUT_NOT_EQUAL = enum.auto()
# not related
NOT_RELATED = enum.auto()
def _get_subgraph_relationship_type(
subgraph_a: NSSubgraph,
subgraph_b: NSSubgraph,
gm_a: GraphModule,
gm_b: GraphModule,
type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
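    """
    Returns the relationship between the base ops of subgraph_a and subgraph_b.

    Illustrative outcomes for simple single-node subgraphs, assuming the
    default related-ops mapping (`my_user_fn` is a hypothetical user-defined
    function unknown to Numeric Suite):

        F.linear   vs F.linear    -> EQUAL
        F.linear   vs toq.linear  -> RELATED_BUT_NOT_EQUAL
        F.linear   vs F.conv2d    -> NOT_RELATED
        my_user_fn vs my_user_fn  -> EQUAL_BUT_UKNOWN
    """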
node_a = subgraph_a.base_op_node
node_b = subgraph_b.base_op_node
# TODO(next): make this code handle matching by what is before the base op
if node_a.op != node_b.op:
if not (
node_a.op in ('call_function', 'call_method') and
node_b.op in ('call_function', 'call_method')
):
return SubgraphTypeRelationship.NOT_RELATED
if node_a.op in ('call_function', 'call_method'):
key = (node_a.target, node_b.target)
if key not in type_a_related_to_b:
if node_a.target == node_b.target:
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
else:
return SubgraphTypeRelationship.NOT_RELATED
# after this point, we are dealing with known types
        if node_a.target == node_b.target:
            # `True` means the subgraph starts at the base op, i.e. there are
            # no extra nodes (such as a dequantize) before the base op.
            node_a_starts_at_base = subgraph_a.base_op_node == subgraph_a.start_node
            node_b_starts_at_base = subgraph_b.base_op_node == subgraph_b.start_node
            if node_a_starts_at_base and (not node_b_starts_at_base):
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif (not node_a_starts_at_base) and node_b_starts_at_base:
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif node_a_starts_at_base and node_b_starts_at_base:
                return SubgraphTypeRelationship.EQUAL
            else:
                # both subgraphs have extra nodes before the base op
                # TODO(future PR): also check that the nodes between
                # start_op_node and base_op_node match
                return SubgraphTypeRelationship.EQUAL
        # at this point key is known to be in type_a_related_to_b and the
        # targets differ
        return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
elif node_a.op == 'call_module':
assert (subgraph_a.base_op_node == subgraph_a.start_node and
subgraph_b.base_op_node == subgraph_b.start_node), \
"Matching call_module patterns where base_op_node != start_node is not supported yet"
# for call_module, we need to look up the modules to do the type check
assert isinstance(node_a.target, str)
mod_a = getattr_from_fqn(gm_a, node_a.target)
assert isinstance(node_b.target, str)
mod_b = getattr_from_fqn(gm_b, node_b.target)
key = (type(mod_a), type(mod_b))
if key not in type_a_related_to_b:
if type(mod_a) == type(mod_b):
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
else:
return SubgraphTypeRelationship.NOT_RELATED
elif type(mod_a) == type(mod_b):
return SubgraphTypeRelationship.EQUAL
else:
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
return SubgraphTypeRelationship.NOT_RELATED
def _get_name_for_subgraph(
subgraph_a: NSSubgraph,
gm_a: GraphModule,
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
existing_names: Set[str],
) -> str:
"""
Returns a unique name for a subgraph. This name is based on two things:
1. the name of the set containing the underlying type of the base op in the
subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
2. the number of previous subgraphs with related underlying type of the base op
For example, in the graph
linear0 -> relu0 -> linear1 -> relu1
The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
from the output node backwards, the name given to (linear1, relu1) will be
`base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
will be `base_op_torch.nn.functional.linear_1`.
Why are we not just using the node name? Answer: because of two requirements:
A. fusions must be supported
B. some Numeric Suite APIs can be called without having all of the models in memory
For example, let's say we need to match nodes of
(1) ... -> linear0 -> relu0 -> ...
And
(2) ... -> linear_relu0 -> ...
Without being able to inspect them together. With the current naming scheme, if
we iterate through both of these graphs in the same order, and assuming the rest
of the graphs match, both of these subgraphs will get the same name without
(1) and (2) knowing anything about each other.
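    Example (a minimal sketch, given some `subgraph_a` inside `gm_a`; the
    mapping shown is illustrative, not the full output of
    get_base_name_to_sets_of_related_ops()):

        base_name_to_sets_of_related_ops = {
            'torch.nn.functional.linear': {F.linear, nn.Linear, toq.linear},
        }
        existing_names: Set[str] = set()
        # for a subgraph whose base op is related to linear, this returns
        # 'base_op_torch.nn.functional.linear_0'; a second call with the same
        # `existing_names` set returns 'base_op_torch.nn.functional.linear_1'
        name = _get_name_for_subgraph(
            subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names)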
"""
target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
target_base_type = None
for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
if target_type in sets_of_related_ops:
target_base_type = base_name
target_base_name = 'base_op_' + str(target_base_type)
counter = 0
proposed_name = target_base_name + '_' + str(counter)
while proposed_name in existing_names:
counter += 1
proposed_name = target_base_name + '_' + str(counter)
existing_names.add(proposed_name)
return proposed_name
def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
if node.op in ('call_function', 'call_method'):
return node.target
elif node.op == 'call_module':
assert isinstance(node.target, str)
mod = getattr_from_fqn(gm, node.target)
return type(mod)
return None
def get_matching_subgraph_pairs(
gm_a: GraphModule,
gm_b: GraphModule,
base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
"""
Matches matchable subgraphs of graph_a to graph_b.
    For a node, "matchable" is defined as a node which is not an observer,
    fake_quant, quantize, or dequantize node.
A subgraph can contain one or more nodes. A subgraph is matchable if
at least one node inside of it is matchable. Currently, all nodes in
a subgraph must be matchable (because we assume no observers will be
inserted in the middle of a fusion).
A subgraph is defined by (start_node, end_node). We assume that only
start_node and end_node are linked with the surrounding graph, all other
nodes in a subgraph are self-contained.
A pair of nodes is "related" if both nodes represent the same mathematical
operation across different quantization flavors. For example,
`F.linear` and `torch.ops.quantized.linear` are related, and
    `F.linear` and `torch.nn.Conv2d` are not related.
    Two matchable nodes node_a and node_b will match if they are related.
    Graphs A and B will match iff:
1. the number of matchable subgraphs in A and B is equivalent
2. when iterating through the matchable subgraphs of A and B in the same order, each
corresponding pair of base nodes is related.
This enables us to find the corresponding subgraphs between
graphs of related models. For example, if we had two graphs such as:
graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
w -/
b -/
graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
packed_params_0 -/
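    Example (a minimal sketch; `m_fp32` and `m_int8` are hypothetical
    GraphModules, e.g. a float model and its quantized counterpart obtained
    from the FX graph mode quantization flow):

        results = get_matching_subgraph_pairs(m_fp32, m_int8)
        # keys are subgraph names such as
        # 'base_op_torch.nn.functional.linear_0', values are
        # (subgraph_in_a, subgraph_in_b) pairs of NSSubgraph objects
        for name, (subgraph_a, subgraph_b) in results.items():
            print(name, subgraph_a.base_op_node, subgraph_b.base_op_node)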