# ray/train/xgboost/xgboost_checkpoint.py
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import xgboost

from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class XGBoostCheckpoint(FrameworkCheckpoint):
    """A :py:class:`~ray.train.Checkpoint` with XGBoost-specific functionality."""

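    # Name of the file the serialized booster is written to inside the checkpoint directory.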
    MODEL_FILENAME = "model.json"

    @classmethod
    def from_model(
        cls,
        booster: xgboost.Booster,
        *,
        preprocessor: Optional["Preprocessor"] = None,
        path: Optional[str] = None,
    ) -> "XGBoostCheckpoint":
        """Create a :py:class:`~ray.train.Checkpoint` that stores an XGBoost
        model.

        Args:
            booster: The XGBoost model to store in the checkpoint.
            preprocessor: A fitted preprocessor to be applied before inference.
            path: The path to the directory where the checkpoint file will be saved.
                This should start as an empty directory, since the *entire*
                directory will be treated as the checkpoint when reported.
                By default, a temporary directory will be created.

        Returns:
            An :py:class:`XGBoostCheckpoint` containing the specified booster.

        Examples:

            .. testcode::

                import numpy as np
                import xgboost

                from ray.train.xgboost import XGBoostCheckpoint

                train_X = np.array([[1, 2], [3, 4]])
                train_y = np.array([0, 1])

                model = xgboost.XGBClassifier().fit(train_X, train_y)
                checkpoint = XGBoostCheckpoint.from_model(model.get_booster())

        """
        checkpoint_path = Path(path or tempfile.mkdtemp())

        if not checkpoint_path.is_dir():
            raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")

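        # Serialize the booster into the checkpoint directory under MODEL_FILENAME.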
        booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())

        checkpoint = cls.from_directory(checkpoint_path.as_posix())
        if preprocessor:
            checkpoint.set_preprocessor(preprocessor)
        return checkpoint

    def get_model(self) -> xgboost.Booster:
        """Retrieve the XGBoost model stored in this checkpoint."""
        with self.as_directory() as checkpoint_path:
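            # Load the serialized booster from the materialized checkpoint directory.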
            booster = xgboost.Booster()
            booster.load_model(Path(checkpoint_path, self.MODEL_FILENAME).as_posix())
            return booster