Forecasting energy demand with machine learning


Joaquín Amat Rodrigo, Javier Escobar Ortiz
April, 2021 (last update April 2024)


Energy demand forecasting plays a critical role in effectively managing and planning resources for power generation, distribution, and utilization. Predicting energy demand is a complex task influenced by factors such as weather patterns, economic conditions, and societal behavior. This document shows how to build forecasting models that use machine learning to predict energy demand.

Time series and forecasting

A time series is a sequence of chronologically ordered data at equal or unequal intervals. The forecasting process consists of predicting the future value of a time series, either by modelling the series solely on the basis of its past behavior (autoregressive) or by using other external variables.

When working with time series, it is rarely necessary to predict only the next element in the series ($t_{+1}$). Instead, the most common goal is to forecast a whole future interval ($t_{+1}$, ..., $t_{+n}$) or a distant point in time ($t_{+n}$). Several strategies can be used to generate this type of forecast. Skforecast implements the following for univariate time series forecasting:

  • Recursive multi-step forecasting: since the value $t_{n-1}$ is needed to predict $t_{n}$, and $t_{n-1}$ is unknown, a recursive process is applied in which each new prediction is based on the previous one. This process is known as recursive forecasting or recursive multi-step forecasting and can be easily generated with the ForecasterAutoreg and ForecasterAutoregCustom classes.

  • Direct multi-step forecasting: this method consists of training a different model for each step of the forecast horizon. For example, to predict the next 5 values of a time series, 5 different models are trained, one for each step. As a result, the predictions are independent of each other. This entire process is automated in the ForecasterAutoregDirect class.
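
The difference between the two strategies can be illustrated without skforecast. The following is a minimal sketch, assuming a noise-free sine wave as a toy series and an ordinary least-squares linear model on 3 lags. The helper names (make_lag_matrix, fit_linear, predict_linear) are hypothetical and are not part of the skforecast API.

```python
import numpy as np

def make_lag_matrix(y, n_lags, step=1):
    """Build rows of n_lags consecutive values and the target `step` positions ahead."""
    X, target = [], []
    for i in range(n_lags, len(y) - step + 1):
        X.append(y[i - n_lags:i])
        target.append(y[i + step - 1])
    return np.array(X), np.array(target)

def fit_linear(X, target):
    """Least-squares linear model with intercept; returns the coefficient vector."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, target, rcond=None)
    return coef

def predict_linear(coef, window):
    """Predict one value from a window of lags."""
    return coef[0] + window @ coef[1:]

# Toy series (assumption): a sine wave with a period of 24 steps
y = np.sin(2 * np.pi * np.arange(200) / 24)
y_train, y_future = y[:180], y[180:185]
n_lags, horizon = 3, 5

# Recursive strategy: one model for t+1, predictions are fed back as new lags
coef = fit_linear(*make_lag_matrix(y_train, n_lags, step=1))
window = y_train[-n_lags:].copy()
recursive = []
for _ in range(horizon):
    pred = predict_linear(coef, window)
    recursive.append(pred)
    window = np.append(window[1:], pred)

# Direct strategy: one model per step, all of them use the last observed window
last_window = y_train[-n_lags:]
direct = []
for step in range(1, horizon + 1):
    coef_step = fit_linear(*make_lag_matrix(y_train, n_lags, step=step))
    direct.append(predict_linear(coef_step, last_window))

print(np.round(recursive, 3))
print(np.round(direct, 3))
```

In the recursive loop a single model is reused and its own predictions become inputs for the next step, whereas the direct loop fits one model per step and never consumes a prediction.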


The libraries used in this document are:

In [1]:
# Data manipulation
# ==============================================================================
import numpy as np
import pandas as pd
from astral.sun import sun
from astral import LocationInfo
from skforecast.datasets import fetch_dataset

# Plots
# ==============================================================================
import matplotlib.pyplot as plt
from statsmodels.graphics.tsaplots import plot_acf
from statsmodels.graphics.tsaplots import plot_pacf
from skforecast.plot import plot_residuals
import plotly.graph_objects as go
import plotly.io as pio
import plotly.offline as poff
pio.templates.default = "seaborn"

# Modelling and Forecasting
# ==============================================================================
import skforecast
import lightgbm
import sklearn
from lightgbm import LGBMRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.feature_selection import RFECV
from skforecast.ForecasterBaseline import ForecasterEquivalentDate
from skforecast.ForecasterAutoreg import ForecasterAutoreg
from skforecast.ForecasterAutoregDirect import ForecasterAutoregDirect
from skforecast.model_selection import bayesian_search_forecaster
from skforecast.model_selection import backtesting_forecaster
from skforecast.model_selection import select_features
import shap

# Warnings configuration
# ==============================================================================
import warnings

print(f"Version skforecast: {skforecast.__version__}")
print(f"Version lightgbm: {lightgbm.__version__}")
print(f"Version scikit-learn: {sklearn.__version__}")
Version skforecast: 0.12.0
Version lightgbm: 4.3.0
Version scikit-learn: 1.3.2


A time series of electricity demand (MW) for the state of Victoria (Australia) is available from 2011-12-31 to 2014-12-31. The data used in this document was obtained from the R tsibbledata package. The dataset contains 5 columns and 52,608 complete records. The information in each column is:

  • Time: date and time of the record.
  • Date: date of the record.
  • Demand: electricity demand (MW).
  • Temperature: temperature in Melbourne, the capital of Victoria.
  • Holiday: indicates if the day is a public holiday.
In [2]:
# Data download
# ==============================================================================
data = fetch_dataset(name='vic_electricity', raw=True)
Half-hourly electricity demand for Victoria, Australia
O'Hara-Wild M, Hyndman R, Wang E, Godahewa R (2022). tsibbledata: Diverse
Datasets for 'tsibble'.
Shape of the dataset: (52608, 5)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 52608 entries, 0 to 52607
Data columns (total 5 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   Time         52608 non-null  object 
 1   Demand       52608 non-null  float64
 2   Temperature  52608 non-null  float64
 3   Date         52608 non-null  object 
 4   Holiday      52608 non-null  bool   
dtypes: bool(1), float64(2), object(2)
memory usage: 1.7+ MB

The Time column is stored as a string. To convert it to datetime, the function pd.to_datetime() is used. Once in datetime format, and to make use of pandas functionalities, it is set as an index. Since the data was recorded every 30 minutes, the frequency '30min' is also specified.

In [3]:
# Data preparation
# ==============================================================================
data = data.copy()
data['Time'] = pd.to_datetime(data['Time'], format='%Y-%m-%dT%H:%M:%SZ')
data = data.set_index('Time')
data = data.asfreq('30min')
data = data.sort_index()
Time                 Demand       Temperature  Date        Holiday
2011-12-31 13:00:00  4382.825174  21.40        2012-01-01  True
2011-12-31 13:30:00  4263.365526  21.05        2012-01-01  True

One of the first analyses to be carried out when working with time series is to check that the series is complete, that is, that there are no missing values.

In [4]:
# Verify that the temporal index is complete
# ==============================================================================
# The index is compared with a full date range built from its own limits and frequency
(data.index == pd.date_range(start=data.index.min(),
                             end=data.index.max(),
                             freq=data.index.freq)).all()

print(f"Proportion of rows with missing values: {data.isnull().any(axis=1).mean()}")
Proportion of rows with missing values: 0.0
In [5]:
# Fill gaps in a temporal index
# ==============================================================================
# data.asfreq(freq='30min', fill_value=np.nan)
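
The Victoria dataset has no gaps, so the effect of the two previous steps is easier to see on a series that does have one. The following is a minimal sketch, assuming a hypothetical four-row series (the name toy and its values are made up for illustration) in which the 01:00 record is missing.

```python
import numpy as np
import pandas as pd

# Toy half-hourly series (assumption): the 01:00 record is missing
idx = pd.to_datetime([
    '2012-01-01 00:00', '2012-01-01 00:30',
    '2012-01-01 01:30', '2012-01-01 02:00',
])
toy = pd.DataFrame({'Demand': [10.0, 11.0, 13.0, 14.0]}, index=idx)

# Timestamps that a complete 30-minute index would contain but the data lacks
expected_index = pd.date_range(start=toy.index.min(), end=toy.index.max(), freq='30min')
missing = expected_index.difference(toy.index)
print(missing)

# asfreq() inserts the missing timestamps; the new rows are filled with NaN
toy_complete = toy.asfreq(freq='30min', fill_value=np.nan)
print(toy_complete)
```

After asfreq() the index is regular, and the inserted rows can be located with isnull() and imputed with whatever method suits the problem.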

Although the data is at 30-minute intervals, the aim is to create a model capable of predicting hourly electricity demand, so the data needs to be aggregated. This type of transformation can be done very easily by combining the pandas DatetimeIndex and its resample() method.

It is very important to use the closed='left' and label='right' arguments correctly to avoid introducing future information into the training (leakage). Suppose that values are available for 10:10, 10:30, 10:45, 11:00, 11:12, and 11:30. To obtain the hourly average, the value assigned to 11:00 must be calculated using the values for 10:10, 10:30, and 10:45; and the value assigned to 12:00 must be calculated using the values for 11:00, 11:12, and 11:30.

Diagram of temporal data aggregation excluding forward-looking information.

The 11:00 average does not include the 11:00 point value because in reality the value is not available at that exact time.
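
The six timestamps of the example can be reproduced directly in pandas. The following is a minimal sketch, assuming made-up values from 1 to 6 and an arbitrary date; the rule '60min' is used here as an hourly frequency.

```python
import pandas as pd

# The six timestamps from the example, with hypothetical values 1..6
times = pd.to_datetime([
    '2012-01-01 10:10', '2012-01-01 10:30', '2012-01-01 10:45',
    '2012-01-01 11:00', '2012-01-01 11:12', '2012-01-01 11:30',
])
values = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], index=times)

# Intervals [10:00, 11:00) and [11:00, 12:00), each labelled with its right edge
no_leak = values.resample(rule='60min', closed='left', label='right').mean()

# Intervals (10:00, 11:00] and (11:00, 12:00]: the 11:00 reading enters the 11:00 average
leak = values.resample(rule='60min', closed='right', label='right').mean()

print(no_leak)
print(leak)
```

With closed='left', the 11:00 average is built only from 10:10, 10:30, and 10:45. With closed='right', it also includes the 11:00 reading, which would not yet be available at that time.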

In [6]:
# Aggregating in 1H intervals
# ==============================================================================
# The Date column is eliminated so that it does not generate an error when aggregating.
# The Holiday column does not generate an error since it is Boolean and is treated as 0-1.
data = data.drop(columns='Date')
data = data.resample(rule='H', closed='left', label='right').mean()
Time                 Demand       Temperature  Holiday
2011-12-31 14:00:00  4323.095350  21.225       1.0
2011-12-31 15:00:00  3963.264688  20.625       1.0
2011-12-31 16:00:00  3950.913495  20.325       1.0
2011-12-31 17:00:00  3627.860675  19.850       1.0
2011-12-31 18:00:00  3396.251676  19.025       1.0
...                  ...          ...          ...
2014-12-31 09:00:00  4069.625550  21.600       0.0
2014-12-31 10:00:00  3909.230704  20.300       0.0
2014-12-31 11:00:00  3900.600901  19.650       0.0
2014-12-31 12:00:00  3758.236494  18.100       0.0
2014-12-31 13:00:00  3785.650720  17.200       0.0

26304 rows × 3 columns

The dataset starts on 2011-12-31 14:00:00 and ends on 2014-12-31 13:00:00. The first 10 and the last 14 records are discarded so that it starts on 2012-01-01 00:00:00 and ends on 2014-12-30 23:00:00. In addition, in order to optimize the hyperparameters of the model and evaluate its predictive ability, the data is divided into 3 sets: training, validation, and test.

In [7]:
# Split data into train-val-test
# ==============================================================================
data = data.loc['2012-01-01 00:00:00': '2014-12-30 23:00:00'].copy()
end_train = '2013-12-31 23:59:00'
end_validation = '2014-11-30 23:59:00'
data_train = data.loc[: end_train, :].copy()
data_val   = data.loc[end_train:end_validation, :].copy()
data_test  = data.loc[end_validation:, :].copy()

print(f"Train dates      : {data_train.index.min()} --- {data_train.index.max()}  (n={len(data_train)})")
print(f"Validation dates : {data_val.index.min()} --- {data_val.index.max()}  (n={len(data_val)})")
print(f"Test dates       : {data_test.index.min()} --- {data_test.index.max()}  (n={len(data_test)})")
Train dates      : 2012-01-01 00:00:00 --- 2013-12-31 23:00:00  (n=17544)
Validation dates : 2014-01-01 00:00:00 --- 2014-11-30 23:00:00  (n=8016)
Test dates       : 2014-12-01 00:00:00 --- 2014-12-30 23:00:00  (n=720)
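
Label-based slicing with .loc includes both ends of the slice, so a boundary that matches an index label would place the same row in two partitions. The following is a minimal sketch, assuming a hypothetical hourly frame (toy) that covers the same dates as the data, which checks that boundaries set at 23:59 produce partitions that do not overlap and together cover every record.

```python
import pandas as pd

# Hypothetical hourly frame spanning the same dates as the trimmed dataset
index = pd.date_range(start='2012-01-01 00:00:00', end='2014-12-30 23:00:00', freq='60min')
toy = pd.DataFrame({'Demand': 0.0}, index=index)

# Same boundaries as above: 23:59 is not a label of the hourly index
end_train = '2013-12-31 23:59:00'
end_validation = '2014-11-30 23:59:00'
toy_train = toy.loc[:end_train]
toy_val = toy.loc[end_train:end_validation]
toy_test = toy.loc[end_validation:]

# No timestamp is shared between consecutive partitions
no_overlap = (
    toy_train.index.intersection(toy_val.index).empty
    and toy_val.index.intersection(toy_test.index).empty
)
# Every record belongs to exactly one partition
covers_everything = len(toy_train) + len(toy_val) + len(toy_test) == len(toy)

print(no_overlap, covers_everything)
print(len(toy_train), len(toy_val), len(toy_test))
```

If the boundaries were set to 23:00:00 instead, the last training hour would also be the first validation hour.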

Graphic exploration

Graphical exploration of time series can be an effective way of identifying trends, patterns, and seasonal variations. This, in turn, helps to guide the selection of the most appropriate forecasting model.

Plot time series

Full time series

In [8]:
# Interactive plot of time series
# ==============================================================================
fig = go.Figure()
fig.add_trace(go.Scatter(x=data_train.index, y=data_train['Demand'], mode='lines', name='Train'))
fig.add_trace(go.Scatter(x=data_val.index, y=data_val['Demand'], mode='lines', name='Validation'))
fig.add_trace(go.Scatter(x=data_test.index, y=data_test['Demand'], mode='lines', name='Test'))
fig.update_layout(
    title  = 'Hourly energy demand',
    margin=dict(l=20, r=20, t=35, b=20),