import dataclasses
import sys
import types
from typing import (
Callable,
Dict,
List,
NamedTuple,
Optional,
OrderedDict,
Protocol,
Union,
)
# The frame type handed to Dynamo callbacks depends on the interpreter:
# before CPython 3.11 it is the ordinary Python frame object; from 3.11 on
# it is the `_PyInterpreterFrame` type exposed by torch's eval_frame module.
if sys.version_info < (3, 11):
    DynamoFrameType = types.FrameType
else:
    # Imported lazily so the torch C extension is only touched on 3.11+.
    from torch._C._dynamo import eval_frame

    DynamoFrameType = eval_frame._PyInterpreterFrame
class GuardFail(NamedTuple):
    """Record describing a single guard failure.

    Fields:
        reason: string form of the piece of guard code that was evaluated
            and failed.
        orig_code: the code object whose guard failed.
    """

    reason: str
    orig_code: types.CodeType
class GuardFn(Protocol):
    """Structural type for a guard-check callable and its metadata.

    Any object that exposes the attributes below and can be called with the
    user function's locals (as keyword arguments), returning a bool,
    satisfies this protocol.
    """

    closure_vars: OrderedDict[str, object]
    args: List[str]
    code_parts: List[str]
    verbose_code_parts: List[str]
    global_scope: Dict[str, object]
    # Optional callback that takes a GuardFail; presumably invoked when the
    # guard fails -- confirm against the guard implementation.
    guard_fail_fn: Optional[Callable[[GuardFail], None]]

    def __call__(self, *maybe_dotzero: object, **f_locals: object) -> bool:
        """Map the locals of the user function to a bool."""
@dataclasses.dataclass
class GuardedCode:
    """A code object bundled with its guard-check function.

    Attributes:
        code: the code object.
        check_fn: the GuardFn associated with ``code``.
    """

    code: types.CodeType
    check_fn: GuardFn
class DynamoCallbackFn(Protocol):
    """Structural type for a Dynamo frame callback.

    The callback receives a frame (of ``DynamoFrameType``) and an integer
    cache size, and returns either a GuardedCode or None.
    """

    def __call__(
        self, frame: DynamoFrameType, cache_size: int
    ) -> Optional[GuardedCode]:
        """Handle ``frame``; return a GuardedCode or None."""
# Accepted callback values: a DynamoCallbackFn, None, or a bool.
# NOTE(review): the meaning of the None / bool cases is defined by the
# consumer of this alias, not here -- confirm before relying on it.
DynamoCallback = Union[
    DynamoCallbackFn,
    None,
    bool,
]
class DynamoGuardHook(Protocol):
    """Structural type for a hook called with guard-related information.

    The hook receives the guard function, a code object, a mapping of local
    names to values, and a ``last`` flag. It returns nothing.
    """

    def __call__(
        self,
        guard_fn: GuardFn,
        code: types.CodeType,
        f_locals: Dict[str, object],
        last: bool,
    ) -> None:
        """Receive one guard event; the return value is ignored (None)."""