Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tid-gradient-boosting-model / processing / data_management.py
Size: Mime:
import pandas as pd
import joblib
from sklearn.pipeline import Pipeline
from gradient_boosting_model.config.core import config, DATASET_DIR, TRAINED_MODEL_DIR
from gradient_boosting_model import __version__ as _version

import logging
from typing import List

# Module-level logger, named after this module's import path.
_logger = logging.getLogger(__name__)


def load_dataset(*, file_name: str) -> pd.DataFrame:
    """Read a CSV from the dataset directory and apply the configured renames.

    Args:
        file_name: Name of the CSV file, resolved relative to ``DATASET_DIR``.

    Returns:
        The loaded frame with its columns renamed according to
        ``config.gradient_boosting_model_config.variables_to_rename``.
    """
    csv_location = f"{DATASET_DIR}/{file_name}"
    raw_frame = pd.read_csv(csv_location)

    # Variables whose names begin with a number are renamed here to avoid
    # syntax errors further down the line.
    rename_map = config.gradient_boosting_model_config.variables_to_rename
    return raw_frame.rename(columns=rename_map)


def save_pipeline(*, pipeline_to_persist: Pipeline) -> None:
    """Write the pipeline to disk under a version-stamped file name.

    Every other entry in ``TRAINED_MODEL_DIR`` (apart from ``__init__.py``)
    is removed first, so the directory ends up holding a single trained
    model whose file name carries the package version it was built with.

    Args:
        pipeline_to_persist: The fitted pipeline object to serialize.
    """
    # File name is the configured prefix + package version + ".pkl".
    versioned_name = f"{config.app_config.pipeline_save_file}{_version}.pkl"
    destination = TRAINED_MODEL_DIR / versioned_name

    # Clear out stale models before writing the new one.
    remove_old_pipelines(files_to_keep=[versioned_name])
    joblib.dump(pipeline_to_persist, destination)
    _logger.info(f"Saved pipeline: {versioned_name}")


def load_pipeline(*, file_name: str) -> Pipeline:
    """Deserialize a pipeline stored in the trained-model directory.

    Args:
        file_name: Name of the persisted file inside ``TRAINED_MODEL_DIR``.

    Returns:
        The object produced by ``joblib.load`` for that file.
    """
    # NOTE: joblib files are pickle-based; only load files you trust.
    return joblib.load(filename=TRAINED_MODEL_DIR / file_name)


def remove_old_pipelines(*, files_to_keep: List[str]) -> None:
    """Remove old model pipelines.

    This is to ensure there is a simple one-to-one mapping between the
    package version and the model version to be imported and used by
    other applications.

    Args:
        files_to_keep: File names inside ``TRAINED_MODEL_DIR`` that must
            survive the clean-up. ``__init__.py`` is always preserved.
    """
    # Build the keep-set once so each membership test is O(1).
    do_not_delete = {*files_to_keep, "__init__.py"}

    for model_file in TRAINED_MODEL_DIR.iterdir():
        if model_file.name in do_not_delete:
            continue
        # Path.unlink() raises on real directories (for example a
        # ``__pycache__`` folder, if one exists here), which would abort
        # the whole save. Leave directories alone; symlinks are still
        # removed as before.
        if model_file.is_dir() and not model_file.is_symlink():
            continue
        model_file.unlink()