# Extra utilities for working with context managers that should have been
# in the standard library but are not
import functools
import inspect
import warnings
import sys
from typing import Any, Callable, TypeVar, cast
# Type helpers used for annotating the decorator usage of
# _DecoratorContextManager (e.g., 'no_grad' and 'enable_grad').
# `__call__` is annotated as taking and returning the same TypeVar `F`, so a
# type checker treats the decorated function as having the original signature.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
# Any callable: arbitrary arguments, arbitrary return type.
FuncType = Callable[..., Any]
# Bound to the concrete type of the decorated callable at each use site.
F = TypeVar('F', bound=FuncType)
def _wrap_generator(ctx_factory, func):
"""
Wrap each generator invocation with the context manager factory.
The input should be a function that returns a context manager,
not a context manager itself, to handle one-shot context managers.
"""
@functools.wraps(func)
def generator_context(*args, **kwargs):
gen = func(*args, **kwargs)
# Generators are suspended and unsuspended at `yield`, hence we
# make sure the grad mode is properly set every time the execution
# flow returns into the wrapped generator and restored when it
# returns through our `yield` to our caller (see PR #49017).
try:
# Issuing `None` to a generator fires it up
with ctx_factory():
response = gen.send(None)
while True:
try:
# Forward the response to our caller and get its next request
request = yield response
except GeneratorExit:
# Inform the still active generator about its imminent closure
with ctx_factory():
gen.close()
raise
except BaseException:
# Propagate the exception thrown at us by the caller
with ctx_factory():
response = gen.throw(*sys.exc_info())
else:
# Pass the last request to the generator and get its response
with ctx_factory():
response = gen.send(request)
# We let the exceptions raised above by the generator's `.throw` or
# `.send` methods bubble up to our caller, except for StopIteration
except StopIteration as e:
# The generator informed us that it is done: take whatever its
# returned value (if any) was and indicate that we're done too
# by returning it (see docs for python's return-statement).
return e.value
return generator_context
def context_decorator(ctx, func):
    """
    Like contextlib.ContextDecorator, but:

    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented from C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be entered multiple times) or a
    callable that produces a context manager.

    Args:
        ctx: a reusable context manager, or a zero-argument callable returning
            a context manager (called once per invocation of the wrapper).
        func: the function or generator function to wrap.

    Returns:
        A wrapper around ``func`` (metadata copied via ``functools.wraps``)
        that runs ``func`` inside the context. Generator functions are wrapped
        so the context is active whenever the generator body executes.

    Raises:
        AssertionError: if ``ctx`` is both callable and has ``__enter__``.
        RuntimeError: if ``func`` is a class.
    """
    # Explicit check instead of `assert`, so it is not stripped under
    # `python -O`. AssertionError is kept for backward compatibility.
    if callable(ctx) and hasattr(ctx, '__enter__'):
        raise AssertionError(
            f"Passed in {ctx} is both callable and also a valid context manager "
            "(has __enter__), making it ambiguous which interface to use. If you "
            "intended to pass a context manager factory, rewrite your call as "
            "context_decorator(lambda: ctx()); if you intended to pass a context "
            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
        )

    if not callable(ctx):
        # A plain context manager: reuse the same object on every call.
        def ctx_factory():
            return ctx
    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type. "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    if inspect.isgeneratorfunction(func):
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context
class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator.

    Subclasses implement ``__enter__``/``__exit__``. Calling an instance on a
    function returns a wrapper built by ``context_decorator`` with
    :meth:`clone` as the factory. Each call of the wrapper therefore enters
    a fresh clone of this context manager.
    """

    def __call__(self, orig_func: F) -> F:
        if inspect.isclass(orig_func):
            # stacklevel=2 attributes the warning to the user's decoration
            # site instead of this line. Users can then locate it, and each
            # site is reported separately under the default warning filter.
            warnings.warn("Decorating classes is deprecated and will be disabled in "
                          "future versions. You should only decorate functions or methods. "
                          "To preserve the current behavior of class decoration, you can "
                          "directly decorate the `__init__` method and nothing else.",
                          stacklevel=2)
            # context_decorator rejects classes. Wrap the constructor call in
            # a plain function to keep the legacy (deprecated) behavior.
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        # Must be implemented by subclasses.
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Must be implemented by subclasses.
        raise NotImplementedError

    def clone(self):
        """Return a new instance of this class, used once per decorated call."""
        # override this method if your children class takes __init__ parameters
        return self.__class__()