import inspect
from typing import Dict, List, Union
from torch import _C
from torch.onnx import _constants
from torch.onnx._internal import registration
class _TorchSchema:
    """Lightweight description of an operator schema used for support reports.

    Built either from a TorchScript ``_C.FunctionSchema`` (name, overload name,
    argument names and return names are copied from it) or from a bare
    operator name, in which case every list starts empty and can be filled in
    later by the caller. Equality and hashing consider only ``name``, so
    different overloads of the same operator compare equal.
    """

    def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None:
        # Defaults used by bare-name schemas. ``optional_arguments`` and
        # ``opsets`` are never derived from a FunctionSchema here.
        self.overload_name: str = ""
        self.arguments: List[str] = []
        self.optional_arguments: List[str] = []
        self.returns: List[str] = []
        self.opsets: List[int] = []
        if isinstance(schema, _C.FunctionSchema):
            self.name: str = schema.name
            self.overload_name = schema.overload_name
            self.arguments = [arg.name for arg in schema.arguments]
            self.returns = [ret.name for ret in schema.returns]
        else:
            self.name = schema

    def __str__(self) -> str:
        # Omit the "." separator for the default (empty) overload so the text
        # reads "aten::relu(...)" instead of "aten::relu.(...)".
        qualified_name = (
            f"{self.name}.{self.overload_name}" if self.overload_name else self.name
        )
        arguments = ", ".join(self.arguments)
        returns = ", ".join(self.returns)
        opsets = ", ".join(str(opset) for opset in self.opsets)
        return f"{qualified_name}({arguments}) -> ({returns}) in opsets {opsets}"

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"overload_name={self.overload_name!r})"
        )

    def __hash__(self) -> int:
        # TODO(thiagocrepaldi): handle overload_name?
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _TorchSchema):
            # Defer to the other operand; Python falls back to identity,
            # so comparisons with foreign types still yield False for ==.
            return NotImplemented
        # TODO(thiagocrepaldi): handle overload_name?
        return self.name == other.name

    def is_aten(self) -> bool:
        """Return True if the name is in the ``aten::`` namespace."""
        return self.name.startswith("aten::")

    def is_backward(self) -> bool:
        """Return True if the name contains the substring ``backward``."""
        return "backward" in self.name
def _symbolic_argument_count(func):
    """Return the names of the required parameters of a symbolic function.

    Despite its name, this returns a list of names (not a count), which is
    what callers store in ``_TorchSchema.arguments``. The ``g`` and
    ``_outputs`` parameters are excluded. Parameters that declare a default
    value are treated as optional and are not included.

    Args:
        func: Any callable accepted by ``inspect.signature``.

    Returns:
        Required parameter names, in declaration order.
    """
    required = []
    for name, parameter in inspect.signature(func).parameters.items():
        if name in {"_outputs", "g"}:
            continue
        # No default value means the caller must supply it: a required arg.
        if parameter.default is inspect.Parameter.empty:
            required.append(name)
    return required
def all_forward_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all TorchScript forward ops.

    Wraps every schema reported by ``_C._jit_get_all_schemas()`` in a
    ``_TorchSchema`` and drops those whose name contains ``"backward"``.
    The result is keyed by operator name; when several overloads share a
    name, the one enumerated last is kept.
    """
    forward: Dict[str, _TorchSchema] = {}
    for raw_schema in _C._jit_get_all_schemas():
        wrapped = _TorchSchema(raw_schema)
        if wrapped.is_backward():
            continue
        forward[wrapped.name] = wrapped
    return forward
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all onnx supported ops.

    For every name in the symbolic registry, builds a ``_TorchSchema`` whose
    ``arguments`` are the required parameter names of the resolved symbolic
    function and whose ``opsets`` list the opset versions reported for it:

    * if ``func_group.get(ONNX_MAX_OPSET)`` finds a function, the opsets run
      from ``func_group.get_min_supported()`` through ``ONNX_MAX_OPSET``
      inclusive;
    * otherwise ``func_group.get(7)`` is tried and, if it finds a function,
      opsets ``7`` up to (not including) ``ONNX_BASE_OPSET`` are reported.

    Names for which neither lookup yields a function are omitted, instead of
    failing with ``TypeError`` inside ``inspect.signature(None)``.
    """
    symbolics_schemas: Dict[str, _TorchSchema] = {}
    for name in registration.registry.all_functions():
        func_group = registration.registry.get_function_group(name)
        assert func_group is not None
        symbolics_schema = _TorchSchema(name)
        func = func_group.get(_constants.ONNX_MAX_OPSET)
        if func is not None:
            symbolics_schema.arguments = _symbolic_argument_count(func)
            symbolics_schema.opsets = list(
                range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1)
            )
        else:
            # Only legacy opsets (7 up to ONNX_BASE_OPSET) may be supported.
            func = func_group.get(7)
            if func is None:
                # NOTE(review): nothing resolvable at either lookup; skip the
                # name rather than crash on inspect.signature(None).
                continue
            symbolics_schema.arguments = _symbolic_argument_count(func)
            symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET))
        symbolics_schemas[name] = symbolics_schema
    return symbolics_schemas