chronos-forecasting/scripts/evaluation/evaluate.py
Abdul Fatir 72ab64166c
Add support for Chronos-Bolt models (#204)
*Issue #, if available:* N/A

*Description of changes:* This PR adds support for Chronos-Bolt models.

TODOs:

- [x] Update evaluation script
- [x] Fix and add tests for Bolt
- [x] Update docstrings
- [x] Update README example and mention Chronos-Bolt
- [x] Update results bar plot in README
- [x] Add versions for libraries in `pyproject.toml`
- [x] Check that the training and eval scripts work
- [x] Change `autogluon` -> `amazon` in model names

Post Merge:
- [ ] Update the citation style in the READMEs of both the GitHub and Hugging Face repos
- [ ] Remove note about AutoGluon
- [ ] Update READMEs of original Chronos models to refer to Chronos-Bolt

NOTE: To be merged after Chronos-Bolt models are available under the
`amazon` namespace on HF.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
Co-authored-by: Caner Turkmen <turkmen.ac@gmail.com>
Co-authored-by: Lorenzo Stella <stellalo@amazon.com>
2024-11-26 17:47:14 +01:00

407 lines
11 KiB
Python

import logging
from pathlib import Path
from typing import Iterable, Optional
import datasets
import numpy as np
import pandas as pd
import torch
import typer
import yaml
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.itertools import batcher
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import QuantileForecast, SampleForecast
from tqdm.auto import tqdm
from chronos import (
BaseChronosPipeline,
ChronosBoltPipeline,
ChronosPipeline,
ForecastType,
)
# Typer CLI application; `main` is registered on it via `@app.command()`.
# Pretty exception rendering is disabled.
app = typer.Typer(pretty_exceptions_enable=False)

# Lookup from pandas offset aliases (the strings produced by `pd.infer_freq`)
# to period aliases usable as `pd.Period(..., freq=...)`. Callers use
# `.get(alias, alias)`, so aliases missing here are passed through unchanged.
# Taken from pandas._libs.tslibs.dtypes.OFFSET_TO_PERIOD_FREQSTR
offset_alias_to_period_alias = {
    "WEEKDAY": "D",
    "EOM": "M",
    "BME": "M",
    "SME": "M",
    "BQS": "Q",
    "QS": "Q",
    "BQE": "Q",
    "BQE-DEC": "Q",
    "BQE-JAN": "Q",
    "BQE-FEB": "Q",
    "BQE-MAR": "Q",
    "BQE-APR": "Q",
    "BQE-MAY": "Q",
    "BQE-JUN": "Q",
    "BQE-JUL": "Q",
    "BQE-AUG": "Q",
    "BQE-SEP": "Q",
    "BQE-OCT": "Q",
    "BQE-NOV": "Q",
    "MS": "M",
    "D": "D",
    "B": "B",
    "min": "min",
    "s": "s",
    "ms": "ms",
    "us": "us",
    "ns": "ns",
    "h": "h",
    "QE": "Q",
    "QE-DEC": "Q-DEC",
    "QE-JAN": "Q-JAN",
    "QE-FEB": "Q-FEB",
    "QE-MAR": "Q-MAR",
    "QE-APR": "Q-APR",
    "QE-MAY": "Q-MAY",
    "QE-JUN": "Q-JUN",
    "QE-JUL": "Q-JUL",
    "QE-AUG": "Q-AUG",
    "QE-SEP": "Q-SEP",
    "QE-OCT": "Q-OCT",
    "QE-NOV": "Q-NOV",
    "YE": "Y",
    "YE-DEC": "Y-DEC",
    "YE-JAN": "Y-JAN",
    "YE-FEB": "Y-FEB",
    "YE-MAR": "Y-MAR",
    "YE-APR": "Y-APR",
    "YE-MAY": "Y-MAY",
    "YE-JUN": "Y-JUN",
    "YE-JUL": "Y-JUL",
    "YE-AUG": "Y-AUG",
    "YE-SEP": "Y-SEP",
    "YE-OCT": "Y-OCT",
    "YE-NOV": "Y-NOV",
    "W": "W",
    "ME": "M",
    "Y": "Y",
    "BYE": "Y",
    "BYE-DEC": "Y",
    "BYE-JAN": "Y",
    "BYE-FEB": "Y",
    "BYE-MAR": "Y",
    "BYE-APR": "Y",
    "BYE-MAY": "Y",
    "BYE-JUN": "Y",
    "BYE-JUL": "Y",
    "BYE-AUG": "Y",
    "BYE-SEP": "Y",
    "BYE-OCT": "Y",
    "BYE-NOV": "Y",
    "YS": "Y",
    "BYS": "Y",
    "QS-JAN": "Q",
    "QS-FEB": "Q",
    "QS-MAR": "Q",
    "QS-APR": "Q",
    "QS-MAY": "Q",
    "QS-JUN": "Q",
    "QS-JUL": "Q",
    "QS-AUG": "Q",
    "QS-SEP": "Q",
    "QS-OCT": "Q",
    "QS-NOV": "Q",
    "QS-DEC": "Q",
    "BQS-JAN": "Q",
    "BQS-FEB": "Q",
    "BQS-MAR": "Q",
    "BQS-APR": "Q",
    "BQS-MAY": "Q",
    "BQS-JUN": "Q",
    "BQS-JUL": "Q",
    "BQS-AUG": "Q",
    "BQS-SEP": "Q",
    "BQS-OCT": "Q",
    "BQS-NOV": "Q",
    "BQS-DEC": "Q",
    "YS-JAN": "Y",
    "YS-FEB": "Y",
    "YS-MAR": "Y",
    "YS-APR": "Y",
    "YS-MAY": "Y",
    "YS-JUN": "Y",
    "YS-JUL": "Y",
    "YS-AUG": "Y",
    "YS-SEP": "Y",
    "YS-OCT": "Y",
    "YS-NOV": "Y",
    "YS-DEC": "Y",
    "BYS-JAN": "Y",
    "BYS-FEB": "Y",
    "BYS-MAR": "Y",
    "BYS-APR": "Y",
    "BYS-MAY": "Y",
    "BYS-JUN": "Y",
    "BYS-JUL": "Y",
    "BYS-AUG": "Y",
    "BYS-SEP": "Y",
    "BYS-OCT": "Y",
    "BYS-NOV": "Y",
    "BYS-DEC": "Y",
    "Y-JAN": "Y-JAN",
    "Y-FEB": "Y-FEB",
    "Y-MAR": "Y-MAR",
    "Y-APR": "Y-APR",
    "Y-MAY": "Y-MAY",
    "Y-JUN": "Y-JUN",
    "Y-JUL": "Y-JUL",
    "Y-AUG": "Y-AUG",
    "Y-SEP": "Y-SEP",
    "Y-OCT": "Y-OCT",
    "Y-NOV": "Y-NOV",
    "Y-DEC": "Y-DEC",
    "Q-JAN": "Q-JAN",
    "Q-FEB": "Q-FEB",
    "Q-MAR": "Q-MAR",
    "Q-APR": "Q-APR",
    "Q-MAY": "Q-MAY",
    "Q-JUN": "Q-JUN",
    "Q-JUL": "Q-JUL",
    "Q-AUG": "Q-AUG",
    "Q-SEP": "Q-SEP",
    "Q-OCT": "Q-OCT",
    "Q-NOV": "Q-NOV",
    "Q-DEC": "Q-DEC",
    "W-MON": "W-MON",
    "W-TUE": "W-TUE",
    "W-WED": "W-WED",
    "W-THU": "W-THU",
    "W-FRI": "W-FRI",
    "W-SAT": "W-SAT",
    "W-SUN": "W-SUN",
}
def to_gluonts_univariate(hf_dataset: datasets.Dataset):
    """Flatten a Hugging Face dataset into a list of univariate GluonTS entries.

    Every ``datasets.Sequence`` column except ``timestamp`` is treated as a
    separate target series, so each row yields one entry per such column.
    Each entry is a dict with:

    - ``"start"``: a ``pd.Period`` built from the row's first timestamp, using
      the frequency inferred from the *first* row's timestamps (mapped through
      ``offset_alias_to_period_alias`` when an alias is known there);
    - ``"target"``: the row's values for that column.

    The number of produced entries is checked against the ``train`` split
    size times the number of target columns.
    """
    features = hf_dataset.features
    target_columns = [
        name for name in features if isinstance(features[name], datasets.Sequence)
    ]
    # list.remove raises ValueError if there is no "timestamp" column.
    target_columns.remove("timestamp")

    expected_count = hf_dataset.info.splits["train"].num_examples * len(target_columns)

    inferred_alias = pd.infer_freq(hf_dataset[0]["timestamp"])
    period_freq = offset_alias_to_period_alias.get(inferred_alias, inferred_alias)

    entries = [
        {
            "start": pd.Period(row["timestamp"][0], freq=period_freq),
            "target": row[column],
        }
        for row in hf_dataset
        for column in target_columns
    ]
    assert len(entries) == expected_count
    return entries
def load_and_split_dataset(backtest_config: dict):
    """Load one benchmark dataset from the Hugging Face Hub and build its test data.

    Parameters
    ----------
    backtest_config : dict
        Must contain the keys ``hf_repo``, ``name``, ``offset``,
        ``prediction_length`` and ``num_rolls``.

    Returns
    -------
    The result of ``generate_instances`` on the test template obtained by
    splitting the univariate dataset at ``offset``, with ``num_rolls`` windows
    of length ``prediction_length``.
    """
    repo_id = backtest_config["hf_repo"]
    config_name = backtest_config["name"]
    split_offset = backtest_config["offset"]
    horizon = backtest_config["prediction_length"]
    window_count = backtest_config["num_rolls"]

    # Datasets in autogluon/chronos_datasets_extra cannot be distributed due to
    # license restrictions and must be generated on the fly, which requires
    # trusting the repository's remote code.
    needs_remote_code = repo_id == "autogluon/chronos_datasets_extra"

    hf_dataset = datasets.load_dataset(
        repo_id, config_name, split="train", trust_remote_code=needs_remote_code
    )
    hf_dataset.set_format("numpy")
    univariate_entries = to_gluonts_univariate(hf_dataset)

    # Only the test template is used for evaluation; the training part is discarded.
    _, test_template = split(univariate_entries, offset=split_offset)
    return test_template.generate_instances(horizon, windows=window_count)
def generate_forecasts(
    test_data_input: Iterable,
    pipeline: BaseChronosPipeline,
    prediction_length: int,
    batch_size: int,
    **predict_kwargs,
):
    """Forecast each series in ``test_data_input`` and wrap results as GluonTS forecasts.

    Parameters
    ----------
    test_data_input : Iterable
        Entries with a ``"target"`` array (used as the context) and a
        ``"start"`` value marking the first target point. It is consumed once
        and materialized, so one-shot iterators are supported.
    pipeline : BaseChronosPipeline
        Loaded Chronos or Chronos-Bolt pipeline.
    prediction_length : int
        Number of steps to forecast for each series.
    batch_size : int
        Number of series passed to ``pipeline.predict`` per call.
    **predict_kwargs
        Extra keyword arguments forwarded to ``pipeline.predict``.

    Returns
    -------
    list
        One ``SampleForecast`` (sample pipelines) or ``QuantileForecast``
        (quantile pipelines) per input entry, in input order. The list is
        empty when the input is empty.

    Raises
    ------
    ValueError
        If ``pipeline.forecast_type`` is neither ``ForecastType.SAMPLES``
        nor ``ForecastType.QUANTILES``.
    """
    # The entries are needed twice: once for batched prediction and once to
    # pair each output with its start date. Materializing them avoids silently
    # producing zero forecasts when a one-shot iterator is passed in.
    test_entries = list(test_data_input)
    if not test_entries:
        # np.concatenate would fail on an empty list of batches.
        return []

    # Fail fast, before running inference, instead of silently dropping
    # every forecast for an unsupported type.
    forecast_type = pipeline.forecast_type
    if forecast_type not in (ForecastType.SAMPLES, ForecastType.QUANTILES):
        raise ValueError(f"Unsupported forecast type: {forecast_type}")

    # Generate forecasts
    batch_outputs = []
    for batch in tqdm(batcher(test_entries, batch_size=batch_size)):
        context = [torch.tensor(entry["target"]) for entry in batch]
        batch_outputs.append(
            pipeline.predict(
                context,
                prediction_length=prediction_length,
                **predict_kwargs,
            ).numpy()
        )
    forecast_outputs = np.concatenate(batch_outputs)

    # Quantile keys are the same for every item, so compute them once.
    quantile_keys = (
        [str(q) for q in pipeline.quantiles]
        if forecast_type == ForecastType.QUANTILES
        else None
    )

    # Convert forecast outputs into GluonTS Forecast objects.
    forecasts = []
    for item, entry in zip(forecast_outputs, test_entries):
        # The forecast begins right after the last observed context point.
        forecast_start_date = entry["start"] + len(entry["target"])
        if forecast_type == ForecastType.SAMPLES:
            forecasts.append(
                SampleForecast(samples=item, start_date=forecast_start_date)
            )
        else:
            forecasts.append(
                QuantileForecast(
                    forecast_arrays=item,
                    # Fresh list per forecast so instances share no mutable state.
                    forecast_keys=list(quantile_keys),
                    start_date=forecast_start_date,
                )
            )
    return forecasts
def _resolve_torch_dtype(torch_dtype):
    """Return the ``torch.dtype`` named by ``torch_dtype``, or ``torch_dtype`` itself if already a dtype.

    Raises
    ------
    ValueError
        If ``torch_dtype`` does not name a ``torch.dtype``. An explicit raise
        is used because ``assert`` checks are stripped under ``python -O``.
    """
    dtype = (
        getattr(torch, torch_dtype, None)
        if isinstance(torch_dtype, str)
        else torch_dtype
    )
    if not isinstance(dtype, torch.dtype):
        raise ValueError(
            f"Invalid torch_dtype {torch_dtype!r}: expected the name of a torch "
            "dtype, e.g. 'bfloat16' or 'float32'"
        )
    return dtype


def _build_predict_kwargs(pipeline, num_samples, temperature, top_k, top_p):
    """Return the extra keyword arguments to pass to ``pipeline.predict``.

    The sampling options apply only to the original ``ChronosPipeline``.
    No extra arguments are passed to ``ChronosBoltPipeline``.

    Raises
    ------
    TypeError
        If ``pipeline`` is of any other type. Without this check,
        ``predict_kwargs`` would be left undefined.
    """
    if isinstance(pipeline, ChronosPipeline):
        return dict(
            num_samples=num_samples,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
    if isinstance(pipeline, ChronosBoltPipeline):
        return {}
    raise TypeError(f"Unsupported pipeline type: {type(pipeline).__name__}")


def _evaluate_backtest(config, pipeline, batch_size, predict_kwargs, logger):
    """Load, forecast and score one backtest config, returning its metrics as a dict."""
    dataset_name = config["name"]
    prediction_length = config["prediction_length"]

    logger.info(f"Loading {dataset_name}")
    test_data = load_and_split_dataset(backtest_config=config)

    logger.info(
        f"Generating forecasts for {dataset_name} "
        f"({len(test_data.input)} time series)"
    )
    forecasts = generate_forecasts(
        test_data.input,
        pipeline=pipeline,
        prediction_length=prediction_length,
        batch_size=batch_size,
        **predict_kwargs,
    )

    logger.info(f"Evaluating forecasts for {dataset_name}")
    metrics = (
        evaluate_forecasts(
            forecasts,
            test_data=test_data,
            metrics=[
                MASE(),
                MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
            ],
            batch_size=5000,
        )
        .reset_index(drop=True)
        .to_dict(orient="records")
    )
    # A single aggregated row is expected per dataset.
    return metrics[0]


@app.command()
def main(
    config_path: Path,
    metrics_path: Path,
    chronos_model_id: str = "amazon/chronos-t5-small",
    device: str = "cuda",
    torch_dtype: str = "bfloat16",
    batch_size: int = 32,
    num_samples: int = 20,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
):
    """Evaluate Chronos models.

    Loads the model, runs every backtest listed in the YAML config, and writes
    one row of metrics (MASE and WQL) per backtest config to ``metrics_path``.

    Parameters
    ----------
    config_path : Path
        Path to the evaluation config. See ./configs/.
    metrics_path : Path
        Path to the CSV file where metrics will be saved.
    chronos_model_id : str, optional, default = "amazon/chronos-t5-small"
        HuggingFace ID of the Chronos model or local path
        Available models on HuggingFace:
        Chronos:
            - amazon/chronos-t5-tiny
            - amazon/chronos-t5-mini
            - amazon/chronos-t5-small
            - amazon/chronos-t5-base
            - amazon/chronos-t5-large
        Chronos-Bolt:
            - amazon/chronos-bolt-tiny
            - amazon/chronos-bolt-mini
            - amazon/chronos-bolt-small
            - amazon/chronos-bolt-base
    device : str, optional, default = "cuda"
        Device on which inference will be performed
    torch_dtype : str, optional, default = "bfloat16"
        Name of the ``torch`` dtype used to load the model
        (e.g. "bfloat16", "float32")
    batch_size : int, optional, default = 32
        Batch size for inference. For Chronos-Bolt models, significantly larger
        batch sizes can be used
    num_samples : int, optional, default = 20
        Number of samples to draw when using the original Chronos models
    temperature : Optional[float], optional, default = None
        Softmax temperature used for the original Chronos models. Passed
        unchanged to ``pipeline.predict``; ignored for Chronos-Bolt
    top_k : Optional[int], optional, default = None
        Top-k sampling for the original Chronos models. Passed unchanged to
        ``pipeline.predict``; ignored for Chronos-Bolt
    top_p : Optional[float], optional, default = None
        Top-p sampling for the original Chronos models. Passed unchanged to
        ``pipeline.predict``; ignored for Chronos-Bolt

    Raises
    ------
    ValueError
        If ``torch_dtype`` does not name a torch dtype, or the config file
        contains no backtest configurations.
    TypeError
        If the loaded pipeline is neither a ``ChronosPipeline`` nor a
        ``ChronosBoltPipeline``.
    """
    # The module only binds `logger` under the `__main__` guard, so calling
    # `main` after importing the module would raise NameError. getLogger
    # returns the same named instance that the guard configures.
    logger = logging.getLogger("Chronos Evaluation")

    dtype = _resolve_torch_dtype(torch_dtype)

    # Load Chronos
    pipeline = BaseChronosPipeline.from_pretrained(
        chronos_model_id,
        device_map=device,
        torch_dtype=dtype,
    )
    predict_kwargs = _build_predict_kwargs(
        pipeline, num_samples, temperature, top_k, top_p
    )

    # Load backtest configs
    with open(config_path, encoding="utf-8") as fp:
        backtest_configs = yaml.safe_load(fp)
    # An empty file yields None and an empty list yields no rows; both used to
    # fail later with unrelated errors (TypeError / KeyError in sort_values).
    if not backtest_configs:
        raise ValueError(f"No backtest configurations found in {config_path}")

    result_rows = []
    for config in backtest_configs:
        metrics = _evaluate_backtest(
            config, pipeline, batch_size, predict_kwargs, logger
        )
        result_rows.append(
            {"dataset": config["name"], "model": chronos_model_id, **metrics}
        )

    # Save results to a CSV file, using short metric column names.
    results_df = (
        pd.DataFrame(result_rows)
        .rename(
            {"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"},
            axis="columns",
        )
        .sort_values(by="dataset")
    )
    results_df.to_csv(metrics_path, index=False)
if __name__ == "__main__":
    # Script logger at INFO so the progress messages logged by `main` are not
    # filtered out. Kept as a module-level name because `main` may refer to it.
    logger = logging.getLogger("Chronos Evaluation")
    logger.setLevel(logging.INFO)
    # Root handler with timestamped, logger-named records.
    logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    # Hand control to the Typer CLI.
    app()