Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
supermeter / supermeter / task_manager / asyncio_task_manager_py3.py
Size: Mime:
import asyncio
from typing import Any, Callable, Dict, List, Optional

import wrapt

from supertenant import consts
from supertenant.supermeter import _get_brain

from .thread_task_manager import ThreadTaskManager, task_holder


class AsyncIOTaskManager(ThreadTaskManager):
    """Task manager that additionally tracks asyncio tasks.

    When code runs inside an asyncio task, ``id(task)`` is used as the current
    task id; otherwise the thread-based id from ``ThreadTaskManager`` is used.
    asyncio tasks are reported to the brain (``_get_brain()``) when they are
    created and when they finish.
    """

    def __init__(self):
        # type: () -> None
        super().__init__()
        # NOTE(review): patch_asyncio() wraps module-level asyncio functions, so
        # constructing more than one manager stacks additional wrappers.
        self.patch_asyncio()

    def get_asyncio_current_task_id(self, create_missing=True):
        # type: (bool) -> int
        """Return ``id()`` of the currently running asyncio task, or 0.

        Returns 0 when this thread has no running event loop, when no task is
        running, or when the running task has not been reported to the brain
        and ``create_missing`` is False. With ``create_missing`` True, an
        untracked running task is tracked on the spot before its id is returned.
        """
        try:
            # Prevent failure when run from a thread without an event loop.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return 0
        # Python 3.6 and below need to use asyncio.Task.current_task(loop=loop)
        task = asyncio.current_task(loop=loop)
        if task is None:
            return 0
        if not getattr(task, "_st_created", False):
            if not create_missing:
                return 0
            self._track_asyncio_task(task)
        return id(task)

    def get_current_task_id(self, create_missing=True):
        # type: (bool) -> int
        """Return the current asyncio task id, falling back to the thread-based id."""
        async_task_id = self.get_asyncio_current_task_id(create_missing)
        if async_task_id != 0:
            return async_task_id
        return super().get_current_task_id(create_missing)

    def _track_asyncio_task(self, task):
        # type: (asyncio.Future[Any]) -> None
        """Attach the done-callback to *task* once and report its creation.

        Idempotent: a task already reported to the brain is skipped, and the
        done-callback is never attached twice, even when creation reporting is
        retried because no brain was available on an earlier attempt.
        """
        if getattr(task, "_st_created", False):
            return
        if not getattr(task, "_st_done_callback_added", False):
            task.add_done_callback(self.on_asyncio_task_done)
            setattr(task, "_st_done_callback_added", True)
        self.on_asyncio_task_create(task)

    def on_asyncio_task_create(self, task):
        # type: (asyncio.Task[Any]) -> None
        """Report *task* to the brain as a coroutine task and mark it as created.

        Does nothing (and leaves the task unmarked) when no brain is available.
        """
        brain = _get_brain()
        taskid = id(task)
        if taskid != 0 and brain is not None:
            # Resolve the parent id BEFORE marking the task as created:
            # get_current_task_id(False) ignores tasks lacking ``_st_created``,
            # so the parent is the enclosing tracked task (or the thread) and
            # never this task itself.
            brain.create_task(self.get_current_task_id(False), taskid, consts.TASK_COROUTINE)
            setattr(task, "_st_created", True)

    def on_asyncio_task_done(self, task):
        # type: (Optional[Any]) -> None
        """Done-callback: tell the brain that *task* has finished."""
        if task is None:
            return
        task_id = id(task)
        brain = _get_brain()
        if brain is not None:
            brain.task_done(task_id)

    def patch_asyncio(self):
        # type: () -> None
        """Patch asyncio so tasks and executor calls carry supermeter task ids.

        - ``asyncio.ensure_future`` / ``asyncio.create_task``: every returned
          task is tracked (done-callback plus creation report).
        - ``BaseEventLoop.run_in_executor``: the submitted callable runs with the
          submitter's task id stored in the thread-local ``task_holder``.
        """

        def track_returned_task(wrapped, instance, argv, kwargs):
            # type: (Callable[..., asyncio.Future[Any]], Any, List[Any], Dict[str, Any]) -> asyncio.Future[Any]
            task = wrapped(*argv, **kwargs)
            # ensure_future() hands back a Future/Task it is given unchanged;
            # the helper skips ones that are already tracked.
            self._track_asyncio_task(task)
            return task

        wrapt.patch_function_wrapper("asyncio", "ensure_future")(track_returned_task)

        if hasattr(asyncio, "create_task"):
            wrapt.patch_function_wrapper("asyncio", "create_task")(track_returned_task)

        # This is just to get the task id in to the thread local storage easily
        class AsyncioID:
            def __init__(self, task_id):
                # type: (int) -> None
                self.id = task_id

            def get_id(self):
                # type: () -> int
                return self.id

        @wrapt.patch_function_wrapper("asyncio", "BaseEventLoop.run_in_executor")
        def run_in_executor_with_supermeter(wrapped, instance, argv, kwargs):
            # type: (Callable[..., Any], Any, List[Any], Dict[str, Any]) -> Any
            # Captured at submission time, in the caller's context.
            current_taskid = self.get_current_task_id(False)

            # Set the task id on the thread-local storage for the duration of the
            # call and ALWAYS restore the previous value, even if the function
            # raises, so executor threads never keep a stale task id.
            @wrapt.decorator
            def taskid_wrapper(func, func_instance, func_args, func_kwargs):
                # type: (Callable[..., Any], Any, List[Any], Dict[str, Any]) -> Any
                id_to_restore = getattr(task_holder, "id", None)
                task_holder.id = AsyncioID(current_taskid)
                try:
                    return func(*func_args, **func_kwargs)
                finally:
                    task_holder.id = id_to_restore

            # run_in_executor(executor, func, *args): func is positional index 1.
            if len(argv) >= 2:
                argv = list(argv)  # argv is a tuple so we need to convert to a list
                argv[1] = taskid_wrapper(argv[1])
            elif "func" in kwargs:
                kwargs["func"] = taskid_wrapper(kwargs["func"])

            return wrapped(*argv, **kwargs)