Skip to content

Instantly share code, notes, and snippets.

@lostella
Last active March 4, 2025 08:56
Show Gist options
  • Select an option

  • Save lostella/30c90c5349c04800b01420a48d71fc9e to your computer and use it in GitHub Desktop.

Select an option

Save lostella/30c90c5349c04800b01420a48d71fc9e to your computer and use it in GitHub Desktop.
Chronos rolling evaluation
# Example script running rolling evaluation of Chronos models.
# See: https://github.com/amazon-science/chronos-forecasting
#
# Requirements:
# uv pip install -U chronos-forecasting matplotlib numpy pandas torch
#
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from itertools import islice
from chronos import BaseChronosPipeline
def batched(iterable, n: int):
    """Yield consecutive lists of at most ``n`` items taken from ``iterable``.

    Every list except possibly the last has exactly ``n`` items; an empty
    iterable yields nothing. ``n`` is validated lazily, on the first ``next()``.
    """
    assert n >= 1
    source = iter(iterable)
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            return
        yield chunk
def cutoffs(total_length, test_length, lead_time, prediction_length):
    """Return the cutoff indices of a rolling evaluation as a ``range``.

    A cutoff ``k`` is a context length: history ``[0, k)`` is used to forecast
    ``prediction_length`` steps, and only the steps from ``lead_time`` onwards
    (positions ``k + lead_time - 1`` ... ``k + prediction_length - 1``) are kept.
    Consecutive cutoffs are ``prediction_length - lead_time + 1`` apart. The
    kept steps therefore tile the last ``test_length`` positions without gaps
    or overlap; the final window may extend past ``total_length``.

    All arguments are positive integers with ``prediction_length >= lead_time``
    and ``total_length >= test_length + lead_time - 1``. The latter keeps the
    first cutoff non-negative. A negative cutoff would make ``series[:k]`` wrap
    around from the end and leak test values into the context.
    """
    assert total_length > 0
    assert test_length > 0
    assert total_length >= test_length
    assert prediction_length > 0
    assert lead_time > 0
    assert prediction_length >= lead_time
    # First cutoff is total_length - lead_time + 1 - test_length; it must be >= 0.
    assert total_length >= test_length + lead_time - 1, (
        "total_length too short for test_length and lead_time"
    )
    stride = prediction_length - lead_time + 1  # stride >= 1
    upper_bound = total_length - lead_time + 1
    return range(upper_bound - test_length, upper_bound, stride)
# Import-time sanity checks for `cutoffs`: each entry pairs keyword arguments
# with the exact list of cutoff indices they must produce.
for _case, _expected in (
    (dict(total_length=10, test_length=3, lead_time=1, prediction_length=1), [7, 8, 9]),
    (dict(total_length=10, test_length=3, lead_time=2, prediction_length=2), [6, 7, 8]),
    (dict(total_length=10, test_length=3, lead_time=2, prediction_length=3), [6, 8]),
    (dict(total_length=10, test_length=3, lead_time=1, prediction_length=3), [7]),
):
    assert list(cutoffs(**_case)) == _expected
del _case, _expected
def roll_chronos(
    model,
    series: pd.Series,
    test_length: int,
    lead_time: int,
    prediction_length: int,
    batch_size: int = 64,
    quantile_levels: "list[float] | None" = None,
) -> pd.DataFrame:
    """Rolling quantile forecasts of ``series`` over its last ``test_length`` points.

    For each cutoff ``k`` returned by ``cutoffs``, the context ``series[:k]`` is
    forecast ``prediction_length`` steps ahead. Contexts are sent to the model in
    batches of ``batch_size``. From each forecast, the first ``lead_time - 1``
    steps are dropped. The remaining quantiles are concatenated in cutoff order
    and truncated to the test window.

    Args:
        model: object exposing ``predict_quantiles(context, prediction_length=...,
            quantile_levels=...)`` returning a ``(quantiles, _)`` pair, e.g. the
            ``BaseChronosPipeline`` used in ``__main__``.
        series: series with a monotonic ``DatetimeIndex`` that has a ``freq``.
        test_length: number of trailing points to forecast.
        lead_time: first forecast step (1-based) that is kept.
        prediction_length: forecast horizon requested from the model.
        batch_size: number of contexts per ``predict_quantiles`` call.
        quantile_levels: quantile levels to request; defaults to
            ``[0.1, 0.5, 0.9]``.

    Returns:
        DataFrame indexed by the last ``test_length`` timestamps of ``series``,
        with one ``"{q}-quantile"`` column per requested level.
    """
    # Fresh list per call instead of a shared mutable default argument.
    quantile_levels = (
        [0.1, 0.5, 0.9] if quantile_levels is None else list(quantile_levels)
    )
    assert series.index.is_monotonic_increasing
    assert isinstance(series.index, pd.DatetimeIndex)
    assert series.index.freq is not None
    assert prediction_length > 0
    assert lead_time > 0
    assert prediction_length >= lead_time
    assert test_length > 0
    assert batch_size > 0
    # The earliest cutoff is len(series) - lead_time + 1 - test_length. If it
    # were negative, series_tensor[:k] would wrap around and the context would
    # silently contain values from the test window.
    assert len(series) >= test_length + lead_time - 1, (
        "series too short for the requested test_length and lead_time"
    )
    cutoff_indices = cutoffs(
        total_length=len(series),
        test_length=test_length,
        lead_time=lead_time,
        prediction_length=prediction_length,
    )
    series_tensor = torch.tensor(series.values)
    contexts = [series_tensor[:k] for k in cutoff_indices]
    batch_forecast = []
    for batch_context in batched(contexts, batch_size):
        quantiles, _ = model.predict_quantiles(
            batch_context,
            prediction_length=prediction_length,
            quantile_levels=quantile_levels,
        )  # (batch_size, prediction_length, len(quantile_levels))
        # .cpu() is a no-op for CPU tensors and keeps .numpy() safe otherwise.
        # Steps before lead_time are dropped; the kept steps of consecutive
        # cutoffs are contiguous, so flattening (cutoff, step) preserves time order.
        kept = quantiles.detach().cpu().numpy()[:, lead_time - 1 :, :]
        batch_forecast.append(kept.reshape(-1, len(quantile_levels)))
    forecast_index = series.index[-test_length:]
    # The last window may run past the end of the series; trim to the test window.
    forecast_data = np.concatenate(batch_forecast, axis=0)[: len(forecast_index)]
    return pd.DataFrame(
        forecast_data,
        index=forecast_index,
        columns=[f"{q}-quantile" for q in quantile_levels],
    )
def roll_naive(
    series: pd.Series,
    test_length: int,
    lead_time: int,
    prediction_length: int,
    num_seasons: int = 1,
) -> pd.DataFrame:
    """Seasonal-naive forecasts for the last ``test_length`` points of ``series``.

    Every point ``t`` of the test window is forecast by the value ``lookback``
    positions earlier. ``lookback`` is the smallest multiple of ``num_seasons``
    that is >= ``prediction_length``. With ``num_seasons == 1`` this is the
    plain naive forecast at horizon ``prediction_length``. ``lead_time`` is
    validated but does not affect the result.

    Returns:
        DataFrame indexed by the last ``test_length`` timestamps of ``series``
        with a single ``"forecast"`` column.

    Raises:
        ValueError: if ``series`` has fewer than ``test_length + lookback`` points.
    """
    assert series.index.is_monotonic_increasing
    assert isinstance(series.index, pd.DatetimeIndex)
    assert series.index.freq is not None
    assert prediction_length > 0
    assert lead_time > 0
    assert prediction_length >= lead_time
    assert test_length > 0
    assert num_seasons > 0
    # Exact integer ceiling: smallest multiple of num_seasons >= prediction_length.
    lookback = (prediction_length + num_seasons - 1) // num_seasons * num_seasons
    if test_length + lookback > len(series):
        # Without this check the slice below silently shrinks and pandas raises
        # an opaque shape-mismatch ValueError; keep the type, improve the message.
        raise ValueError(
            f"series has {len(series)} points but test_length={test_length} "
            f"plus lookback={lookback} needs at least {test_length + lookback}"
        )
    forecast_index = series.index[-test_length:]
    forecast_data = series.values[-test_length - lookback : -lookback]
    return pd.DataFrame(
        data=forecast_data,
        index=forecast_index,
        columns=["forecast"],
    )
def plot_forecasts(series: pd.Series, forecasts: pd.DataFrame):
    """Show ``forecasts`` against ``series`` in two side-by-side panels.

    The left panel overlays the forecasts on the whole series. The right panel
    shows only the actual values at the forecast timestamps, next to the
    forecasts. Each legend lists the actual series' name followed by the
    forecast column names.
    """
    actual_window = series[forecasts.index]
    _, (full_ax, window_ax) = plt.subplots(1, 2, figsize=(20, 8))
    for ax, actual in ((full_ax, series), (window_ax, actual_window)):
        ax.plot(actual)
        ax.plot(forecasts)
        ax.legend([actual.name, *forecasts.columns])
    plt.show()
if __name__ == "__main__":
    # Pretrained Chronos-Bolt (small) pipeline that produces the quantile forecasts.
    chronos_bolt_small = BaseChronosPipeline.from_pretrained(
        "amazon/chronos-bolt-small"
    )

    # Long-format table with item_id / timestamp / target columns; pick one series.
    df = pd.read_csv(
        "https://autogluon.s3.amazonaws.com/datasets/timeseries/m4_hourly_subset/test.csv"
    )
    item_id = "H144"
    group = df.groupby("item_id").get_group(item_id)
    # Regular hourly index (asfreq sets the freq that roll_* functions require).
    series = (
        pd.Series(group["target"].values, index=pd.to_datetime(group["timestamp"]))
        .asfreq("1h")
        .rename(item_id)
    )

    # Evaluate on the last 168 hourly points with 24-step rolling forecasts.
    test_length = 168
    lead_time = 1
    prediction_length = 24

    forecast_chronos = roll_chronos(
        series=series,
        model=chronos_bolt_small,
        test_length=test_length,
        lead_time=lead_time,
        prediction_length=prediction_length,
    )

    # Seasonal-naive baselines: output column name -> season length in steps.
    naive_seasons = {"naive": 1, "naive-1d": 24, "naive-1w": 168}
    columns = {"chronos-bolt-small": forecast_chronos["0.5-quantile"]}
    for name, num_seasons in naive_seasons.items():
        columns[name] = roll_naive(
            series=series,
            num_seasons=num_seasons,
            test_length=test_length,
            lead_time=lead_time,
            prediction_length=prediction_length,
        )["forecast"]
    forecasts = pd.DataFrame(columns)
    assert not forecasts.isna().any().any()

    # Per-model error metrics against the actual values in the test window.
    ground_truth = series[forecasts.index]
    errors = forecasts.sub(ground_truth, axis=0)
    metrics = pd.DataFrame(
        {
            "nMAE": errors.abs().mean() / ground_truth.abs().mean(),
            "RMSE": np.sqrt(errors.pow(2).mean()),
        }
    )
    print(metrics)
    plot_forecasts(series=series, forecasts=forecasts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment