Repository URL to install this package:
|
Version:
0.8.1 ▾
|
import asyncio
from typing import Any, Callable, Dict, List, Optional
import wrapt
from supertenant import consts
from supertenant.supermeter import _get_brain
from .thread_task_manager import ThreadTaskManager, task_holder
class AsyncIOTaskManager(ThreadTaskManager):
    """Task manager that attributes work to asyncio tasks as well as threads.

    Extends ``ThreadTaskManager`` so that:

    * code running inside an asyncio task reports that task's id;
    * tasks created through ``asyncio.ensure_future`` / ``asyncio.create_task``
      are reported to the supermeter brain on creation and on completion;
    * callables handed to ``BaseEventLoop.run_in_executor`` run with the id of
      the task that scheduled them stored in ``task_holder``.
    """

    def __init__(self):
        # type: () -> None
        super().__init__()
        # NOTE(review): the asyncio patches are module-global and are applied
        # once per instance; creating several managers stacks the wrappers.
        self.patch_asyncio()

    def _track_asyncio_task(self, task):
        # type: (Any) -> None
        """Hook *task*'s completion and report its creation, at most once each.

        The same task can reach this point several times: ``ensure_future``
        returns an already-existing task unchanged, and the current task is
        looked up repeatedly while no brain is available (in which case it is
        never marked ``_st_created``). The ``_st_*`` markers prevent stacking
        duplicate done callbacks and duplicate ``create_task`` reports.
        """
        if getattr(task, "_st_created", False):
            # Already registered with the brain (and its done callback added).
            return
        if not getattr(task, "_st_done_callback", False):
            task.add_done_callback(self.on_asyncio_task_done)
            setattr(task, "_st_done_callback", True)
        self.on_asyncio_task_create(task)

    def get_asyncio_current_task_id(self, create_missing=True):
        # type: (bool) -> int
        """Return ``id()`` of the currently running asyncio task, or 0.

        Returns 0 when there is no running event loop in this thread, when no
        task is running, or when the current task is not yet registered and
        *create_missing* is false. With *create_missing* true, an unregistered
        current task (e.g. one created via ``loop.create_task`` directly) is
        registered on the spot.
        """
        try:
            # Prevent failure when run from a thread without an event loop.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return 0
        # get_running_loop/current_task require Python 3.7+; 3.6 and below
        # would need asyncio.Task.current_task(loop=loop) instead.
        task = asyncio.current_task(loop=loop)
        if task is None:
            return 0
        if not getattr(task, "_st_created", False):
            if not create_missing:
                return 0
            self._track_asyncio_task(task)
        return id(task)

    def get_current_task_id(self, create_missing=True):
        # type: (bool) -> int
        """Prefer the current asyncio task id; fall back to the thread's id."""
        async_task_id = self.get_asyncio_current_task_id(create_missing)
        if async_task_id != 0:
            return async_task_id
        return super().get_current_task_id(create_missing)

    def on_asyncio_task_create(self, task):
        # type: (asyncio.Task[Any]) -> None
        """Report *task* to the brain as a coroutine task and mark it created.

        Does nothing (and leaves the task unmarked) when no brain is available.
        """
        brain = _get_brain()
        taskid = id(task)
        if taskid != 0 and brain is not None:
            # The parent is looked up BEFORE marking the task: when *task* is
            # itself the current task (lazy registration), an unmarked task
            # makes get_current_task_id(False) fall back to the thread id
            # instead of returning the task as its own parent.
            brain.create_task(self.get_current_task_id(False), taskid, consts.TASK_COROUTINE)
            setattr(task, "_st_created", True)

    def on_asyncio_task_done(self, task):
        # type: (Optional[Any]) -> None
        """Done callback: tell the brain the task with this ``id()`` finished."""
        if task is None:
            return
        task_id = id(task)
        brain = _get_brain()
        if brain is not None:
            brain.task_done(task_id)

    def patch_asyncio(self):
        # type: () -> None
        """Install wrapt patches on asyncio task factories and run_in_executor."""

        @wrapt.patch_function_wrapper("asyncio", "ensure_future")
        def ensure_future_with_supermeter(wrapped, instance, argv, kwargs):
            # type: (Callable[..., asyncio.Task[Any]], Any, List[Any], Dict[str, Any]) -> asyncio.Task[Any]
            task = wrapped(*argv, **kwargs)
            # ensure_future may hand back an already-tracked task unchanged;
            # _track_asyncio_task makes the registration idempotent.
            self._track_asyncio_task(task)
            return task

        if hasattr(asyncio, "create_task"):

            @wrapt.patch_function_wrapper("asyncio", "create_task")
            def create_task_with_supermeter(wrapped, instance, argv, kwargs):
                # type: (Callable[..., asyncio.Task[Any]], Any, List[Any], Dict[str, Any]) -> asyncio.Task[Any]
                task = wrapped(*argv, **kwargs)
                self._track_asyncio_task(task)
                return task

        # This is just to get the task id in to the thread local storage easily
        class AsyncioID:
            def __init__(self, task_id):
                # type: (int) -> None
                self.id = task_id

            def get_id(self):
                # type: () -> int
                return self.id

        @wrapt.patch_function_wrapper("asyncio", "BaseEventLoop.run_in_executor")
        def run_in_executor_with_supermeter(wrapped, instance, argv, kwargs):
            # type: (Callable[..., Any], Any, List[Any], Dict[str, Any]) -> Any
            # Captured on the calling (event-loop) thread before the hand-off.
            current_taskid = self.get_current_task_id(False)

            @wrapt.decorator
            def taskid_wrapper(fn, _fn_instance, fn_args, fn_kwargs):
                # type: (Callable[..., Any], Any, List[Any], Dict[str, Any]) -> Any
                # Runs on the executor thread: expose the scheduling task's id
                # through the thread-local for the duration of the call. The
                # previous value is restored in `finally` because executor
                # threads are reused; if fn raised without restoring, the next
                # job on this thread would inherit a stale task id.
                id_to_restore = getattr(task_holder, "id", None)
                task_holder.id = AsyncioID(current_taskid)
                try:
                    return fn(*fn_args, **fn_kwargs)
                finally:
                    task_holder.id = id_to_restore

            # Signature is run_in_executor(executor, func, *args); instance is
            # the loop, so func is the second positional argument.
            if len(argv) >= 2:
                argv = list(argv)  # argv is a tuple so we need to convert to a list
                argv[1] = taskid_wrapper(argv[1])
            elif "func" in kwargs:
                kwargs["func"] = taskid_wrapper(kwargs["func"])
            return wrapped(*argv, **kwargs)