import torch
# Check whether the pattern (nn.Module, F.function / torch.Tensor method) matches.
# Works for length-2 patterns consisting of one module followed by one function/method.
def matches_module_function_pattern(pattern, node, modules) -> bool:
    """Return True if ``node`` and its first argument match ``pattern``.

    A match requires a ``call_module`` node feeding directly into a
    ``call_function``/``call_method`` node, where the module and the
    function/method are the ones named by ``pattern``.

    Args:
        pattern: Indexable pair. ``pattern[0]`` is the module class the
            producer must be an instance of (compared with ``type(...) is``,
            so subclasses do not match); ``pattern[1]`` is compared against
            ``node.target``.
        node: Candidate consumer node. Anything that is not a
            ``torch.fx.Node`` yields ``False``.
        modules: Mapping from a ``call_module`` target string to the module
            instance it refers to.

    Returns:
        ``True`` when every condition below holds, ``False`` otherwise.
    """
    # Validate `node` before touching its attributes, so that a non-Node
    # argument yields False instead of raising AttributeError on `.args`.
    if not isinstance(node, torch.fx.Node) or not node.args:
        return False

    producer = node.args[0]
    if not isinstance(producer, torch.fx.Node):
        return False

    # The first node of the pattern must be a call_module whose target is a
    # known key in `modules`.
    if producer.op != "call_module":
        return False
    if not isinstance(producer.target, str) or producer.target not in modules:
        return False
    # Exact class comparison: an instance of a subclass of pattern[0] is
    # rejected here.
    if type(modules[producer.target]) is not pattern[0]:
        return False

    # The second node of the pattern must be a call_function or call_method
    # whose target equals pattern[1].
    if node.op not in ("call_function", "call_method"):
        return False
    if node.target != pattern[1]:
        return False

    # Make sure the producer's output is only used by the current node.
    return len(producer.users) <= 1