Last active
March 4, 2025 08:56
-
-
Save lostella/30c90c5349c04800b01420a48d71fc9e to your computer and use it in GitHub Desktop.
Chronos rolling evaluation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example script running a rolling evaluation of Chronos models.
# See: https://github.com/amazon-science/chronos-forecasting
#
# Requirements:
#   uv pip install -U chronos-forecasting matplotlib numpy pandas torch
#
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from itertools import islice | |
| from chronos import BaseChronosPipeline | |
def batched(iterable, n: int):
    """Yield successive lists of at most ``n`` items drawn from ``iterable``.

    Every yielded list holds exactly ``n`` items except possibly the last one,
    which holds whatever is left over. An empty iterable yields nothing.

    Because this is a generator, the ``n >= 1`` assertion fires when the first
    batch is requested, not when ``batched`` is called.
    """
    assert n >= 1
    source = iter(iterable)
    # Two-argument ``iter`` keeps calling the chunk builder until it returns
    # the sentinel (an empty list), i.e. until ``source`` is exhausted.
    yield from iter(lambda: list(islice(source, n)), [])
def cutoffs(total_length, test_length, lead_time, prediction_length):
    """Return the cutoff positions at which rolling forecasts are issued.

    A cutoff ``k`` means "forecast using the first ``k`` observations"
    (positions ``0 .. k-1``); the forecast then covers positions
    ``k .. k + prediction_length - 1``. Only forecast steps ``lead_time`` and
    later are meant to be kept, so each cutoff contributes
    ``prediction_length - lead_time + 1`` consecutive positions, and
    consecutive cutoffs are spaced by exactly that amount. The first kept
    position is ``total_length - test_length``, the start of the test window.
    The last cutoff may reach past the end of the series; callers are expected
    to trim the surplus.

    Args:
        total_length: Number of observations in the whole series.
        test_length: Number of trailing observations to be forecast.
        lead_time: First forecast step (1-based) that is kept.
        prediction_length: Number of steps forecast from each cutoff.

    Returns:
        A ``range`` of cutoff positions in increasing order.

    Raises:
        AssertionError: If an argument is non-positive, if
            ``test_length > total_length``, if
            ``lead_time > prediction_length``, or if the first cutoff would be
            negative (``test_length > total_length - lead_time + 1``).
    """
    assert total_length > 0
    assert test_length > 0
    assert total_length >= test_length
    assert prediction_length > 0
    assert lead_time > 0
    assert prediction_length >= lead_time
    stride = prediction_length - lead_time + 1  # stride >= 1
    upper_bound = total_length - lead_time + 1
    first_cutoff = upper_bound - test_length
    # A negative cutoff would be read by callers as a slice end counted from
    # the back of the series (``series[:k]``), silently selecting the wrong
    # context, so reject it here instead of returning it.
    assert first_cutoff >= 0, (
        f"test_length={test_length} with lead_time={lead_time} needs a cutoff "
        f"before the start of a series of length {total_length}"
    )
    return range(first_cutoff, upper_bound, stride)
# Import-time sanity checks for ``cutoffs``: a 10-point series whose last three
# points form the test window, evaluated for four (lead_time,
# prediction_length) pairs. A wider gap between the two means a larger stride
# and therefore fewer cutoffs.
assert [
    list(
        cutoffs(
            total_length=10,
            test_length=3,
            lead_time=lead,
            prediction_length=horizon,
        )
    )
    for lead, horizon in [(1, 1), (2, 2), (2, 3), (1, 3)]
] == [[7, 8, 9], [6, 7, 8], [6, 8], [7]]
def roll_chronos(
    model,
    series: pd.Series,
    test_length: int,
    lead_time: int,
    prediction_length: int,
    batch_size: int = 64,
    quantile_levels: list = [0.1, 0.5, 0.9],
) -> pd.DataFrame:
    """Produce rolling quantile forecasts for the tail of ``series``.

    The series is cut at every position returned by ``cutoffs``; the model
    forecasts ``prediction_length`` steps from each cut, and only steps
    ``lead_time`` and later are kept. The kept pieces are concatenated in
    order and trimmed so that exactly the last ``test_length`` timestamps of
    ``series`` are covered.

    Args:
        model: Object exposing ``predict_quantiles(contexts,
            prediction_length=..., quantile_levels=...)`` that returns a tuple
            whose first element is a tensor of quantile forecasts.
        series: Observations on a monotonically increasing ``DatetimeIndex``
            with a set ``freq``.
        test_length: Number of trailing observations to forecast.
        lead_time: First forecast step (1-based) that is kept.
        prediction_length: Number of steps requested from the model per cut.
        batch_size: Maximum number of contexts sent to the model per call.
        quantile_levels: Quantile levels to request; one output column each.

    Returns:
        A DataFrame indexed by the last ``test_length`` timestamps of
        ``series`` with one ``"<q>-quantile"`` column per quantile level.

    Raises:
        AssertionError: If an argument check fails or if there is no
            observation before the first forecast cutoff.
    """
    assert series.index.is_monotonic_increasing
    assert isinstance(series.index, pd.DatetimeIndex)
    assert series.index.freq is not None
    assert prediction_length > 0
    assert lead_time > 0
    assert prediction_length >= lead_time
    assert test_length > 0
    assert batch_size > 0

    # Private copy: the default list is shared between calls, so it must never
    # be handed to code that might mutate it.
    quantile_levels = list(quantile_levels)

    cutoff_indices = cutoffs(
        total_length=len(series),
        test_length=test_length,
        lead_time=lead_time,
        prediction_length=prediction_length,
    )
    # A negative cutoff would make ``series_tensor[:k]`` count from the end and
    # leak test-window values into the context; a zero cutoff would give the
    # model an empty context. Both mean the series is too short for this setup.
    assert cutoff_indices[0] > 0, (
        "series is too short: no observation precedes the first forecast "
        f"cutoff (len(series)={len(series)}, test_length={test_length}, "
        f"lead_time={lead_time})"
    )

    series_tensor = torch.tensor(series.values)
    contexts = [series_tensor[:k] for k in cutoff_indices]

    batch_forecast = []
    for batch_context in batched(contexts, batch_size):
        quantiles, _ = model.predict_quantiles(
            batch_context,
            prediction_length=prediction_length,
            quantile_levels=quantile_levels,
        )  # (batch_size, prediction_length, len(quantile_levels))
        # ``.cpu()`` is a no-op for CPU tensors and keeps ``.numpy()`` working
        # should the model return tensors on another device.
        quantiles = quantiles.detach().cpu().numpy()[:, lead_time - 1 :, :]
        # Flatten (context, step) into one time axis; contexts are in
        # chronological order, so the rows come out chronological as well.
        batch_forecast.append(quantiles.reshape(-1, len(quantile_levels)))

    forecast_index = series.index[-test_length:]
    # The last cutoff may forecast past the end of the series; drop the surplus.
    forecast_data = np.concatenate(batch_forecast, axis=0)[: len(forecast_index)]

    forecasts_df = pd.DataFrame(
        forecast_data,
        index=forecast_index,
        columns=[f"{q}-quantile" for q in quantile_levels],
    )
    return forecasts_df
def roll_naive(
    series: pd.Series,
    test_length: int,
    lead_time: int,
    prediction_length: int,
    num_seasons: int = 1,
) -> pd.DataFrame:
    """Build a (seasonal) naive baseline for the tail of ``series``.

    Each of the last ``test_length`` timestamps is forecast by the observation
    ``lookback`` steps before it, where ``lookback`` is the smallest multiple
    of ``num_seasons`` that is at least ``prediction_length``. ``lead_time`` is
    validated but does not otherwise influence the result.

    Args:
        series: Observations on a monotonically increasing ``DatetimeIndex``
            with a set ``freq``.
        test_length: Number of trailing observations to forecast.
        lead_time: Must satisfy ``0 < lead_time <= prediction_length``.
        prediction_length: Lower bound for the look-back distance.
        num_seasons: Step size, in observations, that the look-back distance
            is rounded up to.

    Returns:
        A DataFrame with a single ``"forecast"`` column, indexed by the last
        ``test_length`` timestamps of ``series``.
    """
    timestamps = series.index
    assert timestamps.is_monotonic_increasing
    assert isinstance(timestamps, pd.DatetimeIndex)
    assert timestamps.freq is not None
    assert prediction_length > 0
    assert lead_time > 0
    assert prediction_length >= lead_time
    assert test_length > 0
    assert num_seasons > 0

    # Round the horizon up to a whole number of seasons.
    whole_seasons = int(np.ceil(prediction_length / num_seasons))
    lookback = whole_seasons * num_seasons

    # The test window shifted ``lookback`` steps into the past.
    shifted_window = slice(-test_length - lookback, -lookback)
    target_index = timestamps[-test_length:]
    shifted_values = series.values[shifted_window]

    return pd.DataFrame(
        data=shifted_values,
        index=target_index,
        columns=["forecast"],
    )
def plot_forecasts(series: pd.Series, forecasts: pd.DataFrame):
    """Show the forecasts against the actual values in two side-by-side panels.

    The left panel draws the whole ``series`` with the forecasts on top; the
    right panel draws only the part of ``series`` that shares timestamps with
    ``forecasts``. Each legend lists the actuals' name followed by the
    forecast column names. The figure is displayed with ``plt.show()``.
    """
    forecast_labels = list(forecasts.columns)

    def draw_panel(position, actuals):
        # One panel: actual values first, then every forecast column.
        plt.subplot(1, 2, position)
        plt.plot(actuals)
        plt.plot(forecasts)
        plt.legend([actuals.name] + forecast_labels)

    plt.figure(figsize=(20, 8))
    draw_panel(1, series)
    draw_panel(2, series[forecasts.index])
    plt.show()
if __name__ == "__main__":
    # Load the pretrained Chronos-Bolt (small) pipeline.
    chronos_bolt_small = BaseChronosPipeline.from_pretrained(
        "amazon/chronos-bolt-small"
    )

    df = pd.read_csv(
        "https://autogluon.s3.amazonaws.com/datasets/timeseries/m4_hourly_subset/test.csv"
    )

    # Evaluate a single item of the dataset.
    item_id = "H144"
    group = df.groupby("item_id").get_group(item_id)
    # ``asfreq`` gives the index an explicit frequency, which the rolling
    # functions assert on.
    series = (
        pd.Series(group["target"].values, index=pd.to_datetime(group["timestamp"]))
        .asfreq("1h")
        .rename(item_id)
    )

    test_length = 168
    lead_time = 1
    prediction_length = 24

    # Every forecaster is evaluated on exactly the same rolling setup.
    rolling_setup = dict(
        series=series,
        test_length=test_length,
        lead_time=lead_time,
        prediction_length=prediction_length,
    )

    forecast_chronos = roll_chronos(model=chronos_bolt_small, **rolling_setup)

    # Naive baselines keyed by output column name; the value is the
    # ``num_seasons`` argument passed to ``roll_naive``.
    naive_baselines = {"naive": 1, "naive-1d": 24, "naive-1w": 168}

    forecasts = pd.DataFrame(
        {
            "chronos-bolt-small": forecast_chronos["0.5-quantile"],
            **{
                label: roll_naive(num_seasons=num_seasons, **rolling_setup)[
                    "forecast"
                ]
                for label, num_seasons in naive_baselines.items()
            },
        }
    )
    assert not forecasts.isna().any().any()

    ground_truth = series[forecasts.index]
    errors = forecasts.sub(ground_truth, axis=0)

    # nMAE: mean absolute error divided by the mean absolute ground truth.
    metrics = pd.DataFrame(
        {
            "nMAE": errors.abs().mean() / ground_truth.abs().mean(),
            "RMSE": (errors * errors).mean().map(np.sqrt),
        }
    )
    print(metrics)

    plot_forecasts(series=series, forecasts=forecasts)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment