# mypy: allow-untyped-defs
"""Utilities for lowering subgraphs used by higher order operators
"""
import functools
import operator
from dataclasses import dataclass
from typing import List, Optional, TypeVar

import torch

from . import ir
from .exc import SubgraphLoweringException
from .ops_handler import SimpleCSEHandler
from .virtualized import ops, V, WrapperHandler

T = TypeVar("T")


class PointwiseSubgraphLowering(torch.fx.Interpreter):
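    """Lower an FX subgraph containing only pointwise operators into Inductor
    IR, reusing the lowering state (e.g. sizevars) of a root GraphLowering.

    Mutation and buffer creation are rejected so that the lowered subgraph
    remains purely functional.
    """
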
graph_outputs: Optional[List[ir.IRNode]]

    def __init__(
self,
gm: torch.fx.GraphModule,
root_graph_lowering: "torch._inductor.graph.GraphLowering",
):
super().__init__(gm)
self.graph_outputs = None
self.root_graph = root_graph_lowering

    @property
def sizevars(self):
return self.root_graph.sizevars
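
    # Mutation and fresh buffer creation cannot occur inside a purely
    # functional pointwise subgraph, so both operations are rejected.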
def mark_buffer_mutated(self, name):
raise SubgraphLoweringException("Mutations are not supported in this context")

    def register_buffer(self, data):
raise SubgraphLoweringException(
"Buffer creation is not supported in this context"
)

    def call_function(self, target, args, kwargs):
from .lowering import lowerings
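
        # getitem on plain Python collections is structural rather than a
        # tensor op, so defer to the default Interpreter behavior.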
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
return super().call_function(target, args, kwargs)
assert isinstance(target, torch._ops.OpOverload)
if target not in lowerings:
raise SubgraphLoweringException(
f"{target} not supported in subgraph, (missing lowering)"
)
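        # Only operators tagged pointwise can participate in this fused
        # lowering.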
if torch.Tag.pointwise not in target.tags:
raise SubgraphLoweringException(
f"Only pointwise operators are supported in this context, but got {target}"
)
return lowerings[target](*args, **kwargs)

    def output(self, target, args, kwargs):
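        # Stash the lowered outputs; callers read graph_outputs after run().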
assert len(args) == 1
self.graph_outputs = args[0]


@dataclass
class InputDescriptor:
dtype: torch.dtype
device: torch.device


class TracingOpsHandler(WrapperHandler[T]):
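    """Ops handler that records each op call as a node in an FX graph,
    allowing several pointwise computations to be re-traced into one
    graph module.
    """
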
def __init__(self, tracer, num_inputs):
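        # Ops are delegated to a placeholder proxy; the real ops handler is
        # passed as the first argument when the traced graph is executed.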
parent = tracer.create_proxy("placeholder", "ops", (), {})
super().__init__(parent)
self.tracer = tracer
self.placeholders = [
self.tracer.create_proxy("placeholder", f"input{i}", (), {})
for i in range(num_inputs)
]

    def placeholder(self, idx):
return self.placeholders[idx]

    def output(self, *args):
return self.tracer.create_node(
"output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
)


def lower_pointwise_subgraph(subgraph: ir.Subgraph, inputs: List[InputDescriptor]):
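    """Lower a pointwise ``ir.Subgraph`` into a single fused inner function.

    The subgraph is first lowered to ``ir.Pointwise`` nodes; each output is
    then re-traced, with common subexpression elimination, into one FX graph.
    The returned ``inner_fn`` replays that graph against the ops handler that
    is active at call time.
    """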
# Lower subgraph to ir.Pointwise nodes
def fake_inner_fn(loop_idx, input_idx):
return ops.placeholder(input_idx)

    graph_inputs = [
ir.Pointwise.create(
device=desc.device,
dtype=desc.dtype,
inner_fn=functools.partial(fake_inner_fn, input_idx=i),
ranges=[],
)
for i, desc in enumerate(inputs)
]
gm = subgraph.graph_module
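    # Interpret the FX subgraph under the restricted pointwise lowering; the
    # lowered outputs are captured on pw_subgraph.graph_outputs.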
pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph)
with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
pw_subgraph.run(*graph_inputs)

    # Combine multiple pointwise computations into a single graph module
# Do this by tracing through each individually and doing CSE
tracer = torch.fx.Tracer()
tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs)))
assert pw_subgraph.graph_outputs is not None
with V.set_ops_handler(trace_ops):
output_irs = []
for out_var in pw_subgraph.graph_outputs:
assert isinstance(out_var, ir.TensorBox), type(out_var)
assert out_var.get_size() == []
assert isinstance(out_var.data, ir.StorageBox)
assert isinstance(out_var.data.data, ir.Pointwise)
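            # Each output is a zero-dim Pointwise node; calling its inner_fn
            # with an empty index traces the scalar computation.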
idx = ()
ir_out = out_var.data.data.inner_fn(idx)
output_irs.append(ir_out)
ops.output(*output_irs)
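
    # Wrap the traced graph in a GraphModule; inner_fn replays it with the
    # ops handler that is active at call time as its first argument.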
lowered_gm = torch.fx.GraphModule({}, tracer.graph)

    def inner_fn(*args, **kwargs):
return lowered_gm(V.get_ops_handler(), *args, **kwargs)

    return inner_fn