Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / purelib / ray / dashboard / modules / dashboard_sdk.py
Size: Mime:
import dataclasses
import importlib
import logging
import json
import yaml
from pathlib import Path
import tempfile
from typing import Any, Dict, List, Optional
from pkg_resources import packaging
import ray

try:
    import requests
except ImportError:
    requests = None


from ray._private.runtime_env.packaging import (
    create_package,
    get_uri_for_directory,
    get_uri_for_package,
)
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray.dashboard.modules.job.common import uri_to_http_components

from ray.util.annotations import PublicAPI
from ray.client_builder import _split_address
from ray.autoscaler._private.cli_logger import cli_logger

# Module-level logger. Its level is pinned to INFO here, so the
# `logger.debug(...)` calls in this module are not emitted unless the level
# is lowered again somewhere else.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# By default, connect to local cluster. Used by `parse_cluster_info` when no
# address is given and no dashboard URL is known for this process.
DEFAULT_DASHBOARD_ADDRESS = "http://localhost:8265"


def parse_runtime_env_args(
    runtime_env: Optional[str] = None,
    runtime_env_json: Optional[str] = None,
    working_dir: Optional[str] = None,
):
    """Build a runtime_env dictionary from the job CLI options.

    At most one of ``runtime_env`` (path to a YAML file) or
    ``runtime_env_json`` (an inline JSON string) may be given. If
    ``working_dir`` is given, it overrides any ``working_dir`` entry loaded
    from those options.

    Args:
        runtime_env: Path to a YAML file holding the runtime_env.
        runtime_env_json: JSON-serialized runtime_env.
        working_dir: Working directory that takes precedence over the
            ``working_dir`` key from ``runtime_env``/``runtime_env_json``.

    Returns:
        The resulting runtime_env dictionary. It is empty when no option was
        given, or when the YAML file / JSON string is empty or ``null``.

    Raises:
        ValueError: If both ``runtime_env`` and ``runtime_env_json`` are given.
    """
    if runtime_env is not None and runtime_env_json is not None:
        raise ValueError(
            "Only one of --runtime-env and --runtime-env-json can be provided."
        )

    final_runtime_env = {}
    if runtime_env is not None:
        with open(runtime_env, "r") as f:
            final_runtime_env = yaml.safe_load(f)
    elif runtime_env_json is not None:
        final_runtime_env = json.loads(runtime_env_json)

    # An empty YAML document or a JSON "null" parses to None. Normalize it to
    # an empty runtime_env so callers always get a dict and the working_dir
    # membership test below does not raise TypeError.
    if final_runtime_env is None:
        final_runtime_env = {}

    if working_dir is not None:
        if "working_dir" in final_runtime_env:
            cli_logger.warning(
                "Overriding runtime_env working_dir with --working-dir option"
            )

        final_runtime_env["working_dir"] = working_dir

    return final_runtime_env


@dataclasses.dataclass
class ClusterInfo:
    """Connection details a SubmissionClient uses to reach the dashboard.

    Attributes:
        address: Base URL; request endpoints are appended to it verbatim.
        cookies: Optional cookies attached to requests.
        metadata: Optional metadata, used as the client's default metadata.
        headers: Optional HTTP headers attached to requests.
    """

    address: str
    cookies: Optional[Dict[str, Any]] = dataclasses.field(default=None)
    metadata: Optional[Dict[str, Any]] = dataclasses.field(default=None)
    headers: Optional[Dict[str, Any]] = dataclasses.field(default=None)


# TODO (shrekris-anyscale): renaming breaks compatibility, do NOT rename
def get_job_submission_client_cluster_info(
    address: str,
    # For backwards compatibility
    *,
    # only used in importlib case in parse_cluster_info, but needed
    # in function signature.
    create_cluster_if_needed: Optional[bool] = False,
    cookies: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, Any]] = None,
    _use_tls: Optional[bool] = False,
) -> ClusterInfo:
    """Get address, cookies, metadata, and headers used for SubmissionClient.

    This is the implementation used for ``http://`` and ``https://``
    addresses. Modules named by any other address prefix are expected to
    expose a function with this same name and keyword arguments, which
    ``parse_cluster_info`` calls instead.

    The address is used as-is apart from prepending the scheme. No port is
    inserted, so ``address`` must already include the dashboard port if one
    is needed.

    Args:
        address: Address without the scheme/module prefix (e.g.
            ``"localhost:8265"``).
        create_cluster_if_needed: Ignored by this implementation; accepted so
            the signature matches module-provided implementations, which may
            start a cluster before jobs are submitted.
        cookies: Cookies to pass through unchanged.
        metadata: Metadata to pass through unchanged.
        headers: HTTP headers to pass through unchanged.
        _use_tls: If True, prepend ``https://`` instead of ``http://``.

    Returns:
        ClusterInfo with the scheme-qualified address and the given cookies,
        metadata, and headers.
    """

    scheme = "https" if _use_tls else "http"
    return ClusterInfo(
        address=f"{scheme}://{address}",
        cookies=cookies,
        metadata=metadata,
        headers=headers,
    )


def parse_cluster_info(
    address: Optional[str] = None,
    create_cluster_if_needed: bool = False,
    cookies: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
    """Resolve ``address`` into the ClusterInfo used by SubmissionClient.

    If ``address`` is None, use the dashboard URL of the Ray instance
    initialized in this process, prefixed with ``http://``, when one is
    known. Otherwise fall back to ``DEFAULT_DASHBOARD_ADDRESS``.

    ``http://`` and ``https://`` addresses are handled by this module's
    ``get_job_submission_client_cluster_info``. For any other
    ``<module>://<rest>`` prefix, ``<module>`` is imported and its own
    ``get_job_submission_client_cluster_info`` is called with ``<rest>``.

    Raises:
        ValueError: If a Ray client (``ray://``) address is given.
        RuntimeError: If the prefix module cannot be imported, or does not
            define ``get_job_submission_client_cluster_info``.
    """
    if address is None:
        webui_url = None
        if ray.is_initialized():
            webui_url = ray._private.worker.global_worker.node.address_info[
                "webui_url"
            ]
        if webui_url is not None:
            address = f"http://{webui_url}"
        else:
            logger.info(
                f"No address provided, defaulting to {DEFAULT_DASHBOARD_ADDRESS}."
            )
            address = DEFAULT_DASHBOARD_ADDRESS

    module_string, inner_address = _split_address(address)

    # If user passes in ray://, raise error. Dashboard submission should
    # not use a Ray client address.
    if module_string == "ray":
        raise ValueError(
            f'Got an unexpected Ray client address "{address}" while trying '
            "to connect to the Ray dashboard. The dashboard SDK requires the "
            "Ray dashboard server's HTTP(S) address (which should start with "
            '"http://" or "https://", not "ray://"). If this address '
            "wasn't passed explicitly, it may be set in the RAY_ADDRESS "
            "environment variable."
        )

    # If user passes http(s)://, go through normal parsing.
    if module_string in {"http", "https"}:
        return get_job_submission_client_cluster_info(
            inner_address,
            create_cluster_if_needed=create_cluster_if_needed,
            cookies=cookies,
            metadata=metadata,
            headers=headers,
            _use_tls=module_string == "https",
        )

    # Otherwise, dynamically import the module that provides cluster info.
    try:
        module = importlib.import_module(module_string)
    except Exception as e:
        # Chain the original error. The import can also fail for reasons
        # other than the module being absent (e.g. an error raised while
        # importing it), and that cause should stay visible.
        raise RuntimeError(
            f"Module: {module_string} does not exist.\n"
            f"This module was parsed from Address: {address}"
        ) from e

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn this into an opaque AttributeError.
    if not hasattr(module, "get_job_submission_client_cluster_info"):
        raise RuntimeError(
            f"Module: {module_string} does "
            "not have `get_job_submission_client_cluster_info`."
        )

    return module.get_job_submission_client_cluster_info(
        inner_address,
        create_cluster_if_needed=create_cluster_if_needed,
        cookies=cookies,
        metadata=metadata,
        headers=headers,
    )


class SubmissionClient:
    def __init__(
        self,
        address: Optional[str] = None,
        create_cluster_if_needed: bool = False,
        cookies: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None,
    ):
        """Resolve ``address`` and store the connection details.

        Args:
            address: Dashboard address, resolved by ``parse_cluster_info``.
                Trailing slashes are removed first.
            create_cluster_if_needed: Forwarded to ``parse_cluster_info``.
            cookies: Forwarded to ``parse_cluster_info``. The resolved
                cookies are attached to every request.
            metadata: Forwarded to ``parse_cluster_info``. The resolved value
                (``{}`` if None) is stored as the default metadata.
            headers: Forwarded to ``parse_cluster_info``. The resolved headers
                are attached to every request.
        """
        # Remove any trailing slashes. Endpoints such as "/api/version" are
        # appended to the address verbatim, so a trailing "/" would double up.
        if address is not None and address.endswith("/"):
            requested_address = address
            address = address.rstrip("/")
            # Log the address as the caller passed it, not the stripped one.
            logger.debug(
                "The submission address cannot contain trailing slashes. Removing "
                f'them from the requested submission address of "{requested_address}".'
            )

        cluster_info = parse_cluster_info(
            address, create_cluster_if_needed, cookies, metadata, headers
        )
        self._address = cluster_info.address
        self._cookies = cluster_info.cookies
        self._default_metadata = cluster_info.metadata or {}
        # Headers used for all requests sent to job server, optional and only
        # needed for cases like authentication to remote cluster.
        self._headers = cluster_info.headers

    def _check_connection_and_version(
        self, min_version: str = "1.9", version_error_message: str = None
    ):
        """Check connectivity and the cluster's Ray version.

        Delegates to ``_check_connection_and_version_with_url`` using that
        method's default version endpoint.
        """
        self._check_connection_and_version_with_url(
            min_version,
            version_error_message,
        )

    def _check_connection_and_version_with_url(
        self,
        min_version: str = "1.9",
        version_error_message: str = None,
        url: str = "/api/version",
    ):
        """GET ``url`` and check that the cluster runs Ray >= ``min_version``.

        Raises:
            ConnectionError: If the HTTP request cannot connect.
            RuntimeError: If the endpoint returns 404 or reports an older Ray
                version; the message is ``version_error_message`` (or a
                default naming ``min_version``).
            Other non-success HTTP statuses are raised by
            ``raise_for_status()``.
        """
        error_msg = version_error_message
        if error_msg is None:
            error_msg = (
                f"Please ensure the cluster is running Ray {min_version} or higher."
            )

        # Only the request itself can fail to connect, so guard just that call.
        try:
            response = self._do_request("GET", url)
        except requests.exceptions.ConnectionError:
            raise ConnectionError(
                f"Failed to connect to Ray at address: {self._address}."
            )

        if response.status_code == 404:
            raise RuntimeError(error_msg)
        response.raise_for_status()

        cluster_version = packaging.version.parse(response.json()["ray_version"])
        required_version = packaging.version.parse(min_version)
        if cluster_version < required_version:
            raise RuntimeError(error_msg)
        # TODO(edoakes): check the version if/when we break compatibility.

    def _raise_error(self, r: "requests.Response"):
        """Always raise a RuntimeError carrying ``r``'s status code and body."""
        message = f"Request failed with status code {r.status_code}: {r.text}."
        raise RuntimeError(message)

    def _do_request(
        self,
        method: str,
        endpoint: str,
        *,
        data: Optional[bytes] = None,
        json_data: Optional[dict] = None,
        **kwargs,
    ) -> "requests.Response":
        """Send one HTTP request to the dashboard and return the response.

        ``endpoint`` is appended verbatim to the client's base address. The
        client's cookies and headers are always attached, so ``kwargs`` must
        not contain ``cookies`` or ``headers``. Any other keyword arguments
        are passed through to ``requests.request()``.
        """
        full_url = self._address + endpoint
        logged_payload = json_data or {}
        logger.debug(f"Sending request to {full_url} with json data: {logged_payload}.")
        return requests.request(
            method,
            full_url,
            data=data,
            json=json_data,
            cookies=self._cookies,
            headers=self._headers,
            **kwargs,
        )

    def _package_exists(
        self,
        package_uri: str,
    ) -> bool:
        """Ask the dashboard whether ``package_uri`` has already been stored.

        Returns True on HTTP 200 and False on HTTP 404. Any other status is
        passed to ``_raise_error``.
        """
        protocol, package_name = uri_to_http_components(package_uri)
        response = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")
        status = response.status_code

        if status == 404:
            logger.debug(f"Package {package_uri} does not exist.")
            return False
        if status != 200:
            self._raise_error(response)
            return None

        logger.debug(f"Package {package_uri} already exists.")
        return True

    def _upload_package(
        self,
        package_uri: str,
        package_path: str,
        include_parent_dir: Optional[bool] = False,
        excludes: Optional[List[str]] = None,
        is_file: bool = False,
    ) -> None:
        """Upload a package to the dashboard under ``package_uri``.

        If ``is_file`` is True, the existing file at ``package_path`` is
        uploaded as-is and left in place. Otherwise ``package_path`` is
        packaged with ``create_package`` into a file in a temporary
        directory; that file is uploaded and then deleted.

        Args:
            package_uri: URI the package is stored under; split into the
                protocol and name used in the upload endpoint.
            package_path: File (``is_file=True``) or path to package.
            include_parent_dir: Forwarded to ``create_package``.
            excludes: Forwarded to ``create_package``.
            is_file: Whether ``package_path`` is an existing package file.

        Raises:
            RuntimeError: If the upload request does not return HTTP 200.
        """
        # The return annotation was previously `bool`, but nothing is ever
        # returned; callers (e.g. `_upload_package_if_needed`) ignore it.
        logger.info(f"Uploading package {package_uri}.")
        with tempfile.TemporaryDirectory() as tmp_dir:
            protocol, package_name = uri_to_http_components(package_uri)
            if is_file:
                package_file = Path(package_path)
            else:
                package_file = Path(tmp_dir) / package_name
                create_package(
                    package_path,
                    package_file,
                    include_parent_dir=include_parent_dir,
                    excludes=excludes,
                )
            try:
                r = self._do_request(
                    "PUT",
                    f"/api/packages/{protocol}/{package_name}",
                    data=package_file.read_bytes(),
                )
                if r.status_code != 200:
                    self._raise_error(r)
            finally:
                # If the package is a user's existing file, don't delete it.
                if not is_file:
                    package_file.unlink()

    def _upload_package_if_needed(
        self,
        package_path: str,
        include_parent_dir: bool = False,
        excludes: Optional[List[str]] = None,
        is_file: bool = False,
    ) -> str:
        """Upload ``package_path`` unless the dashboard already has it.

        Returns:
            The package URI. It is derived from the file itself when
            ``is_file`` is True, otherwise from the directory and
            ``excludes``.
        """
        package_uri = (
            get_uri_for_package(Path(package_path))
            if is_file
            else get_uri_for_directory(package_path, excludes=excludes)
        )

        if self._package_exists(package_uri):
            logger.info(f"Package {package_uri} already exists, skipping upload.")
            return package_uri

        self._upload_package(
            package_uri,
            package_path,
            include_parent_dir=include_parent_dir,
            excludes=excludes,
            is_file=is_file,
        )
        return package_uri

    def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]):
        """Upload the runtime_env's working_dir through this client, if needed.

        The decision logic lives in ``upload_working_dir_if_needed``. This
        method only supplies the upload callback. That helper's return value
        is discarded here (presumably it updates ``runtime_env`` in place;
        confirm against its implementation).
        """

        def _send_working_dir(working_dir, excludes, is_file=False):
            # Unlike py_modules, working_dir is packaged without its parent
            # directory (include_parent_dir=False).
            self._upload_package_if_needed(
                working_dir,
                include_parent_dir=False,
                excludes=excludes,
                is_file=is_file,
            )

        upload_working_dir_if_needed(runtime_env, upload_fn=_send_working_dir)

    def _upload_py_modules_if_needed(self, runtime_env: Dict[str, Any]):
        """Upload the runtime_env's py_modules through this client, if needed.

        The decision logic lives in ``upload_py_modules_if_needed``. This
        method only supplies the upload callback, and that helper's return
        value is discarded here.
        """

        def _send_py_module(module_path, excludes, is_file=False):
            # py_modules keep their parent directory in the package
            # (include_parent_dir=True), unlike working_dir.
            self._upload_package_if_needed(
                module_path,
                include_parent_dir=True,
                excludes=excludes,
                is_file=is_file,
            )

        upload_py_modules_if_needed(runtime_env, upload_fn=_send_py_module)

    @PublicAPI(stability="beta")
    def get_version(self) -> str:
        """Return the ``"version"`` field reported by ``/api/version``.

        Any non-200 response is passed to ``_raise_error``.
        """
        response = self._do_request("GET", "/api/version")
        if response.status_code != 200:
            self._raise_error(response)
            return None
        return response.json().get("version")