Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tid-gradient-boosting-model / train_pipeline.py
Size: Mime:
from sklearn.model_selection import train_test_split

from gradient_boosting_model import pipeline
from gradient_boosting_model.processing.data_management import (
    load_dataset,
    save_pipeline,
)
from gradient_boosting_model.config.core import config
from gradient_boosting_model import __version__ as _version

import logging


# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


def run_training() -> None:
    """Fit the price pipeline on the training split and persist it.

    Loads the dataset named by ``config.app_config.training_data_file``,
    splits it using the features, target, test size and random state from
    ``config.gradient_boosting_model_config``, fits ``pipeline.price_pipe``
    on the training portion, logs the package version and passes the
    pipeline to ``save_pipeline``.
    """

    # load the raw training dataset
    frame = load_dataset(file_name=config.app_config.training_data_file)

    # pick predictors and target according to the model configuration
    model_cfg = config.gradient_boosting_model_config
    predictors = frame[model_cfg.features]
    target = frame[model_cfg.target]

    # split into train and test parts; only the train part is used below
    splits = train_test_split(
        predictors,
        target,
        test_size=model_cfg.test_size,
        random_state=model_cfg.random_state,
    )
    train_features, _test_features, train_target, _test_target = splits

    price_pipe = pipeline.price_pipe
    price_pipe.fit(train_features, train_target)

    _logger.warning(f"saving model version: {_version}")
    save_pipeline(pipeline_to_persist=price_pipe)


# Run the training routine only when this file is executed as a script.
if __name__ == "__main__":
    run_training()