import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from .utils import (
get_node_first_input_and_output_type,
getattr_from_fqn,
NodeInputOrOutputType,
return_first_non_observer_node,
get_number_of_non_param_args,
get_target_type_str,
get_arg_indices_of_inputs_to_log,
get_node_input_qparams,
op_type_supports_shadowing,
get_normalized_nth_input,
)
from .ns_types import (
NSSingleResultValuesType,
NSSubgraph,
NSNodeTargetType,
)
from torch.ao.ns.fx.mappings import (
get_node_type_to_io_type_map,
)
from torch.ao.quantization.observer import _is_activation_post_process
from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set


def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
fqn = None
if hasattr(gm, '_node_name_to_scope'):
        # FQNs are not present on observers, because observers do not yet
        # exist when FQNs are recorded during tracing. If this is an
        # observer, get the FQN of the node being observed.
node_to_use_for_fqn = node
if node.op == 'call_module':
assert isinstance(node.target, str)
module = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(module):
node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index]
return fqn # type: ignore[return-value]
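

# Illustrative usage sketch (assumes `gm` was produced by
# torch.ao.quantization.quantize_fx.prepare_fx, which populates the
# `_node_name_to_scope` mapping this helper relies on):
#
#   for node in gm.graph.nodes:
#       print(node.name, _maybe_get_fqn(node, gm))
#
# An observer node reports the FQN of the node it observes, since observers
# do not yet exist when FQNs are recorded during tracing.
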
def _insert_logger_after_node(
node: Node,
gm: GraphModule,
logger_cls: Callable,
logger_node_name_suffix: str,
ref_node_name: str,
model_name: str,
ref_name: str,
ref_node_target_type: str,
results_type: str,
index_within_arg: int,
index_of_arg: int,
fqn: Optional[str],
) -> Node:
"""
Given a starting graph of
prev_node -> node -> next_node
This function creates a new logger_cls obj and adds it
after node, resulting in
prev_node -> node -> logger_obj -> next_node
"""
# create new name
logger_node_name = \
get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
target_type = get_target_type_str(node, gm)
# create the logger object
logger_obj = logger_cls(
ref_node_name, node.name, model_name, ref_name, target_type,
ref_node_target_type,
results_type, index_within_arg, index_of_arg, fqn)
# attach the logger object to the parent module
setattr(gm, logger_node_name, logger_obj)
logger_node = node.graph.create_node(
'call_module', logger_node_name, (node,), {})
return logger_node
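

# Illustrative sketch (argument values assumed for illustration, with
# OutputLogger from torch.ao.ns._numeric_suite_fx as the logger class):
#
#   logger_node = _insert_logger_after_node(
#       node, gm, OutputLogger, '_ns_logger_', node.name,
#       'model_a', 'ref0', 'torch.nn.Linear',
#       NSSingleResultValuesType.NODE_OUTPUT.value,
#       index_within_arg=0, index_of_arg=0, fqn=None)
#
# This attaches a logger instance to `gm` under a fresh name such as
# `linear1_ns_logger_0` and creates the corresponding call_module node;
# rewiring downstream users to consume the logger's output is left to the
# callers via their `env` maps.
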
def add_loggers_to_model(
gm: GraphModule,
node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
logger_cls: Callable,
model_name: str,
) -> GraphModule:
"""
    Takes the graph of gm, adds loggers to the inputs and/or outputs of
    each node present in the two `node_to_instrument_*` maps, and returns
    a GraphModule with the new graph.
"""
new_graph = Graph()
env: Dict[str, Any] = {}
modules = dict(gm.named_modules())
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
for node in gm.graph.nodes:
if node.op == 'output':
new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
continue
if (
(node in node_to_instrument_inputs_to_ref_node_name) or
(node in node_to_instrument_outputs_to_ref_node_name)
):
fqn = _maybe_get_fqn(node, gm)
if node in node_to_instrument_inputs_to_ref_node_name:
ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such as add and mul are special because either
# one or two of the first two arguments can be tensors,
# and if one argument is a tensor it can be first or
# second (x + 1 versus 1 + x).
arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
for node_arg_idx in arg_indices_to_log:
node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
if type(node_arg) == Node:
# create a single input logger
prev_node = env[node_arg.name]
env[node_arg.name] = _insert_logger_after_node(
prev_node, gm, logger_cls, '_ns_logger_', node.name,
model_name, ref_name, ref_node_type,
NSSingleResultValuesType.NODE_INPUT.value,
index_within_arg=0, index_of_arg=node_arg_idx,
fqn=fqn)
elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
# create N input loggers, one for each node
for arg_idx, arg in enumerate(node_arg): # type: ignore[var-annotated, arg-type]
prev_node = env[arg.name]
env[prev_node.name] = _insert_logger_after_node(
prev_node, gm, logger_cls, '_ns_logger_', node.name,
model_name, ref_name, ref_node_type,
NSSingleResultValuesType.NODE_INPUT.value,
index_within_arg=arg_idx, index_of_arg=node_arg_idx,
fqn=fqn)
else:
pass
# ensure env is populated with base node
# Note: runs for both inputs and outputs
env[node.name] = new_graph.node_copy(node, load_arg)
if node in node_to_instrument_outputs_to_ref_node_name:
ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
# add the logger after the base node
env[node.name] = _insert_logger_after_node(
env[node.name], gm, logger_cls, '_ns_logger_', node.name,
model_name, ref_name, ref_node_type,
NSSingleResultValuesType.NODE_OUTPUT.value,
index_within_arg=0, index_of_arg=0, fqn=fqn)
else:
env[node.name] = new_graph.node_copy(node, load_arg)
new_gm = GraphModule(gm, new_graph)
return new_gm
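

# Hypothetical usage sketch (the supported public entry point is
# torch.ao.ns._numeric_suite_fx.add_loggers, which wraps this pass; the
# node and ref names below are assumed for illustration):
#
#   from torch.ao.ns._numeric_suite_fx import OutputLogger
#   linear_node = next(n for n in gm.graph.nodes if n.op == 'call_module')
#   gm_logged = add_loggers_to_model(
#       gm,
#       node_to_instrument_inputs_to_ref_node_name={},
#       node_to_instrument_outputs_to_ref_node_name={
#           linear_node: ('layer0', 'torch.nn.Linear'),
#       },
#       logger_cls=OutputLogger,
#       model_name='model_a',
#   )
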
def _insert_quantize_per_tensor_node(
prev_node_c: Node,
node_a: Node,
gm_b: GraphModule,
graph_c: Graph,
scale: Union[torch.Tensor, float],
zero_point: Union[torch.Tensor, int],
dtype_cast_name: str,
) -> Node:
# copy scale
scale_node_name = \
get_new_attr_name_with_prefix(
node_a.name + '_input_scale_')(gm_b)
setattr(gm_b, scale_node_name, scale)
scale_node = graph_c.create_node(
'get_attr', scale_node_name, (), {}, scale_node_name)
# copy zero_point
zero_point_node_name = \
get_new_attr_name_with_prefix(
node_a.name + '_input_zero_point_')(gm_b)
setattr(gm_b, zero_point_node_name, zero_point)
zero_point_node = graph_c.create_node(
'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
# create the quantize_per_tensor call
return graph_c.create_node(
'call_function', torch.quantize_per_tensor,
(prev_node_c, scale_node, zero_point_node, torch.quint8), {},
dtype_cast_name)
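

# For reference, the eager-mode equivalent of the node created above
# (values are illustrative):
#
#   x_fp32 = torch.randn(4)
#   x_int8 = torch.quantize_per_tensor(
#       x_fp32, scale=0.1, zero_point=0, dtype=torch.quint8)
#
# In graph form, `scale` and `zero_point` become `get_attr` nodes so that
# tensor-valued qparams can live on the parent module.
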
def _insert_dtype_cast_after_node(
node_a: Node,
node_c: Node,
prev_node_c: Union[Node, List[Node]],
gm_a: GraphModule,
gm_b: GraphModule,
graph_c: Graph,
node_name_prefix: str,
logger_cls: Callable,
node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
"""
Given a starting graph C (derived from graph B) of
... -> prev_node_c -> node_c -> ...
And a corresponding related node_a, inserts the correct dtype
cast node after prev_node_c to cast into the dtype expected
by node_a, resulting in:
dtype_cast
/
... -> prev_node_c -> node_c -> ...
For example, if node_c is an int8 op and node_a is an fp32 op, this function
will insert a dequant.
"""
dtype_cast_op = None
dtype_cast_mod_cls = None
dtype_cast_method = None
dtype_cast_method_dtype = None
dtype_cast_scale = None
dtype_cast_zero_point = None
node_input_type_a, _node_output_type_a = \
get_node_first_input_and_output_type(
node_a, gm_a, logger_cls, node_type_to_io_type_map)
node_input_type_c, _node_output_type_c = \
get_node_first_input_and_output_type(
node_c, gm_b, logger_cls, node_type_to_io_type_map)
if (
(node_input_type_a == NodeInputOrOutputType.FP32 and
node_input_type_c == NodeInputOrOutputType.INT8) or
(node_input_type_a == NodeInputOrOutputType.FP32 and
node_input_type_c == NodeInputOrOutputType.FP16) or
# TODO(future PR): determine the actual dtype of node_c,
# the current code only works because dequantize works with
# multiple input dtypes.
(node_input_type_a == NodeInputOrOutputType.FP32 and
node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
):
dtype_cast_op = torch.dequantize
elif (
node_input_type_a == node_input_type_c and
node_input_type_a != NodeInputOrOutputType.UNKNOWN
):
dtype_cast_mod_cls = torch.nn.Identity
elif (
node_input_type_a == NodeInputOrOutputType.INT8 and
node_input_type_c == NodeInputOrOutputType.FP32
):
# int8 shadows fp32, the dtype cast needs to quantize to int8
# with the right qparams.
node_a_input_qparams = get_node_input_qparams(
node_a, gm_a, node_type_to_io_type_map)
if node_a_input_qparams is not None:
dtype_cast_op = torch.quantize_per_tensor # type: ignore[assignment]
dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
elif (
node_input_type_a == NodeInputOrOutputType.FP16 and
node_input_type_c == NodeInputOrOutputType.FP32
):
dtype_cast_method = 'to'
dtype_cast_method_dtype = torch.float16
else:
raise AssertionError(
f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
f"{node_input_type_a} {node_a.format_node()} needs to be implemented")
if isinstance(prev_node_c, Node):
new_dtype_cast_name = \
get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
if dtype_cast_op:
if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
return _insert_quantize_per_tensor_node(
prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
dtype_cast_zero_point, new_dtype_cast_name)
else:
return graph_c.create_node(
'call_function', dtype_cast_op, (prev_node_c,), {},
new_dtype_cast_name)
elif dtype_cast_method:
return graph_c.create_node(
'call_method', dtype_cast_method,
(prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
else:
assert dtype_cast_mod_cls
dtype_cast_mod = dtype_cast_mod_cls()
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
return graph_c.create_node(
'call_module', new_dtype_cast_name, (prev_node_c,), {},
new_dtype_cast_name)
elif isinstance(prev_node_c, list):
results = []
for prev_node_c_inner in prev_node_c:
new_dtype_cast_name = \
get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
if dtype_cast_op:
# TODO(future PR): add handling for quantize_per_tensor
new_dtype_cast_node = graph_c.create_node(
'call_function', dtype_cast_op, (prev_node_c_inner,), {},
new_dtype_cast_name)
results.append(new_dtype_cast_node)
else:
assert dtype_cast_mod_cls
dtype_cast_mod = dtype_cast_mod_cls()
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
new_dtype_cast_node = graph_c.create_node(
'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
new_dtype_cast_name)
results.append(new_dtype_cast_node)
return results
else:
raise AssertionError(f"type f{type(prev_node_c)} is not handled")
# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
node_a: Node,
gm_a: GraphModule,
gm_b: GraphModule,
graph_c: Graph,
) -> Node:
"""
Simple copy of node_a to graph_c.
"""
if node_a.op == 'get_attr':
node_a_copy_name = \
get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore[arg-type]
if torch.is_tensor(node_a_obj):
node_a_obj = node_a_obj.detach()
setattr(gm_b, node_a_copy_name, node_a_obj)
node_a_copy = graph_c.create_node(
node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
return node_a_copy
elif node_a.op == 'call_method':
assert node_a.target in ('dequantize', 'to'), \
f"target {node_a.target} is not implemented"
if node_a.target == 'dequantize':
arg_copy = _copy_node_from_a_to_c(
get_normalized_nth_input(node_a, gm_a, 0),
gm_a, gm_b, graph_c) # type: ignore[arg-type]
node_a_copy_name = \
get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
node_a_copy = graph_c.create_node(
node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name)
return node_a_copy
else: # to
arg_copy = _copy_node_from_a_to_c(
get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c) # type: ignore[arg-type]
node_a_copy_name = \
get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
node_a_copy = graph_c.create_node(
node_a.op, node_a.target,
(arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
{}, node_a_copy_name)
return node_a_copy
else:
raise AssertionError(
f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")
def _can_insert_copy_of_subgraph_a(
subgraph_a: NSSubgraph,
gm_a: GraphModule,
num_non_param_args_node_a: int,
) -> bool:
"""
This function returns `False` if the input subgraph cannot be copied by
`_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that the subgraph hits corner-case logic for which copying is not yet
    implemented.
"""
# populate the list of nodes we need to check
nodes = []
cur_node = subgraph_a.end_node
while cur_node != subgraph_a.start_node:
nodes.append(cur_node)
cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment]
nodes.append(cur_node)
nodes.reverse()
def _can_insert(node_a_arg, gm_a):
if isinstance(node_a_arg, Node):
arg_a = return_first_non_observer_node(node_a_arg, gm_a)
if arg_a.op == 'call_method':
return arg_a.target in ('dequantize', 'to')
elif arg_a.op == 'get_attr':
return True
else:
return False
        elif isinstance(node_a_arg, (list, tuple)):
            for el in node_a_arg:
                if not isinstance(el, Node):
                    return False
            return True
        # any other arg type (e.g. a scalar) currently fails this check
        return False
# For each node, check if we handle the copy behavior. This follows the
# logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
for node_a in nodes:
local_num_non_param_args_node_a = num_non_param_args_node_a \
if node_a is nodes[0] else 1
norm_args_kwargs = node_a.normalized_arguments(
gm_a, normalize_to_only_use_kwargs=True)
if norm_args_kwargs is not None:
norm_args, norm_kwargs = norm_args_kwargs
else:
norm_args, norm_kwargs = node_a.args, node_a.kwargs
cur_idx = 0
while cur_idx < len(norm_args):
            if cur_idx == 0:
                # the first input is stitched from graph C, not copied
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                # ops with two non-param inputs also stitch the second input
                pass
else:
if not _can_insert(norm_args[cur_idx], gm_a):
return False
cur_idx += 1
for kwarg_val in norm_kwargs.values():
            # same skipping logic as above: stitched inputs need no copy
if cur_idx == 0:
pass
elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
pass
else:
if not _can_insert(kwarg_val, gm_a):
return False
cur_idx += 1
return True
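

# For example (assumed for illustration), for a subgraph whose start node
# normalizes to F.linear(input, weight, bias), index 0 (`input`) is stitched
# from graph C and skipped, while `weight` and `bias` must pass `_can_insert`,
# i.e. resolve to a get_attr node or a dequantize/to call_method chain.
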
def _insert_copy_of_subgraph_a_after_input_node_c(
input_node_c: Union[Node, List[Node]],
input_node_c_2: Optional[Union[Node, List[Node]]],
subgraph_a: NSSubgraph,
gm_a: GraphModule,
gm_b: GraphModule,
node_name_prefix: str,
) -> Node:
"""
TODO(before land): real docblock
"""
if isinstance(input_node_c, Node):
graph_c = input_node_c.graph
else:
assert isinstance(input_node_c, list)
graph_c = input_node_c[0].graph
# create a sequential list of the subgraphs' nodes from start to end,
# because we need to add the nodes to graph C in non-reverse order
nodes_of_a = [subgraph_a.end_node]
cur_node = subgraph_a.end_node
while cur_node != subgraph_a.start_node:
cur_node = get_normalized_nth_input(cur_node, gm_a, 0) # type: ignore[assignment]
nodes_of_a.insert(0, cur_node)
# go through nodes of a in order, and insert them into the graph of c
# sequentially
cur_node_a = nodes_of_a[0]
cur_node_c = _insert_copy_of_node_a_after_input_node_c(
input_node_c,
input_node_c_2,
cur_node_a,
gm_a,
gm_b,
node_name_prefix)
for cur_idx_a in range(1, len(nodes_of_a)):
cur_node_a = nodes_of_a[cur_idx_a]
prev_node_c = cur_node_c # previous added node is the input to next node
cur_node_c = _insert_copy_of_node_a_after_input_node_c(
prev_node_c,
# TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
None,
cur_node_a,
gm_a,
gm_b,
node_name_prefix)
# return the last inserted node
return cur_node_c
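

# Illustrative sketch: for a subgraph A of `linear -> relu`, the copy is
# inserted node by node; the copied `linear` receives `input_node_c` (and
# optionally `input_node_c_2`) as its stitched inputs, and the copied `relu`
# consumes the copied `linear`.
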
def _insert_copy_of_node_a_after_input_node_c(
input_node_c: Union[Node, List[Node]],
input_node_c_2: Optional[Union[Node, List[Node]]],
node_a: Node,
gm_a: GraphModule,
gm_b: GraphModule,
node_name_prefix: str,
) -> Node:
"""
Assume that node_a from graph_a has
args (input, (input2)?, arg1, ...), and
kwargs {kw0: kwarg0, ...}
    Note: input2 is optional. If it is None, we assume that the op
has a single non-param input. If it is specified, we assume that the op
has two non-param inputs.
Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
and creates the corresponding nodes in graph_c. Note: observers are ignored,
so if an arg is an observer we navigate up until we find a non-observer parent.
If node_a is a call_module, points the module pointed to by node_a to gm_b.
Creates the copy of node_a in graph_c, with input as the first arg,
and all other args and kwargs pointing to the copies of the objects
in gm_b created above.
An example in pictures:
graph A:
========
input -------------> node_a
/ / /
(input_2)?----------/ / /
/ /
weight -> weight_obs /
/
bias ----------------
graph C (derived from B):
=========================
input_node_c --> node_a_copy
/ / /
(input_node_c_2)? / /
/ /
weight_copy ----/ /
/
bias_copy ------/
"""
if isinstance(input_node_c, Node):
graph_c = input_node_c.graph
else:
assert isinstance(input_node_c, list)
graph_c = input_node_c[0].graph
norm_args_kwargs = node_a.normalized_arguments(
gm_a, normalize_to_only_use_kwargs=True)
if norm_args_kwargs is not None:
norm_args, norm_kwargs = norm_args_kwargs
else:
norm_args, norm_kwargs = node_a.args, node_a.kwargs
new_args = []
new_kwargs = {}
def _copy_arg(arg):
# copy the other inputs from the other graph
if isinstance(arg, Node):
arg = return_first_non_observer_node(arg, gm_a)
arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
return arg
elif isinstance(arg, (int, float, torch.dtype)):
return arg
        elif isinstance(arg, (list, tuple)):
            for el in arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented")
cur_idx = 0
while cur_idx < len(norm_args):
if cur_idx == 0:
new_arg = input_node_c
elif cur_idx == 1 and input_node_c_2 is not None:
new_arg = input_node_c_2
else:
new_arg = _copy_arg(norm_args[cur_idx])
new_args.append(new_arg)
cur_idx += 1
for kwarg_name, kwarg_val in norm_kwargs.items():
# stitch the inputs from base graph
if cur_idx == 0:
new_kwargs[kwarg_name] = input_node_c
elif cur_idx == 1 and input_node_c_2 is not None:
new_kwargs[kwarg_name] = input_node_c_2
else:
new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
cur_idx += 1
new_args = tuple(new_args) # type: ignore[assignment]
node_a_shadows_c_name = \
get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
if node_a.op == 'call_module':
        # if the target is a module, fetch it from gm_a and attach it to
        # gm_b under a fresh name, then point the new node at it
new_mod_copy_name = \
get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
# fetch the corresponding module from gm_a
assert isinstance(node_a.target, str)
mod_a = getattr_from_fqn(gm_a, node_a.target)
setattr(gm_b, new_mod_copy_name, mod_a)
node_a_shadows_c = graph_c.create_node(
node_a.op, new_mod_copy_name, new_args,
new_kwargs, node_a_shadows_c_name)
return node_a_shadows_c
else:
assert node_a.op in ('call_function', 'call_method')
node_a_shadows_c = graph_c.create_node(
node_a.op, node_a.target, new_args,
new_kwargs, node_a_shadows_c_name)
return node_a_shadows_c


def create_a_shadows_b(
name_a: str,
gm_a: GraphModule,
name_b: str,
gm_b: GraphModule,
matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
logger_cls: Callable,
should_log_inputs: bool,
node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
"""
Creates a new GraphModule consisting of the graph of C, with the meaningful
nodes of A shadowing the corresponding nodes of B. For example,
Graph A:
a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2
Graph B:
b0 -> op0_int8 -> b1 -> op1_int8 -> b2
    matched_subgraph_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}
Graph C (A shadows B):
/ dequant0 -> op0_fp32 -> logger_a_0 / dequant_1 -> op1_fp32 -> logger_a_1
/ /
b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1
In a nutshell, this function does the following for each node pair:
* copies the necessary attributes and modules from gm_a to gm_b,
keeping names unique
* adds a dtype cast op (dequant, quant, etc)
* adds a copy of node_a in gm_b's graph
* adds loggers to the outputs of node_a and node_b
"""
if node_type_to_io_type_map is None:
node_type_to_io_type_map = get_node_type_to_io_type_map()
# graph_c is the graph created from copying the nodes of graph_b and inserting
# the shadows with the nodes copied from graph_a
graph_c = Graph()
env_c: Dict[str, Any] = {}
modules = dict(gm_b.named_modules())
def load_arg(a):
return map_arg(a, lambda node: env_c[node.name])
start_node_b_to_matched_subgraph_a_and_name = {}
end_node_b_to_matched_subgraph_a_and_name = {}
for match_name, match in matched_subgraph_pairs.items():
subgraph_a, subgraph_b = match
ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
(subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
(subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
for node_b in gm_b.graph.nodes:
if node_b.op == 'output':
graph_c.output(map_arg(node_b.args[0], load_arg))
continue
# calculate the flags to determine what to do with this node
node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name
if (node_b_is_start_node or node_b_is_end_node):
if node_b_is_start_node:
subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
start_node_b_to_matched_subgraph_a_and_name[node_b]
else:
assert node_b_is_end_node
subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
end_node_b_to_matched_subgraph_a_and_name[node_b]
all_op_types_support_shadowing = (
op_type_supports_shadowing(subgraph_a.start_node) and
op_type_supports_shadowing(node_b)
)
if not all_op_types_support_shadowing:
print(
f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
', unsupported')
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
continue
# For both start_node and end_node verify that we know how to do
# the dtype cast. If we do not, skip.
node_input_type_a, node_output_type_a = \
get_node_first_input_and_output_type(
subgraph_a.start_node, gm_a, logger_cls,
node_type_to_io_type_map)
node_input_type_b, node_output_type_b = \
get_node_first_input_and_output_type(
node_b, gm_b, logger_cls,
node_type_to_io_type_map)
node_io_types_known_a_and_b = (
node_input_type_a != NodeInputOrOutputType.UNKNOWN and
node_output_type_a != NodeInputOrOutputType.UNKNOWN and
node_input_type_b != NodeInputOrOutputType.UNKNOWN and
node_output_type_b != NodeInputOrOutputType.UNKNOWN
)
if not node_io_types_known_a_and_b:
print(
f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
', unknown dtype cast')
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
continue
# If we are shadowing from fp32 to int8, we need to insert
# quantize_per_tensor call with qparams from the previous node.
# Only do this if we are able to infer these qparams from the graph.
if (
node_input_type_a == NodeInputOrOutputType.INT8 and
node_input_type_b == NodeInputOrOutputType.FP32
):
node_a_input_qparams = get_node_input_qparams(
subgraph_a.start_node, gm_a, node_type_to_io_type_map)
if not node_a_input_qparams:
print(
f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
', unknown input qparams')
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
continue
num_non_param_args_node_a = \
get_number_of_non_param_args(subgraph_a.start_node, gm_a)
if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
print(
f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
', unhandled logic in subgraph copy')
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
continue
fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)
if node_b_is_start_node:
# if necessary, log the input of node_c
if should_log_inputs:
prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
if isinstance(prev_node_b, Node):
prev_node_c = env_c[prev_node_b.name]
env_c[prev_node_c.name] = _insert_logger_after_node(
prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
node_b.name, name_b, ref_name, ref_node_type_b,
NSSingleResultValuesType.NODE_INPUT.value,
index_within_arg=0, index_of_arg=0,
fqn=fqn_base_b)
elif isinstance(prev_node_b, list):
# first, save the prev_node instances, because they
# will be overwritten in the env after the first logger
# is added
prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
for arg_idx, arg in enumerate(prev_node_b):
prev_node_c = prev_node_c_list[arg_idx]
env_c[prev_node_c.name] = _insert_logger_after_node(
prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
node_b.name, name_b, ref_name, ref_node_type_b,
NSSingleResultValuesType.NODE_INPUT.value,
index_within_arg=arg_idx, index_of_arg=0,
fqn=fqn_base_b)
else:
                        # logging of inputs which are neither Nodes nor
                        # lists of Nodes is not supported yet
raise AssertionError(f"type {type(prev_node_b)} is not handled yet")
# subgraph so far:
#
# (prev_node_c)+ -> (logger_c_input)?
# Note: this if statement is always True, spelling it out to clarify code
# intent.
if node_b_is_start_node or node_b_is_end_node:
# ensure env_c is populated with base node
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
node_c = env_c[node_b.name]
# after this point,
#
# node_a is the original node from graph_a, with parent module gm_a
# node_b is the original node from graph_b, with parent module gm_b
# node_c is the copy of node_b in graph_c
#
# subgraph so far:
#
# (prev_node_c)+ -> (logger_c_input)? -> node_start_c
if node_b_is_start_node:
# cast dtype from the dtype of node_c's input to the dtype of
# node_a's input (dequant, etc)
# prev_node_c = node_c.args[0]
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
if should_log_inputs:
# skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = get_normalized_nth_input(prev_node_c, gm_b, 0)
elif isinstance(prev_node_c, list):
prev_node_c = [get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
dtype_cast_node = _insert_dtype_cast_after_node(
subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
node_b.name + '_dtype_cast_', logger_cls,
node_type_to_io_type_map)
# note: not inserting to env_c because all nodes which use the dtype
# casts are copied from graph_a
#
# subgraph so far:
#
# (dtype_cast_node)+
# /
# (prev_node_c)+ -> (logger_c_input)? -> node_start_c
# if input logging is enabled, log the input to the subgraph
if should_log_inputs:
                # create the input logger with an empty ref_node_name for
                # now; it is backfilled with the real name once the subgraph
                # copy exists (see the `should_log_inputs` block below)
                ref_node_name = ''
if isinstance(dtype_cast_node, Node):
dtype_cast_node = _insert_logger_after_node(
dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
ref_node_name, name_a, ref_name, ref_node_type_a,
NSSingleResultValuesType.NODE_INPUT.value,
index_within_arg=0, index_of_arg=0,
fqn=fqn_base_a)
input_logger: Union[Node, List[Node]] = dtype_cast_node
else:
assert isinstance(dtype_cast_node, list)
new_loggers = []
for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
dtype_cast_logger = _insert_logger_after_node(
dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
ref_node_name, name_a, ref_name, ref_node_type_a,
NSSingleResultValuesType.NODE_INPUT.value,
index_within_arg=dtype_cast_idx,
index_of_arg=0,
fqn=fqn_base_a)
new_loggers.append(dtype_cast_logger)
dtype_cast_node = new_loggers
input_logger = dtype_cast_node
# subgraph so far:
#
# (dtype_cast_node)+ -> (logger_a_input)?
# /
# prev_node_c -> (logger_c_input)? -> node_start_c
# hook up the new mod_a copy to be in the graph, receiving the
# same inputs as mod_b does, with dtype cast to match a
# Some ops, such as LSTMs, have two non-param inputs. If we have
# such an op, pass the second param as well. Note: dtype casting
# for the second param is not implemented yet, it can be added
# later if there is a use case.
node_c_second_non_param_arg = None
num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
if num_non_param_args_node_a == 2:
# node_c_second_non_param_arg = node_c.args[1]
node_c_second_non_param_arg = get_normalized_nth_input(node_c, gm_b, 1)
node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
dtype_cast_node, node_c_second_non_param_arg,
subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
env_c[node_a_shadows_c.name] = node_a_shadows_c
# subgraph so far:
#
# dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
# /
# (prev_node_c)+ -> (logger_c_input)? -> node_start_c
if should_log_inputs:
# When we created the input logger, we left the ref_node_name
# as an empty string, because the subgraph copy did not exist
# yet. Now that the subgraph copy exists, we modify this name
# to its true value.
# Note: the alternative to this is to create the input logger
# after creating the subgraph, which is slightly more
# complicated. This is the lesser of two evils.
# input_logger = env_c[dtype_cast_node.name]
# Find the first node in the subgraph
cur_node = node_a_shadows_c
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
if isinstance(input_logger, Node):
input_logger_mod = getattr(gm_b, input_logger.name)
input_logger_mod.ref_node_name = cur_node.name
else:
assert isinstance(input_logger, list)
for input_logger_inner in input_logger:
input_logger_mod = getattr(gm_b, input_logger_inner.name)
input_logger_mod.ref_node_name = cur_node.name
# hook up a logger to the mod_a copy
env_c[node_a_shadows_c.name] = _insert_logger_after_node(
env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
NSSingleResultValuesType.NODE_OUTPUT.value,
index_within_arg=0, index_of_arg=0,
fqn=fqn_base_a)
# subgraph so far:
#
# dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
# /
# (prev_node_c)+ -> (logger_c_input)? -> node_start_c
if node_b_is_end_node:
# hook up a logger to the mod_b copy
env_c[node_b.name] = _insert_logger_after_node(
env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
node_b.name, name_b, ref_name, ref_node_type_b,
NSSingleResultValuesType.NODE_OUTPUT.value,
index_within_arg=0, index_of_arg=0,
fqn=fqn_base_b)
# subgraph so far:
#
# dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
# /
            # (prev_node_c)+ -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
            #
            # Note: node_start_c may be the same node as node_end_c, or they
            # may have nodes in between.
else:
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
gm_c = GraphModule(gm_b, graph_c)
return gm_c
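

# Hypothetical end-to-end sketch (the supported public wrapper for this pass
# is torch.ao.ns._numeric_suite_fx.add_shadow_loggers; the model names and
# variables below are assumed for illustration, with gm_fp32 and gm_int8
# being FX GraphModules of the float and quantized models):
#
#   from torch.ao.ns._numeric_suite_fx import OutputLogger
#   from torch.ao.ns.fx.graph_matcher import get_matching_subgraph_pairs
#
#   matched = get_matching_subgraph_pairs(gm_fp32, gm_int8)
#   gm_shadow = create_a_shadows_b(
#       'fp32', gm_fp32, 'int8', gm_int8, matched,
#       logger_cls=OutputLogger, should_log_inputs=False)
#   gm_shadow(example_input)  # populates both models' loggers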