Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
from __future__ import absolute_import

try:
    # TODO: Either find better typing hints for wrapped functions
    # OR change function signature to use args defined for them
    from typing import Any, Dict, List
except ImportError:
    pass

from supertenant import consts
from supertenant.supermeter import _get_brain
from supertenant.supermeter.logger import IntegrationModuleLog

_log = IntegrationModuleLog(consts.INTEGRATION_MODULE_PYTHON_CELERY)

try:
    import sys

    if sys.version_info >= (3, 0):
        from urllib.parse import urlparse
    else:
        from urlparse import urlparse

    # control behavior based on env keys
    from os import environ

    import celery
    from celery import registry  # type: ignore
    from celery import signals

    from supertenant.supermeter.data.celery_data import CeleryClientData, CeleryData, CeleryWorkerData
    from supertenant.supermeter.data.sanitize import (
        bfs_mask,
        mask_str_full,
        mask_str_none,
        mask_str_partial,
        mask_value,
    )
    from supertenant.supermeter.managers.celery.catalog import (
        get_task_id,
        task_catalog_get,
        task_catalog_pop,
        task_catalog_push,
    )
    from supertenant.supermeter.managers.celery.celery_manager import ClientWorkerManager
    from supertenant.supermeter.scope_manager import Span

    # --- Environment-driven configuration ---------------------------------
    # SUPERTENANT_SUPERMETER_CELERY__ARGS: "t" (the default) enables
    # recording of positional task arguments as span tags.
    ARGS_TAGS = environ.get("SUPERTENANT_SUPERMETER_CELERY__ARGS", "t") == "t"
    # SUPERTENANT_SUPERMETER_CELERY__KWARGS: BFS traversal depth for keyword
    # arguments (default "2"); a non-integer value disables kwargs
    # collection entirely (depth 0).
    try:
        KWARGS_TAGS_DEPTH = int(environ.get("SUPERTENANT_SUPERMETER_CELERY__KWARGS", "2"))
    except ValueError:
        KWARGS_TAGS_DEPTH = 0
    # SUPERTENANT_SUPERMETER_CELERY__MASK_VALUES: selects the masking
    # strategy below; default "no".
    MASK_VALUES = environ.get("SUPERTENANT_SUPERMETER_CELERY__MASK_VALUES", "no")
    # SUPERTENANT_SUPERMETER_CELERY__MAX_KEYS: cap on the number of
    # arg/kwarg tags recorded per span; falls back to 32 on a bad value.
    try:
        MAX_KEYS = int(environ.get("SUPERTENANT_SUPERMETER_CELERY__MAX_KEYS", "32"))
    except ValueError:
        MAX_KEYS = 32

    # Map the env setting to a masking function from sanitize:
    # "yes" -> mask_str_partial, "no" -> mask_str_none, anything else ->
    # mask_str_full (fail closed on unrecognized values).
    if MASK_VALUES == "yes":
        _mask_str = mask_str_partial
    elif MASK_VALUES == "no":
        _mask_str = mask_str_none
    else:
        _mask_str = mask_str_full

    def add_args(data, args):
        # type: (CeleryData, List[Any]) -> None
        """Record positional task arguments on *data* as indexed tags.

        No-op when argument collection is disabled via the environment.
        At most MAX_KEYS entries are recorded, each run through the
        active masking function first.
        """
        if not ARGS_TAGS:
            return
        index = 0
        for value in args:
            if index >= MAX_KEYS:
                break
            data.set_arg(index, mask_value(value, _mask_str))
            index += 1

    def add_kwargs(data, kwargs):
        # type: (CeleryData, Dict[str, Any]) -> None
        """Record keyword task arguments on *data* as path/value tags.

        No-op when the configured traversal depth is zero. Walks the
        kwargs via bfs_mask (which applies the active masking function)
        and records at most MAX_KEYS entries.
        """
        if KWARGS_TAGS_DEPTH == 0:
            return
        remaining = MAX_KEYS
        for path, value in bfs_mask(KWARGS_TAGS_DEPTH, kwargs, _mask_str):
            if remaining <= 0:
                break
            data.set_kwarg(path, value)
            remaining -= 1

    def add_broker_tags(data, task):
        # type: (CeleryData, Any) -> bool
        """Populate broker connection tags (scheme/host/port) on *data*.

        Reads the broker URL from the task's app configuration, accepting
        both the new-style ``broker_url`` and legacy ``BROKER_URL`` keys.

        Returns True when the tags were set, False otherwise (missing
        app/conf/URL, or any unexpected error, which is logged).
        """
        try:
            app = getattr(task, "app", None)
            if not app:
                return False
            conf = getattr(app, "conf", None)
            if not conf:
                return False
            # New-style lowercase setting first, legacy uppercase fallback.
            broker_url = conf.get("broker_url")
            if broker_url is None:
                broker_url = conf.get("BROKER_URL")
            if not broker_url:
                _log.debug("failed to find broker URL")
                return False
            url = urlparse(broker_url)
            url_scheme = str(url.scheme)
            data.set_scheme(url_scheme)

            host = "localhost" if url.hostname is None else url.hostname
            data.set_host(host)
            data.set_integration_module_resource_id(host)

            if url.port is None:
                # Fall back to the conventional default port per scheme.
                # Both "redis" and "rediss" (TLS) default to 6379 (the
                # previous exact match on "redis" missed rediss:// URLs);
                # AMQP variants default to 5672; SQS rides HTTPS (443).
                if url_scheme in ("redis", "rediss"):
                    data.set_port("6379")
                elif "amqp" in url_scheme:
                    data.set_port("5672")
                elif "sqs" in url_scheme:
                    data.set_port("443")
            else:
                data.set_port(str(url.port))
            return True
        except Exception as exc:
            _log.exception("add_broker_tags", exc)
            return False

    @signals.task_prerun.connect
    def task_prerun(*args, **kwargs):
        # type: (Any, Any) -> None
        """Open a worker-side span when a task is about to execute.

        Builds a CeleryWorkerData from the signal payload (task name/id,
        broker tags, args/kwargs, routing headers) and pushes the opened
        span onto the task catalog so task_postrun can close it.
        """
        try:
            task = kwargs.get("sender")
            task_id = kwargs.get("task_id")
            _log.trace("task_prerun", extra={"task": str(task), "task_id": str(task_id)})

            if task is None or task_id is None:
                _log.debug("task_prerun: task/task_id not found")
                return
            # Resolve the registered task instance. Previously a missing
            # registry entry caused an AttributeError on None just below,
            # which the broad except swallowed and logged as an exception;
            # bail out cleanly instead.
            task = registry.tasks.get(task.name)
            if task is None:
                _log.debug("task_prerun: task not found in registry")
                return

            before_data = CeleryWorkerData()
            before_data.set_task(task.name)
            before_data.set_task_id(task_id)
            if not add_broker_tags(before_data, task):
                before_data.set_host("localhost")
            add_args(before_data, kwargs.get("args", []))
            add_kwargs(before_data, kwargs.get("kwargs", {}))

            # Routing data arrives nested under headers["headers"], matching
            # how the client side injects it when publishing.
            headers = task.request.get("headers")
            if headers is not None:
                headers = headers.get("headers", {})
                before_data.extract_routing_data(headers)

            span_id, _, _ = ClientWorkerManager.open_span(before_data)
            if span_id is not None:
                span = Span(span_id, CeleryWorkerData())
                # Store the span on the task to eventually close it out on the "after" signal
                _log.trace(
                    "task_prerun pushing span_id to catalog",
                    extra={"task": str(task), "task_id": str(task_id), "span_id": repr(span_id)},
                )
                task_catalog_push(task, task_id, span, True)
            else:
                _log.debug("task_prerun: span is None", extra={"task": str(task), "task_id": str(task_id)})
        except Exception as exc:
            _log.exception("task_prerun", exc)

    _log.trace("instrumented task_prerun")

    @signals.task_postrun.connect
    def task_postrun(*args, **kwargs):
        # type: (Any, Any) -> None
        """Close the worker-side span that task_prerun opened for this run."""
        try:
            sender = kwargs.get("sender")
            task_id = kwargs.get("task_id")
            _log.trace("task_postrun", extra={"task": str(sender), "task_id": str(task_id)})

            if sender is None or task_id is None:
                _log.debug("task_postrun: sender/task_id not found")
                return
            span = task_catalog_pop(sender, task_id, True)
            if span is None:
                _log.debug("task_postrun: span is None", extra={"task": str(sender), "task_id": str(task_id)})
                return
            _log.trace(
                "task_postrun closing span_id",
                extra={"task": str(sender), "task_id": str(task_id), "span_id": repr(span.span_id)},
            )
            span.finish()
        except Exception as exc:
            _log.exception("task_postrun", exc)

    _log.trace("instrumented task_postrun")

    @signals.task_failure.connect
    def task_failure(*args, **kwargs):
        # type: (Any, Any) -> None
        """Mark the open worker-side span for this task run as failed.

        Does not close the span; it only flags the span's finish data
        (success="false" plus the error marker).
        """
        try:
            task = kwargs.get("sender")
            task_id = kwargs.get("task_id")
            _log.trace("task_failure", extra={"task": str(task), "task_id": str(task_id)})

            if task is None or task_id is None:
                _log.debug("task_failure: task/task_id not found")
                return
            span = task_catalog_get(task, task_id, True)
            if span is not None:
                _log.trace(
                    "task_failure set reason for span_id",
                    extra={"task": str(task), "task_id": str(task_id), "span_id": repr(span.span_id)},
                )
                if isinstance(span.finish_data, CeleryWorkerData):
                    span.finish_data.set_success("false")
                    span.finish_data.mark_error()
            else:
                # Fixed log-message typo ("task_falure" -> "task_failure").
                _log.debug("task_failure: span is None", extra={"task": str(task), "task_id": str(task_id)})
        except Exception as exc:
            _log.exception("task_failure", exc)

    _log.trace("instrumented task_failure")

    @signals.task_retry.connect
    def task_retry(*args, **kwargs):
        # type: (Any, Any) -> None
        """Record the retry reason on the span opened for this task run."""
        try:
            sender = kwargs.get("sender")
            task_id = kwargs.get("task_id", None)
            _log.trace("task_retry", extra={"task": str(sender), "task_id": str(task_id)})

            if sender is None or task_id is None:
                _log.debug("task_retry: task/task_id not found")
                return
            span = task_catalog_get(sender, task_id, True)
            if span is None:
                _log.debug("task_retry: span is None", extra={"task": str(sender), "task_id": str(task_id)})
                return
            reason = kwargs.get("reason", None)
            _log.trace(
                "task_retry set reason for span_id",
                extra={
                    "task": str(sender),
                    "task_id": str(task_id),
                    "span_id": repr(span.span_id),
                    "reason": str(reason),
                },
            )
            if reason is not None and isinstance(span.finish_data, CeleryWorkerData):
                span.finish_data.set_retry_reason(reason)
        except Exception as exc:
            _log.exception("task_retry", exc)

    _log.trace("instrumented task_retry")

    @signals.before_task_publish.connect
    def before_task_publish(*args, **kwargs):
        # type: (Any, Any) -> None
        """Open a client-side span just before a task message is published.

        Builds a CeleryClientData from the signal payload (task name/id,
        exchange, routing key, broker tags, args/kwargs), injects routing
        headers into the outgoing message so the worker side can link its
        span, and pushes the opened span onto the task catalog for
        after_task_publish to close.
        """
        try:
            task_name = kwargs.get("sender")
            headers = kwargs.get("headers")
            body = kwargs.get("body")
            if task_name is None or (headers is None and body is None):
                _log.debug("before_task_publish: sender/headers/body not found")
                return

            # The signal sender is just the task name; resolve the registered
            # task instance (needed for app/conf access in add_broker_tags).
            # get_task_id looks the id up in headers or body.
            task = registry.tasks.get(task_name)
            task_id = get_task_id(headers, body)
            _log.trace("before_task_publish", extra={"task": str(task), "task_id": str(task_id)})
            if task is None or task_id is None:
                _log.debug("before_task_publish: task/task_id not found")
                return

            before_data = CeleryClientData()
            before_data.set_task(task_name)
            before_data.set_task_id(task_id)
            exchange = kwargs.get("exchange")
            if exchange is not None:
                before_data.set_exchange(exchange)
            routing_key = kwargs.get("routing_key")
            if routing_key is not None:
                before_data.set_routing_key(routing_key)
            if not add_broker_tags(before_data, task):
                before_data.set_host("localhost")
            if isinstance(body, tuple):
                # version 2: https://docs.celeryq.dev/en/stable/internals/protocol.html#message-protocol-task-v2
                add_args(before_data, body[0])
                add_kwargs(before_data, body[1])
            elif isinstance(body, dict):
                # version 1: https://docs.celeryq.dev/en/stable/internals/protocol.html#message-protocol-task-v1
                add_args(before_data, body.get("args", []))
                add_kwargs(before_data, body.get("kwargs", {}))

            extra_headers = {}  # type: Dict[str, str]
            before_data.inject_routing_data(extra_headers)

            # https://github.com/celery/celery/issues/4875
            # NOTE(review): kwargs is this handler's local dict, so the
            # reassignment below only matters when "headers" was absent;
            # when the publisher passed a headers dict, propagation relies
            # on the in-place update of that same dict — do not copy it.
            task_headers = kwargs.get("headers") or {}
            task_headers.setdefault("headers", {})
            task_headers["headers"].update(extra_headers)
            kwargs["headers"] = task_headers

            span_id, _, _ = ClientWorkerManager.open_span(before_data)
            if span_id is not None:
                span = Span(span_id, CeleryClientData())
                # Store the span on the task to eventually close it out on the "after" signal
                _log.trace(
                    "before_task_publish pushing span_id to catalog",
                    extra={"task": str(task), "task_id": str(task_id), "span_id": repr(span_id)},
                )
                task_catalog_push(task, task_id, span, False)
            else:
                _log.debug("before_task_publish: span is None", extra={"task": str(task), "task_id": str(task_id)})
        except Exception as exc:
            _log.exception("before_task_publish", exc)

    _log.trace("instrumented before_task_publish")

    @signals.after_task_publish.connect
    def after_task_publish(*args, **kwargs):
        # type: (Any, Any) -> None
        """Close the client-side span opened by before_task_publish."""
        try:
            _log.trace("after_task_publish")
            published_task = registry.tasks.get(kwargs.get("sender"))
            task_id = get_task_id(kwargs.get("headers"), kwargs.get("body"))
            if published_task is None or task_id is None:
                _log.debug("after_task_publish: task/task_id not found")
                return

            span = task_catalog_pop(published_task, task_id, False)
            if span is None:
                _log.debug("after_task_publish: span is None", extra={"task": str(published_task), "task_id": str(task_id)})
                return
            _log.trace(
                "after_task_publish closing span_id",
                extra={"task": str(published_task), "task_id": str(task_id), "span_id": repr(span.span_id)},
            )
            span.finish()
        except Exception as exc:
            _log.exception("after_task_publish", exc)

    _log.trace("instrumented after_task_publish")

    @signals.worker_process_shutdown.connect
    def worker_process_shutdown(*args, **kwargs):
        # type: (Any, Any) -> None
        """Shut down the telemetry brain when a worker process exits."""
        try:
            brain = _get_brain()
            if brain is None:
                return
            _log.info("worker_process_shutdown: shutting down")
            brain.shutdown()
        except Exception as exc:
            _log.exception("worker_process_shutdown", exc)

    _log.instrumentation_success(getattr(celery, "__version__"))
except ImportError as exc:
    _log.instrumentation_skipped("", {"exc": exc})