import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
CALLABLE_NODE_OPS,
FxNetAccFusionsFinder,
Names,
NodeList,
NodeSet,
TensorOrTensors,
Tensors,
)
__all__ = [
"FxNetMinimizerBadModuleError",
"FxNetMinimizerRunFuncError",
"FxNetMinimizerResultMismatchError",
]
_LOGGER = logging.getLogger(__name__)
@compatibility(is_backward_compatible=False)
class FxNetMinimizerBadModuleError(Exception):
"""
    Raised if the minimizer fails to split out a minimize module.
"""
pass
@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
"""
    Raised if an error occurs during the run_a or run_b functions.
"""
pass
@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
"""
    Raised if the comparison function reports that the results mismatch.
"""
pass
@dataclass
class _MinimizerSettingBase:
"""
Args:
    `accumulate_error`: If True, instead of feeding the same inputs to both runs,
        feed each run the outputs it produced for the preceding submodules, so
        that errors accumulate across submodules.
    `traverse_method`: "sequential", "binary" or "accumulate".
        Determines how the nodes in the FX module are traversed.
    `find_all`: Minimizer will go through the entire model and return all problematic nodes.
    `return_intermediate`: If True, when using the `run_nodes()` function to run the
        model, intermediate results of all the ops will be returned as output.
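
    Example (a minimal sketch of constructing and printing the settings; the
    values shown are illustrative, not recommendations)::

        settings = _MinimizerSettingBase(traverse_method="binary", find_all=True)
        print(settings)  # prints a header, then each setting as "<name>: <value>"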
"""
accumulate_error: bool = False
traverse_method: str = "sequential"
find_all: bool = False
return_intermediate: bool = False
def __str__(self):
settings_str = "FX Minimizer Settings:\n"
for k, v in vars(self).items():
settings_str += f"\t{k}: {v}\n"
return settings_str
class _MinimizerBase:
"""
    This class is used to automatically find problematic nodes in a model. It takes an FX
    GraphModule and generates submodules while traversing the graph. Two functions,
    `run_a` and `run_b`, are then used to run the same submodule, and a function
    `compare_fn` is used to compare the results.
    Currently we provide two ways to traverse the graph and generate submodules.
        1. Sequential traversal: traverses the graph node by node and generates
           one submodule containing a single node.
        2. Binary search: performs a binary-search-style traversal of the graph.
    For internal users, a guide can be found here: https://fb.quip.com/HDtuAgiKGfkP.
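
    Example (a minimal sketch of a concrete subclass; ``MyMinimizer``,
    ``lower_to_backend`` and ``tolerance`` are hypothetical and not defined in
    this file)::

        class MyMinimizer(_MinimizerBase):
            def run_a(self, mod, inputs):
                # Reference run, e.g. plain eager execution.
                return mod(*inputs)

            def run_b(self, mod, inputs):
                # Run the same submodule through the backend under test.
                return lower_to_backend(mod)(*inputs)

        def compare_fn(a, b, names):
            # Returns (numeric error, whether the results are considered equal);
            # assumes single-tensor outputs for simplicity.
            err = (a - b).abs().max().item()
            return err, err <= tolerance

        minimizer = MyMinimizer(module, sample_input, compare_fn, _MinimizerSettingBase())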
"""
def __init__(
self,
module: torch.fx.GraphModule,
sample_input: Tensors,
compare_fn: Callable[
[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
],
settings: _MinimizerSettingBase,
):
assert isinstance(module, torch.fx.GraphModule)
self.module = module
self.sample_input = sample_input
self.compare_fn = compare_fn
self.settings = settings
# Stores outputs of run_a function
self.a_outputs: Dict[str, Any] = {}
# Stores outputs of run_b function
self.b_outputs: Dict[str, Any] = {}
# Stores the results of compare_fn
self.results: Dict[Any, Any] = {}
# Stores the report for the runs
self.reports: List[List[str]] = []
# Current iteration
self.iteration: int = 0
callable_nodes = {
node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
}
ShapeProp(self.module).propagate(*self.sample_input)
self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
        # Check if the number of inputs in sample_input matches the number of placeholders
placeholders = [
node.name for node in self.module.graph.nodes if node.op == "placeholder"
]
assert len(placeholders) == len(self.sample_input)
# Store sample_input
for i, name in enumerate(placeholders):
self.a_outputs[name] = sample_input[i]
self.b_outputs[name] = sample_input[i]
def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
"""
        Run `mod` with `inputs` and generate output. The output will be compared with
        the output of run_b().
"""
raise RuntimeError("run_a() is not implemented.")
def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
"""
        Run `mod` with `inputs` and generate output. The output will be compared with
        the output of run_a().
"""
raise RuntimeError("run_b() is not implemented.")
def _store_outputs(
self,
a_result: TensorOrTensors,
b_result: TensorOrTensors,
submodule: torch.fx.GraphModule,
):
"""
Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
        self.b_outputs, so that we can use them later when executing nodes that
        take those outputs as inputs.
Args:
a_result: Output of self.run_a(). Could be a tensor or tensors.
b_result: Output of self.run_b(). Could be a tensor or tensors.
submodule: The module that generates a_result and b_result.
"""
output_node = next(
node for node in submodule.graph.nodes if node.op == "output"
)
# Only one output
if isinstance(output_node.args[0], torch.fx.Node):
self.a_outputs[output_node.args[0].name] = a_result
self.b_outputs[output_node.args[0].name] = b_result
# Multiple outputs
else:
for i, arg in enumerate(output_node.args[0]):
self.a_outputs[arg.name] = a_result[i]
self.b_outputs[arg.name] = b_result[i]
def _get_submod_inputs(
self, main_module: torch.fx.GraphModule, submod_path: str
) -> Tuple[Tensors, Tensors]:
"""
        Try to get the submodule inputs from stored outputs. If they are not found, run
        the main module with a forward pre-hook on the submodule to capture its inputs.
        If accumulate_error is False, use a_input for both run_a() and run_b();
        otherwise use a_input for run_a() and b_input for run_b().
Args:
            main_module: Top-level fx module.
submod_path: Path to the submodule we want to run and compare results.
Returns:
a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
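
        Example (illustrative only; the placeholder names are hypothetical): if
        the submodule's placeholders are ["conv_1", "relu_1"] and both names are
        already present in self.a_outputs, the stored tensors are reused
        (with accumulate_error=True; otherwise a_input is returned for both)::

            a_input == [self.a_outputs["conv_1"], self.a_outputs["relu_1"]]
            b_input == [self.b_outputs["conv_1"], self.b_outputs["relu_1"]]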
"""
a_input = []
b_input = []
submodule = getattr(main_module, submod_path)
placeholders = [
node.name for node in submodule.graph.nodes if node.op == "placeholder"
]
        # If all placeholders can be found in stored outputs, use the stored
        # outputs as inputs. Otherwise, run the main module with a forward
        # pre-hook on the submodule to capture its inputs.
if set(placeholders) <= self.a_outputs.keys():
for name in placeholders:
a_input.append(self.a_outputs[name])
b_input.append(self.b_outputs[name])
else:
if self.settings.accumulate_error:
print(f"Can't find previous stored outputs named {placeholders}!")
def get_inputs(self: torch.nn.Module, inputs: Any):
nonlocal a_input
a_input = inputs
# Use forward hook to get the inputs to the submodule
handle = submodule.register_forward_pre_hook(get_inputs)
main_module(*self.sample_input)
handle.remove()
b_input = a_input
if not self.settings.accumulate_error:
return a_input, a_input
return a_input, b_input
def _tag_nodes(self, selected_nodes: NodeSet):
"""
        Tag selected nodes with the tag "minimize". Nodes with the same tag will
        be split into the same submodule afterwards.
Args:
selected_nodes: Nodes that we want to minimize. We will tag those nodes
with "minimize", all preceding nodes with "main_0" and all following
nodes with "main_1".
"""
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
if node in selected_nodes:
node.tag = "minimize"
elif any(
n.tag in {"minimize", "main_1"}
for n in node.all_input_nodes
if n.op in CALLABLE_NODE_OPS
):
node.tag = "main_1"
else:
node.tag = "main_0"
def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
"""
Split self.module so that one submodule consists of `nodes` and only `nodes`.
Args:
nodes: Nodes that we want to include in the minimize submodule.
Returns:
split_module (torch.fx.GraphModule): the module after split.
submodule_name (str): the name of the submodule that consists of `nodes`.
"""
# Color provided nodes
self._tag_nodes(nodes)
# Split module based on coloring
split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])
# Find submodule containing colored nodes
submodule_name: str = ""
for child_name, _ in split_module.named_children():
# Skip submodules we're not interested in at the moment
if "minimize" not in child_name:
continue
if submodule_name == "":
submodule_name = child_name
else:
raise FxNetMinimizerBadModuleError(
f"Expected only one minimize submodule with nodes {nodes}"
)
if submodule_name == "":
raise FxNetMinimizerBadModuleError(
f"Minimize submodule was not found with nodes {nodes}"
)
return split_module, submodule_name
def _run_and_compare(
self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names
):
"""
Run the submodule in `split_module` that has name `submod_name`
using `self.run_a` and `self.run_b` and compare their results.
Args:
split_module: Main module that contains the minimize submodule.
submod_name: Name of the minimize submodule.
            output_names: Names of the nodes we want to output. If None, the
                original output will be used.
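
        Example (illustrative; the node name "relu_1" is hypothetical)::

            # Rewire the minimize submodule to output the node named "relu_1",
            # then run it through run_a/run_b and compare the results.
            self._run_and_compare(split_module, submod_name, ["relu_1"])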
"""
submodule = getattr(split_module, submod_name)
a_input, b_input = self._get_submod_inputs(split_module, submod_name)
if len(self.reports) == 0:
self.reports.append([])
self.iteration = 1
report = self.reports[self.iteration - 1]
report.append("Run and compare ...")
if output_names:
output_nodes: NodeList = []
for node in submodule.graph.nodes:
if node.op == "output":
submodule.graph.erase_node(node)
if node.name in output_names:
output_nodes.append(node)
submodule.graph.output(
output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
)