import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from .utils import (
get_node_first_input_and_output_type,
getattr_from_fqn,
NodeInputOrOutputType,
return_first_non_observer_node,
get_number_of_non_param_args,
get_target_type_str,
get_arg_indices_of_inputs_to_log,
get_node_input_qparams,
op_type_supports_shadowing,
get_normalized_nth_input,
)
from .ns_types import (
NSSingleResultValuesType,
NSSubgraph,
NSNodeTargetType,
)
from torch.ao.ns.fx.mappings import (
get_node_type_to_io_type_map,
)
from torch.ao.quantization.observer import _is_activation_post_process
from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set
def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    """
    Looks up the fully qualified name recorded for ``node`` in
    ``gm._node_name_to_scope``.

    Returns None when ``gm`` has no ``_node_name_to_scope`` attribute.
    Observer (activation post process) modules are inserted after the fqns
    are recorded during tracing, so they have no entry of their own; for a
    ``call_module`` node targeting such a module, the fqn of its first
    normalized input (the node being observed) is returned instead.
    """
    if not hasattr(gm, '_node_name_to_scope'):
        return None

    lookup_node = node
    if node.op == 'call_module':
        assert isinstance(node.target, str)
        target_module = getattr_from_fqn(gm, node.target)
        if _is_activation_post_process(target_module):
            # observers have no scope entry; use the observed node instead
            lookup_node = get_normalized_nth_input(node, gm, 0)

    scope_entry = gm._node_name_to_scope[lookup_node.name]  # type: ignore[index]
    return scope_entry[0]  # type: ignore[return-value]
def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Creates a logger module for ``node`` and a graph node that calls it.

    Given a starting graph of

      prev_node -> node -> next_node

    a new ``logger_cls`` instance is attached to ``gm`` under a fresh
    attribute name derived from ``node.name + logger_node_name_suffix``,
    and a ``call_module`` node taking ``node`` as its only argument is
    created in ``node.graph``. The returned logger node is intended to sit
    between ``node`` and its consumers:

      prev_node -> node -> logger_obj -> next_node

    Rerouting ``next_node`` to read from the logger node is left to the
    caller; this function does not touch existing users of ``node``.

    The logger is constructed as
    ``logger_cls(ref_node_name, node.name, model_name, ref_name,
    target_type, ref_node_target_type, results_type, index_within_arg,
    index_of_arg, fqn)``, where ``target_type`` is
    ``get_target_type_str(node, gm)``.
    """
    attr_prefix = node.name + logger_node_name_suffix
    logger_attr_name = get_new_attr_name_with_prefix(attr_prefix)(gm)

    logger_instance = logger_cls(
        ref_node_name,
        node.name,
        model_name,
        ref_name,
        get_target_type_str(node, gm),
        ref_node_target_type,
        results_type,
        index_within_arg,
        index_of_arg,
        fqn,
    )

    # the logger must live on gm so the call_module target resolves
    setattr(gm, logger_attr_name, logger_instance)
    return node.graph.create_node(
        'call_module', logger_attr_name, (node,), {})
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Copies the graph of gm into a new graph and inserts loggers.

    A logger is inserted after each logged input of every node in
    ``node_to_instrument_inputs_to_ref_node_name``. A logger is also inserted
    after the output of every node in
    ``node_to_instrument_outputs_to_ref_node_name``. Returns a new
    GraphModule (rooted at gm) built from the instrumented graph.

    Args:
        gm: the model to instrument. Logger instances are attached to it as
            new attributes.
        node_to_instrument_inputs_to_ref_node_name: maps nodes of gm.graph
            whose inputs should be logged to a ``(ref_name, ref_node_type)``
            pair that is passed to each of their input loggers.
        node_to_instrument_outputs_to_ref_node_name: same, for nodes whose
            output should be logged.
        logger_cls: class instantiated once per logged value.
        model_name: passed through to every logger.
    """
    new_graph = Graph()
    # Maps the name of each node in gm.graph to the node in new_graph that
    # currently carries its value. This is either the copied node or the
    # most recently inserted logger consuming it.
    env: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(load_arg(get_normalized_nth_input(node, gm, 0)))
            continue

        log_inputs = node in node_to_instrument_inputs_to_ref_node_name
        log_output = node in node_to_instrument_outputs_to_ref_node_name
        if not (log_inputs or log_output):
            env[node.name] = new_graph.node_copy(node, load_arg)
            continue

        fqn = _maybe_get_fqn(node, gm)

        if log_inputs:
            ref_name, ref_node_type = \
                node_to_instrument_inputs_to_ref_node_name[node]
            # Ops such as add and mul are special because either
            # one or two of the first two arguments can be tensors,
            # and if one argument is a tensor it can be first or
            # second (x + 1 versus 1 + x).
            for node_arg_idx in get_arg_indices_of_inputs_to_log(node):
                node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                if isinstance(node_arg, Node):
                    # create a single input logger
                    env[node_arg.name] = _insert_logger_after_node(
                        env[node_arg.name], gm, logger_cls, '_ns_logger_',
                        node.name, model_name, ref_name, ref_node_type,
                        NSSingleResultValuesType.NODE_INPUT.value,
                        index_within_arg=0, index_of_arg=node_arg_idx,
                        fqn=fqn)
                elif isinstance(
                    node_arg, torch.fx.immutable_collections.immutable_list
                ):
                    # create N input loggers, one for each node in the list
                    for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                        if not isinstance(arg, Node):
                            # constants inside the list have no env entry
                            # and carry nothing to log
                            continue
                        # Key by the original arg name, as the single-input
                        # branch does. The node copied below then reads this
                        # input through the new logger, even when env[arg.name]
                        # is already a (differently named) logger node.
                        env[arg.name] = _insert_logger_after_node(
                            env[arg.name], gm, logger_cls, '_ns_logger_',
                            node.name, model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=arg_idx,
                            index_of_arg=node_arg_idx,
                            fqn=fqn)
                # any other argument kind (e.g. a scalar) is not logged

        # Copy the base node. Its inputs are resolved through env, so it
        # consumes any input loggers inserted above.
        env[node.name] = new_graph.node_copy(node, load_arg)

        if log_output:
            ref_name, ref_node_type = \
                node_to_instrument_outputs_to_ref_node_name[node]
            # add the logger after the base node
            env[node.name] = _insert_logger_after_node(
                env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                model_name, ref_name, ref_node_type,
                NSSingleResultValuesType.NODE_OUTPUT.value,
                index_within_arg=0, index_of_arg=0, fqn=fqn)

    return GraphModule(gm, new_graph)
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
    *,
    dtype: torch.dtype = torch.quint8,
) -> Node:
    """
    Creates a ``torch.quantize_per_tensor`` call on ``prev_node_c`` in graph_c.

    ``scale`` and ``zero_point`` are stored as new attributes on ``gm_b``.
    Their names are derived from ``node_a.name`` with the suffixes
    ``_input_scale_`` and ``_input_zero_point_``. Both are read back
    through ``get_attr`` nodes, so the quantize call references them
    rather than embedding constants.

    Args:
        prev_node_c: node in graph_c whose output is quantized.
        node_a: node whose name prefixes the new qparam attributes.
        gm_b: module that owns the new qparam attributes.
        graph_c: graph in which the new nodes are created.
        scale: quantization scale.
        zero_point: quantization zero point.
        dtype_cast_name: name for the created quantize node.
        dtype: quantized dtype to produce. Defaults to ``torch.quint8``,
            which was previously the only supported value.

    Returns:
        The created ``call_function`` node.
    """
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node(
        'get_attr', scale_node_name, (), {}, scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node(
        'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, dtype), {},
        dtype_cast_name)
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

      ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
      ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    The cast is chosen from the first-input types of node_a and node_c:
      - fp32 <- int8 / fp16 / fp32_or_int8: ``torch.dequantize``
      - same known type on both sides: a ``torch.nn.Identity`` module
        attached to gm_b
      - int8 <- fp32: ``torch.quantize_per_tensor`` with node_a's input qparams
      - fp16 <- fp32: ``.to(torch.float16)``

    If prev_node_c is a list, one cast node is created per element and the
    list of cast nodes is returned. Otherwise the single cast node is
    returned.

    Raises:
        AssertionError: if the dtype combination is not supported, if
            node_a's input qparams cannot be determined for an int8 <- fp32
            cast, or if prev_node_c is neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is None:
            # Without qparams no cast can be built. Previously this fell
            # through to a bare `assert` with no message.
            raise AssertionError(
                f"could not determine input qparams of {node_a.format_node()} "
                "needed to quantize the input of its int8 shadow")
        dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
        dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = 'to'
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")

    def _insert_single_cast(prev: Node) -> Node:
        # Creates one cast node consuming `prev`. This helper is shared by
        # the single-node and list paths, so every cast kind (including
        # quantize with qparams and the fp16 `.to` method) works for lists.
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op is not None:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                # for lists, the same node_a input qparams are applied to
                # every element
                return _insert_quantize_per_tensor_node(
                    prev, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            return graph_c.create_node(
                'call_function', dtype_cast_op, (prev,), {},
                new_dtype_cast_name)
        if dtype_cast_method is not None:
            return graph_c.create_node(
                'call_method', dtype_cast_method,
                (prev, dtype_cast_method_dtype), {}, new_dtype_cast_name)
        assert dtype_cast_mod_cls is not None
        dtype_cast_mod = dtype_cast_mod_cls()
        setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
        return graph_c.create_node(
            'call_module', new_dtype_cast_name, (prev,), {},
            new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _insert_single_cast(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_insert_single_cast(prev_inner) for prev_inner in prev_node_c]
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
node_a: Node,
gm_a: GraphModule,
gm_b: GraphModule,
graph_c: Graph,
) -> Node:
"""
Simple copy of node_a to graph_c.
"""
if node_a.op == 'get_attr':
node_a_copy_name = \
get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore[arg-type]
if torch.is_tensor(node_a_obj):
node_a_obj = node_a_obj.detach()
setattr(gm_b, node_a_copy_name, node_a_obj)
node_a_copy = graph_c.create_node(
node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
return node_a_copy
elif node_a.op == 'call_method':
assert node_a.target in ('dequantize', 'to'), \
f"target {node_a.target} is not implemented"
if node_a.target == 'dequantize':
arg_copy = _copy_node_from_a_to_c(
get_normalized_nth_input(node_a, gm_a, 0),
Loading ...