# Repository URL to install this package: (not included in this copy)
# Version: 2.4.1
# mypy: allow-untyped-defs
import inspect
import logging
from queue import Queue
from functools import wraps
from typing import Callable, Dict, List
import torch.nn as nn
from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
# Module-level logger. Its level is pinned to WARNING at import time, so the
# `logger.debug(...)` calls made while running passes are dropped unless a
# caller changes this logger's level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
# Names exported by `from <module> import *`; the underscore-prefixed helpers
# below are intentionally not listed.
__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
    """
    Wrapper for passes which currently do not return a PassResult.
    This wrapper makes them return a PassResult containing the modified object
    and True for the "modified" flag.

    The wrapped pass may return:
      * ``None``          -- treated as an in-place pass; the input object is
                             reported as the (modified) result.
      * a ``PassResult``  -- returned unchanged.
      * an ``nn.Module``  -- wrapped as ``PassResult(module, True)``.

    Any other return value raises ``TypeError`` when the wrapped function is
    called, instead of silently yielding ``None``.

    Args:
        fn (Callable[Module, Any]): the pass to wrap. ``None`` is passed
            through and ``None`` is returned.

    Returns:
        wrapped_fn (Callable[Module, PassResult])
    """
    if fn is None:
        return None

    @wraps(fn)
    def wrapped_fn(gm):
        res = fn(gm)
        if res is None:
            # In-place pass: the object handed in is the result.
            return PassResult(gm, True)
        if isinstance(res, PassResult):
            return res
        if isinstance(res, nn.Module):
            return PassResult(res, True)
        # Previously this fell off the end and returned None, which only
        # surfaced later as a confusing error in the caller.
        pass_name = getattr(fn, "__name__", type(fn).__name__)
        raise TypeError(
            f"The pass {pass_name} returned an object of type "
            f"{type(res).__name__}; expected None, a PassResult or an nn.Module."
        )

    # Callable objects have no __name__ for `wraps` to copy, so report the
    # class name of the callable instead of "wrapped_fn".
    if not inspect.isfunction(fn):
        wrapped_fn.__name__ = type(fn).__name__

    return wrapped_fn
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
if constraint(a, b):
continue
raise RuntimeError(
f"pass schedule constraint violated. Expected {a} before {b}"
f" but found {a} at index {i} and {b} at index{j} in pass"
f" list."
)
def _topological_sort_passes(
passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
"""
Args
passes: Passes that we are ordering
constraints: Constraints applied on these passes
Returns
A sorted list of callables and a boolean of if a circular dependency
existed
"""
if len(constraints) == 0:
return passes
# Contruct a graph mapping nodes to a list of their users
graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
candidates: Queue = Queue()
for a in passes:
for b in passes:
if a == b:
continue
for constraint in constraints:
if not constraint(a, b):
graph[b].append(a)
indegree_map[a] += 1
if indegree_map[a] == 0:
candidates.put(a)
visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
sorted_passes: List[Callable] = []
while not candidates.empty():
p = candidates.get()
sorted_passes.append(p)
visited[p] = True
for n in graph[p]:
if not visited[n]:
indegree_map[n] -= 1
if indegree_map[n] == 0:
candidates.put(n)
# Check if there are unvisited nodes (aka cycles in the graph)
cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
if len(cycle_passes) != 0:
error = f"Circular dependency detected within the following passes: {cycle_passes}"
raise RuntimeError(error)
return sorted_passes
@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
    """
    Defines a partial order ('depends on' function) where `this` must occur
    before `that`.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [pass_b, pass_a]

    constraints = [
        this_before_that_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only for
        the pair ``(that, this)``, i.e. when `that` would precede `this`, and
        True for every other pair.
    """

    def depends_on(a: Callable, b: Callable):
        # Only the exact ordering "that, then this" is forbidden.
        return not (a == that and b == this)

    return depends_on
@compatibility(is_backward_compatible=False)
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): List of passes. A pass is a
            callable which modifies an object and returns a PassResult
        constraints (Optional[List[Callable]]): List of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            False if A must not run before B and True otherwise. See
            implementation of `this_before_that_pass_constraint` for example.
        steps (int): Max number of times we run the passes (default = 1).
        run_checks_after_each_pass (bool): Whether to run checks and linting
            after each pass
        suppress_check_failures (bool): Whether to raise errors when running
            checks. The flag is stored on the instance; none of the methods
            defined on this class read it.
    """

    passes: List[Callable[[nn.Module], PassResult]]
    constraints: List[Callable[[Callable, Callable], bool]]
    # True once the current pass list is known to satisfy the constraints;
    # reset whenever a pass or constraint is added.
    _validated: bool = False
    steps: int = 1

    def __init__(
        self,
        passes=None,
        constraints=None,
        steps=None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ):
        self.passes = passes or []
        self.constraints = constraints or []
        # A falsy `steps` (None or 0) keeps the class-level default of 1.
        if steps:
            self.steps = steps

        self.run_checks_after_each_pass = run_checks_after_each_pass
        self.suppress_check_failures = suppress_check_failures

    @staticmethod
    def _pass_name(fn: Callable) -> str:
        """Return the display name of a pass: function name, else class name."""
        return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__

    def add_pass(self, _pass: Callable):
        """
        Adds a pass into the current list of passes.
        """
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint: Callable):
        """
        Adds a constraint into the current list of constraints.
        """
        self.constraints.append(constraint)
        self._validated = False

    def validate_constraints(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`

        Does nothing if the schedule is already marked as validated. Raises
        whatever the validation helper raises on the first violated constraint.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def solve_constraints(self):
        """
        Finds a valid traversal order based on the given constraints and orders
        the passes based on this order.

        `self.passes` is replaced by the sorted list and the schedule is marked
        as validated. If the constraints contain a circular dependency, the
        error raised by the sort propagates, regardless of `steps`.
        """
        self.passes = _topological_sort_passes(self.passes, self.constraints)
        self._validated = True

    def add_checks(self, check: Callable) -> None:
        """
        Adds a function which runs various checks on a given graph module.
        This function is run once before the passes, and after each pass if
        the `run_checks_after_each_pass` flag is enabled.

        Raises:
            TypeError: if `check` does not take exactly one parameter.
        """
        sig = inspect.signature(check)

        if len(sig.parameters) != 1:
            raise TypeError("PassManager check function should only take in one variable, a module")

        # Stored as an instance attribute, shadowing the no-op `check` method.
        setattr(self, "check", check)  # noqa: B010

    def check(self, module: nn.Module) -> None:
        """Default check: does nothing. Replaced per instance by `add_checks`."""
        pass

    def __call__(self, module: nn.Module) -> PassResult:
        """
        Runs a list of passes in the order based on `self.passes` on the given
        graph module. Each time a pass is run, checks and linting will be run on
        the graph module if `run_checks_after_each_pass` is set.

        The whole list of passes is run until a full sweep reports no
        modification, or until `steps` number of times.

        Returns:
            PassResult holding the final module and whether any pass in any
            sweep reported a modification.

        Raises:
            Exception: if a pass (or a check run after it) fails; the original
                error is chained as the cause.
        """
        # Order the passes based on the constraints
        if not self._validated:
            self.solve_constraints()

        # Check graph invariants
        self.check(module)

        # Run the set of passes `steps` number of times or until the graph stops
        # changing
        overall_modified = False
        for _ in range(self.steps):
            modified = False

            # Run the set of passes on the graph module
            for i, fn in enumerate(self.passes):
                fn_name = self._pass_name(fn)
                logger.debug("Running pass '%s'", fn_name)

                try:
                    res = fn(module)

                    # Accept PassResult or any object exposing `graph_module`.
                    if not isinstance(res, PassResult) and not hasattr(
                        res, "graph_module"
                    ):
                        raise TypeError(
                            f"The result of the pass {fn_name} should be type PassResult."
                            + "Please wrap it with pass_result_wrapper()"
                        )
                    module = res.graph_module
                    modified = modified or res.modified

                    if isinstance(module, GraphModule):
                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
                        module.recompile()

                    # Check graph invariants
                    if self.run_checks_after_each_pass:
                        self.check(module)

                except Exception as e:
                    prev_pass_names = [self._pass_name(p) for p in self.passes[:i]]
                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
                    raise Exception(msg) from e  # noqa: TRY002

            # If the graph no longer changes, then we can stop running these passes
            overall_modified = overall_modified or modified
            if not modified:
                break

        return PassResult(module, overall_modified)