mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Add evaluation script (#134)
*Description of changes:* This PR adds configs and a script to evaluate Chronos models in the same way as described in the paper. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
parent
afd9cfd062
commit
fead4ecbca
7 changed files with 646 additions and 0 deletions
|
|
@ -13,6 +13,7 @@ dependencies = [
|
|||
test = ["pytest~=8.0", "numpy~=1.21"]
|
||||
typecheck = ["mypy~=1.9"]
|
||||
training = ["gluonts[pro]", "numpy", "tensorboard", "typer", "typer-config", "joblib", "scikit-learn"]
|
||||
evaluation = ["gluonts[pro]", "datasets", "numpy", "typer"]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
|
|
|
|||
|
|
@ -98,4 +98,49 @@
|
|||
|
||||
pipeline = ChronosPipeline.from_pretrained("/path/to/fine-tuned/model/ckpt/dir/")
|
||||
pipeline.model.model.push_to_hub("chronos-t5-small-fine-tuned")
|
||||
```
|
||||
|
||||
## Evaluating Chronos models
|
||||
|
||||
Follow these steps to compute the WQL and MASE values for the in-domain and zero-shot benchmarks in our paper.
|
||||
|
||||
- Install this package with with the `evaluation` extra:
|
||||
```
|
||||
pip install "chronos[evaluation] @ git+https://github.com/amazon-science/chronos-forecasting.git"
|
||||
```
|
||||
- Run the evaluation script:
|
||||
```sh
|
||||
# In-domain evaluation
|
||||
# Results will be saved in: evaluation/results/chronos-t5-small-in-domain.csv
|
||||
python evaluation/evaluate.py evaluation/configs/in-domain.yaml evaluation/results/chronos-t5-small-in-domain.csv \
|
||||
--chronos-model-id "amazon/chronos-t5-small" \
|
||||
--batch-size=32 \
|
||||
--device=cuda:0 \
|
||||
--num-samples 20
|
||||
|
||||
# Zero-shot evaluation
|
||||
# Results will be saved in: evaluation/results/chronos-t5-small-zero-shot.csv
|
||||
python evaluation/evaluate.py evaluation/configs/zero-shot.yaml evaluation/results/chronos-t5-small-zero-shot.csv \
|
||||
--chronos-model-id "amazon/chronos-t5-small" \
|
||||
--batch-size=32 \
|
||||
--device=cuda:0 \
|
||||
--num-samples 20
|
||||
```
|
||||
- Use the following snippet to compute the aggregated relative WQL and MASE scores:
|
||||
```py
|
||||
import pandas as pd
|
||||
from scipy.stats import gmean # requires: pip install scipy
|
||||
|
||||
|
||||
def agg_relative_score(model_df: pd.DataFrame, baseline_df: pd.DataFrame):
|
||||
relative_score = model_df.drop("model", axis="columns") / baseline_df.drop(
|
||||
"model", axis="columns"
|
||||
)
|
||||
return relative_score.agg(gmean)
|
||||
|
||||
|
||||
result_df = pd.read_csv("evaluation/results/chronos-t5-small-in-domain.csv").set_index("dataset")
|
||||
baseline_df = pd.read_csv("evaluation/results/seasonal-naive-in-domain.csv").set_index("dataset")
|
||||
|
||||
agg_score_df = agg_relative_score(result_df, baseline_df)
|
||||
```
|
||||
78
scripts/evaluation/configs/in-domain.yaml
Normal file
78
scripts/evaluation/configs/in-domain.yaml
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# Backtest configs for the 15 "in-domain" datasets.
|
||||
# The training portion of these datasets was part of the
|
||||
# training corpus for Chronos models.
|
||||
- name: electricity_15min
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -5376
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: monash_electricity_hourly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -24
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: monash_electricity_weekly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -8
|
||||
prediction_length: 8
|
||||
num_rolls: 1
|
||||
- name: monash_kdd_cup_2018
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -48
|
||||
prediction_length: 48
|
||||
num_rolls: 1
|
||||
- name: m4_daily
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -14
|
||||
prediction_length: 14
|
||||
num_rolls: 1
|
||||
- name: m4_hourly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -48
|
||||
prediction_length: 48
|
||||
num_rolls: 1
|
||||
- name: m4_monthly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -18
|
||||
prediction_length: 18
|
||||
num_rolls: 1
|
||||
- name: m4_weekly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -13
|
||||
prediction_length: 13
|
||||
num_rolls: 1
|
||||
- name: monash_pedestrian_counts
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -48
|
||||
prediction_length: 48
|
||||
num_rolls: 1
|
||||
- name: taxi_30min
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -48
|
||||
prediction_length: 48
|
||||
num_rolls: 1
|
||||
- name: uber_tlc_hourly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -24
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: uber_tlc_daily
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -7
|
||||
prediction_length: 7
|
||||
num_rolls: 1
|
||||
- name: monash_rideshare
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -24
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: monash_temperature_rain
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -30
|
||||
prediction_length: 30
|
||||
num_rolls: 1
|
||||
- name: monash_london_smart_meters
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -48
|
||||
prediction_length: 48
|
||||
num_rolls: 1
|
||||
137
scripts/evaluation/configs/zero-shot.yaml
Normal file
137
scripts/evaluation/configs/zero-shot.yaml
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
# Backtest configs for the 27 "zero-shot" datasets.
|
||||
# These datasets were not seen by Chronos models during training.
|
||||
- name: monash_traffic
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -24
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: monash_australian_electricity
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -48
|
||||
prediction_length: 48
|
||||
num_rolls: 1
|
||||
- name: ercot
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -24
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: ETTm
|
||||
hf_repo: autogluon/chronos_datasets_extra
|
||||
offset: -96
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: ETTh
|
||||
hf_repo: autogluon/chronos_datasets_extra
|
||||
offset: -24
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: exchange_rate
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -30
|
||||
prediction_length: 30
|
||||
num_rolls: 1
|
||||
- name: nn5
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -56
|
||||
prediction_length: 56
|
||||
num_rolls: 1
|
||||
- name: monash_nn5_weekly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -8
|
||||
prediction_length: 8
|
||||
num_rolls: 1
|
||||
- name: monash_weather
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -30
|
||||
prediction_length: 30
|
||||
num_rolls: 1
|
||||
- name: monash_covid_deaths
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -30
|
||||
prediction_length: 30
|
||||
num_rolls: 1
|
||||
- name: monash_fred_md
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -12
|
||||
prediction_length: 12
|
||||
num_rolls: 1
|
||||
- name: m4_quarterly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -8
|
||||
prediction_length: 8
|
||||
num_rolls: 1
|
||||
- name: m4_yearly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -6
|
||||
prediction_length: 6
|
||||
num_rolls: 1
|
||||
- name: dominick
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -8
|
||||
prediction_length: 8
|
||||
num_rolls: 1
|
||||
- name: m5
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -28
|
||||
prediction_length: 28
|
||||
num_rolls: 1
|
||||
- name: monash_tourism_monthly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -24
|
||||
prediction_length: 24
|
||||
num_rolls: 1
|
||||
- name: monash_tourism_quarterly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -8
|
||||
prediction_length: 8
|
||||
num_rolls: 1
|
||||
- name: monash_tourism_yearly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -4
|
||||
prediction_length: 4
|
||||
num_rolls: 1
|
||||
- name: monash_car_parts
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -12
|
||||
prediction_length: 12
|
||||
num_rolls: 1
|
||||
- name: monash_hospital
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -12
|
||||
prediction_length: 12
|
||||
num_rolls: 1
|
||||
- name: monash_cif_2016
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -12
|
||||
prediction_length: 12
|
||||
num_rolls: 1
|
||||
- name: monash_m1_yearly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -6
|
||||
prediction_length: 6
|
||||
num_rolls: 1
|
||||
- name: monash_m1_quarterly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -8
|
||||
prediction_length: 8
|
||||
num_rolls: 1
|
||||
- name: monash_m1_monthly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -18
|
||||
prediction_length: 18
|
||||
num_rolls: 1
|
||||
- name: monash_m3_monthly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -18
|
||||
prediction_length: 18
|
||||
num_rolls: 1
|
||||
- name: monash_m3_yearly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -6
|
||||
prediction_length: 6
|
||||
num_rolls: 1
|
||||
- name: monash_m3_quarterly
|
||||
hf_repo: autogluon/chronos_datasets
|
||||
offset: -8
|
||||
prediction_length: 8
|
||||
num_rolls: 1
|
||||
341
scripts/evaluation/evaluate.py
Normal file
341
scripts/evaluation/evaluate.py
Normal file
|
|
@ -0,0 +1,341 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import typer
|
||||
import yaml
|
||||
from gluonts.dataset.split import split
|
||||
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
|
||||
from gluonts.itertools import batcher
|
||||
from gluonts.model.evaluation import evaluate_forecasts
|
||||
from gluonts.model.forecast import SampleForecast
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from chronos import ChronosPipeline
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
|
||||
# Taken from pandas._libs.tslibs.dtypes.OFFSET_TO_PERIOD_FREQSTR
|
||||
offset_alias_to_period_alias = {
|
||||
"WEEKDAY": "D",
|
||||
"EOM": "M",
|
||||
"BME": "M",
|
||||
"SME": "M",
|
||||
"BQS": "Q",
|
||||
"QS": "Q",
|
||||
"BQE": "Q",
|
||||
"BQE-DEC": "Q",
|
||||
"BQE-JAN": "Q",
|
||||
"BQE-FEB": "Q",
|
||||
"BQE-MAR": "Q",
|
||||
"BQE-APR": "Q",
|
||||
"BQE-MAY": "Q",
|
||||
"BQE-JUN": "Q",
|
||||
"BQE-JUL": "Q",
|
||||
"BQE-AUG": "Q",
|
||||
"BQE-SEP": "Q",
|
||||
"BQE-OCT": "Q",
|
||||
"BQE-NOV": "Q",
|
||||
"MS": "M",
|
||||
"D": "D",
|
||||
"B": "B",
|
||||
"min": "min",
|
||||
"s": "s",
|
||||
"ms": "ms",
|
||||
"us": "us",
|
||||
"ns": "ns",
|
||||
"h": "h",
|
||||
"QE": "Q",
|
||||
"QE-DEC": "Q-DEC",
|
||||
"QE-JAN": "Q-JAN",
|
||||
"QE-FEB": "Q-FEB",
|
||||
"QE-MAR": "Q-MAR",
|
||||
"QE-APR": "Q-APR",
|
||||
"QE-MAY": "Q-MAY",
|
||||
"QE-JUN": "Q-JUN",
|
||||
"QE-JUL": "Q-JUL",
|
||||
"QE-AUG": "Q-AUG",
|
||||
"QE-SEP": "Q-SEP",
|
||||
"QE-OCT": "Q-OCT",
|
||||
"QE-NOV": "Q-NOV",
|
||||
"YE": "Y",
|
||||
"YE-DEC": "Y-DEC",
|
||||
"YE-JAN": "Y-JAN",
|
||||
"YE-FEB": "Y-FEB",
|
||||
"YE-MAR": "Y-MAR",
|
||||
"YE-APR": "Y-APR",
|
||||
"YE-MAY": "Y-MAY",
|
||||
"YE-JUN": "Y-JUN",
|
||||
"YE-JUL": "Y-JUL",
|
||||
"YE-AUG": "Y-AUG",
|
||||
"YE-SEP": "Y-SEP",
|
||||
"YE-OCT": "Y-OCT",
|
||||
"YE-NOV": "Y-NOV",
|
||||
"W": "W",
|
||||
"ME": "M",
|
||||
"Y": "Y",
|
||||
"BYE": "Y",
|
||||
"BYE-DEC": "Y",
|
||||
"BYE-JAN": "Y",
|
||||
"BYE-FEB": "Y",
|
||||
"BYE-MAR": "Y",
|
||||
"BYE-APR": "Y",
|
||||
"BYE-MAY": "Y",
|
||||
"BYE-JUN": "Y",
|
||||
"BYE-JUL": "Y",
|
||||
"BYE-AUG": "Y",
|
||||
"BYE-SEP": "Y",
|
||||
"BYE-OCT": "Y",
|
||||
"BYE-NOV": "Y",
|
||||
"YS": "Y",
|
||||
"BYS": "Y",
|
||||
"QS-JAN": "Q",
|
||||
"QS-FEB": "Q",
|
||||
"QS-MAR": "Q",
|
||||
"QS-APR": "Q",
|
||||
"QS-MAY": "Q",
|
||||
"QS-JUN": "Q",
|
||||
"QS-JUL": "Q",
|
||||
"QS-AUG": "Q",
|
||||
"QS-SEP": "Q",
|
||||
"QS-OCT": "Q",
|
||||
"QS-NOV": "Q",
|
||||
"QS-DEC": "Q",
|
||||
"BQS-JAN": "Q",
|
||||
"BQS-FEB": "Q",
|
||||
"BQS-MAR": "Q",
|
||||
"BQS-APR": "Q",
|
||||
"BQS-MAY": "Q",
|
||||
"BQS-JUN": "Q",
|
||||
"BQS-JUL": "Q",
|
||||
"BQS-AUG": "Q",
|
||||
"BQS-SEP": "Q",
|
||||
"BQS-OCT": "Q",
|
||||
"BQS-NOV": "Q",
|
||||
"BQS-DEC": "Q",
|
||||
"YS-JAN": "Y",
|
||||
"YS-FEB": "Y",
|
||||
"YS-MAR": "Y",
|
||||
"YS-APR": "Y",
|
||||
"YS-MAY": "Y",
|
||||
"YS-JUN": "Y",
|
||||
"YS-JUL": "Y",
|
||||
"YS-AUG": "Y",
|
||||
"YS-SEP": "Y",
|
||||
"YS-OCT": "Y",
|
||||
"YS-NOV": "Y",
|
||||
"YS-DEC": "Y",
|
||||
"BYS-JAN": "Y",
|
||||
"BYS-FEB": "Y",
|
||||
"BYS-MAR": "Y",
|
||||
"BYS-APR": "Y",
|
||||
"BYS-MAY": "Y",
|
||||
"BYS-JUN": "Y",
|
||||
"BYS-JUL": "Y",
|
||||
"BYS-AUG": "Y",
|
||||
"BYS-SEP": "Y",
|
||||
"BYS-OCT": "Y",
|
||||
"BYS-NOV": "Y",
|
||||
"BYS-DEC": "Y",
|
||||
"Y-JAN": "Y-JAN",
|
||||
"Y-FEB": "Y-FEB",
|
||||
"Y-MAR": "Y-MAR",
|
||||
"Y-APR": "Y-APR",
|
||||
"Y-MAY": "Y-MAY",
|
||||
"Y-JUN": "Y-JUN",
|
||||
"Y-JUL": "Y-JUL",
|
||||
"Y-AUG": "Y-AUG",
|
||||
"Y-SEP": "Y-SEP",
|
||||
"Y-OCT": "Y-OCT",
|
||||
"Y-NOV": "Y-NOV",
|
||||
"Y-DEC": "Y-DEC",
|
||||
"Q-JAN": "Q-JAN",
|
||||
"Q-FEB": "Q-FEB",
|
||||
"Q-MAR": "Q-MAR",
|
||||
"Q-APR": "Q-APR",
|
||||
"Q-MAY": "Q-MAY",
|
||||
"Q-JUN": "Q-JUN",
|
||||
"Q-JUL": "Q-JUL",
|
||||
"Q-AUG": "Q-AUG",
|
||||
"Q-SEP": "Q-SEP",
|
||||
"Q-OCT": "Q-OCT",
|
||||
"Q-NOV": "Q-NOV",
|
||||
"Q-DEC": "Q-DEC",
|
||||
"W-MON": "W-MON",
|
||||
"W-TUE": "W-TUE",
|
||||
"W-WED": "W-WED",
|
||||
"W-THU": "W-THU",
|
||||
"W-FRI": "W-FRI",
|
||||
"W-SAT": "W-SAT",
|
||||
"W-SUN": "W-SUN",
|
||||
}
|
||||
|
||||
|
||||
def to_gluonts_univariate(hf_dataset: datasets.Dataset):
|
||||
series_fields = [
|
||||
col
|
||||
for col in hf_dataset.features
|
||||
if isinstance(hf_dataset.features[col], datasets.Sequence)
|
||||
]
|
||||
series_fields.remove("timestamp")
|
||||
dataset_length = hf_dataset.info.splits["train"].num_examples * len(series_fields)
|
||||
dataset_freq = pd.infer_freq(hf_dataset[0]["timestamp"])
|
||||
dataset_freq = offset_alias_to_period_alias.get(dataset_freq, dataset_freq)
|
||||
|
||||
gts_dataset = []
|
||||
for hf_entry in hf_dataset:
|
||||
for field in series_fields:
|
||||
gts_dataset.append(
|
||||
{
|
||||
"start": pd.Period(
|
||||
hf_entry["timestamp"][0],
|
||||
freq=dataset_freq,
|
||||
),
|
||||
"target": hf_entry[field],
|
||||
}
|
||||
)
|
||||
assert len(gts_dataset) == dataset_length
|
||||
|
||||
return gts_dataset
|
||||
|
||||
|
||||
def load_and_split_dataset(backtest_config: dict):
|
||||
hf_repo = backtest_config["hf_repo"]
|
||||
dataset_name = backtest_config["name"]
|
||||
offset = backtest_config["offset"]
|
||||
prediction_length = backtest_config["prediction_length"]
|
||||
num_rolls = backtest_config["num_rolls"]
|
||||
|
||||
# This is needed because the datasets in autogluon/chronos_datasets_extra cannot
|
||||
# be distribued due to license restrictions and must be generated on the fly
|
||||
trust_remote_code = True if hf_repo == "autogluon/chronos_datasets_extra" else False
|
||||
|
||||
ds = datasets.load_dataset(
|
||||
hf_repo, dataset_name, split="train", trust_remote_code=trust_remote_code
|
||||
)
|
||||
ds.set_format("numpy")
|
||||
|
||||
gts_dataset = to_gluonts_univariate(ds)
|
||||
|
||||
# Split dataset for evaluation
|
||||
_, test_template = split(gts_dataset, offset=offset)
|
||||
test_data = test_template.generate_instances(prediction_length, windows=num_rolls)
|
||||
|
||||
return test_data
|
||||
|
||||
|
||||
def generate_sample_forecasts(
|
||||
test_data_input: Iterable,
|
||||
pipeline: ChronosPipeline,
|
||||
prediction_length: int,
|
||||
num_samples: int,
|
||||
batch_size: int,
|
||||
):
|
||||
# Generate forecast samples
|
||||
forecast_samples = []
|
||||
for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
|
||||
context = [torch.tensor(entry["target"]) for entry in batch]
|
||||
forecast_samples.append(
|
||||
pipeline.predict(
|
||||
context,
|
||||
prediction_length=prediction_length,
|
||||
num_samples=num_samples,
|
||||
).numpy()
|
||||
)
|
||||
forecast_samples = np.concatenate(forecast_samples)
|
||||
|
||||
# Convert forecast samples into gluonts SampleForecast objects
|
||||
sample_forecasts = []
|
||||
for item, ts in zip(forecast_samples, test_data_input):
|
||||
forecast_start_date = ts["start"] + len(ts["target"])
|
||||
sample_forecasts.append(
|
||||
SampleForecast(samples=item, start_date=forecast_start_date)
|
||||
)
|
||||
|
||||
return sample_forecasts
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
config_path: Path,
|
||||
metrics_path: Path,
|
||||
chronos_model_id: str = "amazon/chronos-t5-small",
|
||||
device: str = "cuda",
|
||||
torch_dtype: str = "bfloat16",
|
||||
batch_size: int = 32,
|
||||
num_samples: int = 20,
|
||||
):
|
||||
if isinstance(torch_dtype, str):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
assert isinstance(torch_dtype, torch.dtype)
|
||||
|
||||
# Load Chronos
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
chronos_model_id,
|
||||
device_map=device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Load backtest configs
|
||||
with open(config_path) as fp:
|
||||
backtest_configs = yaml.safe_load(fp)
|
||||
|
||||
result_rows = []
|
||||
for config in backtest_configs:
|
||||
dataset_name = config["name"]
|
||||
prediction_length = config["prediction_length"]
|
||||
|
||||
logger.info(f"Loading {dataset_name}")
|
||||
test_data = load_and_split_dataset(backtest_config=config)
|
||||
|
||||
logger.info(
|
||||
f"Generating forecasts for {dataset_name} "
|
||||
f"({len(test_data.input)} time series)"
|
||||
)
|
||||
sample_forecasts = generate_sample_forecasts(
|
||||
test_data.input,
|
||||
pipeline=pipeline,
|
||||
prediction_length=prediction_length,
|
||||
num_samples=num_samples,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
logger.info(f"Evaluating forecasts for {dataset_name}")
|
||||
metrics = (
|
||||
evaluate_forecasts(
|
||||
sample_forecasts,
|
||||
test_data=test_data,
|
||||
metrics=[
|
||||
MASE(),
|
||||
MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
|
||||
],
|
||||
batch_size=5000,
|
||||
)
|
||||
.reset_index(drop=True)
|
||||
.to_dict(orient="records")
|
||||
)
|
||||
result_rows.append(
|
||||
{"dataset": dataset_name, "model": chronos_model_id, **metrics[0]}
|
||||
)
|
||||
|
||||
# Save results to a CSV file
|
||||
results_df = (
|
||||
pd.DataFrame(result_rows)
|
||||
.rename(
|
||||
{"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"},
|
||||
axis="columns",
|
||||
)
|
||||
.sort_values(by="dataset")
|
||||
)
|
||||
results_df.to_csv(metrics_path, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger("Chronos Evaluation")
|
||||
logger.setLevel(logging.INFO)
|
||||
app()
|
||||
16
scripts/evaluation/results/seasonal-naive-in-domain.csv
Normal file
16
scripts/evaluation/results/seasonal-naive-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,seasonal-naive,0.4978697476132387,0.1169378163151378
|
||||
m4_daily,seasonal-naive,3.278424323759728,0.0279332664832445
|
||||
m4_hourly,seasonal-naive,1.1932105781333862,0.0483091941403194
|
||||
m4_monthly,seasonal-naive,1.2597170386001693,0.1455332906092934
|
||||
m4_weekly,seasonal-naive,2.777295109814942,0.0633986476090776
|
||||
monash_electricity_hourly,seasonal-naive,1.839634785956572,0.1468968206229902
|
||||
monash_electricity_weekly,seasonal-naive,3.0371656285424,0.1979332504059267
|
||||
monash_kdd_cup_2018,seasonal-naive,0.9943785889052376,0.5555856702439576
|
||||
monash_london_smart_meters,seasonal-naive,0.9661872287141056,0.5413187715028914
|
||||
monash_pedestrian_counts,seasonal-naive,0.3691951941442247,0.3185271550430794
|
||||
monash_rideshare,seasonal-naive,1.2495987545425715,0.1860080644135506
|
||||
monash_temperature_rain,seasonal-naive,2.243384627173123,1.4244854980220072
|
||||
taxi_30min,seasonal-naive,1.160268631066241,0.4711417890926274
|
||||
uber_tlc_daily,seasonal-naive,1.37803447078482,0.2313550175912078
|
||||
uber_tlc_hourly,seasonal-naive,0.930916273455971,0.298849044501192
|
||||
|
28
scripts/evaluation/results/seasonal-naive-zero-shot.csv
Normal file
28
scripts/evaluation/results/seasonal-naive-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,seasonal-naive,0.9316203114697056,0.1220896585205886
|
||||
ETTm,seasonal-naive,1.1693053852270578,0.1413480385734046
|
||||
dominick,seasonal-naive,0.8706150115348875,0.4529164093744346
|
||||
ercot,seasonal-naive,0.7613354813741452,0.0366036447606282
|
||||
exchange_rate,seasonal-naive,1.7401824286954128,0.0129841406759913
|
||||
m4_quarterly,seasonal-naive,1.6022471766126911,0.1186484661559648
|
||||
m4_yearly,seasonal-naive,3.974360261259571,0.1614389663357925
|
||||
m5,seasonal-naive,1.399206213076729,1.0240883478068443
|
||||
monash_australian_electricity,seasonal-naive,1.2533189641227642,0.0836951323308387
|
||||
monash_car_parts,seasonal-naive,1.2014638390969912,1.5999522140809177
|
||||
monash_cif_2016,seasonal-naive,1.289290577415544,0.0150830409089921
|
||||
monash_covid_deaths,seasonal-naive,46.91239825526407,0.1330848762571827
|
||||
monash_fred_md,seasonal-naive,1.1008000463101226,0.1222237702571737
|
||||
monash_hospital,seasonal-naive,0.9205278266364826,0.0726263373268254
|
||||
monash_m1_monthly,seasonal-naive,1.3144614957646543,0.1914632595030148
|
||||
monash_m1_quarterly,seasonal-naive,2.077536550805995,0.1495022062865622
|
||||
monash_m1_yearly,seasonal-naive,4.894322225232431,0.2092955931101782
|
||||
monash_m3_monthly,seasonal-naive,1.1462045758327934,0.1485446007554992
|
||||
monash_m3_quarterly,seasonal-naive,1.425343793700714,0.1012520529806161
|
||||
monash_m3_yearly,seasonal-naive,3.1717102364409517,0.1665329650420048
|
||||
monash_nn5_weekly,seasonal-naive,1.0628482559107015,0.1226908962169196
|
||||
monash_tourism_monthly,seasonal-naive,1.630939994944413,0.1041824322151567
|
||||
monash_tourism_quarterly,seasonal-naive,1.6989892627474672,0.1193750169177449
|
||||
monash_tourism_yearly,seasonal-naive,3.5520097206480883,0.2091826587673241
|
||||
monash_traffic,seasonal-naive,1.0767397173107436,0.3618532196990004
|
||||
monash_weather,seasonal-naive,1.0038475713182748,0.2165947349654047
|
||||
nn5,seasonal-naive,1.2917285866431214,0.4246208074843067
|
||||
|
Loading…
Reference in a new issue