import hashlib
import logging
import os
import shutil
from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Tuple
from urllib.parse import urlparse
from zipfile import ZipFile
from filelock import FileLock
from ray._private.ray_constants import (
RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT,
RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR,
)
from ray._private.gcs_utils import GcsAioClient
from ray._private.thirdparty.pathspec import PathSpec
from ray.experimental.internal_kv import (
_internal_kv_exists,
_internal_kv_put,
_pin_runtime_env_uri,
)
default_logger = logging.getLogger(__name__)
# If an individual file exceeds this size, log a warning.
FILE_SIZE_WARNING = 10 * 1024 * 1024 # 10MiB
# The size is bounded by the max gRPC message size.
# Keep in sync with max_grpc_message_size in ray_config_def.h.
GCS_STORAGE_MAX_SIZE = int(
os.environ.get("RAY_max_grpc_message_size", 250 * 1024 * 1024)
)
RAY_PKG_PREFIX = "_ray_pkg_"
def _mib_string(num_bytes: float) -> str:
size_mib = float(num_bytes / 1024 ** 2)
return f"{size_mib:.2f}MiB"
class Protocol(Enum):
"""A enum for supported storage backends."""
# For docstring
def __new__(cls, value, doc=None):
self = object.__new__(cls)
self._value_ = value
if doc is not None:
self.__doc__ = doc
return self
GCS = "gcs", "For packages dynamically uploaded and managed by the GCS."
CONDA = "conda", "For conda environments installed locally on each node."
PIP = "pip", "For pip environments installed locally on each node."
HTTPS = "https", "Remote https path, assumes everything packed in one zip file."
S3 = "s3", "Remote s3 path, assumes everything packed in one zip file."
GS = "gs", "Remote google storage path, assumes everything packed in one zip file."
FILE = "file", "File storage path, assumes everything packed in one zip file."
@classmethod
def remote_protocols(cls):
# Returns a list of protocols that support remote storage
# These protocols should only be used with paths that end in ".zip"
return [cls.HTTPS, cls.S3, cls.GS, cls.FILE]
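# Example (illustrative): members can be looked up by scheme string, as
# parse_uri() does below.
#   >>> Protocol("s3")
#   <Protocol.S3: 's3'>
#   >>> Protocol.S3 in Protocol.remote_protocols()
#   True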
def _xor_bytes(left: bytes, right: bytes) -> bytes:
if left and right:
return bytes(a ^ b for (a, b) in zip(left, right))
return left or right
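# Example (illustrative): XOR combines two equal-length digests; if either
# side is empty, the other is returned unchanged. Note that zip() truncates
# to the shorter input when the lengths differ.
#   >>> _xor_bytes(b"\x0f\x0f", b"\x01\x01")
#   b'\x0e\x0e'
#   >>> _xor_bytes(b"", b"\xff")
#   b'\xff'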
def _dir_travel(
path: Path,
excludes: List[Callable],
handler: Callable,
logger: Optional[logging.Logger] = default_logger,
):
"""Travels the path recursively, calling the handler on each subpath.
Respects excludes, which will be called to check if this path is skipped.
"""
e = _get_gitignore(path)
if e is not None:
excludes.append(e)
skip = any(e(path) for e in excludes)
if not skip:
try:
handler(path)
        except Exception:
            logger.error(f"Issue with path: {path}")
            raise
if path.is_dir():
for sub_path in path.iterdir():
_dir_travel(sub_path, excludes, handler, logger=logger)
if e is not None:
excludes.pop()
def _hash_directory(
root: Path,
relative_path: Path,
excludes: Optional[Callable],
logger: Optional[logging.Logger] = default_logger,
) -> bytes:
"""Helper function to create hash of a directory.
It'll go through all the files in the directory and xor
hash(file_name, file_content) to create a hash value.
"""
hash_val = b"0" * 8
BUF_SIZE = 4096 * 1024
def handler(path: Path):
md5 = hashlib.md5()
md5.update(str(path.relative_to(relative_path)).encode())
if not path.is_dir():
try:
f = path.open("rb")
except Exception as e:
logger.debug(
f"Skipping contents of file {path} when calculating package hash "
f"because the file could not be opened: {e}"
)
else:
try:
data = f.read(BUF_SIZE)
while len(data) != 0:
md5.update(data)
data = f.read(BUF_SIZE)
finally:
f.close()
nonlocal hash_val
hash_val = _xor_bytes(hash_val, md5.digest())
excludes = [] if excludes is None else [excludes]
_dir_travel(root, excludes, handler, logger=logger)
return hash_val
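# Example (illustrative sketch; "/some/dir" is a hypothetical path): hash a
# directory relative to itself with no extra excludes. The result is an
# 8-byte digest that changes whenever any file name or content changes.
#   >>> root = Path("/some/dir")
#   >>> _hash_directory(root, root, None)  # doctest: +SKIP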
def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
"""
    Parse a resource URI into a protocol and package name based on its format.
    Note that the output of this function is not used for actual I/O; it is
    only used to set up local directory folders, using the package name as
    the path.
For GCS URIs, netloc is the package name.
urlparse("gcs://_ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip")
-> ParseResult(
scheme='gcs',
netloc='_ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip'
)
-> ("gcs", "_ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip")
For HTTPS URIs, the netloc will have '.' replaced with '_', and
the path will have '/' replaced with '_'. The package name will be the
adjusted path with 'https_' prepended.
urlparse(
"https://github.com/shrekris-anyscale/test_module/archive/HEAD.zip"
)
-> ParseResult(
scheme='https',
netloc='github.com',
            path='/shrekris-anyscale/test_module/archive/HEAD.zip'
        )
        -> ("https",
            "https_github_com_shrekris-anyscale_test_module_archive_HEAD.zip")
    For S3 URIs, the path will have '/' replaced with '_'. The package name
    will be the bucket and the adjusted path with 's3_' prepended.
urlparse("s3://bucket/dir/file.zip")
-> ParseResult(
scheme='s3',
netloc='bucket',
path='/dir/file.zip'
)
-> ("s3", "bucket_dir_file.zip")
    For GS URIs, the path will have '/' replaced with '_'. The package name
    will be the bucket and the adjusted path with 'gs_' prepended.
urlparse("gs://public-runtime-env-test/test_module.zip")
-> ParseResult(
scheme='gs',
netloc='public-runtime-env-test',
path='/test_module.zip'
)
-> ("gs",
"gs_public-runtime-env-test_test_module.zip")
For FILE URIs, the path will have '/' replaced with '_'. The package name
will be the adjusted path with 'file_' prepended.
urlparse("file:///path/to/test_module.zip")
-> ParseResult(
scheme='file',
            netloc='',
path='/path/to/test_module.zip'
)
-> ("file", "file__path_to_test_module.zip")
"""
uri = urlparse(pkg_uri)
try:
protocol = Protocol(uri.scheme)
except ValueError as e:
raise ValueError(
f"Invalid protocol for runtime_env URI {pkg_uri}. "
f"Supported protocols: {Protocol._member_names_}. Original error: {e}"
)
if protocol == Protocol.S3 or protocol == Protocol.GS:
return (protocol, f"{protocol.value}_{uri.netloc}{uri.path.replace('/', '_')}")
elif protocol == Protocol.HTTPS:
return (
protocol,
f"https_{uri.netloc.replace('.', '_')}{uri.path.replace('/', '_')}",
)
elif protocol == Protocol.FILE:
return (
protocol,
f"file_{uri.path.replace('/', '_')}",
)
else:
return (protocol, uri.netloc)
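# Example (runnable; mirrors the docstring above):
#   >>> parse_uri("s3://bucket/dir/file.zip")
#   (<Protocol.S3: 's3'>, 's3_bucket_dir_file.zip')
#   >>> parse_uri("gcs://_ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip")
#   (<Protocol.GCS: 'gcs'>, '_ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip')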
def is_zip_uri(uri: str) -> bool:
try:
        _, path = parse_uri(uri)
except ValueError:
return False
return Path(path).suffix == ".zip"
def is_whl_uri(uri: str) -> bool:
try:
_, path = parse_uri(uri)
except ValueError:
return False
return Path(path).suffix == ".whl"
def is_jar_uri(uri: str) -> bool:
try:
_, path = parse_uri(uri)
except ValueError:
return False
return Path(path).suffix == ".jar"
def _get_excludes(path: Path, excludes: List[str]) -> Callable:
path = path.absolute()
pathspec = PathSpec.from_lines("gitwildmatch", excludes)
def match(p: Path):
path_str = str(p.absolute().relative_to(path))
return pathspec.match_file(path_str)
return match
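# Example (illustrative; paths are hypothetical):
#   >>> match = _get_excludes(Path("/proj"), ["*.log", "tmp/"])
#   >>> match(Path("/proj/debug.log"))
#   True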
def _get_gitignore(path: Path) -> Optional[Callable]:
path = path.absolute()
ignore_file = path / ".gitignore"
if ignore_file.is_file():
with ignore_file.open("r") as f:
pathspec = PathSpec.from_lines("gitwildmatch", f.readlines())
def match(p: Path):
path_str = str(p.absolute().relative_to(path))
return pathspec.match_file(path_str)
return match
else:
return None
def pin_runtime_env_uri(uri: str, *, expiration_s: Optional[int] = None) -> None:
"""Pin a reference to a runtime_env URI in the GCS on a timeout.
This is used to avoid premature eviction in edge conditions for job
reference counting. See https://github.com/ray-project/ray/pull/24719.
Packages are uploaded to GCS in order to be downloaded by a runtime env plugin
(e.g. working_dir, py_modules) after the job starts.
This function adds a temporary reference to the package in the GCS to prevent
it from being deleted before the job starts. (See #23423 for the bug where
this happened.)
If this reference didn't have an expiration, then if the script exited
(e.g. via Ctrl-C) before the job started, the reference would never be
removed, so the package would never be deleted.
"""
if expiration_s is None:
expiration_s = int(
os.environ.get(
RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR,
RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT,
)
)
elif not isinstance(expiration_s, int):
raise ValueError(f"expiration_s must be an int, got {type(expiration_s)}.")
if expiration_s < 0:
raise ValueError(f"expiration_s must be >= 0, got {expiration_s}.")
elif expiration_s > 0:
_pin_runtime_env_uri(uri, expiration_s=expiration_s)
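# Example (illustrative sketch; assumes a connected Ray driver and a
# hypothetical URI):
#   pin_runtime_env_uri("gcs://_ray_pkg_abc.zip")  # default expiration
#   pin_runtime_env_uri("gcs://_ray_pkg_abc.zip", expiration_s=0)  # no pin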
def _store_package_in_gcs(
pkg_uri: str,
data: bytes,
logger: Optional[logging.Logger] = default_logger,
) -> int:
"""Stores package data in the Global Control Store (GCS).
Args:
pkg_uri: The GCS key to store the data in.
data: The serialized package's bytes to store in the GCS.
logger (Optional[logging.Logger]): The logger used by this function.
    Returns:
        The size of the data in bytes.
Raises:
RuntimeError: If the upload to the GCS fails.
ValueError: If the data's size exceeds GCS_STORAGE_MAX_SIZE.
"""
file_size = len(data)
size_str = _mib_string(file_size)
    if file_size >= GCS_STORAGE_MAX_SIZE:
raise ValueError(
f"Package size ({size_str}) exceeds the maximum size of "
f"{_mib_string(GCS_STORAGE_MAX_SIZE)}. You can exclude large "
"files using the 'excludes' option to the runtime_env."
)
logger.info(f"Pushing file package '{pkg_uri}' ({size_str}) to Ray cluster...")
try:
_internal_kv_put(pkg_uri, data)
except Exception as e:
raise RuntimeError(
"Failed to store package in the GCS.\n"
f" - GCS URI: {pkg_uri}\n"
f" - Package data ({size_str}): {data[:15]}...\n"
) from e
logger.info(f"Successfully pushed file package '{pkg_uri}'.")
    return file_size
def _get_local_path(base_directory: str, pkg_uri: str) -> str:
_, pkg_name = parse_uri(pkg_uri)
return os.path.join(base_directory, pkg_name)
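# Example (illustrative; the base directory is hypothetical):
#   >>> _get_local_path("/base", "gcs://_ray_pkg_abc.zip")
#   '/base/_ray_pkg_abc.zip'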
def _zip_directory(
directory: str,
excludes: List[str],
output_path: str,
include_parent_dir: bool = False,
logger: Optional[logging.Logger] = default_logger,
) -> None:
"""Zip the target directory and write it to the output_path.
directory: The directory to zip.
excludes (List(str)): The directories or file to be excluded.
output_path: The output path for the zip file.
include_parent_dir: If true, includes the top-level directory as a
directory inside the zip file.
"""
pkg_file = Path(output_path).absolute()
with ZipFile(pkg_file, "w") as zip_handler:
# Put all files in the directory into the zip file.
dir_path = Path(directory).absolute()
def handler(path: Path):
# Pack this path if it's an empty directory or it's a file.
            if (
                path.is_dir() and next(path.iterdir(), None) is None
            ) or path.is_file():
file_size = path.stat().st_size
if file_size >= FILE_SIZE_WARNING:
logger.warning(
f"File {path} is very large "
f"({_mib_string(file_size)}). Consider adding this "
"file to the 'excludes' list to skip uploading it: "
"`ray.init(..., "
f"runtime_env={{'excludes': ['{path}']}})`"
)
to_path = path.relative_to(dir_path)
if include_parent_dir:
to_path = dir_path.name / to_path
zip_handler.write(path, to_path)
excludes = [_get_excludes(dir_path, excludes)]
_dir_travel(dir_path, excludes, handler, logger=logger)
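# Example (illustrative sketch; paths are hypothetical):
#   _zip_directory("/proj/app", excludes=["*.log"],
#                  output_path="/tmp/app.zip", include_parent_dir=True)
# With include_parent_dir=True, entries in /tmp/app.zip are rooted at "app/".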
def package_exists(pkg_uri: str) -> bool:
"""Check whether the package with given URI exists or not.
Args:
pkg_uri: The uri of the package
Return:
True for package existing and False for not.
"""
protocol, pkg_name = parse_uri(pkg_uri)
if protocol == Protocol.GCS:
return _internal_kv_exists(pkg_uri)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")
def get_uri_for_package(package: Path) -> str:
"""Get a content-addressable URI from a package's contents."""
if package.suffix == ".whl":
# Wheel file names include the Python package name, version
# and tags, so it is already effectively content-addressed.
return "{protocol}://{whl_filename}".format(
protocol=Protocol.GCS.value, whl_filename=package.name
)
else:
hash_val = hashlib.md5(package.read_bytes()).hexdigest()
return "{protocol}://{pkg_name}.zip".format(
protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val
)
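# Example (illustrative): wheels keep their (already content-addressed)
# file name, while other packages are named by an md5 of their bytes.
#   >>> get_uri_for_package(Path("my_pkg-1.0-py3-none-any.whl"))
#   'gcs://my_pkg-1.0-py3-none-any.whl'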
def get_uri_for_directory(directory: str, excludes: Optional[List[str]] = None) -> str:
"""Get a content-addressable URI from a directory's contents.

    This function generates the package name from the directory's contents.
    It goes through all the files in the directory and hashes their contents
    to produce the hash value of the package. The final package name is
    _ray_pkg_<HASH_VAL>.zip,
    e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip.

    Examples:

        .. code-block:: python

            >>> get_uri_for_directory("/my_directory")
            'gcs://_ray_pkg_af2734982a741.zip'

    Args:
        directory: The directory.
        excludes (list[str]): The directories or files that should be excluded.

    Returns:
        URI (str): The content-addressable URI of the package.

    Raises:
        ValueError: If the directory doesn't exist.
"""
if excludes is None:
excludes = []
directory = Path(directory).absolute()
if not directory.exists() or not directory.is_dir():
raise ValueError(f"directory {directory} must be an existing directory")
hash_val = _hash_directory(directory, directory, _get_excludes(directory, excludes))
return "{protocol}://{pkg_name}.zip".format(
protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val.hex()
)
def upload_package_to_gcs(pkg_uri: str, pkg_bytes: bytes):
protocol, pkg_name = parse_uri(pkg_uri)
if protocol == Protocol.GCS:
_store_package_in_gcs(pkg_uri, pkg_bytes)
elif protocol in Protocol.remote_protocols():
raise RuntimeError(
"upload_package_to_gcs should not be called with remote path."
)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")
def create_package(
directory: str,
target_path: Path,
include_parent_dir: bool = False,
excludes: Optional[List[str]] = None,
logger: Optional[logging.Logger] = default_logger,
):
if excludes is None:
excludes = []
if logger is None:
logger = default_logger
if not target_path.exists():
logger.info(f"Creating a file package for local directory '{directory}'.")
_zip_directory(
directory,
excludes,
target_path,
include_parent_dir=include_parent_dir,
logger=logger,
)
def upload_package_if_needed(
pkg_uri: str,
base_directory: str,
directory: str,
include_parent_dir: bool = False,
excludes: Optional[List[str]] = None,
logger: Optional[logging.Logger] = default_logger,
) -> bool:
"""Upload the contents of the directory under the given URI.
This will first create a temporary zip file under the passed
base_directory.
If the package already exists in storage, this is a no-op.
Args:
pkg_uri: URI of the package to upload.
base_directory: Directory where package files are stored.
directory: Directory to be uploaded.
include_parent_dir: If true, includes the top-level directory as a
directory inside the zip file.
        excludes: List specifying files to exclude.

    Returns:
        True if the package was uploaded, False if it already existed.
    """
if excludes is None:
excludes = []
if logger is None:
logger = default_logger
pin_runtime_env_uri(pkg_uri)
if package_exists(pkg_uri):
return False
package_file = Path(_get_local_path(base_directory, pkg_uri))
create_package(
directory,
package_file,
include_parent_dir=include_parent_dir,
excludes=excludes,
)
upload_package_to_gcs(pkg_uri, package_file.read_bytes())
# Remove the local file to avoid accumulating temporary zip files.
package_file.unlink()
return True
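# Example (illustrative sketch; assumes a running Ray cluster and a
# hypothetical project directory):
#   uri = get_uri_for_directory("/proj/app", excludes=["*.log"])
#   uploaded = upload_package_if_needed(
#       uri, "/tmp/pkg_staging", "/proj/app", excludes=["*.log"]
#   )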
def get_local_dir_from_uri(uri: str, base_directory: str) -> Path:
"""Return the local directory corresponding to this URI."""
pkg_file = Path(_get_local_path(base_directory, uri))
local_dir = pkg_file.with_suffix("")
return local_dir
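# Example (illustrative; shows the suffix being stripped, on POSIX):
#   >>> get_local_dir_from_uri("gcs://_ray_pkg_abc.zip", "/base")
#   PosixPath('/base/_ray_pkg_abc')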
async def download_and_unpack_package(
pkg_uri: str,
base_directory: str,
gcs_aio_client: GcsAioClient,
logger: Optional[logging.Logger] = default_logger,
) -> str:
"""Download the package corresponding to this URI and unpack it if zipped.
Will be written to a file or directory named {base_directory}/{uri}.
Returns the path to this file or directory.
"""
pkg_file = Path(_get_local_path(base_directory, pkg_uri))
with FileLock(str(pkg_file) + ".lock"):
if logger is None:
logger = default_logger
logger.debug(f"Fetching package for URI: {pkg_uri}")
local_dir = get_local_dir_from_uri(pkg_uri, base_directory)
assert local_dir != pkg_file, "Invalid pkg_file!"
if local_dir.exists():
assert local_dir.is_dir(), f"{local_dir} is not a directory"
else:
protocol, pkg_name = parse_uri(pkg_uri)
if protocol == Protocol.GCS:
# Download package from the GCS.
code = await gcs_aio_client.internal_kv_get(
pkg_uri.encode(), namespace=None, timeout=None
)
                if code is None:
                    raise IOError(f"Failed to fetch URI {pkg_uri} from GCS.")
                pkg_file.write_bytes(code)
if is_zip_uri(pkg_uri):
unzip_package(
package_path=pkg_file,
target_dir=local_dir,
remove_top_level_directory=False,
unlink_zip=True,
logger=logger,
)
else:
return str(pkg_file)
elif protocol in Protocol.remote_protocols():
# Download package from remote URI
tp = None
if protocol == Protocol.S3:
try:
import boto3
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` and "
"`pip install boto3` to fetch URIs in s3 "
"bucket."
)
tp = {"client": boto3.client("s3")}
elif protocol == Protocol.GS:
try:
from google.cloud import storage # noqa: F401
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` and "
"`pip install google-cloud-storage` "
"to fetch URIs in Google Cloud Storage bucket."
)
elif protocol == Protocol.FILE:
pkg_uri = pkg_uri[len("file://") :]
def open_file(uri, mode, *, transport_params=None):
return open(uri, mode)
else:
try:
from smart_open import open as open_file
except ImportError:
raise ImportError(
"You must `pip install smart_open` "
f"to fetch {protocol.value.upper()} URIs."
)
with open_file(pkg_uri, "rb", transport_params=tp) as package_zip:
                with open_file(pkg_file, "wb") as fout:
                    fout.write(package_zip.read())
unzip_package(
package_path=pkg_file,
target_dir=local_dir,
remove_top_level_directory=True,
unlink_zip=True,
logger=logger,
)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")
return str(local_dir)
def get_top_level_dir_from_compressed_package(package_path: str):
"""
If compressed package at package_path contains a single top-level
directory, returns the name of the top-level directory. Otherwise,
returns None.
"""
    with ZipFile(package_path, "r") as package_zip:
        top_level_directory = None
        for file_name in package_zip.namelist():
            if top_level_directory is None:
                # Cache the top_level_directory name when checking
                # the first file in the zipped package
                if "/" in file_name:
                    top_level_directory = file_name.split("/")[0]
                else:
                    return None
            else:
                # Confirm that all other files
                # belong to the same top_level_directory
                if (
                    "/" not in file_name
                    or file_name.split("/")[0] != top_level_directory
                ):
                    return None
    return top_level_directory
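# Example (illustrative sketch; archive contents are hypothetical): a zip
# containing only "app/main.py" and "app/util.py" yields "app", while one
# that also contains a top-level "README.md" yields None.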
def remove_dir_from_filepaths(base_dir: str, rdir: str):
"""
base_dir: String path of the directory containing rdir
rdir: String path of directory relative to base_dir whose contents should
be moved to its base_dir, its parent directory
Removes rdir from the filepaths of all files and directories inside it.
In other words, moves all the files inside rdir to the directory that
contains rdir. Assumes base_dir's contents and rdir's contents have no
name conflicts.
"""
# Move rdir to a temporary directory, so its contents can be moved to
# base_dir without any name conflicts
with TemporaryDirectory() as tmp_dir:
# shutil.move() is used instead of os.rename() in case rdir and tmp_dir
# are located on separate file systems
shutil.move(os.path.join(base_dir, rdir), os.path.join(tmp_dir, rdir))
# Shift children out of rdir and into base_dir
rdir_children = os.listdir(os.path.join(tmp_dir, rdir))
for child in rdir_children:
shutil.move(
os.path.join(tmp_dir, rdir, child), os.path.join(base_dir, child)
)
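# Example (illustrative sketch; paths are hypothetical): given
# /base/app/main.py and /base/app/util.py,
#   remove_dir_from_filepaths("/base", "app")
# leaves /base/main.py and /base/util.py; the now-empty "app" directory is
# discarded along with the temporary directory.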
def unzip_package(
package_path: str,
target_dir: str,
remove_top_level_directory: bool,
unlink_zip: bool,
logger: Optional[logging.Logger] = default_logger,
):
"""
Unzip the compressed package contained at package_path and store the
contents in target_dir. If remove_top_level_directory is True, the function
will automatically remove the top_level_directory and store the contents
directly in target_dir. If unlink_zip is True, the function will unlink the
zip file stored at package_path.
"""
try:
os.mkdir(target_dir)
except FileExistsError:
logger.info(f"Directory at {target_dir} already exists")
logger.debug(f"Unpacking {package_path} to {target_dir}")
with ZipFile(str(package_path), "r") as zip_ref:
zip_ref.extractall(target_dir)
if remove_top_level_directory:
top_level_directory = get_top_level_dir_from_compressed_package(package_path)
if top_level_directory is None:
raise ValueError(
"The package at package_path must contain "
"a single top level directory. Make sure there "
"are no hidden files at the same level as the "
"top level directory."
)
remove_dir_from_filepaths(target_dir, top_level_directory)
if unlink_zip:
Path(package_path).unlink()
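# Example (illustrative sketch; paths are hypothetical):
#   unzip_package(
#       package_path="/tmp/_ray_pkg_abc.zip",
#       target_dir="/tmp/_ray_pkg_abc",
#       remove_top_level_directory=True,  # e.g., for GitHub archive zips
#       unlink_zip=True,
#   )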
def delete_package(pkg_uri: str, base_directory: str) -> bool:
    """Deletes a specific URI from the local filesystem.

    Args:
        pkg_uri: URI to delete.
        base_directory: Directory where package files are stored.

    Returns:
        True if the URI was successfully deleted, else False.
    """
deleted = False
path = Path(_get_local_path(base_directory, pkg_uri))
with FileLock(str(path) + ".lock"):
path = path.with_suffix("")
if path.exists():
if path.is_dir() and not path.is_symlink():
shutil.rmtree(str(path))
else:
path.unlink()
deleted = True
return deleted