Modelos de forecasting interpretables

Si te gusta  Skforecast , ayúdanos dándonos una estrella en  GitHub! ⭐️

Modelos de forecasting interpretables

Joaquín Amat Rodrigo, Javier Escobar Ortiz
Julio, 2024 (última actualización Agosto 2024)

Introducción

La interpretabilidad de los modelos predictivos, también conocida como explicabilidad, se refiere a la capacidad de entender, interpretar y explicar las decisiones o predicciones tomadas por los modelos de una forma comprensible para el ser humano. Su objetivo es comprender cómo un modelo llega a un determinado resultado o decisión.

Debido a la naturaleza compleja de muchos modelos modernos de machine learning, como los métodos ensemble, que a menudo funcionan como cajas negras, es dificil comprender por qué se ha hecho una predicción concreta. Las técnicas de explicabilidad pretenden desmitificar estos modelos, proporcionando información sobre su funcionamiento interno y ayudando a generar confianza, a mejorar la transparencia y a cumplir los requisitos normativos en diversos ámbitos. Mejorar la explicabilidad de los modelos no solo ayuda a comprender su comportamiento, sino también a identificar sesgos, a mejorar el rendimiento de los modelos y a permitir que las partes interesadas tomen decisiones más informadas basadas en los conocimientos del aprendizaje automático.

La librería skforecast es compatible con algunos de los métodos de interpretabilidad más utilizados: Shap values, Partial Dependency Plots y métodos específicos de los modelos.

Librerías

Librerías utilizadas en este documento.

In [24]:
# Manipulación de datos
# ==============================================================================
import pandas as pd
import numpy as np
from skforecast.datasets import fetch_dataset

# Gráficos
# ==============================================================================
import matplotlib.pyplot as plt
import shap
from skforecast.plot import set_dark_theme

# Modelado y forecasting
# ==============================================================================
import sklearn
import lightgbm
import skforecast
from sklearn.inspection import PartialDependenceDisplay
from lightgbm import LGBMRegressor
from skforecast.ForecasterAutoreg import ForecasterAutoreg

color = '\033[1m\033[38;5;208m'
print(f"{color}Versión skforecast: {skforecast.__version__}")
print(f"{color}Versión scikit-learn: {sklearn.__version__}")
print(f"{color}Versión lightgbm: {lightgbm.__version__}")
print(f"{color}Versión pandas: {pd.__version__}")
print(f"{color}Versión numpy: {np.__version__}")
Versión skforecast: 0.13.0
Versión scikit-learn: 1.5.1
Versión lightgbm: 4.4.0
Versión pandas: 2.2.2
Versión numpy: 2.0.1

Datos

In [25]:
# Descarga de los datos
# ==============================================================================
data = fetch_dataset(name="vic_electricity")
data.head(3)
vic_electricity
---------------
Half-hourly electricity demand for Victoria, Australia
O'Hara-Wild M, Hyndman R, Wang E, Godahewa R (2022).tsibbledata: Diverse
Datasets for 'tsibble'. https://tsibbledata.tidyverts.org/,
https://github.com/tidyverts/tsibbledata/.
https://tsibbledata.tidyverts.org/reference/vic_elec.html
Shape of the dataset: (52608, 4)
Out[25]:
Demand Temperature Date Holiday
Time
2011-12-31 13:00:00 4382.825174 21.40 2012-01-01 True
2011-12-31 13:30:00 4263.365526 21.05 2012-01-01 True
2011-12-31 14:00:00 4048.966046 20.70 2012-01-01 True
In [26]:
# Agregación a frecuencia diaria
# ==============================================================================
data = data.resample('D').agg({'Demand': 'sum', 'Temperature': 'mean'})
data.head(3)
Out[26]:
Demand Temperature
Time
2011-12-31 82531.745918 21.047727
2012-01-01 227778.257304 26.578125
2012-01-02 275490.988882 31.751042
In [27]:
# Crear variables de calendario
# ==============================================================================
data['day_of_week'] = data.index.dayofweek
data['month'] = data.index.month
data.head(3)
Out[27]:
Demand Temperature day_of_week month
Time
2011-12-31 82531.745918 21.047727 5 12
2012-01-01 227778.257304 26.578125 6 1
2012-01-02 275490.988882 31.751042 0 1
In [28]:
# División train-test
# ==============================================================================
end_train = '2014-12-01 23:59:00'
data_train = data.loc[: end_train, :]
data_test  = data.loc[end_train:, :]
print(f"Fechas train : {data_train.index.min()} --- {data_train.index.max()}  (n={len(data_train)})")
print(f"Fechas test  : {data_test.index.min()} --- {data_test.index.max()}  (n={len(data_test)})")
Fechas train : 2011-12-31 00:00:00 --- 2014-12-01 00:00:00  (n=1067)
Fechas test  : 2014-12-02 00:00:00 --- 2014-12-31 00:00:00  (n=30)

Modelos de forecasting

Se creará un modelo de forecasting para predecir la demanda de energía utilizando los últimos 7 valores (última semana) y la temperatura como variable exógena.

In [29]:
# Crear un forecaster recursivo de múltiples pasos (ForecasterAutoreg)
# ==============================================================================
exog_features = ['Temperature', 'day_of_week', 'month']
forecaster = ForecasterAutoreg(
                 regressor = LGBMRegressor(random_state=123, verbose=-1),
                 lags      = 7
             )

forecaster.fit(
    y    = data_train['Demand'],
    exog = data_train[exog_features],
)
forecaster
Out[29]:
================= 
ForecasterAutoreg 
================= 
Regressor: LGBMRegressor(random_state=123, verbose=-1) 
Lags: [1 2 3 4 5 6 7] 
Transformer for y: None 
Transformer for exog: None 
Window size: 7 
Weight function included: False 
Differentiation order: None 
Exogenous included: True 
Exogenous variables names: ['Temperature', 'day_of_week', 'month'] 
Training range: [Timestamp('2011-12-31 00:00:00'), Timestamp('2014-12-01 00:00:00')] 
Training index type: DatetimeIndex 
Training index frequency: D 
Regressor parameters: {'boosting_type': 'gbdt', 'class_weight': None, 'colsample_bytree': 1.0, 'importance_type': 'split', 'learning_rate': 0.1, 'max_depth': -1, 'min_child_samples': 20, 'min_child_weight': 0.001, 'min_split_gain': 0.0, 'n_estimators': 100, 'n_jobs': None, 'num_leaves': 31, 'objective': None, 'random_state': 123, 'reg_alpha': 0.0, 'reg_lambda': 0.0, 'subsample': 1.0, 'subsample_for_bin': 200000, 'subsample_freq': 0, 'verbose': -1} 
fit_kwargs: {} 
Creation date: 2024-08-10 20:48:01 
Last fit date: 2024-08-10 20:48:01 
Skforecast version: 0.13.0 
Python version: 3.12.4 
Forecaster id: None 

Importancia de predictores específica de los modelos

La importancia de los predictores en modelos de machine learning determina la relevancia de cada predictor (o variable) en la predicción de un modelo. En otras palabras, mide cuánto contribuye cada predictor al resultado del modelo.

La importancia de los predictores puede utilizarse para varios fines, como identificar aquellos más relevantes para una predicción determinada, comprender el comportamiento de un modelo y seleccionar el mejor conjunto de predictores para una tarea determinada. También puede ayudar a identificar posibles sesgos o errores en los datos utilizados para entrenar el modelo. Es importante señalar que la importancia de un predictor no es una medida definitiva de causalidad. El hecho de que una característica se identifique como importante no significa necesariamente que haya causado el resultado. También pueden intervenir otros factores, como las variables de confusión.

El método utilizado para calcular la importancia de los predictores puede variar en función del tipo de modelo de machine learning que se utilice. Los distintos modelos pueden tener distintos supuestos y características que afecten al cálculo de la importancia. Por ejemplo, los modelos basados en árboles de decisión, como Random Forest y Gradient Boosting, suelen utilizar métodos que miden la disminución de impurezas o el impacto de las permutaciones. Los modelos de regresión lineal suelen utilizar los coeficientes. La magnitud del coeficiente refleja la magnitud y la dirección de la relación entre el predictor y la variable objetivo.

La importancia de los predictores incluidos en un forecaster se puede obtener utilizando el método get_feature_importances(). Este método accede a los atributos coef_ y feature_importances_ del regresor interno.

Warning

El método `get_feature_importances()` solo devolverá valores si el regresor del forecaster tiene el atributo `coef_` o `feature_importances_`, que es el nombre por defecto en scikit-learn.
In [30]:
# Extraer importancia de los predictores
# ==============================================================================
importance = forecaster.get_feature_importances()
importance
Out[30]:
feature importance
7 Temperature 623
0 lag_1 442
6 lag_7 286
1 lag_2 268
2 lag_3 257
4 lag_5 252
5 lag_6 241
9 month 239
3 lag_4 232
8 day_of_week 160

Valores Shap

SHAP (SHapley Additive exPlanations) values are a popular method for explaining machine learning models, as they help to understand how variables and values influence predictions visually and quantitatively.

Los valores SHAP (SHapley Additive exPlanations) son una técnica de explicabilidad que se basa en la teoría de juegos cooperativos y en la idea de asignar un valor a cada jugador en función de su contribución al juego. En el contexto de los modelos de machine learning, los valores SHAP se utilizan para explicar cómo las variables y los valores influyen en las predicciones de un modelo. Proporcionan una forma de entender cómo se ha llegado a una predicción concreta, mostrando la contribución de cada variable a la predicción final.

Es posible calcular los valores SHAP a partir de los modelos de skforecast con solo dos elementos:

  • El regresor interno del forecaster.

  • Las matrices de entrenamiento creadas a partir de la serie temporal y variables exógenas, utilizadas para ajustar el forecaster.

Aprovechando estos dos componentes, los usuarios pueden crear explicaciones interpretables para sus modelos de skforecast. Estas explicaciones pueden utilizarse para verificar la fiabilidad del modelo, identificar los factores más significativos que contribuyen a las predicciones y comprender mejor la relación subyacente entre las variables de entrada y la variable objetivo.

Shap explainer y matrices de entrenamiento

In [31]:
# Matrices de entrenamiento usadas por el forecaster para ajustar el regresor interno
# ==============================================================================
X_train, y_train = forecaster.create_train_X_y(
                       y    = data_train['Demand'],
                       exog = data_train[exog_features],
                   )

display(X_train.head(3))
lag_1 lag_2 lag_3 lag_4 lag_5 lag_6 lag_7 Temperature day_of_week month
Time
2012-01-07 205338.714620 211066.426550 213792.376946 258955.329422 275490.988882 227778.257304 82531.745918 24.098958 5 1
2012-01-08 200693.270298 205338.714620 211066.426550 213792.376946 258955.329422 275490.988882 227778.257304 20.223958 6 1
2012-01-09 200061.614738 200693.270298 205338.714620 211066.426550 213792.376946 258955.329422 275490.988882 19.161458 0 1
In [32]:
# Crear un objeto explainer de SHAP (para modelos basados en árboles)
# ==============================================================================
explainer = shap.TreeExplainer(forecaster.regressor)

# Muestreo del 50% de los datos para acelerar el cálculo
rng = np.random.default_rng(seed=785412)
sample = rng.choice(X_train.index, size=int(len(X_train)*0.5), replace=False)
X_train_sample = X_train.loc[sample, :]
shap_values = explainer.shap_values(X_train_sample)

✎ Note

La librería Shap tiene varios explainers, cada uno diseñado para un tipo de modelo diferente. El explainer shap.TreeExplainer se utiliza para modelos basados en árboles, como el LGBMRegressor utilizado en este ejemplo. Para obtener más información, consulte la documentación de SHAP.

SHAP Summary Plot

El SHAP summary plot muestra la contribución de cada variable a la predicción del modelo en para varias observaciones. Muestra cuánto contribuye cada variable a alejar la predicción del modelo de un valor base (a menudo la predicción media del modelo). Al examinar un SHAP summary plot, se pueden obtener información sobre qué variables tienen un impacto más significativo en las predicciones, si influyen positiva o negativamente en el resultado y cómo contribuyen los diferentes valores de las variables a predicciones específicas.

In [33]:
# Shap summary plot (top 10)
# ==============================================================================
shap.initjs()
shap.summary_plot(shap_values, X_train_sample, max_display=10, show=False)
fig, ax = plt.gcf(), plt.gca()
ax.set_title("SHAP Summary plot")
ax.tick_params(labelsize=8)
fig.set_size_inches(6, 3)
In [34]:
shap.summary_plot(shap_values, X_train, plot_type="bar", plot_size=(6, 3))

Explicación de predicciones en datos de entrenamiento

Un shap.force_plot es un tipo específico de visualización que proporciona una vista interactiva y detallada de cómo las variables individuales contribuyen a una predicción concreta realizada por un modelo de machine learning. Es una herramienta de interpretación local que ayuda a comprender por qué un modelo ha realizado una predicción específica para una instancia dada.

Visualizar una única predicción

In [35]:
# Force plot para la primera observación del conjunto de entrenamiento
# ==============================================================================
shap.force_plot(explainer.expected_value, shap_values[0,:], X_train_sample.iloc[0,:])
Out[35]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Visualizar varias predicciones

In [36]:
# Force plot para las primeras 200 observaciones del conjunto de entrenamiento
# ==============================================================================
shap.force_plot(explainer.expected_value, shap_values[:200, :], X_train_sample.iloc[:200, :])
Out[36]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Gráfico de dependencia SHAP

Los gráficos de dependencia SHAP son visualizaciones utilizadas para entender la relación entre una variable y la salida del modelo. Para ello muestran cómo, el valor de una única variable, afecta a las predicciones realizadas por el modelo teniendo en cuenta las interacciones con otras variables. Estos gráficos son especialmente útiles para examinar cómo una determinada variable afecta a las predicciones del modelo en todo su rango de valores.

In [37]:
# Gráfico de dependencia para la variable Temperature
# ==============================================================================
fig, ax = plt.subplots(figsize=(6, 3))
shap.dependence_plot("Temperature", shap_values, X_train_sample, ax=ax)

Esplicar valores predichos

También es posible utilizar los valores SHAP para explicar las nuevas predicciones. Esto ayuda a entender por qué el modelo ha hecho una predicción específica para una fecha en el futuro.

In [38]:
# Forecasting de los próximos 30 días
# ==============================================================================
predictions = forecaster.predict(steps=30, exog=data_test[exog_features])
predictions
Out[38]:
2014-12-02    231850.159907
2014-12-03    233738.715964
2014-12-04    241516.294866
2014-12-05    207726.329720
2014-12-06    184258.602264
2014-12-07    190241.266911
2014-12-08    203667.298592
2014-12-09    201749.687739
2014-12-10    204691.377205
2014-12-11    206697.670901
2014-12-12    195775.940867
2014-12-13    199211.797661
2014-12-14    195942.950076
2014-12-15    216015.742784
2014-12-16    214146.238551
2014-12-17    208364.076939
2014-12-18    210368.280751
2014-12-19    195598.822566
2014-12-20    181364.886645
2014-12-21    199777.397488
2014-12-22    226236.725725
2014-12-23    223184.494875
2014-12-24    214212.828008
2014-12-25    209518.645332
2014-12-26    195934.342007
2014-12-27    181329.838726
2014-12-28    199747.138261
2014-12-29    208724.345197
2014-12-30    212262.252933
2014-12-31    222919.126896
Freq: D, Name: pred, dtype: float64
In [39]:
# Gráfico de las predicciones
# ==============================================================================
set_dark_theme()
fig, ax = plt.subplots(figsize=(6, 2.5))
data_test['Demand'].plot(ax=ax, label='Test')
predictions.plot(ax=ax, label='Predictions', linestyle='--')
ax.set_xlabel(None)
ax.legend();

El método create_predict_X se utiliza para crear la matriz de entrada utilizada internamente por el método predict del forecaster. Esta matriz se utiliza para generar los valores SHAP de los valores predichos.

In [40]:
# Crear la matriz de entrada usada por el método predict
# ==============================================================================
X_predict = forecaster.create_predict_X(steps=30, exog=data_test[exog_features])
X_predict
Out[40]:
lag_1 lag_2 lag_3 lag_4 lag_5 lag_6 lag_7 Temperature day_of_week month
2014-12-02 237812.592388 234970.336660 189653.758108 202017.012448 214602.854760 218321.456402 214318.765210 19.833333 1.0 12.0
2014-12-03 231850.159907 237812.592388 234970.336660 189653.758108 202017.012448 214602.854760 218321.456402 19.616667 2.0 12.0
2014-12-04 233738.715964 231850.159907 237812.592388 234970.336660 189653.758108 202017.012448 214602.854760 21.702083 3.0 12.0
2014-12-05 241516.294866 233738.715964 231850.159907 237812.592388 234970.336660 189653.758108 202017.012448 18.352083 4.0 12.0
2014-12-06 207726.329720 241516.294866 233738.715964 231850.159907 237812.592388 234970.336660 189653.758108 17.356250 5.0 12.0
2014-12-07 184258.602264 207726.329720 241516.294866 233738.715964 231850.159907 237812.592388 234970.336660 16.175000 6.0 12.0
2014-12-08 190241.266911 184258.602264 207726.329720 241516.294866 233738.715964 231850.159907 237812.592388 17.747917 0.0 12.0
2014-12-09 203667.298592 190241.266911 184258.602264 207726.329720 241516.294866 233738.715964 231850.159907 17.050000 1.0 12.0
2014-12-10 201749.687739 203667.298592 190241.266911 184258.602264 207726.329720 241516.294866 233738.715964 18.504167 2.0 12.0
2014-12-11 204691.377205 201749.687739 203667.298592 190241.266911 184258.602264 207726.329720 241516.294866 17.968750 3.0 12.0
2014-12-12 206697.670901 204691.377205 201749.687739 203667.298592 190241.266911 184258.602264 207726.329720 21.441667 4.0 12.0
2014-12-13 195775.940867 206697.670901 204691.377205 201749.687739 203667.298592 190241.266911 184258.602264 26.131250 5.0 12.0
2014-12-14 199211.797661 195775.940867 206697.670901 204691.377205 201749.687739 203667.298592 190241.266911 19.918750 6.0 12.0
2014-12-15 195942.950076 199211.797661 195775.940867 206697.670901 204691.377205 201749.687739 203667.298592 19.920833 0.0 12.0
2014-12-16 216015.742784 195942.950076 199211.797661 195775.940867 206697.670901 204691.377205 201749.687739 19.250000 1.0 12.0
2014-12-17 214146.238551 216015.742784 195942.950076 199211.797661 195775.940867 206697.670901 204691.377205 17.675000 2.0 12.0
2014-12-18 208364.076939 214146.238551 216015.742784 195942.950076 199211.797661 195775.940867 206697.670901 16.520833 3.0 12.0
2014-12-19 210368.280751 208364.076939 214146.238551 216015.742784 195942.950076 199211.797661 195775.940867 16.543750 4.0 12.0
2014-12-20 195598.822566 210368.280751 208364.076939 214146.238551 216015.742784 195942.950076 199211.797661 18.506250 5.0 12.0
2014-12-21 181364.886645 195598.822566 210368.280751 208364.076939 214146.238551 216015.742784 195942.950076 24.031250 6.0 12.0
2014-12-22 199777.397488 181364.886645 195598.822566 210368.280751 208364.076939 214146.238551 216015.742784 22.950000 0.0 12.0
2014-12-23 226236.725725 199777.397488 181364.886645 195598.822566 210368.280751 208364.076939 214146.238551 18.829167 1.0 12.0
2014-12-24 223184.494875 226236.725725 199777.397488 181364.886645 195598.822566 210368.280751 208364.076939 18.312500 2.0 12.0
2014-12-25 214212.828008 223184.494875 226236.725725 199777.397488 181364.886645 195598.822566 210368.280751 16.933333 3.0 12.0
2014-12-26 209518.645332 214212.828008 223184.494875 226236.725725 199777.397488 181364.886645 195598.822566 16.429167 4.0 12.0
2014-12-27 195934.342007 209518.645332 214212.828008 223184.494875 226236.725725 199777.397488 181364.886645 18.189583 5.0 12.0
2014-12-28 181329.838726 195934.342007 209518.645332 214212.828008 223184.494875 226236.725725 199777.397488 24.539583 6.0 12.0
2014-12-29 199747.138261 181329.838726 195934.342007 209518.645332 214212.828008 223184.494875 226236.725725 17.677083 0.0 12.0
2014-12-30 208724.345197 199747.138261 181329.838726 195934.342007 209518.645332 214212.828008 223184.494875 17.391667 1.0 12.0
2014-12-31 212262.252933 208724.345197 199747.138261 181329.838726 195934.342007 209518.645332 214212.828008 21.034615 2.0 12.0
In [41]:
# Valores SHAP para las predicciones
# ==============================================================================
shap_values = explainer.shap_values(X_predict)

Visualizar una única fecha predicha

In [42]:
# Force plot para una fecha predicha específica
# ==============================================================================
predicted_date = '2014-12-08'
iloc_predicted_date = X_predict.index.get_loc(predicted_date)
shap.force_plot(
    explainer.expected_value,
    shap_values[iloc_predicted_date,:],
    X_predict.iloc[iloc_predicted_date,:]
)
Out[42]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Visualizar varias fechas predichas

In [43]:
# Force plot para varias fechas predichas
# ==============================================================================
shap.force_plot(explainer.expected_value, shap_values, X_predict)
Out[43]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Scikit-learn gráficos de dependencia parcial

Scikit-learn permite crear gráficos de dependencia parcial utilizando la función plot_partial_dependence. Esta función visualiza el efecto de una o dos variables en el resultado predicho, marginalizando el efecto de todas las demás características. Estos gráficos proporcionan información sobre cómo influyen las variables seleccionadas en las predicciones del modelo y pueden ayudar a identificar relaciones no lineales o interacciones entre variables.

Una descripción más detallada sobre gráficos de dependencia parcial se puede encontrar en la Guía de Usuario de Scikitlearn.

In [44]:
# Scikit-learn gráfico de dependencia parcial
# ==============================================================================
fig, ax = plt.subplots(figsize=(8, 3))
pd.plots = PartialDependenceDisplay.from_estimator(
    estimator = forecaster.regressor,
    X         = X_train,
    features  = ["Temperature", "lag_1"],
    kind      = 'both',
    ax        = ax,
)
ax.set_title("Partial Dependence Plot")
fig.tight_layout();