# supertenant/supermeter celery integration module — package version 0.8.1
from __future__ import absolute_import
try:
# TODO: Either find better typing hints for wrapped functions
# OR change function signature to use args defined for them
from typing import Any, Dict, List
except ImportError:
pass
from supertenant import consts
from supertenant.supermeter import _get_brain
from supertenant.supermeter.logger import IntegrationModuleLog
_log = IntegrationModuleLog(consts.INTEGRATION_MODULE_PYTHON_CELERY)
try:
import sys
if sys.version_info >= (3, 0):
from urllib.parse import urlparse
else:
from urlparse import urlparse
# control behavior based on env keys
from os import environ
import celery
from celery import registry # type: ignore
from celery import signals
from supertenant.supermeter.data.celery_data import CeleryClientData, CeleryData, CeleryWorkerData
from supertenant.supermeter.data.sanitize import (
bfs_mask,
mask_str_full,
mask_str_none,
mask_str_partial,
mask_value,
)
from supertenant.supermeter.managers.celery.catalog import (
get_task_id,
task_catalog_get,
task_catalog_pop,
task_catalog_push,
)
from supertenant.supermeter.managers.celery.celery_manager import ClientWorkerManager
from supertenant.supermeter.scope_manager import Span
# --- Behavior toggles, read once at import time from the environment ---

# Record positional task arguments as span tags when set to "t" (default: on).
ARGS_TAGS = environ.get("SUPERTENANT_SUPERMETER_CELERY__ARGS", "t") == "t"
try:
    # Depth to which nested kwargs are flattened into tags; 0 disables kwargs tags.
    KWARGS_TAGS_DEPTH = int(environ.get("SUPERTENANT_SUPERMETER_CELERY__KWARGS", "2"))
except ValueError:
    # A non-integer override disables kwargs tagging rather than crashing.
    KWARGS_TAGS_DEPTH = 0
# Masking policy for recorded values: "no" records values as-is, "yes" masks
# partially, and any other value masks fully (fails closed on unknown input).
MASK_VALUES = environ.get("SUPERTENANT_SUPERMETER_CELERY__MASK_VALUES", "no")
try:
    # Upper bound on how many args/kwargs entries are tagged per task.
    MAX_KEYS = int(environ.get("SUPERTENANT_SUPERMETER_CELERY__MAX_KEYS", "32"))
except ValueError:
    # Fall back to the default on a malformed override.
    MAX_KEYS = 32
if MASK_VALUES == "yes":
    _mask_str = mask_str_partial
elif MASK_VALUES == "no":
    _mask_str = mask_str_none
else:
    _mask_str = mask_str_full
def add_args(data, args):
    # type: (CeleryData, List[Any]) -> None
    """Record masked positional task arguments on *data*, up to MAX_KEYS of them."""
    if not ARGS_TAGS:
        return
    limit = min(len(args), MAX_KEYS)
    for index in range(limit):
        data.set_arg(index, mask_value(args[index], _mask_str))
def add_kwargs(data, kwargs):
    # type: (CeleryData, Dict[str, Any]) -> None
    """Record masked keyword arguments on *data*, up to MAX_KEYS entries."""
    if KWARGS_TAGS_DEPTH == 0:
        return
    recorded = 0
    for path, masked in bfs_mask(KWARGS_TAGS_DEPTH, kwargs, _mask_str):
        if recorded >= MAX_KEYS:
            break
        data.set_kwarg(path, masked)
        recorded += 1
def add_broker_tags(data, task):
    # type: (CeleryData, Any) -> bool
    """Parse the task's broker URL and set scheme/host/port tags on *data*.

    Returns True when the broker URL was found and tags were set,
    False otherwise (including on any unexpected error).
    """
    try:
        app = getattr(task, "app", None)
        if not app:
            return False
        conf = getattr(app, "conf", None)
        if not conf:
            return False
        # Newer Celery settings use lowercase "broker_url"; fall back to the
        # legacy uppercase "BROKER_URL" key.
        broker_url = conf.get("broker_url")
        if broker_url is None:
            broker_url = conf.get("BROKER_URL")
        if not broker_url:
            _log.debug("failed to find broker URL")
            return False
        url = urlparse(broker_url)
        url_scheme = str(url.scheme)
        data.set_scheme(url_scheme)
        host = "localhost" if url.hostname is None else url.hostname
        data.set_host(host)
        data.set_integration_module_resource_id(host)
        if url.port is None:
            # Set default port if not specified
            if url_scheme == "redis":
                data.set_port("6379")
            elif "amqp" in url_scheme:
                data.set_port("5672")
            elif "sqs" in url_scheme:
                data.set_port("443")
            # NOTE(review): an unrecognized scheme with no explicit port
            # leaves the port tag unset.
        else:
            data.set_port(str(url.port))
        return True
    except Exception as exc:
        _log.exception("add_broker_tags", exc)
        return False
@signals.task_prerun.connect
def task_prerun(*args, **kwargs):
    # type: (Any, Any) -> None
    """Worker-side signal handler: open a span just before a task executes."""
    try:
        task = kwargs.get("sender")
        task_id = kwargs.get("task_id")
        _log.trace("task_prerun", extra={"task": str(task), "task_id": str(task_id)})
        if task is None or task_id is None:
            _log.debug("task_prerun: task/task_id not found")
            return
        # Resolve the registered task instance by name so the same object is
        # used here as in the catalog keys elsewhere in this module.
        task = registry.tasks.get(task.name)
        before_data = CeleryWorkerData()
        before_data.set_task(task.name)
        before_data.set_task_id(task_id)
        if not add_broker_tags(before_data, task):
            before_data.set_host("localhost")
        add_args(before_data, kwargs.get("args", []))
        add_kwargs(before_data, kwargs.get("kwargs", {}))
        # Routing data is nested one level down, under headers["headers"]
        # (mirrors the injection done in before_task_publish).
        headers = task.request.get("headers")
        if headers is not None:
            headers = headers.get("headers", {})
            before_data.extract_routing_data(headers)
        span_id, _, _ = ClientWorkerManager.open_span(before_data)
        if span_id is not None:
            span = Span(span_id, CeleryWorkerData())
            # Store the span on the task to eventually close it out on the "after" signal
            _log.trace(
                "task_prerun pushing span_id to catalog",
                extra={"task": str(task), "task_id": str(task_id), "span_id": repr(span_id)},
            )
            task_catalog_push(task, task_id, span, True)
        else:
            _log.debug("task_prerun: span is None", extra={"task": str(task), "task_id": str(task_id)})
    except Exception as exc:
        _log.exception("task_prerun", exc)
_log.trace("instrumented task_prerun")
@signals.task_postrun.connect
def task_postrun(*args, **kwargs):
    # type: (Any, Any) -> None
    """Worker-side signal handler: close the span opened by task_prerun."""
    try:
        sender = kwargs.get("sender")
        task_id = kwargs.get("task_id")
        _log.trace("task_postrun", extra={"task": str(sender), "task_id": str(task_id)})
        if sender is None or task_id is None:
            _log.debug("task_postrun: sender/task_id not found")
            return
        span = task_catalog_pop(sender, task_id, True)
        if span is None:
            _log.debug("task_postrun: span is None", extra={"task": str(sender), "task_id": str(task_id)})
            return
        _log.trace(
            "task_postrun closing span_id",
            extra={"task": str(sender), "task_id": str(task_id), "span_id": repr(span.span_id)},
        )
        span.finish()
    except Exception as exc:
        _log.exception("task_postrun", exc)
_log.trace("instrumented task_postrun")
@signals.task_failure.connect
def task_failure(*args, **kwargs):
    # type: (Any, Any) -> None
    """Worker-side signal handler: mark the current span as failed.

    The span is left open here; task_postrun remains responsible for
    closing it.
    """
    try:
        task = kwargs.get("sender")
        task_id = kwargs.get("task_id")
        _log.trace("task_failure", extra={"task": str(task), "task_id": str(task_id)})
        if task is None or task_id is None:
            _log.debug("task_failure: task/task_id not found")
            return
        # Peek (do not pop) the span so postrun can still find and close it.
        span = task_catalog_get(task, task_id, True)
        if span is not None:
            _log.trace(
                "task_failure set reason for span_id",
                extra={"task": str(task), "task_id": str(task_id), "span_id": repr(span.span_id)},
            )
            if isinstance(span.finish_data, CeleryWorkerData):
                span.finish_data.set_success("false")
                span.finish_data.mark_error()
        else:
            # Fixed log-message typo: was "task_falure".
            _log.debug("task_failure: span is None", extra={"task": str(task), "task_id": str(task_id)})
    except Exception as exc:
        _log.exception("task_failure", exc)
_log.trace("instrumented task_failure")
@signals.task_retry.connect
def task_retry(*args, **kwargs):
    # type: (Any, Any) -> None
    """Worker-side signal handler: record the retry reason on the open span."""
    try:
        task = kwargs.get("sender")
        task_id = kwargs.get("task_id", None)
        _log.trace("task_retry", extra={"task": str(task), "task_id": str(task_id)})
        if task is None or task_id is None:
            _log.debug("task_retry: task/task_id not found")
            return
        span = task_catalog_get(task, task_id, True)
        if span is None:
            _log.debug("task_retry: span is None", extra={"task": str(task), "task_id": str(task_id)})
            return
        reason = kwargs.get("reason", None)
        _log.trace(
            "task_retry set reason for span_id",
            extra={
                "task": str(task),
                "task_id": str(task_id),
                "span_id": repr(span.span_id),
                "reason": str(reason),
            },
        )
        if reason is not None and isinstance(span.finish_data, CeleryWorkerData):
            span.finish_data.set_retry_reason(reason)
    except Exception as exc:
        _log.exception("task_retry", exc)
_log.trace("instrumented task_retry")
@signals.before_task_publish.connect
def before_task_publish(*args, **kwargs):
    # type: (Any, Any) -> None
    """Client-side signal handler: open a span and inject routing data into
    the outgoing task message before it is published to the broker.
    """
    try:
        task_name = kwargs.get("sender")
        headers = kwargs.get("headers")
        body = kwargs.get("body")
        if task_name is None or (headers is None and body is None):
            _log.debug("before_task_publish: sender/headers/body not found")
            return
        task = registry.tasks.get(task_name)
        # Task id location differs by message protocol version (headers vs body).
        task_id = get_task_id(headers, body)
        _log.trace("before_task_publish", extra={"task": str(task), "task_id": str(task_id)})
        if task is None or task_id is None:
            _log.debug("before_task_publish: task/task_id not found")
            return
        before_data = CeleryClientData()
        before_data.set_task(task_name)
        before_data.set_task_id(task_id)
        exchange = kwargs.get("exchange")
        if exchange is not None:
            before_data.set_exchange(exchange)
        routing_key = kwargs.get("routing_key")
        if routing_key is not None:
            before_data.set_routing_key(routing_key)
        if not add_broker_tags(before_data, task):
            before_data.set_host("localhost")
        if isinstance(body, tuple):
            # version 2: https://docs.celeryq.dev/en/stable/internals/protocol.html#message-protocol-task-v2
            add_args(before_data, body[0])
            add_kwargs(before_data, body[1])
        elif isinstance(body, dict):
            # version 1: https://docs.celeryq.dev/en/stable/internals/protocol.html#message-protocol-task-v1
            add_args(before_data, body.get("args", []))
            add_kwargs(before_data, body.get("kwargs", {}))
        extra_headers = {}  # type: Dict[str, str]
        before_data.inject_routing_data(extra_headers)
        # Mutate the signal's headers dict in place so the routing data is
        # published with the message; nested under "headers" — see
        # https://github.com/celery/celery/issues/4875
        task_headers = kwargs.get("headers") or {}
        task_headers.setdefault("headers", {})
        task_headers["headers"].update(extra_headers)
        kwargs["headers"] = task_headers
        span_id, _, _ = ClientWorkerManager.open_span(before_data)
        if span_id is not None:
            span = Span(span_id, CeleryClientData())
            # Store the span on the task to eventually close it out on the "after" signal
            _log.trace(
                "before_task_publish pushing span_id to catalog",
                extra={"task": str(task), "task_id": str(task_id), "span_id": repr(span_id)},
            )
            task_catalog_push(task, task_id, span, False)
        else:
            _log.debug("before_task_publish: span is None", extra={"task": str(task), "task_id": str(task_id)})
    except Exception as exc:
        _log.exception("before_task_publish", exc)
_log.trace("instrumented before_task_publish")
@signals.after_task_publish.connect
def after_task_publish(*args, **kwargs):
    # type: (Any, Any) -> None
    """Client-side signal handler: close the span opened in before_task_publish."""
    try:
        _log.trace("after_task_publish")
        task = registry.tasks.get(kwargs.get("sender"))
        task_id = get_task_id(kwargs.get("headers"), kwargs.get("body"))
        if task is None or task_id is None:
            _log.debug("after_task_publish: task/task_id not found")
            return
        span = task_catalog_pop(task, task_id, False)
        if span is None:
            _log.debug("after_task_publish: span is None", extra={"task": str(task), "task_id": str(task_id)})
            return
        _log.trace(
            "after_task_publish closing span_id",
            extra={"task": str(task), "task_id": str(task_id), "span_id": repr(span.span_id)},
        )
        span.finish()
    except Exception as exc:
        _log.exception("after_task_publish", exc)
_log.trace("instrumented after_task_publish")
@signals.worker_process_shutdown.connect
def worker_process_shutdown(*args, **kwargs):
    # type: (Any, Any) -> None
    """Shut down the agent when a celery worker process exits."""
    try:
        brain = _get_brain()
        if brain is None:
            return
        _log.info("worker_process_shutdown: shutting down")
        brain.shutdown()
    except Exception as exc:
        _log.exception("worker_process_shutdown", exc)
_log.instrumentation_success(getattr(celery, "__version__"))
except ImportError as exc:
_log.instrumentation_skipped("", {"exc": exc})