Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ _functorch / fx_minifier.py

import torch.fx as fx
import copy
import torch
import math
from typing import Callable, List
from functools import wraps, partial
from dataclasses import dataclass
from .compile_utils import get_placeholders, get_outputs

class ConcreteProp(torch.fx.Interpreter):
    """Interpreter that remembers the concrete value each node evaluated to.

    After a node is executed, its result is stored under
    ``node.meta['concrete_value']`` whenever that result contains at least
    one ``torch.Tensor`` (directly or nested inside an aggregate).
    """

    def run_node(self, n):
        """Execute ``n`` and stash its result in ``n.meta`` if it holds a tensor."""
        from torch.fx.node import map_aggregate

        result = super().run_node(n)
        saw_tensor = False

        def note_tensor(value):
            # Leaves every element unchanged; only flags that a tensor was seen.
            nonlocal saw_tensor
            if isinstance(value, torch.Tensor):
                saw_tensor = True
            return value

        recorded = map_aggregate(result, note_tensor)
        if saw_tensor:
            n.meta['concrete_value'] = recorded
        return result

    def propagate(self, *args):
        """Run the wrapped module on ``args`` and return its output."""
        return super().run(*args)


# Mutates both `node` and `inps` in place.
def _convert_node_to_placeholder(node, inps):
    if node.op == 'output' or node.op == "placeholder":
        return
    node.op = 'placeholder'
    node.args = ()
    node.kwargs = {}
    node.target = node.name
    concrete_val = node.meta.get('concrete_value', None)
    if isinstance(concrete_val, torch.Tensor):
        inps.append(concrete_val)
    else:
        inps.append(torch.zeros(()))
        for tuple_user in list(node.users):
            _convert_node_to_placeholder(tuple_user, inps)

def dump_state(fx_g, inps):
    """Print the current repro: node count, input descriptors, and generated code.

    Each input is summarised as a ``(shape, dtype, device type)`` triple.
    """
    node_count = len(fx_g.graph.nodes)
    inp_summaries = [(t.shape, t.dtype, t.device.type) for t in inps]
    print(f"""
# Working Repro with {node_count} nodes
inps = {inp_summaries}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
{fx_g.code}
""")

@dataclass
class ReproState:
    """A candidate repro: an FX graph together with the inputs to run it on."""
    # The FX graph for this candidate.
    graph: fx.Graph
    # Inputs fed to the graph; the minifier asserts one entry per placeholder.
    inps: List[torch.Tensor]

def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable = dump_state):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
        tries replacing quarter of the graph, etc.

    >>> # xdoctest: +SKIP(failing)
    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.

    Args:
        fail_f: The failing GraphModule. It is executed once on ``inps`` to
            record concrete per-node values, and is used as the root for every
            GraphModule built from a candidate graph.
        inps: Inputs for ``fail_f``, one per placeholder.
        module_fails: Callable ``(GraphModule, inps)`` returning True when the
            failure still reproduces.
        dump_state: Callable ``(GraphModule, inps)`` invoked at the start of
            every minification round and once more with the final result.

    Returns:
        A ``(GraphModule, inps)`` tuple holding the minimized repro.

    Raises:
        RuntimeError: If the initial graph does not fail, if a strategy returns
            a state that made no progress, or if the final graph no longer fails.
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    num_queries = 0

    def deepcopy_fx_graph(fx_graph):
        return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph


    def graph_fails(graph, inps):
        # Every call counts as one query; the graph is copied so that building
        # a GraphModule around it cannot disturb the caller's graph.
        nonlocal num_queries
        graph = copy.deepcopy(graph)
        num_queries += 1
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    # Record concrete values on the nodes so they can later be fed as inputs.
    ConcreteProp(fail_f).propagate(*inps)
    if not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def _register_strategy(strategy: Callable, name: str):
        # Wraps a raw strategy ``(graph, inps, granularity) -> ReproState | None``
        # so that it runs on copies, reports progress, and re-validates the result.
        @wraps(strategy)
        def new_func(old_state: ReproState, granularity=1):
            print()
            print(f"Strategy: {name} (G: {granularity}) ({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)")
            new_state = strategy(deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity)
            if new_state is not None:
                new_nodes = len(new_state.graph.nodes)
                old_nodes = len(old_state.graph.nodes)
                new_inps = len(new_state.inps)
                old_inps = len(old_state.inps)
                new_outs = len(get_outputs(new_state.graph))
                old_outs = len(get_outputs(old_state.graph))
                progress_made = False
                if new_nodes < old_nodes:
                    progress_made = True
                    print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes")
                # More inputs counts as progress: it means computation nodes
                # were replaced by placeholders.
                if new_inps > old_inps:
                    progress_made = True
                    print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs")
                if new_outs < old_outs:
                    progress_made = True
                    print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs")

                if not progress_made:
                    raise RuntimeError("Success raised but no progress made?")

                if not graph_fails(new_state.graph, new_state.inps):
                    print("WARNING: Something went wrong, not applying this minification")
                    return None
                return new_state
            else:
                print(f"FAIL: {name}")
            return None

        return new_func

    def register_strategy(name: str):
        return partial(_register_strategy, name=name)

    @register_strategy("Truncate suffix")
    def remove_suffix(cur_graph, cur_inps, granularity):
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            if node.op not in ['placeholder', 'output']:
                # If idx is divisible by (granularity * 2), it would have been checked already.
                if idx % granularity == 0 and (idx % (granularity * 2) != 0):
                    # Tentatively end the graph at this node and test the prefix.
                    output_node = new_graph.output((new_node,))
                    if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(new_graph, cur_inps):
                        return ReproState(new_graph, cur_inps)
                    else:
                        new_graph.erase_node(output_node)
            env[node] = new_node
        return None

    @register_strategy("Remove outputs")
    def remove_outputs(cur_graph, cur_inps, granularity):
        granularity = max(1, granularity // 2)
        output = None
        node_position = {}
        for idx, node in enumerate(cur_graph.nodes):
            node_position[node] = idx
            if node.op == 'output':
                output = node
                break

        # Nothing can be dropped when there is no output node or when the graph
        # returns a single value rather than a list/tuple of values.
        if output is None or not isinstance(output.args[0], (list, tuple)):
            return None

        # Order outputs by where they are produced; non-node outputs go last.
        output_args = sorted(
            output.args[0],
            key=lambda x: node_position[x] if isinstance(x, fx.Node) else int(1e9),
        )
        if len(output_args) == 1:
            return None

        for idx in range(0, len(output_args), granularity):
            output.args = (output_args[:idx] + output_args[idx + granularity:],)
            if graph_fails(cur_graph, cur_inps):
                return ReproState(cur_graph, cur_inps)
        return None


    def remove_unused_inputs_unchecked(cur_state: ReproState):
        # Erases placeholders without users (mutating the graph in place) and
        # drops their inputs. Returns None when every placeholder is used.
        cur_graph = cur_state.graph
        cur_inps = cur_state.inps
        ph_nodes = get_placeholders(cur_graph)
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for ph_node, inp in zip(ph_nodes, cur_inps):
            if len(ph_node.users) == 0:
                cur_graph.erase_node(ph_node)
            else:
                new_inps.append(inp)
        if len(new_inps) < len(cur_inps):
            return ReproState(cur_graph, new_inps)
        return None

    def remove_unused_inputs_checked(cur_state: ReproState):
        new_state = remove_unused_inputs_unchecked(cur_state)
        if new_state is not None and graph_fails(new_state.graph, new_state.inps):
            return new_state
        return None

    def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
        return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))

    remove_unused_inputs = register_strategy("Remove unused inputs")(_remove_unused_wrapper)

    @register_strategy("Eliminate dead code")
    def eliminate_dead_code(cur_graph, cur_inps, granularity):
        if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None


    def _consolidate_placeholders(cur_graph):
        # Rebuilds the graph with all placeholders first, followed by every
        # other node in its original relative order.
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    @register_strategy("Delta Debugging")
    def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
        num_nodes = len(cur_graph.nodes)
        for start_range in range(0, num_nodes, granularity):
            is_removing = False
            new_graph = deepcopy_fx_graph(cur_graph)
            new_inps = cur_inps[:]
            end_range = min(num_nodes, start_range + granularity)
            # Snapshot the node list once per window; converting a node edits
            # it in place and never adds or removes nodes.
            graph_nodes = list(new_graph.nodes)
            for new_node in graph_nodes[start_range:end_range]:
                if new_node.op not in ['placeholder', 'output']:
                    is_removing = True
                    _convert_node_to_placeholder(new_node, new_inps)
            if not is_removing:
                continue
            new_graph = _consolidate_placeholders(new_graph)
            new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps))
            if new_state is None:
                new_state = ReproState(new_graph, new_inps)
            if graph_fails(new_state.graph, new_state.inps):
                return new_state

        return None

    failing_state = ReproState(failing_graph, inps)

    def try_granularity(failing_state, granularity, use_non_granular):
        print(f"Trying granularity {granularity}")

        strategies = []
        num_nodes = len(failing_state.graph.nodes)
        num_outputs = len(get_outputs(failing_state.graph))
        # Only prioritise output removal when outputs dominate the graph.
        if num_outputs > num_nodes // 2:
            strategies += [remove_outputs]

        if use_non_granular:
            strategies += [eliminate_dead_code, remove_unused_inputs]

        strategies += [remove_suffix, delta_debugging]

        for strategy in strategies:
            new_state = strategy(failing_state, granularity)
            if new_state is not None:
                return new_state
        return None

    while True:
        dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps)
        # Start from the largest power of two not exceeding the node count.
        granularity = int(2**(math.floor(math.log2(len(failing_state.graph.nodes)))))
        new_state = try_granularity(failing_state, granularity, use_non_granular=True)
        if new_state is not None:
            failing_state = new_state
            continue

        granularity //= 2
        has_progress = False
        while granularity >= 1:
            new_state = try_granularity(failing_state, granularity, use_non_granular=False)
            if new_state is not None:
                failing_state = new_state
                has_progress = True
                break
            granularity //= 2
        if has_progress:
            continue

        # Last resort before giving up: try dropping outputs one at a time.
        new_state = remove_outputs(failing_state, 1)
        if new_state is not None:
            failing_state = new_state
            continue

        break

    if not graph_fails(failing_state.graph, failing_state.inps):
        raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")

    print(f"Made {num_queries} queries")
    failing_fx = fx.GraphModule(fail_f, failing_state.graph)
    dump_state(failing_fx, failing_state.inps)
    print("Wrote minimal repro out to repro.py")
    return failing_fx, failing_state.inps