Interpretable forecasting models


Joaquín Amat Rodrigo, Javier Escobar Ortiz
July, 2024 (last update November 2024)

Introduction

Machine learning interpretability, also known as explainability, refers to the ability to understand, interpret, and explain the decisions or predictions made by machine learning models in a human-understandable way. It aims to shed light on how a model arrives at a particular result or decision.

Many modern machine learning models, such as ensemble methods, are complex enough to function as black boxes, making it difficult to understand why a particular prediction was made. Explainability techniques aim to demystify these models by providing insight into their inner workings, which helps to build trust, improve transparency, and meet regulatory requirements in various domains. Improving model explainability not only helps to understand model behavior, but also helps to identify biases, improve model performance, and enable stakeholders to make better-informed decisions based on machine learning insights.

The skforecast library is compatible with some of the most widely used interpretability methods: SHAP values, partial dependence plots, and model-specific feature importances.

Libraries

Libraries used in this document.

In [22]:
# Data manipulation
# ==============================================================================
import pandas as pd
import numpy as np
from skforecast.datasets import fetch_dataset

# Plotting
# ==============================================================================
import matplotlib.pyplot as plt
import shap
from skforecast.plot import set_dark_theme

# Modeling and forecasting
# ==============================================================================
import sklearn
import lightgbm
import skforecast
from sklearn.inspection import PartialDependenceDisplay
from lightgbm import LGBMRegressor
from skforecast.recursive import ForecasterRecursive
from skforecast.preprocessing import RollingFeatures

color = '\033[1m\033[38;5;208m'
print(f"{color}Version skforecast: {skforecast.__version__}")
print(f"{color}Version scikit-learn: {sklearn.__version__}")
print(f"{color}Version lightgbm: {lightgbm.__version__}")
print(f"{color}Version pandas: {pd.__version__}")
print(f"{color}Version numpy: {np.__version__}")
Version skforecast: 0.14.0
Version scikit-learn: 1.5.2
Version lightgbm: 4.4.0
Version pandas: 2.2.3
Version numpy: 2.0.2

Data

In [23]:
# Download data
# ==============================================================================
data = fetch_dataset(name="vic_electricity")
data.head(3)
vic_electricity
---------------
Half-hourly electricity demand for Victoria, Australia
O'Hara-Wild M, Hyndman R, Wang E, Godahewa R (2022).tsibbledata: Diverse
Datasets for 'tsibble'. https://tsibbledata.tidyverts.org/,
https://github.com/tidyverts/tsibbledata/.
https://tsibbledata.tidyverts.org/reference/vic_elec.html
Shape of the dataset: (52608, 4)
Out[23]:
Demand Temperature Date Holiday
Time
2011-12-31 13:00:00 4382.825174 21.40 2012-01-01 True
2011-12-31 13:30:00 4263.365526 21.05 2012-01-01 True
2011-12-31 14:00:00 4048.966046 20.70 2012-01-01 True
In [24]:
# Aggregation to daily frequency
# ==============================================================================
data = data.resample('D').agg({'Demand': 'sum', 'Temperature': 'mean'})
data.head(3)
Out[24]:
Demand Temperature
Time
2011-12-31 82531.745918 21.047727
2012-01-01 227778.257304 26.578125
2012-01-02 275490.988882 31.751042
In [25]:
# Create calendar variables
# ==============================================================================
data['day_of_week'] = data.index.dayofweek
data['month'] = data.index.month
data.head(3)
Out[25]:
Demand Temperature day_of_week month
Time
2011-12-31 82531.745918 21.047727 5 12
2012-01-01 227778.257304 26.578125 6 1
2012-01-02 275490.988882 31.751042 0 1
In [26]:
# Split train-test
# ==============================================================================
end_train = '2014-12-01 23:59:00'
data_train = data.loc[: end_train, :]
data_test  = data.loc[end_train:, :]
print(f"Dates train : {data_train.index.min()} --- {data_train.index.max()}  (n={len(data_train)})")
print(f"Dates test  : {data_test.index.min()} --- {data_test.index.max()}  (n={len(data_test)})")
Dates train : 2011-12-31 00:00:00 --- 2014-12-01 00:00:00  (n=1067)
Dates test  : 2014-12-02 00:00:00 --- 2014-12-31 00:00:00  (n=30)

Forecasting model

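A forecasting model is created to predict daily energy demand using the past 7 values (last week) and the rolling mean of the previous 24 days as autoregressive predictors, together with temperature, day of the week, and month as exogenous variables.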
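To make these predictors concrete, the sketch below builds the same kind of lag and rolling-mean columns by hand with pandas on a toy daily series. It is a minimal illustration, not skforecast's internal code: `make_lag_matrix` is a hypothetical helper, and exogenous columns would simply be joined on the same index. Dropping the first `max(n_lags, window)` targets is consistent with the window size of 24 reported by the forecaster and with the training matrix starting on 2012-01-24.

```python
import numpy as np
import pandas as pd


def make_lag_matrix(y, n_lags=7, window=24):
    """Hypothetical helper: lag and rolling-mean predictors for one-step-ahead targets."""
    X = pd.DataFrame(index=y.index)
    for lag in range(1, n_lags + 1):
        # lag_k holds the value observed k steps before the target
        X[f"lag_{lag}"] = y.shift(lag)
    # Mean of the `window` values preceding the target; shift(1) keeps y_t out of it
    X[f"roll_mean_{window}"] = y.shift(1).rolling(window).mean()
    # The first max(n_lags, window) targets lack a complete history and are dropped
    start = max(n_lags, window)
    return X.iloc[start:], y.iloc[start:]


# Toy daily series 0, 1, ..., 39
y = pd.Series(
    np.arange(40, dtype=float),
    index=pd.date_range("2012-01-01", periods=40, freq="D"),
    name="Demand",
)
X, target = make_lag_matrix(y)
print(X.shape)                                                 # (16, 8)
print(X.iloc[0][["lag_1", "lag_7", "roll_mean_24"]].tolist())  # [23.0, 17.0, 11.5]
```

Each row pairs a target with values strictly from its past, which is why one-step-ahead regressors such as LightGBM can be trained on it directly.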

In [27]:
# Create a recursive multi-step forecaster (ForecasterRecursive)
# ==============================================================================
window_features = RollingFeatures(stats=['mean'], window_sizes=24)
exog_features = ['Temperature', 'day_of_week', 'month']
forecaster = ForecasterRecursive(
                 regressor       = LGBMRegressor(random_state=123, verbose=-1),
                 lags            = 7,
                 window_features = window_features
             )

forecaster.fit(
    y    = data_train['Demand'],
    exog = data_train[exog_features],
)
forecaster
Out[27]:

ForecasterRecursive

General Information
  • Regressor: LGBMRegressor(random_state=123, verbose=-1)
  • Lags: [1 2 3 4 5 6 7]
  • Window features: ['roll_mean_24']
  • Window size: 24
  • Exogenous included: True
  • Weight function included: False
  • Differentiation order: None
  • Creation date: 2024-11-05 13:04:27
  • Last fit date: 2024-11-05 13:04:27
  • Skforecast version: 0.14.0
  • Python version: 3.12.4
  • Forecaster id: None
Exogenous Variables
    Temperature, day_of_week, month
Data Transformations
  • Transformer for y: None
  • Transformer for exog: None
Training Information
  • Training range: [Timestamp('2011-12-31 00:00:00'), Timestamp('2014-12-01 00:00:00')]
  • Training index type: DatetimeIndex
  • Training index frequency: D
Regressor Parameters
    {'boosting_type': 'gbdt', 'class_weight': None, 'colsample_bytree': 1.0, 'importance_type': 'split', 'learning_rate': 0.1, 'max_depth': -1, 'min_child_samples': 20, 'min_child_weight': 0.001, 'min_split_gain': 0.0, 'n_estimators': 100, 'n_jobs': None, 'num_leaves': 31, 'objective': None, 'random_state': 123, 'reg_alpha': 0.0, 'reg_lambda': 0.0, 'subsample': 1.0, 'subsample_for_bin': 200000, 'subsample_freq': 0, 'verbose': -1}
Fit Kwargs
    {}


Model-specific feature importances

Feature importance in machine learning determines the relevance or importance of each feature (or variable) in a model's prediction. In other words, it measures how much each feature contributes to the model's output.

Feature importance can be used for several purposes, such as identifying the most relevant features for a given prediction, understanding the behavior of a model, and selecting the best set of features for a given task. It can also help to identify potential biases or errors in the data used to train the model. It is important to note that feature importance is not a definitive measure of causality. Just because a feature is identified as important does not necessarily mean that it caused the outcome. Other factors, such as confounding variables, may also be at play.

The method used to calculate feature importance may vary depending on the type of machine learning model used. Different models can have different assumptions and characteristics that affect the importance calculation. For example, decision tree-based models, such as Random Forest and Gradient Boosting, typically use methods that measure the reduction of impurities or the effect of permutations. Linear regression models typically use coefficients. The magnitude of the coefficient reflects the strength and direction of the relationship between the predictor and the target variable.
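The permutation approach mentioned above can be sketched in a few lines of NumPy: shuffle one column at a time and measure how much the error grows. This is a minimal sketch on synthetic data with a least-squares linear model (an assumption for illustration), so the fitted coefficients can be compared with the permutation scores. scikit-learn provides a complete implementation in `sklearn.inspection.permutation_importance`.

```python
import numpy as np


def permutation_importance(predict, X, y, n_repeats=10, seed=0):
    """Mean increase in MSE when each column of X is randomly shuffled."""
    rng = np.random.default_rng(seed)
    baseline_mse = np.mean((y - predict(X)) ** 2)
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        increases = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffling a column breaks its link with y but keeps its distribution
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            increases.append(np.mean((y - predict(X_perm)) ** 2) - baseline_mse)
        importances[j] = np.mean(increases)
    return importances


# Synthetic data: x0 matters a lot, x1 a little, x2 not at all
rng = np.random.default_rng(42)
X = rng.normal(size=(500, 3))
y = 3.0 * X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=500)

# Ordinary least squares with an intercept column
design = np.column_stack([np.ones(len(X)), X])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)


def predict(X_new):
    return coef[0] + X_new @ coef[1:]


importances = permutation_importance(predict, X, y)
print("coefficients:          ", np.round(coef[1:], 2))
print("permutation importance:", np.round(importances, 3))
```

Both measures rank the features in the same order here. Coefficient magnitudes are only comparable as importances when the features share a scale, which holds in this example because all three columns are standard normal.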

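The importance of the predictors included in a forecaster can be obtained using the method get_feature_importances(). This method accesses the coef_ or feature_importances_ attribute of the internal regressor. With LightGBM's default importance_type='split' (see the regressor parameters above), the values are the number of times each feature is used in the model's splits, not the total gain those splits produce.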

Warning

The get_feature_importances() method only returns values if the forecaster's regressor has a coef_ or feature_importances_ attribute, the names scikit-learn uses for these quantities.
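The attribute lookup described in the warning can be sketched as follows. `importances_table` is a hypothetical helper and does not reproduce skforecast's get_feature_importances(); sorting by absolute value is a choice made here so that large negative linear coefficients are not pushed to the bottom. The `SimpleNamespace` objects are stand-ins carrying only the attributes being read, not real models.

```python
from types import SimpleNamespace

import numpy as np
import pandas as pd


def importances_table(regressor, feature_names):
    """Hypothetical helper: pair a fitted regressor's importances with feature names."""
    if hasattr(regressor, "feature_importances_"):
        values = np.asarray(regressor.feature_importances_)
    elif hasattr(regressor, "coef_"):
        # Linear models: flatten in case coef_ is stored as a 2-D array
        values = np.ravel(regressor.coef_)
    else:
        raise AttributeError("Regressor exposes neither feature_importances_ nor coef_.")
    table = pd.DataFrame({"feature": list(feature_names), "importance": values})
    # Rank by magnitude; a stable sort keeps the original order for ties
    order = np.argsort(-np.abs(values), kind="stable")
    return table.iloc[order].reset_index(drop=True)


names = ["lag_1", "Temperature", "month"]
tree_like = SimpleNamespace(feature_importances_=[10, 40, 25])
linear_like = SimpleNamespace(coef_=[0.2, -3.0, 1.5])
print(importances_table(tree_like, names))
print(importances_table(linear_like, names))
```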
In [28]:
# Extract feature importance
# ==============================================================================
importance = forecaster.get_feature_importances()
importance
Out[28]:
feature importance
8 Temperature 568
0 lag_1 426
1 lag_2 291
6 lag_7 248
4 lag_5 243
2 lag_3 236
7 roll_mean_24 227
10 month 224
5 lag_6 200
3 lag_4 177
9 day_of_week 160

Shap Values

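SHAP (SHapley Additive exPlanations) values are a popular method for explaining machine learning models. Based on Shapley values from cooperative game theory, they decompose each prediction into a base value plus one contribution per feature, which makes it possible to understand, both visually and quantitatively, how variables and their values influence predictions.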
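To see what a Shapley value is, the sketch below computes exact values by brute force for a toy three-feature function, treating a feature as "absent" when it is replaced by a single baseline value. This is an illustrative simplification: the shap library averages over background data and uses much faster algorithms (TreeExplainer exploits the structure of the trees), whereas this enumeration grows exponentially with the number of features.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x; absent features take their baseline value."""
    n = len(x)

    def value(coalition):
        # Features in the coalition come from x, the rest from the baseline
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for j in range(n):
        others = [i for i in range(n) if i != j]
        for size in range(n):
            # Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[j] += weight * (value(s | {j}) - value(s))
    return phi


def f(z):
    # Toy model with an interaction between features 0 and 2
    return 2.0 * z[0] + 1.0 * z[1] + 0.5 * z[0] * z[2]


x = [3.0, 1.0, 2.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(f, x, baseline)
print(phi)                          # ≈ [7.5, 1.0, 1.5]
print(f(baseline) + sum(phi), f(x))  # both ≈ 10.0
```

The interaction term (worth 3.0 at x) is split equally between features 0 and 2, and the contributions add up to the difference between the prediction and the baseline output.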

SHAP explanations can be generated for skforecast models with just two essential elements:

  • The internal regressor of the forecaster.

  • The training matrices created from the time series and used to fit the forecaster.

By leveraging these two components, users can create insightful and interpretable explanations for their skforecast models. These explanations can be used to verify the reliability of the model, identify the most significant factors that contribute to model predictions, and gain a deeper understanding of the underlying relationship between the input variables and the target variable.

Shap explainer and training matrices

In [29]:
# Training matrices used by the forecaster to fit the internal regressor
# ==============================================================================
X_train, y_train = forecaster.create_train_X_y(
                       y    = data_train['Demand'],
                       exog = data_train[exog_features],
                   )

display(X_train.head(3))
lag_1 lag_2 lag_3 lag_4 lag_5 lag_6 lag_7 roll_mean_24 Temperature day_of_week month
Time
2012-01-24 280188.298774 239810.374218 207949.859910 225035.325476 240187.677944 247722.494256 292458.685446 222658.202570 26.611458 1.0 1.0
2012-01-25 287474.816646 280188.298774 239810.374218 207949.859910 225035.325476 240187.677944 247722.494256 231197.497184 19.759375 2.0 1.0
2012-01-26 239083.684380 287474.816646 280188.298774 239810.374218 207949.859910 225035.325476 240187.677944 231668.556646 20.038542 3.0 1.0
In [30]:
# Create SHAP explainer (for tree-based models)
# ==============================================================================
explainer = shap.TreeExplainer(forecaster.regressor)

# Sample 50% of the data to speed up the calculation
rng = np.random.default_rng(seed=785412)
sample = rng.choice(X_train.index, size=int(len(X_train)*0.5), replace=False)
X_train_sample = X_train.loc[sample, :]
shap_values = explainer.shap_values(X_train_sample)

✎ Note

The SHAP library provides several explainers, each designed for a different type of model. shap.TreeExplainer is used for tree-based models, such as the LGBMRegressor in this example. For more information, see the SHAP documentation.

SHAP Summary Plot

The SHAP summary plot typically displays the feature importance or contribution of each feature to the model's output across multiple data points. It shows how much each feature contributes to pushing the model's prediction away from a base value (often the model's average prediction). By examining a SHAP summary plot, one can gain insights into which features have the most significant impact on predictions, whether they positively or negatively influence the outcome, and how different feature values contribute to specific predictions.
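For a linear model with independent features, SHAP values have a known closed form, phi_j = w_j * (x_j - E[x_j]), and the base value is the average prediction. The NumPy sketch below uses synthetic data and a hand-written linear model (both assumptions, not the LightGBM model above) to check the additivity property and the sign pattern that a summary plot displays.

```python
import numpy as np

# Synthetic, independent features and a hand-written linear model
rng = np.random.default_rng(0)
X = rng.normal(loc=[20.0, 3.0], scale=[5.0, 2.0], size=(1000, 2))
w = np.array([1.5, -4.0])
b = 10.0


def model(X_new):
    return b + X_new @ w


# Base value: the average prediction over the background data
base_value = model(X).mean()
# Closed-form SHAP values for a linear model: phi_j = w_j * (x_j - E[x_j])
shap_vals = (X - X.mean(axis=0)) * w

# Additivity: base value + contributions reproduces each prediction
print(np.allclose(base_value + shap_vals.sum(axis=1), model(X)))  # True

# Sign pattern: high values of feature 0 push predictions up,
# high values of feature 1 push them down
corr_0 = np.corrcoef(X[:, 0], shap_vals[:, 0])[0, 1]
corr_1 = np.corrcoef(X[:, 1], shap_vals[:, 1])[0, 1]
print(round(float(corr_0), 3), round(float(corr_1), 3))  # 1.0 -1.0
```

In a summary plot, this pattern appears as red (high-value) points on the right for feature 0 and on the left for feature 1. With non-linear models such as gradient boosting, the relationship is rarely this clean.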

In [31]:
# Shap summary plot (top 10)
# ==============================================================================
shap.initjs()
shap.summary_plot(shap_values, X_train_sample, max_display=10, show=False)
fig, ax = plt.gcf(), plt.gca()
ax.set_title("SHAP Summary plot")
ax.tick_params(labelsize=8)
fig.set_size_inches(6, 3)
In [32]:
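# Shap summary plot (bar): mean absolute SHAP value per feature
# ==============================================================================
shap.summary_plot(shap_values, X_train_sample, plot_type="bar", plot_size=(6, 3))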
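The bar version of the summary plot ranks features by their mean absolute SHAP value. The sketch below computes that ranking with NumPy and pandas on a small made-up SHAP matrix; `shap_bar_ranking` is a hypothetical helper. Taking the absolute value matters: in this toy matrix the plain mean for Temperature is -25 because positive and negative contributions cancel, while its mean absolute value is 125.

```python
import numpy as np
import pandas as pd


def shap_bar_ranking(shap_values, feature_names):
    """Global importance as the mean absolute SHAP value of each feature."""
    mean_abs = np.abs(np.asarray(shap_values)).mean(axis=0)
    return pd.Series(mean_abs, index=list(feature_names), name="mean_abs_shap").sort_values(
        ascending=False
    )


# Made-up SHAP matrix: 4 observations x 3 features
toy_shap = np.array([
    [120.0, -30.0, 5.0],
    [-200.0, 10.0, -5.0],
    [80.0, -50.0, 0.0],
    [-100.0, 20.0, 10.0],
])
ranking = shap_bar_ranking(toy_shap, ["Temperature", "lag_1", "month"])
print(ranking)
```

With the objects in this notebook, `shap_bar_ranking(shap_values, X_train_sample.columns)` should return the numbers behind the bar chart.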

Explain predictions in training data

A shap.force_plot is a specific type of visualization that provides an interactive and detailed view of how individual features contribute to a particular prediction made by a machine learning model. It's a local interpretation tool that helps understand why a model made a specific prediction for a given instance.
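When the JavaScript visualization cannot be rendered, the same information can be summarized as text. The helper below is hypothetical and the numbers are made up. It sorts one row of SHAP values by magnitude and reports which features raise or lower the prediction relative to the base value.

```python
def explain_prediction(base_value, contributions, feature_names, top_k=3):
    """Hypothetical text alternative to a force plot for a single prediction."""
    pairs = sorted(zip(feature_names, contributions), key=lambda p: abs(p[1]), reverse=True)
    # Additivity: prediction = base value + sum of feature contributions
    prediction = base_value + sum(contributions)
    lines = [f"base value: {base_value:,.1f}"]
    for name, value in pairs[:top_k]:
        direction = "raises" if value > 0 else "lowers"
        lines.append(f"{name} {direction} the prediction by {abs(value):,.1f}")
    lines.append(f"prediction: {prediction:,.1f}")
    return prediction, "\n".join(lines)


pred, report = explain_prediction(
    base_value=210_000.0,
    contributions=[-12_000.0, 4_500.0, 800.0, -300.0],
    feature_names=["day_of_week", "lag_1", "Temperature", "month"],
)
print(report)
```

With the notebook objects, a call such as `explain_prediction(explainer.expected_value, shap_values[0, :], X_train_sample.columns)` should produce a similar report, assuming `expected_value` is a scalar for this single-output regressor. Alternatively, `shap.force_plot` accepts `matplotlib=True` to draw a static plot for a single observation.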

Visualize a single prediction

In [33]:
# Force plot for the first observation in the sampled training data
# ==============================================================================
shap.force_plot(explainer.expected_value, shap_values[0,:], X_train_sample.iloc[0,:])

Visualize many predictions

In [34]:
# Force plot for the first 200 observations in the sampled training data
# ==============================================================================
shap.force_plot(explainer.expected_value, shap_values[:200, :], X_train_sample.iloc[:200, :])

SHAP Dependence Plots

SHAP dependence plots show the relationship between a feature and the model output by plotting each observation's feature value against its SHAP value. Vertical dispersion at a given feature value indicates interactions with other features; by default, points are colored by the feature that appears to interact most strongly with the selected one. These plots are particularly useful for examining how a feature affects the model's predictions across its range of values.

In [35]:
# Dependence plot for Temperature
# ==============================================================================
fig, ax = plt.subplots(figsize=(6, 3))
shap.dependence_plot("Temperature", shap_values, X_train_sample, ax=ax)

Explain forecasted values

SHAP values can also explain forecasted values, showing why the model produced a specific prediction for a future date.

In [36]:
# Forecasting next 30 days
# ==============================================================================
predictions = forecaster.predict(steps=30, exog=data_test[exog_features])
predictions
Out[36]:
2014-12-02    230878.900870
2014-12-03    230782.656189
2014-12-04    237992.220195
2014-12-05    204807.430451
2014-12-06    178976.825634
2014-12-07    189987.574968
2014-12-08    205003.377184
2014-12-09    203696.919024
2014-12-10    209051.236412
2014-12-11    209968.258997
2014-12-12    205781.805973
2014-12-13    195530.724715
2014-12-14    201915.375753
2014-12-15    223563.240369
2014-12-16    218723.730618
2014-12-17    218731.280255
2014-12-18    211443.910502
2014-12-19    201556.174129
2014-12-20    183424.934062
2014-12-21    206318.055051
2014-12-22    235697.634764
2014-12-23    221620.562404
2014-12-24    218372.992891
2014-12-25    219748.601630
2014-12-26    203865.864787
2014-12-27    181653.080277
2014-12-28    200849.858496
2014-12-29    212235.635255
2014-12-30    210632.512899
2014-12-31    219284.469160
Freq: D, Name: pred, dtype: float64
In [37]:
# Plot predictions
# ==============================================================================
set_dark_theme()
fig, ax = plt.subplots(figsize=(6, 2.5))
data_test['Demand'].plot(ax=ax, label='Test')
predictions.plot(ax=ax, label='Predictions', linestyle='--')
ax.set_xlabel(None)
ax.legend();

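The create_predict_X method creates the input matrix used internally by the forecaster's predict method. This matrix is then used to compute SHAP values for the forecasted values. Because the forecaster is recursive, from the second step onwards the lag and rolling-mean columns are built from the model's own previous predictions: for example, lag_1 for 2014-12-03 equals the prediction for 2014-12-02.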
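The recursive mechanism behind this matrix can be sketched with NumPy: at each step, the latest prediction is appended to the history so that it becomes lag_1 for the next step. This is a simplified illustration covering lags only, not skforecast's implementation; `recursive_predict_matrix` and `toy_model` are hypothetical. A rolling-mean feature would be recomputed from the same extended history.

```python
import numpy as np


def recursive_predict_matrix(last_window, steps, one_step_model, n_lags=7):
    """Sketch: predictor rows used at each step of a recursive forecast (lags only)."""
    history = list(last_window)
    rows, preds = [], []
    for _ in range(steps):
        # lag_1 is the most recent value, lag_k the value k steps back
        lags = np.array(history[-1:-n_lags - 1:-1])
        rows.append(lags)
        y_hat = float(one_step_model(lags))
        preds.append(y_hat)
        # Feed the prediction back so it becomes lag_1 at the next step
        history.append(y_hat)
    return np.vstack(rows), np.array(preds)


def toy_model(lags):
    # Hypothetical one-step model: mean of the last 7 values plus a constant
    return lags.mean() + 1.0


X_pred, y_pred = recursive_predict_matrix(
    np.arange(10, dtype=float), steps=3, one_step_model=toy_model
)
print(y_pred[0], X_pred[1, 0])  # 7.0 7.0 (first prediction reused as lag_1)
```

One consequence for interpretation is that SHAP values attributed to lag features at later steps partly reflect earlier predictions rather than observed data.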

In [38]:
# Create input matrix used by the predict method
# ==============================================================================
X_predict = forecaster.create_predict_X(steps=30, exog=data_test[exog_features])
X_predict
Out[38]:
lag_1 lag_2 lag_3 lag_4 lag_5 lag_6 lag_7 roll_mean_24 Temperature day_of_week month
2014-12-02 237812.592388 234970.336660 189653.758108 202017.012448 214602.854760 218321.456402 214318.765210 211369.709659 19.833333 1.0 12.0
2014-12-03 230878.900870 237812.592388 234970.336660 189653.758108 202017.012448 214602.854760 218321.456402 212777.981610 19.616667 2.0 12.0
2014-12-04 230782.656189 230878.900870 237812.592388 234970.336660 189653.758108 202017.012448 214602.854760 214485.198829 21.702083 3.0 12.0
2014-12-05 237992.220195 230782.656189 230878.900870 237812.592388 234970.336660 189653.758108 202017.012448 215457.947659 18.352083 4.0 12.0
2014-12-06 204807.430451 237992.220195 230782.656189 230878.900870 237812.592388 234970.336660 189653.758108 214962.018955 17.356250 5.0 12.0
2014-12-07 178976.825634 204807.430451 237992.220195 230782.656189 230878.900870 237812.592388 234970.336660 213141.307938 16.175000 6.0 12.0
2014-12-08 189987.574968 178976.825634 204807.430451 237992.220195 230782.656189 230878.900870 237812.592388 210937.056718 17.747917 0.0 12.0
2014-12-09 205003.377184 189987.574968 178976.825634 204807.430451 237992.220195 230782.656189 230878.900870 210786.576389 17.050000 1.0 12.0
2014-12-10 203696.919024 205003.377184 189987.574968 178976.825634 204807.430451 237992.220195 230782.656189 211532.961206 18.504167 2.0 12.0
2014-12-11 209051.236412 203696.919024 205003.377184 189987.574968 178976.825634 204807.430451 237992.220195 212209.477689 17.968750 3.0 12.0
2014-12-12 209968.258997 209051.236412 203696.919024 205003.377184 189987.574968 178976.825634 204807.430451 212188.542641 21.441667 4.0 12.0
2014-12-13 205781.805973 209968.258997 209051.236412 203696.919024 205003.377184 189987.574968 178976.825634 212016.059027 26.131250 5.0 12.0
2014-12-14 195530.724715 205781.805973 209968.258997 209051.236412 203696.919024 205003.377184 189987.574968 210936.875468 19.918750 6.0 12.0
2014-12-15 201915.375753 195530.724715 205781.805973 209968.258997 209051.236412 203696.919024 205003.377184 209902.060260 19.920833 0.0 12.0
2014-12-16 223563.240369 201915.375753 195530.724715 205781.805973 209968.258997 209051.236412 203696.919024 210653.141682 19.250000 1.0 12.0
2014-12-17 218723.730618 223563.240369 201915.375753 195530.724715 205781.805973 209968.258997 209051.236412 211786.711094 17.675000 2.0 12.0
2014-12-18 218731.280255 218723.730618 223563.240369 201915.375753 195530.724715 205781.805973 209968.258997 212502.799259 16.520833 3.0 12.0
2014-12-19 211443.910502 218731.280255 218723.730618 223563.240369 201915.375753 195530.724715 205781.805973 212022.176837 16.543750 4.0 12.0
2014-12-20 201556.174129 211443.910502 218731.280255 218723.730618 223563.240369 201915.375753 195530.724715 211490.402208 18.506250 5.0 12.0
2014-12-21 183424.934062 201556.174129 211443.910502 218731.280255 218723.730618 223563.240369 201915.375753 210036.380444 24.031250 6.0 12.0
2014-12-22 206318.055051 183424.934062 201556.174129 211443.910502 218731.280255 218723.730618 223563.240369 209691.180456 22.950000 0.0 12.0
2014-12-23 235697.634764 206318.055051 183424.934062 201556.174129 211443.910502 218731.280255 218723.730618 211094.539720 18.829167 1.0 12.0
2014-12-24 221620.562404 235697.634764 206318.055051 183424.934062 201556.174129 211443.910502 218731.280255 212426.489899 18.312500 2.0 12.0
2014-12-25 218372.992891 221620.562404 235697.634764 206318.055051 183424.934062 201556.174129 211443.910502 211734.933908 16.933333 3.0 12.0
2014-12-26 219748.601630 218372.992891 221620.562404 235697.634764 206318.055051 183424.934062 201556.174129 210982.267627 16.429167 4.0 12.0
2014-12-27 203865.864787 219748.601630 218372.992891 221620.562404 235697.634764 206318.055051 183424.934062 209856.724457 18.189583 5.0 12.0
2014-12-28 181653.080277 203865.864787 219748.601630 218372.992891 221620.562404 235697.634764 206318.055051 207809.658794 24.539583 6.0 12.0
2014-12-29 200849.858496 181653.080277 203865.864787 219748.601630 218372.992891 221620.562404 235697.634764 206262.060389 17.677083 0.0 12.0
2014-12-30 212235.635255 200849.858496 181653.080277 203865.864787 219748.601630 218372.992891 221620.562404 206571.568923 17.391667 1.0 12.0
2014-12-31 210632.512899 212235.635255 200849.858496 181653.080277 203865.864787 219748.601630 218372.992891 207890.555892 21.034615 2.0 12.0
In [39]:
# SHAP values for the predictions
# ==============================================================================
shap_values = explainer.shap_values(X_predict)

Visualize a single forecasted date

In [40]:
# Force plot for a specific forecasted date
# ==============================================================================
predicted_date = '2014-12-08'
iloc_predicted_date = X_predict.index.get_loc(predicted_date)
shap.force_plot(
    explainer.expected_value,
    shap_values[iloc_predicted_date,:],
    X_predict.iloc[iloc_predicted_date,:]
)

Visualize many forecasted dates

In [41]:
# Force plot for all forecasted dates
# ==============================================================================
shap.force_plot(explainer.expected_value, shap_values, X_predict)

Scikit-learn partial dependence plots

Partial dependence plots (PDPs) are a useful tool for understanding the relationship between a feature and the target outcome in a machine learning model. In scikit-learn, they can be created with PartialDependenceDisplay.from_estimator (the older plot_partial_dependence function was removed in version 1.2). This function visualizes the effect of one or two features on the predicted outcome while marginalizing over the effect of all other features.

The resulting plots show how the model's average prediction changes as the selected feature(s) vary, averaging over the observed values of the other features. With kind='both', the individual conditional expectation (ICE) curve of each observation is drawn together with their average. These plots describe the model's predictions, not causal effects, and should be interpreted in the context of your model and data.
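The brute-force computation behind a partial dependence plot can be sketched directly: for each grid value, overwrite the feature in every row, predict, and average. The per-row predictions before averaging are the ICE curves. The toy model and data below are assumptions for illustration; scikit-learn also offers a faster "recursion" method for some of its own tree-based estimators.

```python
import numpy as np


def partial_dependence(predict, X, feature, grid):
    """Brute-force ICE curves and their average (the partial dependence)."""
    ice = np.empty((X.shape[0], len(grid)))
    for k, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value   # same value for every row
        ice[:, k] = predict(X_mod)  # one prediction per original row
    return ice.mean(axis=0), ice    # PD curve, ICE curves


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))


def predict(X_new):
    # Toy model: quadratic in feature 0, linear in feature 1, no interaction
    return X_new[:, 0] ** 2 + 2.0 * X_new[:, 1]


grid = np.linspace(-2, 2, 5)
pd_curve, ice = partial_dependence(predict, X, feature=0, grid=grid)
print(np.round(pd_curve, 3))
```

Because the toy model has no interaction, all ICE curves are parallel and the partial dependence recovers the quadratic shape shifted by a constant. Diverging ICE curves would signal interactions that the average hides.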

A more detailed description of partial dependence plots can be found in scikit-learn's User Guide.

In [42]:
# Scikit-learn partial dependence plots
# ==============================================================================
fig, ax = plt.subplots(figsize=(8, 3))
pdp_display = PartialDependenceDisplay.from_estimator(
    estimator = forecaster.regressor,
    X         = X_train,
    features  = ["Temperature", "lag_1"],
    kind      = 'both',
    ax        = ax,
)
ax.set_title("Partial Dependence Plot")
fig.tight_layout();

Session information

In [43]:
import session_info
session_info.show(html=False)
-----
lightgbm            4.4.0
matplotlib          3.9.0
numpy               2.0.2
pandas              2.2.3
session_info        1.0.0
shap                0.46.0
skforecast          0.14.0
sklearn             1.5.2
-----
IPython             8.25.0
jupyter_client      8.6.2
jupyter_core        5.7.2
notebook            6.4.12
-----
Python 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:03:56) [MSC v.1929 64 bit (AMD64)]
Windows-11-10.0.26100-SP0
-----
Session information updated at 2024-11-05 13:04

Citation

How to cite this document

If you use this document or any part of it, please acknowledge the source, thank you!

Interpretable forecasting models by Joaquín Amat Rodrigo and Javier Escobar Ortiz, available under an Attribution-NonCommercial-ShareAlike 4.0 International license at https://www.cienciadedatos.net/documentos/py57-interpretable-forecasting-models.html

How to cite skforecast

If you use skforecast for a scientific publication, we would appreciate it if you cite the published software.

Zenodo:

Amat Rodrigo, Joaquin, & Escobar Ortiz, Javier. (2024). skforecast (v0.14.0). Zenodo. https://doi.org/10.5281/zenodo.8382788

APA:

Amat Rodrigo, J., & Escobar Ortiz, J. (2024). skforecast (Version 0.14.0) [Computer software]. https://doi.org/10.5281/zenodo.8382788

BibTeX:

@software{skforecast,
  author  = {Amat Rodrigo, Joaquin and Escobar Ortiz, Javier},
  title   = {skforecast},
  version = {0.14.0},
  month   = {11},
  year    = {2024},
  license = {BSD-3-Clause},
  url     = {https://skforecast.org/},
  doi     = {10.5281/zenodo.8382788}
}




Creative Commons Licence
This work by Joaquín Amat Rodrigo and Javier Escobar Ortiz is licensed under an Attribution-NonCommercial-ShareAlike 4.0 International license.