diff --git a/pyproject.toml b/pyproject.toml index 430efa6..7c3af18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ test = ["pytest~=8.0", "numpy~=1.21"] typecheck = ["mypy~=1.9"] training = ["gluonts[pro]", "numpy", "tensorboard", "typer", "typer-config", "joblib", "scikit-learn"] +evaluation = ["gluonts[pro]", "datasets", "numpy", "typer"] [tool.mypy] ignore_missing_imports = true diff --git a/scripts/README.md b/scripts/README.md index 3dadd90..e8bd401 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -98,4 +98,49 @@ pipeline = ChronosPipeline.from_pretrained("/path/to/fine-tuned/model/ckpt/dir/") pipeline.model.model.push_to_hub("chronos-t5-small-fine-tuned") + ``` + +## Evaluating Chronos models + +Follow these steps to compute the WQL and MASE values for the in-domain and zero-shot benchmarks in our paper. + +- Install this package with the `evaluation` extra: + ``` + pip install "chronos[evaluation] @ git+https://github.com/amazon-science/chronos-forecasting.git" + ``` +- Run the evaluation script: + ```sh + # In-domain evaluation + # Results will be saved in: evaluation/results/chronos-t5-small-in-domain.csv + python evaluation/evaluate.py evaluation/configs/in-domain.yaml evaluation/results/chronos-t5-small-in-domain.csv \ + --chronos-model-id "amazon/chronos-t5-small" \ + --batch-size=32 \ + --device=cuda:0 \ + --num-samples 20 + + # Zero-shot evaluation + # Results will be saved in: evaluation/results/chronos-t5-small-zero-shot.csv + python evaluation/evaluate.py evaluation/configs/zero-shot.yaml evaluation/results/chronos-t5-small-zero-shot.csv \ + --chronos-model-id "amazon/chronos-t5-small" \ + --batch-size=32 \ + --device=cuda:0 \ + --num-samples 20 + ``` +- Use the following snippet to compute the aggregated relative WQL and MASE scores: + ```py + import pandas as pd + from scipy.stats import gmean # requires: pip install scipy + + + def agg_relative_score(model_df: pd.DataFrame, baseline_df: 
pd.DataFrame): + relative_score = model_df.drop("model", axis="columns") / baseline_df.drop( + "model", axis="columns" + ) + return relative_score.agg(gmean) + + + result_df = pd.read_csv("evaluation/results/chronos-t5-small-in-domain.csv").set_index("dataset") + baseline_df = pd.read_csv("evaluation/results/seasonal-naive-in-domain.csv").set_index("dataset") + + agg_score_df = agg_relative_score(result_df, baseline_df) ``` \ No newline at end of file diff --git a/scripts/evaluation/configs/in-domain.yaml b/scripts/evaluation/configs/in-domain.yaml new file mode 100644 index 0000000..119b22f --- /dev/null +++ b/scripts/evaluation/configs/in-domain.yaml @@ -0,0 +1,78 @@ +# Backtest configs for the 15 "in-domain" datasets. +# The training portion of these datasets was part of the +# training corpus for Chronos models. +- name: electricity_15min + hf_repo: autogluon/chronos_datasets + offset: -5376 + prediction_length: 24 + num_rolls: 1 +- name: monash_electricity_hourly + hf_repo: autogluon/chronos_datasets + offset: -24 + prediction_length: 24 + num_rolls: 1 +- name: monash_electricity_weekly + hf_repo: autogluon/chronos_datasets + offset: -8 + prediction_length: 8 + num_rolls: 1 +- name: monash_kdd_cup_2018 + hf_repo: autogluon/chronos_datasets + offset: -48 + prediction_length: 48 + num_rolls: 1 +- name: m4_daily + hf_repo: autogluon/chronos_datasets + offset: -14 + prediction_length: 14 + num_rolls: 1 +- name: m4_hourly + hf_repo: autogluon/chronos_datasets + offset: -48 + prediction_length: 48 + num_rolls: 1 +- name: m4_monthly + hf_repo: autogluon/chronos_datasets + offset: -18 + prediction_length: 18 + num_rolls: 1 +- name: m4_weekly + hf_repo: autogluon/chronos_datasets + offset: -13 + prediction_length: 13 + num_rolls: 1 +- name: monash_pedestrian_counts + hf_repo: autogluon/chronos_datasets + offset: -48 + prediction_length: 48 + num_rolls: 1 +- name: taxi_30min + hf_repo: autogluon/chronos_datasets + offset: -48 + prediction_length: 48 + num_rolls: 1 +- 
name: uber_tlc_hourly + hf_repo: autogluon/chronos_datasets + offset: -24 + prediction_length: 24 + num_rolls: 1 +- name: uber_tlc_daily + hf_repo: autogluon/chronos_datasets + offset: -7 + prediction_length: 7 + num_rolls: 1 +- name: monash_rideshare + hf_repo: autogluon/chronos_datasets + offset: -24 + prediction_length: 24 + num_rolls: 1 +- name: monash_temperature_rain + hf_repo: autogluon/chronos_datasets + offset: -30 + prediction_length: 30 + num_rolls: 1 +- name: monash_london_smart_meters + hf_repo: autogluon/chronos_datasets + offset: -48 + prediction_length: 48 + num_rolls: 1 diff --git a/scripts/evaluation/configs/zero-shot.yaml b/scripts/evaluation/configs/zero-shot.yaml new file mode 100644 index 0000000..e3cabd6 --- /dev/null +++ b/scripts/evaluation/configs/zero-shot.yaml @@ -0,0 +1,137 @@ +# Backtest configs for the 27 "zero-shot" datasets. +# These datasets were not seen by Chronos models during training. +- name: monash_traffic + hf_repo: autogluon/chronos_datasets + offset: -24 + prediction_length: 24 + num_rolls: 1 +- name: monash_australian_electricity + hf_repo: autogluon/chronos_datasets + offset: -48 + prediction_length: 48 + num_rolls: 1 +- name: ercot + hf_repo: autogluon/chronos_datasets + offset: -24 + prediction_length: 24 + num_rolls: 1 +- name: ETTm + hf_repo: autogluon/chronos_datasets_extra + offset: -96 + prediction_length: 24 + num_rolls: 1 +- name: ETTh + hf_repo: autogluon/chronos_datasets_extra + offset: -24 + prediction_length: 24 + num_rolls: 1 +- name: exchange_rate + hf_repo: autogluon/chronos_datasets + offset: -30 + prediction_length: 30 + num_rolls: 1 +- name: nn5 + hf_repo: autogluon/chronos_datasets + offset: -56 + prediction_length: 56 + num_rolls: 1 +- name: monash_nn5_weekly + hf_repo: autogluon/chronos_datasets + offset: -8 + prediction_length: 8 + num_rolls: 1 +- name: monash_weather + hf_repo: autogluon/chronos_datasets + offset: -30 + prediction_length: 30 + num_rolls: 1 +- name: monash_covid_deaths + hf_repo: 
autogluon/chronos_datasets + offset: -30 + prediction_length: 30 + num_rolls: 1 +- name: monash_fred_md + hf_repo: autogluon/chronos_datasets + offset: -12 + prediction_length: 12 + num_rolls: 1 +- name: m4_quarterly + hf_repo: autogluon/chronos_datasets + offset: -8 + prediction_length: 8 + num_rolls: 1 +- name: m4_yearly + hf_repo: autogluon/chronos_datasets + offset: -6 + prediction_length: 6 + num_rolls: 1 +- name: dominick + hf_repo: autogluon/chronos_datasets + offset: -8 + prediction_length: 8 + num_rolls: 1 +- name: m5 + hf_repo: autogluon/chronos_datasets + offset: -28 + prediction_length: 28 + num_rolls: 1 +- name: monash_tourism_monthly + hf_repo: autogluon/chronos_datasets + offset: -24 + prediction_length: 24 + num_rolls: 1 +- name: monash_tourism_quarterly + hf_repo: autogluon/chronos_datasets + offset: -8 + prediction_length: 8 + num_rolls: 1 +- name: monash_tourism_yearly + hf_repo: autogluon/chronos_datasets + offset: -4 + prediction_length: 4 + num_rolls: 1 +- name: monash_car_parts + hf_repo: autogluon/chronos_datasets + offset: -12 + prediction_length: 12 + num_rolls: 1 +- name: monash_hospital + hf_repo: autogluon/chronos_datasets + offset: -12 + prediction_length: 12 + num_rolls: 1 +- name: monash_cif_2016 + hf_repo: autogluon/chronos_datasets + offset: -12 + prediction_length: 12 + num_rolls: 1 +- name: monash_m1_yearly + hf_repo: autogluon/chronos_datasets + offset: -6 + prediction_length: 6 + num_rolls: 1 +- name: monash_m1_quarterly + hf_repo: autogluon/chronos_datasets + offset: -8 + prediction_length: 8 + num_rolls: 1 +- name: monash_m1_monthly + hf_repo: autogluon/chronos_datasets + offset: -18 + prediction_length: 18 + num_rolls: 1 +- name: monash_m3_monthly + hf_repo: autogluon/chronos_datasets + offset: -18 + prediction_length: 18 + num_rolls: 1 +- name: monash_m3_yearly + hf_repo: autogluon/chronos_datasets + offset: -6 + prediction_length: 6 + num_rolls: 1 +- name: monash_m3_quarterly + hf_repo: autogluon/chronos_datasets + offset: 
"""Evaluate a Chronos model on the datasets listed in a YAML backtest config.

For every config entry the script loads the dataset from the Hugging Face Hub,
holds out the evaluation window(s), draws sample forecasts from a
``ChronosPipeline`` and writes one row of MASE / WQL metrics per dataset to a
CSV file.
"""

import logging
from pathlib import Path
from typing import Any, Dict, Iterable, List

import datasets
import numpy as np
import pandas as pd
import torch
import typer
import yaml
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.itertools import batcher
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline

app = typer.Typer(pretty_exceptions_enable=False)

# Bound at import time so that ``main`` can log even when this module is
# imported rather than executed as a script. Handlers and the log level are
# still only configured under the ``__main__`` guard at the bottom of the file.
logger = logging.getLogger("Chronos Evaluation")

# Taken from pandas._libs.tslibs.dtypes.OFFSET_TO_PERIOD_FREQSTR
offset_alias_to_period_alias = {
    "WEEKDAY": "D",
    "EOM": "M",
    "BME": "M",
    "SME": "M",
    "BQS": "Q",
    "QS": "Q",
    "BQE": "Q",
    "BQE-DEC": "Q",
    "BQE-JAN": "Q",
    "BQE-FEB": "Q",
    "BQE-MAR": "Q",
    "BQE-APR": "Q",
    "BQE-MAY": "Q",
    "BQE-JUN": "Q",
    "BQE-JUL": "Q",
    "BQE-AUG": "Q",
    "BQE-SEP": "Q",
    "BQE-OCT": "Q",
    "BQE-NOV": "Q",
    "MS": "M",
    "D": "D",
    "B": "B",
    "min": "min",
    "s": "s",
    "ms": "ms",
    "us": "us",
    "ns": "ns",
    "h": "h",
    "QE": "Q",
    "QE-DEC": "Q-DEC",
    "QE-JAN": "Q-JAN",
    "QE-FEB": "Q-FEB",
    "QE-MAR": "Q-MAR",
    "QE-APR": "Q-APR",
    "QE-MAY": "Q-MAY",
    "QE-JUN": "Q-JUN",
    "QE-JUL": "Q-JUL",
    "QE-AUG": "Q-AUG",
    "QE-SEP": "Q-SEP",
    "QE-OCT": "Q-OCT",
    "QE-NOV": "Q-NOV",
    "YE": "Y",
    "YE-DEC": "Y-DEC",
    "YE-JAN": "Y-JAN",
    "YE-FEB": "Y-FEB",
    "YE-MAR": "Y-MAR",
    "YE-APR": "Y-APR",
    "YE-MAY": "Y-MAY",
    "YE-JUN": "Y-JUN",
    "YE-JUL": "Y-JUL",
    "YE-AUG": "Y-AUG",
    "YE-SEP": "Y-SEP",
    "YE-OCT": "Y-OCT",
    "YE-NOV": "Y-NOV",
    "W": "W",
    "ME": "M",
    "Y": "Y",
    "BYE": "Y",
    "BYE-DEC": "Y",
    "BYE-JAN": "Y",
    "BYE-FEB": "Y",
    "BYE-MAR": "Y",
    "BYE-APR": "Y",
    "BYE-MAY": "Y",
    "BYE-JUN": "Y",
    "BYE-JUL": "Y",
    "BYE-AUG": "Y",
    "BYE-SEP": "Y",
    "BYE-OCT": "Y",
    "BYE-NOV": "Y",
    "YS": "Y",
    "BYS": "Y",
    "QS-JAN": "Q",
    "QS-FEB": "Q",
    "QS-MAR": "Q",
    "QS-APR": "Q",
    "QS-MAY": "Q",
    "QS-JUN": "Q",
    "QS-JUL": "Q",
    "QS-AUG": "Q",
    "QS-SEP": "Q",
    "QS-OCT": "Q",
    "QS-NOV": "Q",
    "QS-DEC": "Q",
    "BQS-JAN": "Q",
    "BQS-FEB": "Q",
    "BQS-MAR": "Q",
    "BQS-APR": "Q",
    "BQS-MAY": "Q",
    "BQS-JUN": "Q",
    "BQS-JUL": "Q",
    "BQS-AUG": "Q",
    "BQS-SEP": "Q",
    "BQS-OCT": "Q",
    "BQS-NOV": "Q",
    "BQS-DEC": "Q",
    "YS-JAN": "Y",
    "YS-FEB": "Y",
    "YS-MAR": "Y",
    "YS-APR": "Y",
    "YS-MAY": "Y",
    "YS-JUN": "Y",
    "YS-JUL": "Y",
    "YS-AUG": "Y",
    "YS-SEP": "Y",
    "YS-OCT": "Y",
    "YS-NOV": "Y",
    "YS-DEC": "Y",
    "BYS-JAN": "Y",
    "BYS-FEB": "Y",
    "BYS-MAR": "Y",
    "BYS-APR": "Y",
    "BYS-MAY": "Y",
    "BYS-JUN": "Y",
    "BYS-JUL": "Y",
    "BYS-AUG": "Y",
    "BYS-SEP": "Y",
    "BYS-OCT": "Y",
    "BYS-NOV": "Y",
    "BYS-DEC": "Y",
    "Y-JAN": "Y-JAN",
    "Y-FEB": "Y-FEB",
    "Y-MAR": "Y-MAR",
    "Y-APR": "Y-APR",
    "Y-MAY": "Y-MAY",
    "Y-JUN": "Y-JUN",
    "Y-JUL": "Y-JUL",
    "Y-AUG": "Y-AUG",
    "Y-SEP": "Y-SEP",
    "Y-OCT": "Y-OCT",
    "Y-NOV": "Y-NOV",
    "Y-DEC": "Y-DEC",
    "Q-JAN": "Q-JAN",
    "Q-FEB": "Q-FEB",
    "Q-MAR": "Q-MAR",
    "Q-APR": "Q-APR",
    "Q-MAY": "Q-MAY",
    "Q-JUN": "Q-JUN",
    "Q-JUL": "Q-JUL",
    "Q-AUG": "Q-AUG",
    "Q-SEP": "Q-SEP",
    "Q-OCT": "Q-OCT",
    "Q-NOV": "Q-NOV",
    "Q-DEC": "Q-DEC",
    "W-MON": "W-MON",
    "W-TUE": "W-TUE",
    "W-WED": "W-WED",
    "W-THU": "W-THU",
    "W-FRI": "W-FRI",
    "W-SAT": "W-SAT",
    "W-SUN": "W-SUN",
}


def to_gluonts_univariate(hf_dataset: datasets.Dataset) -> List[Dict[str, Any]]:
    """Flatten a Hugging Face dataset into a list of univariate GluonTS entries.

    Every ``datasets.Sequence`` feature other than ``"timestamp"`` is treated
    as a separate target series, so a row with ``k`` such columns yields ``k``
    entries of the form ``{"start": pd.Period, "target": <series values>}``.

    Raises:
        ValueError: if the dataset has no ``"timestamp"`` sequence feature
            (raised by ``list.remove``).
        AssertionError: if the number of produced entries differs from
            ``<num_examples of the "train" split> * <number of series columns>``.
    """
    series_fields = [
        col
        for col in hf_dataset.features
        if isinstance(hf_dataset.features[col], datasets.Sequence)
    ]
    series_fields.remove("timestamp")
    # The expected size is read from the "train" split metadata.
    # NOTE(review): this presumes the dataset passed in IS the train split,
    # which is what load_and_split_dataset requests.
    dataset_length = hf_dataset.info.splits["train"].num_examples * len(series_fields)

    # The frequency is inferred from the first row only and then translated
    # from a pandas offset alias to the corresponding period alias, because the
    # start of each series is stored as a ``pd.Period``.
    dataset_freq = pd.infer_freq(hf_dataset[0]["timestamp"])
    dataset_freq = offset_alias_to_period_alias.get(dataset_freq, dataset_freq)

    gts_dataset = [
        {
            "start": pd.Period(hf_entry["timestamp"][0], freq=dataset_freq),
            "target": hf_entry[field],
        }
        for hf_entry in hf_dataset
        for field in series_fields
    ]
    # Internal sanity check that no series was dropped or duplicated.
    assert len(gts_dataset) == dataset_length

    return gts_dataset


def load_and_split_dataset(backtest_config: dict):
    """Load one benchmark dataset and build its evaluation instances.

    Args:
        backtest_config: one entry of the YAML config; the keys ``hf_repo``,
            ``name``, ``offset``, ``prediction_length`` and ``num_rolls`` are
            required.

    Returns:
        The object returned by ``generate_instances`` on the GluonTS test
        template; ``main`` consumes its ``.input`` attribute and passes the
        object itself to ``evaluate_forecasts``.

    Raises:
        KeyError: if a required key is missing from ``backtest_config``.
    """
    hf_repo = backtest_config["hf_repo"]
    dataset_name = backtest_config["name"]
    offset = backtest_config["offset"]
    prediction_length = backtest_config["prediction_length"]
    num_rolls = backtest_config["num_rolls"]

    # This is needed because the datasets in autogluon/chronos_datasets_extra cannot
    # be distributed due to license restrictions and must be generated on the fly
    trust_remote_code = hf_repo == "autogluon/chronos_datasets_extra"

    ds = datasets.load_dataset(
        hf_repo, dataset_name, split="train", trust_remote_code=trust_remote_code
    )
    ds.set_format("numpy")

    gts_dataset = to_gluonts_univariate(ds)

    # Split dataset for evaluation: the training part is discarded, only the
    # test template positioned at ``offset`` is used.
    _, test_template = split(gts_dataset, offset=offset)
    test_data = test_template.generate_instances(prediction_length, windows=num_rolls)

    return test_data


def generate_sample_forecasts(
    test_data_input: Iterable,
    pipeline: ChronosPipeline,
    prediction_length: int,
    num_samples: int,
    batch_size: int,
) -> List[SampleForecast]:
    """Draw sample forecasts for every series in ``test_data_input``.

    Args:
        test_data_input: entries with ``"start"`` and ``"target"`` keys. It is
            iterated twice (once to predict, once to attach start dates), so it
            must be re-iterable rather than a one-shot iterator.
        pipeline: the Chronos pipeline used for prediction.
        prediction_length: forecast horizon passed to ``pipeline.predict``.
        num_samples: number of sample paths passed to ``pipeline.predict``.
        batch_size: number of series sent to the pipeline per call.

    Returns:
        One ``SampleForecast`` per input series, in input order.
    """
    # Generate forecast samples
    forecast_samples = []
    for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
        context = [torch.tensor(entry["target"]) for entry in batch]
        forecast_samples.append(
            pipeline.predict(
                context,
                prediction_length=prediction_length,
                num_samples=num_samples,
            ).numpy()
        )
    forecast_samples = np.concatenate(forecast_samples)

    # Convert forecast samples into gluonts SampleForecast objects
    sample_forecasts = []
    for item, ts in zip(forecast_samples, test_data_input):
        # The forecast begins right after the last observed value: the series
        # start shifted by the length of the context.
        forecast_start_date = ts["start"] + len(ts["target"])
        sample_forecasts.append(
            SampleForecast(samples=item, start_date=forecast_start_date)
        )

    return sample_forecasts


def _resolve_torch_dtype(torch_dtype) -> torch.dtype:
    """Turn a dtype name such as ``"bfloat16"`` into the ``torch.dtype`` object.

    A value that already is a ``torch.dtype`` is returned unchanged, which
    keeps programmatic calls of ``main`` working.

    Raises:
        typer.BadParameter: if the value does not name a torch dtype.
    """
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    resolved = (
        getattr(torch, torch_dtype, None) if isinstance(torch_dtype, str) else None
    )
    if not isinstance(resolved, torch.dtype):
        raise typer.BadParameter(
            f"{torch_dtype!r} is not a valid torch dtype name "
            "(expected e.g. 'bfloat16' or 'float32')"
        )
    return resolved


@app.command()
def main(
    config_path: Path,
    metrics_path: Path,
    chronos_model_id: str = "amazon/chronos-t5-small",
    device: str = "cuda",
    torch_dtype: str = "bfloat16",
    batch_size: int = 32,
    num_samples: int = 20,
) -> None:
    """Evaluate a Chronos model on every dataset of a backtest config.

    CONFIG_PATH is the YAML file listing the datasets to evaluate and
    METRICS_PATH is the CSV file the per-dataset MASE and WQL are written to.
    """
    dtype = _resolve_torch_dtype(torch_dtype)

    # Load Chronos. The dtype selected on the command line is forwarded here;
    # previously ``torch.bfloat16`` was hard-coded and the option was ignored.
    pipeline = ChronosPipeline.from_pretrained(
        chronos_model_id,
        device_map=device,
        torch_dtype=dtype,
    )

    # Load backtest configs
    with open(config_path) as fp:
        backtest_configs = yaml.safe_load(fp)

    result_rows = []
    for config in backtest_configs:
        dataset_name = config["name"]
        prediction_length = config["prediction_length"]

        logger.info(f"Loading {dataset_name}")
        test_data = load_and_split_dataset(backtest_config=config)

        logger.info(
            f"Generating forecasts for {dataset_name} "
            f"({len(test_data.input)} time series)"
        )
        sample_forecasts = generate_sample_forecasts(
            test_data.input,
            pipeline=pipeline,
            prediction_length=prediction_length,
            num_samples=num_samples,
            batch_size=batch_size,
        )

        logger.info(f"Evaluating forecasts for {dataset_name}")
        metrics = (
            evaluate_forecasts(
                sample_forecasts,
                test_data=test_data,
                metrics=[
                    MASE(),
                    MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
                ],
                batch_size=5000,
            )
            .reset_index(drop=True)
            .to_dict(orient="records")
        )
        # Only the first record of the evaluation result is stored per dataset.
        result_rows.append(
            {"dataset": dataset_name, "model": chronos_model_id, **metrics[0]}
        )

    # Save results to a CSV file
    results_df = (
        pd.DataFrame(result_rows)
        .rename(
            {"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"},
            axis="columns",
        )
        .sort_values(by="dataset")
    )
    # Create the output directory up front so that a finished evaluation run
    # is not lost to a missing folder.
    metrics_path = Path(metrics_path)
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(metrics_path, index=False)


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    logger.setLevel(logging.INFO)
    app()
+ercot,seasonal-naive,0.7613354813741452,0.0366036447606282 +exchange_rate,seasonal-naive,1.7401824286954128,0.0129841406759913 +m4_quarterly,seasonal-naive,1.6022471766126911,0.1186484661559648 +m4_yearly,seasonal-naive,3.974360261259571,0.1614389663357925 +m5,seasonal-naive,1.399206213076729,1.0240883478068443 +monash_australian_electricity,seasonal-naive,1.2533189641227642,0.0836951323308387 +monash_car_parts,seasonal-naive,1.2014638390969912,1.5999522140809177 +monash_cif_2016,seasonal-naive,1.289290577415544,0.0150830409089921 +monash_covid_deaths,seasonal-naive,46.91239825526407,0.1330848762571827 +monash_fred_md,seasonal-naive,1.1008000463101226,0.1222237702571737 +monash_hospital,seasonal-naive,0.9205278266364826,0.0726263373268254 +monash_m1_monthly,seasonal-naive,1.3144614957646543,0.1914632595030148 +monash_m1_quarterly,seasonal-naive,2.077536550805995,0.1495022062865622 +monash_m1_yearly,seasonal-naive,4.894322225232431,0.2092955931101782 +monash_m3_monthly,seasonal-naive,1.1462045758327934,0.1485446007554992 +monash_m3_quarterly,seasonal-naive,1.425343793700714,0.1012520529806161 +monash_m3_yearly,seasonal-naive,3.1717102364409517,0.1665329650420048 +monash_nn5_weekly,seasonal-naive,1.0628482559107015,0.1226908962169196 +monash_tourism_monthly,seasonal-naive,1.630939994944413,0.1041824322151567 +monash_tourism_quarterly,seasonal-naive,1.6989892627474672,0.1193750169177449 +monash_tourism_yearly,seasonal-naive,3.5520097206480883,0.2091826587673241 +monash_traffic,seasonal-naive,1.0767397173107436,0.3618532196990004 +monash_weather,seasonal-naive,1.0038475713182748,0.2165947349654047 +nn5,seasonal-naive,1.2917285866431214,0.4246208074843067