Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / purelib / ray / util / tracing / tracing_helper.py
Size: Mime:
import importlib
import inspect
import logging
import os
from contextlib import contextmanager
from functools import wraps
from inspect import Parameter
from types import ModuleType
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    MutableMapping,
    Optional,
    Sequence,
    Union,
    cast,
)

import ray._private.worker
from ray._private.inspect_util import (
    is_class_method,
    is_function_or_method,
    is_static_method,
)
from ray.runtime_context import get_runtime_context

logger = logging.getLogger(__name__)


class _OpenTelemetryProxy:
    """Lazy accessor for the `opentelemetry` modules used by tracing.

    Tracing must be able to stay disabled when `opentelemetry` is installed
    locally but not on the cluster (for example `ray[full]` locally, but
    plain `ray` with no extras on the cluster).  Because the functions may be
    pickled where `opentelemetry` exists and executed where it does not, the
    import is attempted at the place of execution, every time one of the
    allowed attributes is accessed.

    Accessing `trace`, `context` or `propagate` yields the corresponding
    `opentelemetry` submodule, and `Context` yields the `Context` class.
    Each yields None when the import fails and the `RAY_TRACING_ENABLED`
    environment variable is not "true"/"1".
    """

    # Names that may be resolved through __getattr__; each maps to the
    # private resolver method called "_<name>".
    allowed_functions = {"trace", "context", "propagate", "Context"}

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so the resolver
        # methods themselves are found without passing through here.
        if name not in _OpenTelemetryProxy.allowed_functions:
            raise AttributeError(f"Attribute does not exist: {name}")
        resolver = getattr(self, f"_{name}")
        return resolver()

    def _trace(self):
        return self._try_import("opentelemetry.trace")

    def _context(self):
        return self._try_import("opentelemetry.context")

    def _propagate(self):
        return self._try_import("opentelemetry.propagate")

    def _Context(self):
        context_module = self._context()
        return context_module.context.Context if context_module else None

    def try_all(self):
        """Run every resolver once so a missing install surfaces early."""
        for resolver in (self._trace, self._context, self._propagate, self._Context):
            resolver()

    def _try_import(self, module):
        """Import `module`, returning None if it is missing and tracing is off.

        Raises:
            ImportError: if the import fails while `RAY_TRACING_ENABLED` is
                set to "true" or "1" (case-insensitive).
        """
        try:
            return importlib.import_module(module)
        except ImportError:
            tracing_flag = os.getenv("RAY_TRACING_ENABLED", "False").lower()
            if tracing_flag in ["true", "1"]:
                raise ImportError(
                    "Install opentelemetry with "
                    "'pip install opentelemetry-api==1.0.0rc1' "
                    "and 'pip install opentelemetry-sdk==1.0.0rc1' to enable "
                    "tracing. See more at docs.ray.io/tracing.html"
                )
            return None


# Module-wide proxy through which every opentelemetry module is reached.
_opentelemetry = _OpenTelemetryProxy()
# Attempt all of the opentelemetry imports once, at module import time. Per
# `_OpenTelemetryProxy._try_import`, this raises ImportError only when the
# import fails and RAY_TRACING_ENABLED is "true"/"1".
_opentelemetry.try_all()

# Either a plain name or a callable; `_actor_hydrate_span_args` replaces a
# callable with its `__name__`.
_nameable = Union[str, Callable[..., Any]]
# Value returned by `_is_tracing_enabled`; starts off. Nothing in this part of
# the file assigns it again -- presumably it is set elsewhere when tracing is
# configured (TODO confirm).
_global_is_tracing_enabled = False


def _sort_params_list(params_list: List[Parameter]):
    """Move the `**kwargs` Parameter, if present, to the end of the list.

    The list is modified in place and the same list object is returned.
    Only the first VAR_KEYWORD parameter found is moved.
    """
    kwargs_index = next(
        (
            index
            for index, candidate in enumerate(params_list)
            if candidate.kind == Parameter.VAR_KEYWORD
        ),
        None,
    )
    if kwargs_index is not None:
        params_list.append(params_list.pop(kwargs_index))
    return params_list


def _add_param_to_signature(function: Callable, new_param: Parameter):
    """Return `function`'s signature extended with `new_param`.

    If a parameter with the same name is already present, the existing
    signature is returned unchanged.  Otherwise `new_param` is appended and
    any `**kwargs` parameter is moved after it so it stays last.
    """
    signature = inspect.signature(function)
    parameters = list(signature.parameters.values())

    # Never add the same parameter twice.
    for existing in parameters:
        if existing.name == new_param.name:
            return signature

    parameters.append(new_param)
    return signature.replace(parameters=_sort_params_list(parameters))


def _is_tracing_enabled() -> bool:
    """Return the module-level tracing flag.

    This reads the `_global_is_tracing_enabled` module global; it does not
    consult an environment variable itself.  The global starts out False, so
    tracing is off by default.  Presumably other code assigns the global when
    tracing is configured -- TODO confirm against the tracing setup path.
    """
    return _global_is_tracing_enabled


class _ImportFromStringError(Exception):
    """Raised when an import string is malformed or cannot be resolved."""


def _import_from_string(import_str: Union[ModuleType, str]) -> Any:
    """Resolve a "<module>:<attribute>" string to the object it names.

    Args:
        import_str: Either a string of the form "<module>:<attribute>", where
            the attribute part may be a dotted path (e.g. "os:path.join"),
            or an already-imported object, which is returned unchanged.

    Returns:
        The object found by importing the module and walking the attribute
        path.

    Raises:
        _ImportFromStringError: if the string lacks a module or attribute
            part, the module itself cannot be found, or an attribute on the
            path does not exist.
        ImportError: if importing the module fails because of a *different*
            missing module (for example one of its own dependencies).
    """
    if not isinstance(import_str, str):
        return import_str

    module_str, _, attrs_str = import_str.partition(":")
    if not module_str or not attrs_str:
        raise _ImportFromStringError(
            f'Import string "{import_str}" must be in format '
            '"<module>:<attribute>".'
        )

    try:
        module = importlib.import_module(module_str)
    except ImportError as exc:
        # Only translate "the requested module is missing"; a failure on some
        # other module is re-raised untouched.
        if exc.name != module_str:
            raise exc from None
        raise _ImportFromStringError(
            f'Could not import module "{module_str}".'
        ) from exc

    instance = module
    try:
        for attr_str in attrs_str.split("."):
            instance = getattr(instance, attr_str)
    except AttributeError as exc:
        raise _ImportFromStringError(
            f'Attribute "{attrs_str}" not found in module "{module_str}".'
        ) from exc

    return instance


class _DictPropagator:
    """Namespace of helpers that move trace context in and out of a dict.

    The helpers take no instance state, so they are static methods; they can
    be called on the class (as the rest of this module does) or on an
    instance.
    """

    @staticmethod
    def inject_current_context() -> Dict[Any, Any]:
        """Inject the current trace context into a new dict and return it."""
        context_dict: Dict[Any, Any] = {}
        _opentelemetry.propagate.inject(context_dict)
        return context_dict

    @staticmethod
    def extract(context_dict: Dict[Any, Any]) -> "_opentelemetry.Context":
        """Given a trace context dict, extract it as a Context."""
        return cast(
            _opentelemetry.Context, _opentelemetry.propagate.extract(context_dict)
        )


@contextmanager
def _use_context(
    parent_context: "_opentelemetry.Context",
) -> Generator[None, None, None]:
    """Attach `parent_context` for the duration of the `with` body.

    When `parent_context` is None a freshly constructed Context is attached
    instead.  The attachment is always detached on exit, including when the
    body raises.
    """
    active_context = (
        _opentelemetry.Context() if parent_context is None else parent_context
    )
    attach_token = _opentelemetry.context.attach(active_context)
    try:
        yield
    finally:
        _opentelemetry.context.detach(attach_token)


def _function_hydrate_span_args(func: Callable[..., Any]):
    """Build the span attributes reported for a remote function.

    `func` is stored as-is under "ray.function"; the callers in this module
    pass the function's name as a string.  The process id, job id and node id
    are always included; the task id is added only in worker mode when the
    runtime context has one, and the worker id only when the global worker
    has one.
    """
    runtime_context = get_runtime_context().get()

    span_args = {
        "ray.remote": "function",
        "ray.function": func,
        "ray.pid": str(os.getpid()),
        "ray.job_id": runtime_context["job_id"].hex(),
        "ray.node_id": runtime_context["node_id"].hex(),
    }

    worker = ray._private.worker.global_worker

    # We only get task ID for workers
    in_worker_mode = worker.mode == ray._private.worker.WORKER_MODE
    if in_worker_mode and runtime_context.get("task_id"):
        task_hex = runtime_context["task_id"].hex()
        if task_hex:
            span_args["ray.task_id"] = task_hex

    worker_id = getattr(worker, "worker_id", None)
    if worker_id:
        span_args["ray.worker_id"] = worker_id.hex()

    return span_args


def _function_span_producer_name(func: Callable[..., Any]) -> str:
    """Return the name of the producer-kind span for a remote function.

    The name is the "ray.function" span attribute followed by " ray.remote".
    """
    span_args = _function_hydrate_span_args(func)
    return f"{span_args['ray.function']} ray.remote"


def _function_span_consumer_name(func: Callable[..., Any]) -> str:
    """Return the name of the consumer-kind span for a remote function.

    The name is the "ray.function" span attribute followed by
    " ray.remote_worker".
    """
    span_args = _function_hydrate_span_args(func)
    return f"{span_args['ray.function']} ray.remote_worker"


def _actor_hydrate_span_args(class_: _nameable, method: _nameable):
    """Build the span attributes reported for an actor method.

    `class_` and `method` may each be a name or a callable; a callable is
    replaced by its `__name__`.  The process id, job id and node id are
    always included; the actor id is added only in worker mode when the
    runtime context has one, and the worker id only when the global worker
    has one.
    """
    class_name = class_.__name__ if callable(class_) else class_
    method_name = method.__name__ if callable(method) else method

    runtime_context = get_runtime_context().get()

    span_args = {
        "ray.remote": "actor",
        "ray.actor_class": class_name,
        "ray.actor_method": method_name,
        "ray.function": f"{class_name}.{method_name}",
        "ray.pid": str(os.getpid()),
        "ray.job_id": runtime_context["job_id"].hex(),
        "ray.node_id": runtime_context["node_id"].hex(),
    }

    worker = ray._private.worker.global_worker

    # We only get actor ID for workers
    in_worker_mode = worker.mode == ray._private.worker.WORKER_MODE
    if in_worker_mode and runtime_context.get("actor_id"):
        actor_hex = runtime_context["actor_id"].hex()
        if actor_hex:
            span_args["ray.actor_id"] = actor_hex

    worker_id = getattr(worker, "worker_id", None)
    if worker_id:
        span_args["ray.worker_id"] = worker_id.hex()

    return span_args


def _actor_span_producer_name(class_: _nameable, method: _nameable) -> str:
    """Return the name of the producer-kind span for an actor method.

    The name is the "ray.function" span attribute ("<class>.<method>")
    followed by " ray.remote".
    """
    span_args = _actor_hydrate_span_args(class_, method)
    assert span_args is not None
    return f"{span_args['ray.function']} ray.remote"


def _actor_span_consumer_name(class_: _nameable, method: _nameable) -> str:
    """Return the name of the consumer-kind span for an actor method.

    The name is the "ray.function" span attribute ("<class>.<method>")
    followed by " ray.remote_worker".
    """
    span_args = _actor_hydrate_span_args(class_, method)
    assert span_args is not None
    return f"{span_args['ray.function']} ray.remote_worker"


def _tracing_task_invocation(method):
    """Decorate a remote-task submission method so that it is traced.

    When tracing is enabled (and the task is not cross-language), the wrapped
    call runs inside a PRODUCER span and the current span context is injected
    into the task's `kwargs` under "_ray_trace_ctx" for propagation.
    Otherwise the call is forwarded untouched.

    Args:
        method: The submission method to wrap.  It is called as
            `method(self, args, kwargs, *_args, **_kwargs)`.

    Returns:
        The wrapping function, carrying `method`'s metadata via `wraps`.
    """

    @wraps(method)
    def _invocation_remote_span(
        self,
        args: Any = None,  # from tracing
        kwargs: Optional[MutableMapping[Any, Any]] = None,  # from tracing
        *_args: Any,  # from Ray
        **_kwargs: Any,  # from Ray
    ) -> Any:
        # If tracing feature flag is not on, perform a no-op.
        # Tracing doesn't work for cross lang yet.
        if not _is_tracing_enabled() or self._is_cross_language:
            if kwargs is not None:
                assert "_ray_trace_ctx" not in kwargs
            return method(self, args, kwargs, *_args, **_kwargs)

        # `kwargs` defaults to None; the trace context needs a mapping to be
        # injected into, so start from an empty one in that case.
        if kwargs is None:
            kwargs = {}

        assert "_ray_trace_ctx" not in kwargs

        tracer = _opentelemetry.trace.get_tracer(__name__)
        with tracer.start_as_current_span(
            _function_span_producer_name(self._function_name),
            kind=_opentelemetry.trace.SpanKind.PRODUCER,
            attributes=_function_hydrate_span_args(self._function_name),
        ):
            # Inject a _ray_trace_ctx as a dictionary
            kwargs["_ray_trace_ctx"] = _DictPropagator.inject_current_context()
            return method(self, args, kwargs, *_args, **_kwargs)

    return _invocation_remote_span


def _inject_tracing_into_function(function):
    """Return a version of `function` whose executions are traced.

    When tracing is disabled, `function` is returned unchanged.  Otherwise a
    keyword-only `_ray_trace_ctx` parameter is added to `function`'s
    signature and a wrapper is returned.  The wrapper runs the function
    inside a CONSUMER span under the context extracted from
    `_ray_trace_ctx`, or calls it directly when no context was passed.
    """
    if not _is_tracing_enabled():
        return function

    # Add _ray_trace_ctx to function signature. This has to happen before
    # `wraps` below, which copies the function's attributes to the wrapper.
    trace_ctx_param = inspect.Parameter(
        "_ray_trace_ctx", inspect.Parameter.KEYWORD_ONLY, default=None
    )
    function.__signature__ = _add_param_to_signature(function, trace_ctx_param)

    @wraps(function)
    def _function_with_tracing(
        *args: Any,
        _ray_trace_ctx: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        # Without a propagated context there is nothing to resume.
        if _ray_trace_ctx is None:
            return function(*args, **kwargs)

        tracer = _opentelemetry.trace.get_tracer(__name__)
        function_name = ".".join((function.__module__, function.__name__))

        # Retrieves the context from the _ray_trace_ctx dictionary we injected
        parent_context = _DictPropagator.extract(_ray_trace_ctx)
        with _use_context(parent_context):
            with tracer.start_as_current_span(
                _function_span_consumer_name(function_name),
                kind=_opentelemetry.trace.SpanKind.CONSUMER,
                attributes=_function_hydrate_span_args(function_name),
            ):
                return function(*args, **kwargs)

    return _function_with_tracing


def _tracing_actor_creation(method):
    """Decorate an actor-creation method so that the creation is traced.

    With tracing enabled, the wrapped call runs inside a PRODUCER span named
    after the actor class's `__init__`, the current span context is injected
    into `kwargs` under "_ray_trace_ctx", and the created actor's id is
    recorded on the span.  With tracing disabled the call is forwarded as-is
    (with `kwargs` normalised to a dict).
    """

    @wraps(method)
    def _invocation_actor_class_remote_span(
        self,
        args: Any = tuple(),  # from tracing
        kwargs: MutableMapping[Any, Any] = None,  # from tracing
        *_args: Any,  # from Ray
        **_kwargs: Any,  # from Ray
    ):
        kwargs = {} if kwargs is None else kwargs

        # If tracing feature flag is not on, perform a no-op
        if not _is_tracing_enabled():
            assert "_ray_trace_ctx" not in kwargs
            return method(self, args, kwargs, *_args, **_kwargs)

        class_name = self.__ray_metadata__.class_name
        method_name = "__init__"
        # NOTE(review): this checks `_kwargs`, while the disabled path above
        # checks `kwargs` -- confirm which one is intended.
        assert "_ray_trace_ctx" not in _kwargs

        tracer = _opentelemetry.trace.get_tracer(__name__)
        creation_span = tracer.start_as_current_span(
            name=_actor_span_producer_name(class_name, method_name),
            kind=_opentelemetry.trace.SpanKind.PRODUCER,
            attributes=_actor_hydrate_span_args(class_name, method_name),
        )
        with creation_span as span:
            # Inject a _ray_trace_ctx as a dictionary
            kwargs["_ray_trace_ctx"] = _DictPropagator.inject_current_context()

            actor_handle = method(self, args, kwargs, *_args, **_kwargs)

            # The actor id is only known once the actor has been created.
            span.set_attribute("ray.actor_id", actor_handle._ray_actor_id.hex())

            return actor_handle

    return _invocation_actor_class_remote_span


def _tracing_actor_method_invocation(method):
    """Decorate an actor-method submission method so that it is traced.

    When tracing is enabled (and the actor is not cross-language), the
    wrapped call runs inside a PRODUCER span, the current span context is
    injected into the call's `kwargs` under "_ray_trace_ctx", and the target
    actor's id is recorded on the span.  Otherwise the call is forwarded
    untouched.

    Args:
        method: The submission method to wrap.  It is called as
            `method(self, args, kwargs, *_args, **_kwargs)`.

    Returns:
        The wrapping function, carrying `method`'s metadata via `wraps`.
    """

    @wraps(method)
    def _start_span(
        self,
        args: Optional[Sequence[Any]] = None,
        kwargs: Optional[MutableMapping[Any, Any]] = None,
        *_args: Any,
        **_kwargs: Any,
    ) -> Any:
        # If tracing feature flag is not on, perform a no-op
        if not _is_tracing_enabled() or self._actor_ref()._ray_is_cross_language:
            if kwargs is not None:
                assert "_ray_trace_ctx" not in kwargs
            return method(self, args, kwargs, *_args, **_kwargs)

        # Resolve the actor handle once and reuse it below.
        actor_handle = self._actor_ref()
        class_name = actor_handle._ray_actor_creation_function_descriptor.class_name
        method_name = self._method_name
        # NOTE(review): this checks `_kwargs`, while the disabled path above
        # checks `kwargs` -- confirm which one is intended.
        assert "_ray_trace_ctx" not in _kwargs

        # `kwargs` defaults to None; the trace context needs a mapping to be
        # injected into, so start from an empty one in that case.
        if kwargs is None:
            kwargs = {}

        tracer = _opentelemetry.trace.get_tracer(__name__)
        with tracer.start_as_current_span(
            name=_actor_span_producer_name(class_name, method_name),
            kind=_opentelemetry.trace.SpanKind.PRODUCER,
            attributes=_actor_hydrate_span_args(class_name, method_name),
        ) as span:
            # Inject a _ray_trace_ctx as a dictionary
            kwargs["_ray_trace_ctx"] = _DictPropagator.inject_current_context()

            span.set_attribute("ray.actor_id", actor_handle._ray_actor_id.hex())

            return method(self, args, kwargs, *_args, **_kwargs)

    return _start_span


def _inject_tracing_into_class(_cls):
    """Wrap the methods of a soon-to-be actor class so they are traced.

    Every function/method member of `_cls` except static methods, class
    methods and `__del__` gets a keyword-only `_ray_trace_ctx` parameter
    added to its signature and is replaced on the class by a wrapper.  The
    wrapper resumes the propagated trace context in a CONSUMER span, or
    calls the method directly when tracing is off or no context was passed.
    The same class object is returned.
    """

    def span_wrapper(method: Callable[..., Any]) -> Any:
        # Wrapper for regular (synchronous) methods.
        def _resume_span(
            self: Any,
            *_args: Any,
            _ray_trace_ctx: Optional[Dict[str, Any]] = None,
            **_kwargs: Any,
        ) -> Any:
            # Calls the user's method, first extracting the trace context
            # when there is one to resume.
            should_trace = _is_tracing_enabled() and _ray_trace_ctx is not None
            if not should_trace:
                return method(self, *_args, **_kwargs)

            tracer = _opentelemetry.trace.get_tracer(__name__)

            # Retrieves the context from the _ray_trace_ctx dictionary we
            # injected.
            parent_context = _DictPropagator.extract(_ray_trace_ctx)
            with _use_context(parent_context):
                with tracer.start_as_current_span(
                    _actor_span_consumer_name(self.__class__.__name__, method),
                    kind=_opentelemetry.trace.SpanKind.CONSUMER,
                    attributes=_actor_hydrate_span_args(
                        self.__class__.__name__, method
                    ),
                ):
                    return method(self, *_args, **_kwargs)

        return _resume_span

    def async_span_wrapper(method: Callable[..., Any]) -> Any:
        # Wrapper for coroutine methods; mirrors `span_wrapper` but awaits.
        async def _resume_span(
            self: Any,
            *_args: Any,
            _ray_trace_ctx: Optional[Dict[str, Any]] = None,
            **_kwargs: Any,
        ) -> Any:
            # Awaits the user's method, first extracting the trace context
            # when there is one to resume.
            should_trace = _is_tracing_enabled() and _ray_trace_ctx is not None
            if not should_trace:
                return await method(self, *_args, **_kwargs)

            tracer = _opentelemetry.trace.get_tracer(__name__)

            # Retrieves the context from the _ray_trace_ctx dictionary we
            # injected, or starts a new context
            parent_context = _DictPropagator.extract(_ray_trace_ctx)
            with _use_context(parent_context):
                with tracer.start_as_current_span(
                    _actor_span_consumer_name(
                        self.__class__.__name__, method.__name__
                    ),
                    kind=_opentelemetry.trace.SpanKind.CONSUMER,
                    attributes=_actor_hydrate_span_args(
                        self.__class__.__name__, method.__name__
                    ),
                ):
                    return await method(self, *_args, **_kwargs)

        return _resume_span

    # The extra parameter is the same for every method, so build it once.
    trace_ctx_param = inspect.Parameter(
        "_ray_trace_ctx", inspect.Parameter.KEYWORD_ONLY, default=None
    )

    for name, method in inspect.getmembers(_cls, is_function_or_method):
        # Skip tracing for staticmethod or classmethod, because these methods
        # might not be called directly by remote calls. Additionally, they are
        # tricky to get wrapped and unwrapped.
        if is_static_method(_cls, name) or is_class_method(method):
            continue

        # Don't decorate the __del__ magic method.
        # __del__ can be called after Python modules are garbage collected,
        # which means the modules used for the decorator (e.g.,
        # `span_wrapper`) may not be available. For example, it is not
        # guaranteed that `_is_tracing_enabled` is available when `__del__`
        # is called. Tracing `__del__` is also not very useful.
        # https://joekuan.wordpress.com/2015/06/30/python-3-__del__-method-and-imported-modules/ # noqa
        if name == "__del__":
            continue

        # Add _ray_trace_ctx to method signature. This has to happen before
        # `wraps` below, which copies the method's attributes to the wrapper.
        method.__signature__ = _add_param_to_signature(method, trace_ctx_param)

        # Async methods need the awaiting wrapper; everything else gets the
        # synchronous one.
        make_wrapper = (
            async_span_wrapper if inspect.iscoroutinefunction(method) else span_wrapper
        )
        setattr(_cls, name, wraps(method)(make_wrapper(method)))

    return _cls