Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
supermeter / supermeter / managers / django / middleware_py3.py
Size: Mime:
from __future__ import absolute_import

try:
    from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
except ImportError:
    pass


import wrapt

from supertenant.consts import ACTION_REJECT, INTEGRATION_MODULE_PYTHON_DJANGO
from supertenant.supermeter.logger import (
    log_instrumentation_failed,
    log_instrumentation_skipped,
    log_instrumentation_success,
    log_integration_module_exception,
)

# Dotted import path of SupertenantMiddleware (defined below); load_middleware_wrapper
# prepends this string to settings.MIDDLEWARE so Django loads the middleware automatically.
DJANGO_SUPER_MIDDLEWARE = "supertenant.supermeter.managers.django.middleware_py3.SupertenantMiddleware"

# Everything Django-specific lives inside this try: if any of the imports below fails,
# the `except ImportError` handler at the end of the module logs the integration as skipped.
try:
    import django
    from asgiref.sync import async_to_sync, iscoroutinefunction, markcoroutinefunction, sync_to_async

    # Version string passed to the instrumentation log helpers; "" if django has no __version__.
    _django_version = getattr(django, "__version__", "")

    from django.http.response import Http404

    from supertenant import consts
    from supertenant.supermeter.data import http_data, http_requests_utils
    from supertenant.supermeter.managers import http_manager
    from supertenant.supermeter.managers.actions import SyncActions
    from supertenant.supermeter.scope_manager import Span

    # Name of the request attribute under which SupertenantMiddleware stores the request's open Span.
    ST_SPAN_ATTR = "st_span"

    def ensure_callback_signature(name, callback):
        # type: (str, Any) -> Union[Callable[[django.http.HttpRequest], Optional[str]], Callable[[django.http.HttpRequest], Awaitable[Optional[str]]]]  # noqa: E501
        """Validate a user tag callback and adapt it to accept exactly one ``request`` argument.

        :param name: attribute name of the callback; used only in error messages.
        :param callback: a plain function or bound method taking zero or one parameter,
            either sync or ``async``.
        :returns: ``callback`` unchanged when it already takes one parameter; otherwise a
            wrapper that accepts and ignores the request (an ``async`` wrapper when
            ``callback`` is a coroutine function).
        :raises ValueError: if ``callback`` is not a function/method, or if it declares
            more than one parameter.
        """
        import inspect

        if not (inspect.isfunction(callback) or inspect.ismethod(callback)):
            raise ValueError(f"SupertenantMiddleware {name} is not a function ({callback})")

        sig = inspect.signature(callback)
        arity = len(sig.parameters)
        if arity == 1:
            return callback
        if arity > 1:
            raise ValueError(
                f"SupertenantMiddleware {name} should receive only one parameter but instead its signature is {sig}"
            )

        # Zero-parameter callback: wrap it so every caller can uniformly pass the request.
        if not iscoroutinefunction(callback):

            def sync_callback(request):
                # type: (django.http.HttpRequest) -> Optional[str]
                return callback()

            return sync_callback

        async def async_callback(request):
            # type: (django.http.HttpRequest) -> Optional[str]
            result = await callback()  # type: Optional[str]
            return result

        return async_callback

    class TagCallback(object):
        """Expose one tag callback through both a sync and an async entry point.

        ``sync_call(request)`` and ``async_call(request)`` both invoke the callback after
        ``ensure_callback_signature`` has normalized it. The flavour the callback does
        not natively support is bridged with asgiref's ``async_to_sync`` /
        ``sync_to_async``.
        """

        def __init__(self, name, callback):
            # type: (str, Callable[[django.http.HttpRequest], Optional[str]]) -> None
            normalized = ensure_callback_signature(name, callback)  # type: ignore
            if not iscoroutinefunction(normalized):
                # Sync callback: call it directly from sync code; from async code, run it
                # via sync_to_async (thread-sensitive).
                self.async_call = sync_to_async(normalized, thread_sensitive=True)  # type: ignore
                self.sync_call = normalized
                return
            # Async callback: await it directly from async code; block on it from sync code.
            self.async_call = normalized
            self.sync_call = async_to_sync(normalized)  # type: ignore

    class HttpRejectResponse(django.http.HttpResponse):
        """Empty-bodied HttpResponse returned by the middleware when a request is rejected."""

        def __init__(self, status_code):
            # type: (Any) -> None
            super().__init__()
            # The base response is built with its defaults; only the status is overridden.
            self.status_code = status_code

    class SupertenantMiddleware(object):
        """Django middleware that opens a supermeter span per request and may reject it.

        Declared both sync- and async-capable: ``__init__`` selects ``sync_call`` or
        ``async_call`` depending on whether ``get_response`` is a coroutine function.

        ``tenant_callback`` / ``resource_callback`` default to ``None``. When they are
        overridden (e.g. in a subclass) with a sync or async callable taking the request
        (or no arguments), their results are attached to the request data as the tenant
        and resource ID tags.
        """

        sync_capable = True
        async_capable = True
        tenant_callback = None  # type: Optional[Callable[[django.http.HttpRequest], Optional[str]]]
        resource_callback = None  # type: Optional[Callable[[django.http.HttpRequest], Optional[str]]]

        def __init__(self, get_response):
            # type: (Union[Callable[[django.http.HttpRequest], django.http.HttpResponse], Callable[[django.http.HttpRequest], Awaitable[django.http.HttpResponse]]]) -> None  # noqa: E501
            """Bind ``get_response`` and prepare the configured tag callbacks.

            :raises ValueError: if ``get_response`` is ``None``, or if a configured tag
                callback has an unsupported signature.
            """
            if get_response is None:
                raise ValueError("get_response must be provided.")
            self.tag_callbacks = {}  # type: Dict[str, TagCallback]
            self.get_response = get_response
            if iscoroutinefunction(get_response):
                # Mark the instance itself as a coroutine function so that
                # iscoroutinefunction(self) is true. This is the pattern Django's async
                # middleware documentation uses to signal an async middleware.
                markcoroutinefunction(self)
                self.call = self.async_call
            else:
                self.call = self.sync_call  # type: ignore
            if self.tenant_callback is not None:
                self.tag_callbacks[consts.LABEL_SUPERTENANT_TENANT_ID] = TagCallback(
                    "tenant_callback", self.tenant_callback
                )
            if self.resource_callback is not None:
                self.tag_callbacks[consts.LABEL_SUPERTENANT_RESOURCE_ID] = TagCallback(
                    "resource_callback", self.resource_callback
                )

        def __call__(self, request):
            # type: (django.http.HttpRequest) -> Any
            """Dispatch to the sync or async implementation chosen in ``__init__``.

            Returns an ``HttpResponse`` in sync mode, or an awaitable of one in async mode.
            """
            return self.call(request)

        def sync_call(self, request):
            # type: (django.http.HttpRequest) -> django.http.HttpResponse
            """Sync request path: compute tags, open the span, run the view, finish the span."""
            if hasattr(request, ST_SPAN_ATTR):
                # A span is already attached; guard against instrumenting the request twice.
                return self.get_response(request)  # type: ignore
            tags = {k: v.sync_call(request) for k, v in self.tag_callbacks.items()}
            # on_request returns a rejection response when the request must not reach the view.
            response = self.on_request(request, tags)
            if response is None:
                response = self.get_response(request)  # type: ignore
            return self.on_response(request, response)  # type: ignore

        async def async_call(self, request):
            # type: (django.http.HttpRequest) -> django.http.HttpResponse
            """Async request path; same steps as ``sync_call`` with awaited callbacks and view."""
            if hasattr(request, ST_SPAN_ATTR):
                # A span is already attached; guard against instrumenting the request twice.
                return await self.get_response(request)  # type: ignore
            tags = {k: await v.async_call(request) for k, v in self.tag_callbacks.items()}  # type: ignore
            # on_request returns a rejection response when the request must not reach the view.
            response = self.on_request(request, tags)
            if response is None:
                response = await self.get_response(request)  # type: ignore
            return self.on_response(request, response)  # type: ignore

        def on_request(self, request, tags):
            # type: (django.http.HttpRequest, Dict[str, str]) -> Optional[django.http.HttpResponse]
            """Collect request data, open a span and apply the action returned for it.

            :param request: the incoming Django request.
            :param tags: tag name -> value computed by the configured tag callbacks.
            :returns: an ``HttpRejectResponse`` when the action is ``ACTION_REJECT``,
                otherwise ``None`` (the request proceeds to the view).

            Exceptions are passed to ``log_integration_module_exception`` and not
            re-raised, so instrumentation problems never block the request.
            """
            try:
                env = getattr(request, "META", None)
                if env is None:
                    env = getattr(request, "environ", {})
                before_data = http_data.HTTPServerData(INTEGRATION_MODULE_PYTHON_DJANGO)

                if request.method is not None:
                    before_data.set_method(request.method)

                if "PATH_INFO" in env:
                    before_data.set_path(env["PATH_INFO"])
                if "QUERY_STRING" in env and len(env["QUERY_STRING"]):
                    before_data.set_params(env["QUERY_STRING"])
                if "HTTP_HOST" in env:
                    before_data.set_host(env["HTTP_HOST"])
                if "HTTP_USER_AGENT" in env:
                    before_data.set_user_agent(env["HTTP_USER_AGENT"])

                before_data.set_headers_from_wsgi_env(env)
                for k, v in tags.items():
                    before_data.set_tag(k, v)

                if not before_data.get_tag(consts.LABEL_SUPERTENANT_RESOURCE_ID):
                    # No explicit resource ID: derive one from the host. request.get_host()
                    # honours X-Forwarded-Host when USE_X_FORWARDED_HOST is enabled, but it
                    # raises (e.g. DisallowedHost) when the host is not in ALLOWED_HOSTS.
                    try:
                        before_data.set_integration_module_resource_id(request.get_host())
                    except Exception:
                        # Fall back to the host stored on before_data (presumably the
                        # HTTP_HOST value recorded above). This previously referenced the
                        # unbound `host`, which aborted span creation entirely.
                        data_host = before_data.get_host()
                        if data_host is not None:
                            before_data.set_integration_module_resource_id(data_host)

                span_id, act, poll_key = http_manager.HTTPManager.open_span(data=before_data)
                if span_id is not None:
                    span = Span(span_id, http_data.HTTPServerData(INTEGRATION_MODULE_PYTHON_DJANGO))
                    # Attach the span so on_response / process_exception can find it.
                    setattr(request, ST_SPAN_ATTR, span)
                    action, action_desc = SyncActions.get_action(span_id, act, poll_key)
                    if action == ACTION_REJECT:
                        # Type narrowing for the checker; finish_data was created just above.
                        assert isinstance(span.finish_data, http_data.HTTPData)
                        rc = http_requests_utils.reject_http_request(action, action_desc, span.finish_data)
                        # Short-circuit the view; sync_call/async_call still pass this
                        # response through on_response, which finishes the span.
                        return HttpRejectResponse(rc or 429)
            except Exception as exc:
                log_integration_module_exception(INTEGRATION_MODULE_PYTHON_DJANGO, "on_request", exc)
            return None

        def on_response(self, request, response):
            # type: (django.http.HttpRequest, django.http.HttpResponse) -> django.http.HttpResponse
            """Finish the request's span (if any) with the response status and return ``response``.

            Statuses 429 and >= 500 mark the span as an error. Exceptions are passed to
            ``log_integration_module_exception`` and not re-raised.
            """
            try:
                span = getattr(request, ST_SPAN_ATTR, None)
                if span:
                    if hasattr(response, "status_code"):
                        if response.status_code == 429 or response.status_code >= 500:
                            span.finish_data.mark_error()
                        span.finish_data.set_status(response.status_code)
                    span.finish()
                    # Detach the finished span from the request.
                    delattr(request, ST_SPAN_ATTR)
            except Exception as exc:
                log_integration_module_exception(INTEGRATION_MODULE_PYTHON_DJANGO, "on_response", exc)
            return response

        def process_exception(self, request, exception):
            # type: (django.http.HttpRequest, Exception) -> None
            """Django ``process_exception`` hook: mark the request's span as an error.

            ``Http404`` is ignored. Always returns ``None``, so Django's default exception
            handling continues. The span is not finished here.
            """
            if isinstance(exception, Http404):
                return

            span = getattr(request, ST_SPAN_ATTR, None)
            if span:
                try:
                    span.finish_data.mark_error()
                except Exception as exc:
                    log_integration_module_exception(INTEGRATION_MODULE_PYTHON_DJANGO, "on_exception", exc)

    def load_middleware_wrapper(wrapped, instance, args, kwargs):
        # type: (Callable[..., None], Any, Any, Dict[str, Any]) -> None
        """wrapt wrapper around ``BaseHandler.load_middleware`` that injects SupertenantMiddleware.

        Prepends ``DJANGO_SUPER_MIDDLEWARE`` to ``settings.MIDDLEWARE`` when that setting is
        a list or tuple (possibly empty) that does not already contain it. Afterwards it
        always delegates to Django's original ``load_middleware``.

        Injection is best effort: any failure is reported via ``log_instrumentation_failed``
        and Django still loads its middleware. Exceptions raised by ``wrapped`` itself
        propagate unchanged.

        :param wrapped: the original ``BaseHandler.load_middleware``.
        :param instance: the handler instance (unused; part of wrapt's wrapper signature).
        :param args: positional arguments for ``wrapped``.
        :param kwargs: keyword arguments for ``wrapped``.
        :returns: whatever ``wrapped`` returns.
        """
        try:
            from django.conf import settings

            middleware = getattr(settings, "MIDDLEWARE", None)
            if middleware is None:
                log_instrumentation_failed(
                    INTEGRATION_MODULE_PYTHON_DJANGO,
                    _django_version,
                    {"reason": "MIDDLEWARE not found in settings"},
                )
            elif DJANGO_SUPER_MIDDLEWARE not in middleware:
                # An empty MIDDLEWARE is still a valid list/tuple and gets instrumented too.
                if isinstance(middleware, tuple):
                    settings.MIDDLEWARE = (DJANGO_SUPER_MIDDLEWARE,) + middleware
                elif isinstance(middleware, list):
                    settings.MIDDLEWARE = [DJANGO_SUPER_MIDDLEWARE] + middleware
                else:
                    log_instrumentation_failed(
                        INTEGRATION_MODULE_PYTHON_DJANGO, _django_version, {"reason": "MIDDLEWARE is not list or tuple"}
                    )
        except Exception as exc:
            log_instrumentation_failed(INTEGRATION_MODULE_PYTHON_DJANGO, _django_version, {"exc": exc})

        # Deliberately outside the try. A failure in our injection must not stop Django
        # from building its middleware chain. Django's own errors (e.g. a misconfigured
        # middleware) must reach the caller instead of being swallowed here.
        return wrapped(*args, **kwargs)

    # Patch BaseHandler.load_middleware so load_middleware_wrapper can add
    # SupertenantMiddleware to settings.MIDDLEWARE before the original method runs.
    wrapt.wrap_function_wrapper("django.core.handlers.base", "BaseHandler.load_middleware", load_middleware_wrapper)
    # Logged once the patch is registered; actual middleware injection happens later,
    # whenever Django calls load_middleware.
    log_instrumentation_success(INTEGRATION_MODULE_PYTHON_DJANGO, _django_version)

except ImportError as exc:
    # Django, asgiref, or one of the imports at the top of the try is unavailable:
    # the Django integration is skipped.
    log_instrumentation_skipped(INTEGRATION_MODULE_PYTHON_DJANGO, "", {"exc": exc})