Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / _private / runtime_env / plugin.py
Size: Mime:
import json
import logging
import os
from abc import ABC
from typing import Any, Dict, List, Optional, Type

from ray._common.utils import import_attr
from ray._private.runtime_env.constants import (
    RAY_RUNTIME_ENV_CLASS_FIELD_NAME,
    RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY,
    RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY,
    RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY,
    RAY_RUNTIME_ENV_PLUGINS_ENV_VAR,
    RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME,
)
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.uri_cache import URICache
from ray.util.annotations import DeveloperAPI

default_logger = logging.getLogger(__name__)


@DeveloperAPI
class RuntimeEnvPlugin(ABC):
    """Abstract base class for runtime environment plugins.

    Subclasses must set ``name`` and may override ``priority`` and any of the
    hook methods below. Every hook has a no-op default implementation.
    """

    # Key of the runtime_env field handled by this plugin. Subclasses must set
    # it; RuntimeEnvPluginManager rejects plugin classes with an empty name.
    name: Optional[str] = None
    # Setup order: plugins with a lower value are set up first. The manager
    # requires it to be an integer within its allowed priority range.
    priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY

    @staticmethod
    def validate(runtime_env_dict: dict) -> None:
        """Validate user entry for this plugin.

        The method is invoked upon installation of runtime env, before
        ``get_uris`` and ``create`` are called.

        Args:
            runtime_env_dict: The user-supplied runtime environment dict.

        Raises:
            ValueError: If the validation fails.
        """
        pass

    def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]:  # noqa: F821
        """Return the URIs of the resources this plugin installs.

        Each URI is looked up in the plugin's URI cache so that resources that
        are already installed can be reused. If no URIs are returned, ``create``
        is called with ``uri=None`` on every setup, without consulting the
        cache.

        Args:
            runtime_env: The RuntimeEnv object.

        Returns:
            The URIs for this plugin's part of ``runtime_env``. The default
            implementation returns an empty list.
        """
        return []

    async def create(
        self,
        uri: Optional[str],
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> float:
        """Create and install the runtime environment.

        Gets called in the runtime env agent at install time. The URI can be
        used as a caching mechanism.

        Args:
            uri: A URI uniquely describing this resource, or ``None`` when the
                plugin reported no URIs.
            runtime_env: The RuntimeEnv object.
            context: Auxiliary information supplied by Ray.
            logger: A logger to log messages during the creation.

        Returns:
            float: The disk space taken up by this plugin installation for this
                environment. For example, for working_dir this is the size of
                the files downloaded to the local node. The caller records the
                value as the URI cache entry's size (presumably in bytes;
                confirm against URICache).
        """
        return 0

    def modify_context(
        self,
        uris: List[str],
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> None:
        """Modify context to change worker startup behavior.

        For example, you can use this to prepend "cd <dir>" command to worker
        startup, or add new environment variables.

        Args:
            uris: The URIs used by this resource.
            runtime_env: The RuntimeEnv object.
            context: Auxiliary information supplied by Ray.
            logger: A logger to log messages during the context modification.
        """
        return

    def delete_uri(self, uri: str, logger: logging.Logger) -> float:
        """Delete the runtime environment given uri.

        Args:
            uri: A URI uniquely describing this resource.
            logger: The logger used to log messages during the deletion.

        Returns:
            float: The amount of space reclaimed by the deletion (presumably in
                the same unit that ``create`` reports).
        """
        return 0


class PluginSetupContext:
    """Registration record for one loaded runtime env plugin.

    Bundles the plugin instance with the setup priority and URI cache that the
    plugin manager assigned to it.
    """

    def __init__(
        self,
        name: str,
        class_instance: RuntimeEnvPlugin,
        priority: int,
        uri_cache: URICache,
    ):
        # Plugin name; also the key under which the manager stores this record.
        self.name = name
        # The plugin instance whose hooks are invoked during setup.
        self.class_instance = class_instance
        # Setup order key: lower values are set up first.
        self.priority = priority
        # Cache of URIs this plugin has installed.
        self.uri_cache = uri_cache

    def __repr__(self) -> str:
        # Readable form for error messages and logs. The default object repr
        # would only show a memory address.
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"class_instance={self.class_instance!r}, "
            f"priority={self.priority!r})"
        )


class RuntimeEnvPluginManager:
    """This manager is used to load plugins in runtime env agent.

    Plugins are registered by name in ``self.plugins``. Each one has a setup
    priority and its own URI cache. On construction, any plugins described by
    the JSON array in the environment variable named by
    ``RAY_RUNTIME_ENV_PLUGINS_ENV_VAR`` are loaded automatically.
    """

    def __init__(self):
        self.plugins: Dict[str, PluginSetupContext] = {}
        plugin_config_str = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
        if plugin_config_str:
            plugin_configs = json.loads(plugin_config_str)
            # Reject non-array JSON up front. Iterating a dict or string would
            # yield keys or characters and produce a misleading per-item error,
            # and an empty object would silently load nothing.
            if not isinstance(plugin_configs, list):
                raise RuntimeError(
                    f"Invalid value for {RAY_RUNTIME_ENV_PLUGINS_ENV_VAR}: "
                    f"{plugin_config_str!r}. It should be a JSON array of "
                    "runtime env plugin configs."
                )
            self.load_plugins(plugin_configs)

    def validate_plugin_class(self, plugin_class: Type[RuntimeEnvPlugin]) -> None:
        """Check that ``plugin_class`` can be registered with this manager.

        Args:
            plugin_class: The candidate plugin class.

        Raises:
            RuntimeError: If ``plugin_class`` is not a ``RuntimeEnvPlugin``
                subclass, has an empty ``name``, or uses a name that is
                already registered.
        """
        # issubclass() raises TypeError for non-class objects (e.g. an imported
        # function), so check for a class first and report a clear error.
        if not (
            isinstance(plugin_class, type)
            and issubclass(plugin_class, RuntimeEnvPlugin)
        ):
            raise RuntimeError(
                f"Invalid runtime env plugin class {plugin_class}. "
                "The plugin class must inherit "
                "ray._private.runtime_env.plugin.RuntimeEnvPlugin."
            )
        if not plugin_class.name:
            raise RuntimeError(f"No valid name in runtime env plugin {plugin_class}.")
        if plugin_class.name in self.plugins:
            existing = self.plugins[plugin_class.name]
            # Name the conflicting plugin's class rather than interpolating the
            # PluginSetupContext record itself.
            raise RuntimeError(
                f"The name of runtime env plugin {plugin_class} conflicts "
                f"with {type(existing.class_instance)}.",
            )

    def validate_priority(self, priority: Any) -> None:
        """Check that ``priority`` is an integer within the allowed range.

        Args:
            priority: The candidate priority value.

        Raises:
            RuntimeError: If ``priority`` is not an integer, or lies outside
                [RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY,
                RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY]. ``bool`` is rejected even
                though it subclasses ``int``.
        """
        is_integer = isinstance(priority, int) and not isinstance(priority, bool)
        if not (
            is_integer
            and RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
            <= priority
            <= RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
        ):
            raise RuntimeError(
                f"Invalid runtime env priority {priority}, "
                "it should be an integer between "
                f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
                f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
            )

    def load_plugins(self, plugin_configs: List[Dict]) -> None:
        """Load runtime env plugins and create URI caches for them.

        Args:
            plugin_configs: Plugin config objects. Each one must contain the
                ``RAY_RUNTIME_ENV_CLASS_FIELD_NAME`` field, which is the import
                path of the plugin class. It may also contain the
                ``RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME`` field, which overrides
                the class's ``priority`` attribute.

        Raises:
            RuntimeError: If a config is malformed, or if a plugin class or
                priority fails validation.
        """
        for plugin_config in plugin_configs:
            if (
                not isinstance(plugin_config, dict)
                or RAY_RUNTIME_ENV_CLASS_FIELD_NAME not in plugin_config
            ):
                raise RuntimeError(
                    f"Invalid runtime env plugin config {plugin_config}, "
                    "it should be a object which contains the "
                    f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field."
                )
            plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME])
            self.validate_plugin_class(plugin_class)

            # An explicit priority in the config overrides the class default.
            # A smaller number means a higher priority:
            # sorted_plugin_setup_contexts() orders plugins by ascending
            # priority, so such plugins are set up first.
            priority = plugin_config.get(
                RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME, plugin_class.priority
            )
            self.validate_priority(priority)

            class_instance = plugin_class()
            self.plugins[plugin_class.name] = PluginSetupContext(
                plugin_class.name,
                class_instance,
                priority,
                self.create_uri_cache_for_plugin(class_instance),
            )

    def add_plugin(self, plugin: RuntimeEnvPlugin) -> None:
        """Add a plugin to the manager and create a URI cache for it.

        The plugin's class-level ``name`` and ``priority`` are used.

        Args:
            plugin: The class instance of the plugin.

        Raises:
            RuntimeError: If the plugin class or its priority fails validation.
        """
        plugin_class = type(plugin)
        self.validate_plugin_class(plugin_class)
        self.validate_priority(plugin_class.priority)
        self.plugins[plugin_class.name] = PluginSetupContext(
            plugin_class.name,
            plugin,
            plugin_class.priority,
            self.create_uri_cache_for_plugin(plugin),
        )

    def create_uri_cache_for_plugin(self, plugin: RuntimeEnvPlugin) -> URICache:
        """Create a URI cache for a plugin.

        The maximum cache size in GB is read from the environment variable
        ``RAY_RUNTIME_ENV_<PLUGIN NAME>_CACHE_SIZE_GB`` (upper-cased) and
        defaults to 10 GB. ``plugin.delete_uri`` is handed to the cache as its
        deletion callback (presumably invoked on eviction).

        Args:
            plugin: The plugin instance the cache is created for.

        Returns:
            The created URI cache for the plugin.

        Raises:
            ValueError: If the cache size environment variable is not a number.
        """
        cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper()
        raw_cache_size_gb = os.environ.get(cache_size_env_var, 10)
        try:
            cache_size_gb = float(raw_cache_size_gb)
        except ValueError as e:
            raise ValueError(
                f"Invalid value {raw_cache_size_gb!r} for {cache_size_env_var}; "
                "expected a number of gigabytes."
            ) from e
        cache_size_bytes = int((1024**3) * cache_size_gb)
        return URICache(plugin.delete_uri, cache_size_bytes)

    def sorted_plugin_setup_contexts(self) -> List[PluginSetupContext]:
        """Get the sorted plugin setup contexts, sorted by increasing priority.

        Returns:
            The sorted plugin setup contexts.
        """
        return sorted(self.plugins.values(), key=lambda x: x.priority)


async def create_for_plugin_if_needed(
    runtime_env: "RuntimeEnv",  # noqa: F821
    plugin: RuntimeEnvPlugin,
    uri_cache: URICache,
    context: RuntimeEnvContext,
    logger: logging.Logger = default_logger,
):
    """Install the plugin's part of ``runtime_env``, reusing cached URIs.

    Nothing happens unless ``runtime_env`` has a non-``None`` entry under
    ``plugin.name``. Otherwise the entry is validated, and then:

    * if the plugin reports no URIs, ``plugin.create`` runs once with
      ``uri=None``, without a cache lookup;
    * each reported URI missing from ``uri_cache`` is created, then added to
      the cache with the size that ``create`` returned;
    * each reported URI already in ``uri_cache`` is marked as used instead.

    Finally, ``plugin.modify_context`` is called with all reported URIs.

    Args:
        runtime_env: The runtime environment being set up.
        plugin: The plugin responsible for the ``plugin.name`` field.
        uri_cache: The URI cache belonging to ``plugin``.
        context: Context passed through to the plugin's hooks.
        logger: Logger for setup messages.
    """
    plugin_name = plugin.name
    requested = plugin_name in runtime_env and runtime_env[plugin_name] is not None
    if not requested:
        return

    plugin.validate(runtime_env)
    uris = plugin.get_uris(runtime_env)

    if not uris:
        logger.debug(
            f"No URIs for runtime env plugin {plugin_name}; "
            "create always without checking the cache."
        )
        await plugin.create(None, runtime_env, context, logger=logger)

    for uri in uris:
        if uri in uri_cache:
            logger.info(
                f"Runtime env {plugin_name} {uri} is already installed "
                "and will be reused. Search "
                "all runtime_env_setup-*.log to find the corresponding setup log."
            )
            uri_cache.mark_used(uri, logger=logger)
            continue

        logger.debug(f"Cache miss for URI {uri}.")
        installed_size = await plugin.create(
            uri, runtime_env, context, logger=logger
        )
        uri_cache.add(uri, installed_size, logger=logger)

    plugin.modify_context(uris, runtime_env, context, logger)