# Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages
#
# Repository URL to install this package:
#
# Details
# ray / purelib / ray / train / sklearn / _sklearn_utils.py
# Size: Mime:
from sklearn.base import BaseEstimator

# Parameter-name suffixes that this module treats as CPU-related settings.
SKLEARN_CPU_PARAM_NAMES = [
    "n_jobs",
    "thread_count",  # catboost's parameter
]


def _has_cpu_params(estimator: BaseEstimator) -> bool:
    """Report whether the estimator exposes at least one CPU-related param.

    The names returned by ``estimator.get_params(deep=True)`` are checked
    against the suffixes listed in ``SKLEARN_CPU_PARAM_NAMES``.

    Args:
        estimator: Estimator whose parameter names are inspected.

    Returns:
        True if any parameter name ends with an entry of
        ``SKLEARN_CPU_PARAM_NAMES``, otherwise False.
    """
    param_names = estimator.get_params(deep=True)
    # str.endswith accepts a tuple, so one call tests every suffix at once.
    cpu_suffixes = tuple(SKLEARN_CPU_PARAM_NAMES)
    for param_name in param_names:
        if param_name.endswith(cpu_suffixes):
            return True
    return False


def _set_cpu_params(estimator: BaseEstimator, num_cpus: int) -> None:
    """Assign ``num_cpus`` to every CPU-related param of the estimator.

    A parameter is selected when its name, as returned by
    ``estimator.get_params(deep=True)``, ends with an entry of
    ``SKLEARN_CPU_PARAM_NAMES``. All selected parameters are applied in a
    single ``estimator.set_params`` call.

    Args:
        estimator: Estimator to update in place.
        num_cpus: Value written to each matching parameter.
    """
    param_names = estimator.get_params(deep=True)
    # str.endswith accepts a tuple, so one call tests every suffix at once.
    cpu_suffixes = tuple(SKLEARN_CPU_PARAM_NAMES)
    overrides = {}
    for param_name in param_names:
        if param_name.endswith(cpu_suffixes):
            overrides[param_name] = num_cpus
    estimator.set_params(**overrides)