Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ ao / quantization / fx / fuse.py

from torch.fx import (
    GraphModule,
    Node,
    map_arg
)
from torch.fx.graph import Graph
from .match_utils import (
    _is_match,
    MatchAllNode,
)
from .pattern_utils import (
    _sorted_patterns_dict,
)

from ..backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from ..backend_config.utils import (
    get_fuser_method_mapping,
    get_fusion_pattern_to_root_node_getter,
    get_fusion_pattern_to_extra_inputs_getter,
)

from .custom_config import FuseCustomConfig

from .fuse_handler import (
    _get_fusion_pattern_to_fuse_handler_cls,
    FuseHandler,
)

from typing import Any, Callable, Dict, List, Tuple, Union
import warnings

from torch.ao.quantization.utils import Pattern, NodePattern


# Names exported by ``from ... fuse import *``.
# TODO: We should make "FuseHandler" private in the future; it is currently
# listed only because test_public_bindings needs it for some reason.
__all__ = ["fuse", "FuseHandler"]


def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    """Fuse the patterns registered in ``backend_config`` found in ``model``.

    The graph of ``model`` is searched with ``_find_matches``. Each match is
    handed to its ``FuseHandler``, whose ``fuse`` method emits the fused
    replacement into a fresh graph. Nodes covered by a match other than the
    node the match is rooted at are dropped. Every other node is copied
    unchanged.

    Args:
        model: the FX ``GraphModule`` to fuse. Its ``named_modules()`` are
            passed to every fuse handler.
        is_qat: forwarded unchanged to each ``FuseHandler.fuse`` call.
        fuse_custom_config: a ``FuseCustomConfig``. ``None`` means a default
            ``FuseCustomConfig()``. A plain dict is still accepted
            (converted with ``FuseCustomConfig.from_dict``) but emits a
            deprecation warning.
        backend_config: a ``BackendConfig``. ``None`` means
            ``get_native_backend_config()``. A plain dict is still accepted
            (converted with ``BackendConfig.from_dict``) but emits a
            deprecation warning.

    Returns:
        A new ``GraphModule`` constructed from ``model`` and the fused graph.
    """
    # --- Normalize the (possibly legacy dict) configuration arguments. ---
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.")
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    if isinstance(backend_config, dict):
        # The message must name this entry point (`fuse`), not `prepare`.
        warnings.warn(
            "Passing a backend_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    named_modules = dict(model.named_modules())

    # --- Derive the per-pattern lookup tables from the backend config. ---
    fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))
    fuser_method_mapping = get_fuser_method_mapping(backend_config)
    fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
    fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)

    # find fusion
    fusion_pairs = _find_matches(
        model, model.graph, fusion_pattern_to_fuse_handler_cls)
    # TODO: change this to inplace changes to graph, since we no longer construct
    # new GraphModule anymore
    fused_graph = Graph()
    # original node name -> value emitted for it into fused_graph
    env: Dict[Any, Any] = {}

    def load_arg(a):
        # Replace every Node inside `a` with its counterpart already emitted
        # into fused_graph (looked up by the original node's name).
        return map_arg(a, lambda node: env[node.name])

    def default_root_node_getter(node_pattern):
        # Fallback used when the backend config registers no root-node getter
        # for a pattern: descend into the last element of the (possibly
        # nested) matched pattern until a Node is reached, and return it.
        while not isinstance(node_pattern[-1], Node):
            node_pattern = node_pattern[-1]
        return node_pattern[-1]

    for node in model.graph.nodes:
        maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
            fusion_pairs.get(node.name, (None, None, None, None, None))
        # get the corresponding subpattern for the current node
        if node_to_subpattern is not None:
            node_subpattern = node_to_subpattern.get(node, None)
        else:
            node_subpattern = None
        if maybe_last_node is node:
            # This node is the one the match was rooted at: emit the fused
            # replacement for the whole matched pattern here.
            assert obj is not None
            root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
            root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
            extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
            extra_inputs = []
            if extra_inputs_getter is not None:
                extra_inputs = extra_inputs_getter(matched_node_pattern)
            # TODO: add validation that root_node is a module and has the same type
            # as the root_module in the configuration
            env[node.name] = obj.fuse(
                load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern,  # type: ignore[arg-type]
                fuse_custom_config, fuser_method_mapping, is_qat)
        elif maybe_last_node is None or node_subpattern is MatchAllNode:
            # Unmatched nodes (and nodes matched only as MatchAllNode inputs)
            # are copied through unchanged.
            env[node.name] = fused_graph.node_copy(node, load_arg)
        # node matched in patterns and is not root is removed here

    model = GraphModule(model, fused_graph)
    return model

def _find_matches(
        root: GraphModule,
        graph: Graph,
        pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
    """Find fusion-pattern matches in ``graph``, visiting nodes last to first.

    For each node not already claimed by an earlier match, the patterns in
    ``pattern_to_fuse_handler_cls`` are tried in iteration order. The first
    one for which ``_is_match`` succeeds is applied, and the remaining
    patterns are skipped for that node.

    Args:
        root: module whose ``named_modules()`` are passed to ``_is_match``.
        graph: the graph to search.
        pattern_to_fuse_handler_cls: pattern -> handler class. The class is
            instantiated with the node the match is rooted at
            (``fuse_handler_cls(node)``).

    Returns:
        A dict keyed by node name. Every node claimed by a match, except
        nodes matched against ``MatchAllNode``, maps to
        ``(root_node, pattern, matched_node_pattern, handler, node_to_subpattern)``:

        - ``root_node`` is the node the pattern was matched at.
        - ``matched_node_pattern`` is the list of matched nodes for that
          node's tuple level of the pattern.
        - ``node_to_subpattern`` is a single dict shared by all entries.
    """
    modules = dict(root.named_modules())
    # node name -> (root_node, pattern, matched_node_pattern, handler, node_to_subpattern)
    match_map : Dict[
        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
    # a map from node to the matched subpattern; one dict shared by every match
    node_to_subpattern: Dict[Node, Any] = {}

    # TODO: dedup with quantization matching function in match_utils.py
    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
        # Walk `pattern` alongside `node`, recording claimed nodes in match_map.
        # `match` is (root_node, full_pattern, handler) for the whole match.
        if isinstance(pattern, tuple):
            # A tuple pattern is (op_pattern, *arg_patterns). The arg patterns
            # are paired positionally with node.args (zip stops at the shorter).
            s, *args = pattern
            # Nodes for this tuple level go into a fresh list. It is appended to
            # the caller's list as a tuple, so the nesting mirrors the pattern.
            # Note: match_map entries created below store a reference to this
            # same list while it is still being filled, so later appends are
            # visible through those entries.
            current_node_pattern: List[Node] = []
            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
            for subpattern, arg in zip(args, node.args):
                apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern)
            matched_node_pattern.append(tuple(current_node_pattern))
        else:
            # the first pattern matches will take precedence
            if node.name not in match_map:
                matched_node_pattern.append(node)
                # MatchAllNode here is actually MatchAllInputNode which should not
                # be added to match_map
                if pattern is not MatchAllNode:
                    node_to_subpattern[node] = pattern
                    # `pattern` is rebound from the leaf subpattern to the full
                    # matched pattern carried in `match`.
                    root_node, pattern, handler = match
                    match_map[node.name] = (root_node, pattern, matched_node_pattern, handler, node_to_subpattern)

    # Visit nodes last to first. A match is anchored at the node where matching
    # starts, and the upstream nodes it claims are skipped when they are reached.
    for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
                matched_node_pattern: List[Node] = []
                if _is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)), matched_node_pattern, node_to_subpattern)
                    break

    return match_map