Forecasting time series with gradient boosting: Skforecast, XGBoost, LightGBM and CatBoost

Joaquín Amat Rodrigo, Javier Escobar Ortiz
February, 2021 (last update April 2024)


Gradient boosting models have gained popularity in the machine learning community due to their ability to achieve excellent results in a wide range of use cases, including both regression and classification. Although these models have traditionally been less common in forecasting, they can be highly effective in this domain. Some of the key benefits of using gradient boosting models for forecasting include:

  • The ease with which exogenous variables can be included in the model, in addition to autoregressive variables.

  • The ability to capture non-linear relationships between variables.

  • High scalability, allowing models to handle large volumes of data.

  • Some implementations allow the inclusion of categorical variables without the need for additional encoding, such as one-hot encoding.

Despite these benefits, the use of machine learning models for forecasting can present several challenges that can make analysts reluctant to use them, the main ones being:

  • Transforming the data so that it can be used as a regression problem.

  • Depending on how many future predictions are needed (prediction horizon), an iterative process may be required where each new prediction is based on previous ones.

  • Model validation requires specific strategies such as backtesting, walk-forward validation or time series cross-validation. Traditional cross-validation cannot be used.
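The first two challenges can be illustrated with a minimal sketch that does not use skforecast. The helper names `make_lag_matrix` and `recursive_forecast` are hypothetical, and an ordinary least-squares fit stands in for the gradient boosting regressor; the toy series is generated from a known rule so that the result can be checked by hand.

```python
import numpy as np

def make_lag_matrix(y, n_lags):
    """Return (X, target): each row of X holds the n_lags values preceding target."""
    y = np.asarray(y, dtype=float)
    # Column j holds lag j + 1, so column 0 is the most recent value
    X = np.column_stack([y[n_lags - 1 - j : len(y) - 1 - j] for j in range(n_lags)])
    target = y[n_lags:]
    return X, target

def recursive_forecast(coef, last_window, steps):
    """Predict `steps` values, feeding each prediction back in as the newest lag."""
    window = np.asarray(last_window, dtype=float).copy()  # oldest first, newest last
    predictions = []
    for _ in range(steps):
        features = np.concatenate(([1.0], window[::-1]))  # intercept, lag 1, lag 2, ...
        y_hat = float(features @ coef)
        predictions.append(y_hat)
        window = np.append(window[1:], y_hat)  # drop the oldest value, append the prediction
    return np.array(predictions)

# Toy series that follows y_t = 1 + 0.5 * y_{t-1} + 0.3 * y_{t-2} exactly
y = [1.0, 2.0]
for _ in range(20):
    y.append(1.0 + 0.5 * y[-1] + 0.3 * y[-2])
y = np.array(y)

X, target = make_lag_matrix(y, n_lags=2)
A = np.column_stack([np.ones(len(X)), X])           # add an intercept column
coef, *_ = np.linalg.lstsq(A, target, rcond=None)   # stand-in for any regressor

forecast = recursive_forecast(coef, last_window=y[-2:], steps=3)
print(forecast)
```

Once the series is expressed as a matrix of lags, any regressor can be trained on it. Multi-step forecasts are then obtained by reusing earlier predictions as inputs, which is why errors can accumulate over long horizons.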

The skforecast library provides automated solutions to these challenges, making it easier to apply machine learning models to forecasting problems and to validate them. The library supports several advanced gradient boosting models, including XGBoost, LightGBM, CatBoost and the scikit-learn HistGradientBoostingRegressor. This document shows how to use them to build accurate forecasting models. To ensure a smooth learning experience, an initial exploration of the data is performed. Then, the modeling process is explained step by step, starting with a recursive model that uses a LightGBM regressor and moving on to a model that incorporates exogenous variables and various encoding strategies. The document concludes by demonstrating the use of other gradient boosting implementations, including XGBoost, CatBoost, and the scikit-learn HistGradientBoostingRegressor.
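To make the validation challenge more concrete, the following is a rough sketch of how walk-forward (backtesting) folds can be generated so that the model is always evaluated on observations that come after its training data. The function `walk_forward_folds` is a hypothetical helper written for illustration, not part of skforecast, and the fold sizes are arbitrary.

```python
def walk_forward_folds(n_obs, initial_train_size, steps):
    """Yield (train_indices, test_indices) pairs with an expanding training window."""
    start = initial_train_size
    while start < n_obs:
        end = min(start + steps, n_obs)
        # Training data always ends right before the test block begins
        yield range(0, start), range(start, end)
        start = end

folds = list(walk_forward_folds(n_obs=120, initial_train_size=48, steps=36))
for train_idx, test_idx in folds:
    print(f"train 0..{train_idx[-1]}  ->  test {test_idx[0]}..{test_idx[-1]}")
```

Unlike shuffled k-fold cross-validation, every test block lies strictly after its training block, so no information from the future leaks into the model.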

✎ Note

Machine learning models do not always outperform statistical learning models such as AR, ARIMA or Exponential Smoothing. Which one works best depends largely on the characteristics of the use case to which it is applied. Visit ARIMA and SARIMAX models with skforecast to learn more about statistical models.

✎ Note

Additional examples of how to use gradient boosting models for forecasting can be found in the documents Forecasting energy demand with machine learning and Global Forecasting Models.

Use case

Bicycle sharing is a popular shared transport service that provides bicycles to individuals for short-term use. These systems typically provide bike docks where riders can borrow a bike and return it to any dock belonging to the same system. The docks are equipped with special bike racks that secure the bike and only release it via computer control. One of the major challenges faced by operators of these systems is the need to redistribute bikes to ensure that there are bikes available at all docks, as well as free spaces for returns.

In order to improve the planning and execution of bicycle distribution, it is proposed to create a model capable of forecasting the number of users over the next 36 hours. In this way, at 12:00 every day, the company in charge of managing the system will be able to know the expected users for the rest of the day (12 hours) and the next day (24 hours).
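The 36-hour horizon can be checked with a few lines of pandas. This is only a sketch: the forecast date is an arbitrary example, and it assumes that the first predicted hour is the one starting at 12:00.

```python
import pandas as pd

# Hypothetical forecast origin: 12:00 on an arbitrary day
forecast_origin = pd.Timestamp("2012-09-01 12:00:00")
horizon = pd.date_range(start=forecast_origin, periods=36, freq=pd.Timedelta(hours=1))

# Split the horizon into "rest of the current day" and "next day"
same_day = horizon.normalize() == forecast_origin.normalize()
rest_of_day = horizon[same_day]
next_day = horizon[~same_day]

print(len(rest_of_day), len(next_day))
print(horizon[0], "...", horizon[-1])
```

Under this assumption the horizon covers 12 hourly steps of the current day plus the 24 steps of the following day.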

For illustrative purposes, the current example only models a single station, but it can be easily extended to cover multiple stations using global multi-series forecasting, thereby improving the management of bike-sharing systems on a larger scale.


Libraries used in this document.

In [1]:
# Data processing
# ==============================================================================
import numpy as np
import pandas as pd
from astral.sun import sun
from astral import LocationInfo
from skforecast.datasets import fetch_dataset

# Plots
# ==============================================================================
import matplotlib.pyplot as plt
from statsmodels.graphics.tsaplots import plot_acf
from statsmodels.graphics.tsaplots import plot_pacf
from skforecast.plot import plot_residuals
import plotly.graph_objects as go
import plotly.io as pio
import plotly.offline as poff
pio.templates.default = "seaborn"

# Modelling and Forecasting
# ==============================================================================
import xgboost
import lightgbm
import catboost
import sklearn
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from catboost import CatBoostRegressor
from sklearn.ensemble import HistGradientBoostingRegressor

from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import PolynomialFeatures
from sklearn.feature_selection import RFECV
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector

import skforecast
from skforecast.ForecasterBaseline import ForecasterEquivalentDate
from skforecast.ForecasterAutoreg import ForecasterAutoreg
from skforecast.model_selection import bayesian_search_forecaster
from skforecast.model_selection import backtesting_forecaster
from skforecast.model_selection import select_features
import shap

# Warnings configuration
# ==============================================================================
import warnings

print(f"Version skforecast: {skforecast.__version__}")
print(f"Version scikit-learn: {sklearn.__version__}")
print(f"Version lightgbm: {lightgbm.__version__}")
print(f"Version xgboost: {xgboost.__version__}")
print(f"Version catboost: {catboost.__version__}")
Version skforecast: 0.12.0
Version scikit-learn: 1.4.2
Version lightgbm: 4.3.0
Version xgboost: 2.0.3
Version catboost: 1.2.5


The data in this document represent the hourly usage of the bike share system in the city of Washington, D.C. during the years 2011 and 2012. In addition to the number of users per hour, information about weather conditions and holidays is available. The original data was obtained from the UCI Machine Learning Repository and was cleaned beforehand by applying the following modifications:

  • Renamed columns with more descriptive names.

  • Renamed categories of the weather variables. The category of heavy_rain has been combined with that of rain.

  • Denormalized temperature, humidity and wind variables.

  • date_time variable created and set as index.

  • Imputed missing values by forward filling.
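The last two cleaning steps (datetime index and forward filling) can be sketched on a toy table. The three rows below are made up for illustration and are not taken from the cleaning script. The sketch assumes that a missing hour appears as an absent row rather than as an explicit NaN.

```python
import pandas as pd

# Toy input with the 02:00 observation missing
raw = pd.DataFrame({
    "date_time": ["2011-01-01 00:00:00", "2011-01-01 01:00:00", "2011-01-01 03:00:00"],
    "users": [16.0, 40.0, 13.0],
})

raw["date_time"] = pd.to_datetime(raw["date_time"], format="%Y-%m-%d %H:%M:%S")
clean = raw.set_index("date_time").sort_index()

# Build the complete hourly grid, then carry the last known value forward
full_index = pd.date_range(
    clean.index.min(), clean.index.max(), freq=pd.Timedelta(hours=1), name="date_time"
)
clean = clean.reindex(full_index).ffill()
print(clean)
```

After reindexing, the series has one row per hour, which is the regular frequency that forecasting tools typically expect.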

In [2]:
# Downloading data
# ==============================================================================
data = fetch_dataset('bike_sharing', raw=True)
Hourly usage of the bike share system in the city of Washington D.C. during the
years 2011 and 2012. In addition to the number of users per hour, information
about weather conditions and holidays is available.
Fanaee-T,Hadi. (2013). Bike Sharing Dataset. UCI Machine Learning Repository.
Shape of the dataset: (17544, 12)
In [3]:
# Preprocessing data (setting index and frequency)
# ==============================================================================
# .copy() avoids a SettingWithCopyWarning when the column is reassigned below
data = data[['date_time', 'users', 'holiday', 'weather', 'temp', 'atemp', 'hum', 'windspeed']].copy()
data['date_time'] = pd.to_datetime(data['date_time'], format='%Y-%m-%d %H:%M:%S')
data = data.set_index('date_time')
data = data.asfreq('H')
data = data.sort_index()
date_time            users  holiday  weather  temp  atemp   hum   windspeed
2011-01-01 00:00:00  16.0   0.0      clear    9.84  14.395  81.0  0.0
2011-01-01 01:00:00  40.0   0.0      clear    9.02  13.635  80.0  0.0
2011-01-01 02:00:00  32.0   0.0      clear    9.02  13.635  80.0  0.0
2011-01-01 03:00:00  13.0   0.0      clear    9.84  14.395  75.0  0.0
2011-01-01 04:00:00  1.0    0.0      clear    9.84  14.395  75.0  0.0

To facilitate the training of the models, the search for optimal hyperparameters and the evaluation of their predictive accuracy, the data are divided into three separate sets: training, validation and test.

In [4]:
# Split train-validation-test
# ==============================================================================
end_train = '2012-03-31 23:59:00'
end_validation = '2012-08-31 23:59:00'
data_train = data.loc[: end_train, :]
data_val   = data.loc[end_train:end_validation, :]
data_test  = data.loc[end_validation:, :]

print(f"Dates train      : {data_train.index.min()} --- {data_train.index.max()}  (n={len(data_train)})")
print(f"Dates validation : {data_val.index.min()} --- {data_val.index.max()}  (n={len(data_val)})")
print(f"Dates test       : {data_test.index.min()} --- {data_test.index.max()}  (n={len(data_test)})")
Dates train      : 2011-01-01 00:00:00 --- 2012-03-31 23:00:00  (n=10944)
Dates validation : 2012-04-01 00:00:00 --- 2012-08-31 23:00:00  (n=3672)
Dates test       : 2012-09-01 00:00:00 --- 2012-12-31 23:00:00  (n=2928)
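Label-based slicing with `.loc` includes both endpoints, so the cut-off values used for the split matter. The toy example below (four made-up hourly rows, with `pd.Timestamp` cut-offs instead of strings) suggests why a boundary at 23:59, which is not an index value, keeps the partitions disjoint, while a boundary that falls exactly on an observation would place that row in both partitions.

```python
import pandas as pd

index = pd.date_range("2012-03-31 22:00:00", periods=4, freq=pd.Timedelta(hours=1))
toy = pd.DataFrame({"users": [1, 2, 3, 4]}, index=index)

# Cut-off between two observations: no row is shared
end_train = pd.Timestamp("2012-03-31 23:59:00")
train = toy.loc[:end_train]
rest = toy.loc[end_train:]
shared_rows = len(train.index.intersection(rest.index))

# Cut-off exactly on an observation: that row lands in both slices
on_observation = pd.Timestamp("2012-03-31 23:00:00")
shared_rows_on_observation = len(
    toy.loc[:on_observation].index.intersection(toy.loc[on_observation:].index)
)

print(shared_rows, shared_rows_on_observation)
```

This is consistent with the partition sizes printed above, which add up to the 17544 rows of the dataset.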

Data exploration

Graphical exploration of time series can be an effective way of identifying trends, patterns, and seasonal variations. This, in turn, helps to guide the selection of the most appropriate forecasting model.

Plot time series

Full time series

In [5]:
# Interactive plot of time series
# ==============================================================================
fig = go.Figure()
fig.add_trace(go.Scatter(x=data_train.index, y=data_train['users'], mode='lines', name='Train'))
fig.add_trace(go.Scatter(x=data_val.index, y=data_val['users'], mode='lines', name='Validation'))
fig.add_trace(go.Scatter(x=data_test.index, y=data_test['users'], mode='lines', name='Test'))
fig.update_layout(
    title  = 'Number of users',
    margin=dict(l=20, r=20, t=35, b=20),
)
fig.show()
In [6]:
# Static plot of time series with zoom
# ==============================================================================
zoom = ('2011-08-01 00:00:00', '2011-08-15 00:00:00')
fig = plt.figure(figsize=(8, 4))
grid = plt.GridSpec(nrows=8, ncols=1, hspace=0.1, wspace=0)
main_ax = fig.add_subplot(grid[1:3, :])

data_train['users'].plot(ax=main_ax, label='train', alpha=0.5)
data_val['users'].plot(ax=main_ax, label='validation', alpha=0.5)
data_test['users'].plot(ax=main_ax, label='test', alpha=0.5)
min_y = min(data['users'])
max_y = max(data['users'])
main_ax.fill_between(zoom, min_y, max_y, facecolor='blue', alpha=0.5, zorder=0)
main_ax.set_title(f'Number of users: {data.index.min()}, {data.index.max()}', fontsize=10)
main_ax.legend(loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.8))
zoom_ax = fig.add_subplot(grid[5:, :])
data.loc[zoom[0]: zoom[1]]['users'].plot(ax=zoom_ax, color='blue', linewidth=1)
zoom_ax.set_title(f'Number of users: {zoom}', fontsize=10)

Seasonality plots

Seasonal plots are a useful tool for identifying seasonal patterns and trends in a time series. They are created by averaging the values of each season over time and then plotting them against time.

In [7]:
# Annual, weekly and daily seasonality
# ==============================================================================
fig, axs = plt.subplots(2, 2, figsize=(8.5, 5.5), sharex=False, sharey=True)
axs = axs.ravel()

# Users distribution by month
data['month'] = data.index.month
data.boxplot(column='users', by='month', ax=axs[0])
data.groupby('month')['users'].median().plot(style='o-', linewidth=0.8, ax=axs[0])
axs[0].set_title('Users distribution by month', fontsize=10)

# Users distribution by week day
data['week_day'] = data.index.day_of_week + 1
data.boxplot(column='users', by='week_day', ax=axs[1])
data.groupby('week_day')['users'].median().plot(style='o-', linewidth=0.8, ax=axs[1])
axs[1].set_title('Users distribution by week day', fontsize=10)

# Users distribution by the hour of the day
data['hour_day'] = data.index.hour + 1
data.boxplot(column='users', by='hour_day', ax=axs[2])
data.groupby('hour_day')['users'].median().plot(style='o-', linewidth=0.8, ax=axs[2])
axs[2].set_title('Users distribution by the hour of the day', fontsize=10)

# Users distribution by week day and hour of the day
mean_day_hour = data.groupby(["week_day", "hour_day"])["users"].mean()
mean_day_hour.plot(ax=axs[3])
axs[3].set(
    title       = "Mean users during week",
    xticks      = [i * 24 for i in range(7)],
    xticklabels = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"],
    xlabel      = "Day and hour",
    ylabel      = "Number of users"
)

fig.suptitle("Seasonality plots", fontsize=14)

There is a clear difference between weekdays and weekends. There is also a clear intra-day pattern, with a different influx of users depending on the time of day.
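The aggregation behind such seasonal plots can be reproduced without any plotting. The sketch below uses a synthetic series (made up for illustration, not the bike sharing data) with commute peaks on weekdays and flat usage on weekends, and recovers both patterns with `groupby`.

```python
import numpy as np
import pandas as pd

# Two weeks of synthetic hourly data; 2011-01-03 is a Monday
index = pd.date_range("2011-01-03", periods=24 * 14, freq=pd.Timedelta(hours=1))
is_weekend = index.dayofweek >= 5

# Weekdays: peaks at 08:00 and 17:00. Weekends: constant usage.
weekday_profile = np.where(np.isin(index.hour, [8, 17]), 100.0, 20.0)
users = pd.Series(np.where(is_weekend, 40.0, weekday_profile), index=index, name="users")

# Average each "season" (hour of day, type of day) over the whole period
mean_by_hour = users.groupby(index.hour).mean()
day_type = np.where(is_weekend, "weekend", "weekday")
mean_by_day_type = users.groupby(day_type).mean()

print(mean_by_hour.loc[[3, 8, 17]])
print(mean_by_day_type)
```

Grouping by two keys at once, as done with `week_day` and `hour_day` in the cell above, follows the same idea and shows how the hourly profile changes across the week.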

Autocorrelation plots

Autocorrelation plots are a useful tool for identifying the order of an autoregressive model. The autocorrelation function (ACF) measures the correlation between the time series and a lagged version of itself. The partial autocorrelation function (PACF) measures the same correlation after controlling for the values of the time series at all shorter lags. Together, these plots help to identify which lags should be included in the autoregressive model.
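The two definitions can be made concrete with a small NumPy sketch. The functions `acf` and `pacf_ols` below are hypothetical helpers written for illustration. They are not the estimators used by the plotting functions, whose defaults may differ. The PACF is approximated here as the coefficient of lag k in a least-squares regression on lags 1 to k, and the data is a simulated AR(1) process with coefficient 0.8.

```python
import numpy as np

def acf(x, max_lag):
    """Sample autocorrelation for lags 0..max_lag."""
    x = np.asarray(x, dtype=float)
    xc = x - x.mean()
    denom = np.sum(xc ** 2)
    return np.array([np.sum(xc[k:] * xc[: len(x) - k]) / denom for k in range(max_lag + 1)])

def pacf_ols(x, max_lag):
    """PACF at lag k: coefficient of lag k in an OLS regression on lags 1..k."""
    x = np.asarray(x, dtype=float)
    values = [1.0]
    for k in range(1, max_lag + 1):
        lags = np.column_stack([x[k - j : len(x) - j] for j in range(1, k + 1)])
        A = np.column_stack([np.ones(len(lags)), lags])  # intercept + lags 1..k
        coef, *_ = np.linalg.lstsq(A, x[k:], rcond=None)
        values.append(coef[-1])
    return np.array(values)

# Simulated AR(1) process: x_t = 0.8 * x_{t-1} + noise
rng = np.random.default_rng(0)
noise = rng.normal(size=5000)
x = np.zeros(5000)
for t in range(1, 5000):
    x[t] = 0.8 * x[t - 1] + noise[t]

acf_values = acf(x, max_lag=3)
pacf_values = pacf_ols(x, max_lag=3)
print(np.round(acf_values, 2))
print(np.round(pacf_values, 2))
```

For an AR(1) process the ACF decays gradually across lags, while the PACF is close to zero after lag 1. That contrast is what makes the PACF useful for choosing how many lags to include.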

In [8]:
# Autocorrelation plot
# ==============================================================================
fig, ax = plt.subplots(figsize=(5, 2))
plot_acf(data['users'], ax=ax, lags=72)