diff --git a/README.md b/README.md
index 145733b..1b9b647 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,8 @@
## 🚀 News
-- **27 June 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper.
+- **26 Nov 2024**: ⚡️ Chronos-Bolt models released [on HuggingFace](https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444). Chronos-Bolt models are more accurate (5% lower error), up to 250x faster and 20x more memory efficient than the original Chronos models of the same size!
+- **27 Jun 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper.
- **17 May 2024**: 🐛 Fixed an off-by-one error in bin indices in the `output_transform`. This simple fix significantly improves the overall performance of Chronos. We will update the results in the next revision on ArXiv.
- **10 May 2024**: 🚀 We added the code for pretraining and fine-tuning Chronos models. You can find it in [this folder](./scripts/training). We also added [a script](./scripts/kernel-synth.py) for generating synthetic time series data from Gaussian processes (KernelSynth; see Section 4.2 in the paper for details). Check out the [usage examples](./scripts/).
- **19 Apr 2024**: 🚀 Chronos is now supported on [AutoGluon-TimeSeries](https://auto.gluon.ai/stable/tutorials/timeseries/index.html), the powerful AutoML package for time series forecasting which enables model ensembles, cloud deployments, and much more. Get started with the [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
@@ -52,62 +53,72 @@ The models in this repository are based on the [T5 architecture](https://arxiv.o
| [**chronos-t5-small**](https://huggingface.co/amazon/chronos-t5-small) | 46M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
| [**chronos-t5-base**](https://huggingface.co/amazon/chronos-t5-base) | 200M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
| [**chronos-t5-large**](https://huggingface.co/amazon/chronos-t5-large) | 710M | [t5-efficient-large](https://huggingface.co/google/t5-efficient-large) |
+| [**chronos-bolt-tiny**](https://huggingface.co/amazon/chronos-bolt-tiny) | 9M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) |
+| [**chronos-bolt-mini**](https://huggingface.co/amazon/chronos-bolt-mini) | 21M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) |
+| [**chronos-bolt-small**](https://huggingface.co/amazon/chronos-bolt-small) | 48M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
+| [**chronos-bolt-base**](https://huggingface.co/amazon/chronos-bolt-base) | 205M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
### Zero-Shot Results
-The following figure showcases the remarkable **zero-shot** performance of Chronos models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815).
+The following figure showcases the remarkable **zero-shot** performance of Chronos and Chronos-Bolt models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815).
-
+
- Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets not seen by Chronos models during training. This benchmark provides insights into the zero-shot performance of Chronos models against local statistical models, which fit parameters individually for each time series, task-specific models trained on each task, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively.
+ Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets not seen by Chronos and Chronos-Bolt models during training. This benchmark provides insights into the zero-shot performance of Chronos and Chronos-Bolt models against local statistical models, which fit parameters individually for each time series, task-specific models trained on each task, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively.
## 📈 Usage
-To perform inference with Chronos models, install this package by running:
+To perform inference with Chronos or Chronos-Bolt models, install this package by running:
```
pip install git+https://github.com/amazon-science/chronos-forecasting.git
```
> [!TIP]
-> The recommended way of using Chronos for production use cases is through [AutoGluon](https://auto.gluon.ai), which features ensembling with other statistical and machine learning models for time series forecasting as well as seamless deployments on AWS with SageMaker 🧠. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
+> This repository is intended for research purposes and provides a minimal interface to Chronos models. The recommended way of using Chronos for production use cases is through [AutoGluon](https://auto.gluon.ai), which features effortless fine-tuning, augmenting Chronos models with exogenous information through covariate regressors, ensembling with other statistical and machine learning models, as well as seamless deployments on AWS with SageMaker 🧠. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
### Forecasting
-A minimal example showing how to perform forecasting using Chronos models:
+A minimal example showing how to perform forecasting using Chronos and Chronos-Bolt models:
```python
import pandas as pd # requires: pip install pandas
import torch
-from chronos import ChronosPipeline
+from chronos import BaseChronosPipeline
-pipeline = ChronosPipeline.from_pretrained(
- "amazon/chronos-t5-small",
+pipeline = BaseChronosPipeline.from_pretrained(
+ "amazon/chronos-t5-small", # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
device_map="cuda", # use "cpu" for CPU inference and "mps" for Apple Silicon
torch_dtype=torch.bfloat16,
)
-df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
+df = pd.read_csv(
+ "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
+)
# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
-# forecast shape: [num_series, num_samples, prediction_length]
+# The original Chronos models generate forecast samples, so forecast has shape
+# [num_series, num_samples, prediction_length].
+# Chronos-Bolt models generate quantile forecasts, so forecast has shape
+# [num_series, num_quantiles, prediction_length].
forecast = pipeline.predict(
- context=torch.tensor(df["#Passengers"]),
- prediction_length=12,
- num_samples=20,
+ context=torch.tensor(df["#Passengers"]), prediction_length=12
)
```
More options for `pipeline.predict` can be found with:
```python
-print(ChronosPipeline.predict.__doc__)
+from chronos import ChronosPipeline, ChronosBoltPipeline
+
+print(ChronosPipeline.predict.__doc__) # for Chronos models
+print(ChronosBoltPipeline.predict.__doc__) # for Chronos-Bolt models
```
We can now visualize the forecast:
diff --git a/figures/zero_shot-agg_scaled_score.png b/figures/zero_shot-agg_scaled_score.png
deleted file mode 100644
index 0210d1c..0000000
Binary files a/figures/zero_shot-agg_scaled_score.png and /dev/null differ
diff --git a/figures/zero_shot-agg_scaled_score.svg b/figures/zero_shot-agg_scaled_score.svg
new file mode 100644
index 0000000..560c4de
--- /dev/null
+++ b/figures/zero_shot-agg_scaled_score.svg
@@ -0,0 +1,4875 @@
+
+
+
diff --git a/pyproject.toml b/pyproject.toml
index 55dd210..c3b916d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,19 +1,19 @@
[project]
name = "chronos"
-version = "1.2.1"
+version = "1.3.0"
requires-python = ">=3.8"
license = { file = "LICENSE" }
dependencies = [
- "torch~=2.0", # package was tested on 2.2
- "transformers~=4.30",
- "accelerate",
+ "torch>=2.0,<2.6", # package was tested on 2.2
+ "transformers>=4.30,<4.48",
+ "accelerate>=0.32,<1",
]
[project.optional-dependencies]
test = ["pytest~=8.0", "numpy~=1.21"]
typecheck = ["mypy~=1.9"]
-training = ["gluonts[pro]", "numpy", "tensorboard", "typer", "typer-config", "joblib", "scikit-learn"]
-evaluation = ["gluonts[pro]", "datasets", "numpy", "typer"]
+training = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer", "typer-config", "joblib", "scikit-learn", "tensorboard"]
+evaluation = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer"]
[tool.mypy]
ignore_missing_imports = true
diff --git a/scripts/evaluation/agg-relative-score.py b/scripts/evaluation/agg-relative-score.py
new file mode 100644
index 0000000..0a308e7
--- /dev/null
+++ b/scripts/evaluation/agg-relative-score.py
@@ -0,0 +1,60 @@
+import pandas as pd
+import typer
+from scipy.stats import gmean
+from pathlib import Path
+
+app = typer.Typer(pretty_exceptions_enable=False)
+DEFAULT_RESULTS_DIR = Path(__file__).parent / "results"
+
+
+def agg_relative_score(model_csv: Path, baseline_csv: Path):
+ model_df = pd.read_csv(model_csv).set_index("dataset")
+ baseline_df = pd.read_csv(baseline_csv).set_index("dataset")
+ relative_score = model_df.drop("model", axis="columns") / baseline_df.drop(
+ "model", axis="columns"
+ )
+ return relative_score.agg(gmean)
+
+
+@app.command()
+def main(
+ model_name: str,
+ baseline_name: str = "seasonal-naive",
+ results_dir: Path = DEFAULT_RESULTS_DIR,
+):
+ """
+ Compute the aggregated relative score as reported in the Chronos paper.
+ Results will be saved to {results_dir}/{model_name}-agg-rel-scores.csv
+
+ Parameters
+ ----------
+ model_name : str
+ Name of the model used in the CSV files. The in-domain and zero-shot CSVs
+ are expected to be named {model_name}-in-domain.csv and {model_name}-zero-shot.csv.
+ results_dir : Path, optional, default = results/
+ Directory where results CSVs generated by evaluate.py are stored
+ """
+
+ in_domain_agg_score_df = agg_relative_score(
+ results_dir / f"{model_name}-in-domain.csv",
+ results_dir / f"{baseline_name}-in-domain.csv",
+ )
+ in_domain_agg_score_df.name = "value"
+ in_domain_agg_score_df.index.name = "metric"
+
+ zero_shot_agg_score_df = agg_relative_score(
+ results_dir / f"{model_name}-zero-shot.csv",
+ results_dir / f"{baseline_name}-zero-shot.csv",
+ )
+ zero_shot_agg_score_df.name = "value"
+ zero_shot_agg_score_df.index.name = "metric"
+
+ agg_score_df = pd.concat(
+ {"in-domain": in_domain_agg_score_df, "zero-shot": zero_shot_agg_score_df},
+ names=["benchmark"],
+ )
+ agg_score_df.to_csv(f"{results_dir}/{model_name}-agg-rel-scores.csv")
+
+
+if __name__ == "__main__":
+ app()
diff --git a/scripts/evaluation/evaluate.py b/scripts/evaluation/evaluate.py
index 756f544..8995e31 100644
--- a/scripts/evaluation/evaluate.py
+++ b/scripts/evaluation/evaluate.py
@@ -12,10 +12,15 @@ from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.itertools import batcher
from gluonts.model.evaluation import evaluate_forecasts
-from gluonts.model.forecast import SampleForecast
+from gluonts.model.forecast import QuantileForecast, SampleForecast
from tqdm.auto import tqdm
-from chronos import ChronosPipeline
+from chronos import (
+ BaseChronosPipeline,
+ ChronosBoltPipeline,
+ ChronosPipeline,
+ ForecastType,
+)
app = typer.Typer(pretty_exceptions_enable=False)
@@ -228,37 +233,45 @@ def load_and_split_dataset(backtest_config: dict):
return test_data
-def generate_sample_forecasts(
+def generate_forecasts(
test_data_input: Iterable,
- pipeline: ChronosPipeline,
+ pipeline: BaseChronosPipeline,
prediction_length: int,
batch_size: int,
- num_samples: int,
**predict_kwargs,
):
- # Generate forecast samples
- forecast_samples = []
+ # Generate forecasts
+ forecast_outputs = []
for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
context = [torch.tensor(entry["target"]) for entry in batch]
- forecast_samples.append(
+ forecast_outputs.append(
pipeline.predict(
context,
prediction_length=prediction_length,
- num_samples=num_samples,
**predict_kwargs,
).numpy()
)
- forecast_samples = np.concatenate(forecast_samples)
+ forecast_outputs = np.concatenate(forecast_outputs)
- # Convert forecast samples into gluonts SampleForecast objects
- sample_forecasts = []
- for item, ts in zip(forecast_samples, test_data_input):
+ # Convert forecast samples into gluonts Forecast objects
+ forecasts = []
+ for item, ts in zip(forecast_outputs, test_data_input):
forecast_start_date = ts["start"] + len(ts["target"])
- sample_forecasts.append(
- SampleForecast(samples=item, start_date=forecast_start_date)
- )
- return sample_forecasts
+ if pipeline.forecast_type == ForecastType.SAMPLES:
+ forecasts.append(
+ SampleForecast(samples=item, start_date=forecast_start_date)
+ )
+ elif pipeline.forecast_type == ForecastType.QUANTILES:
+ forecasts.append(
+ QuantileForecast(
+ forecast_arrays=item,
+ forecast_keys=list(map(str, pipeline.quantiles)),
+ start_date=forecast_start_date,
+ )
+ )
+
+ return forecasts
@app.command()
@@ -274,17 +287,65 @@ def main(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
):
+ """Evaluate Chronos models.
+
+ Parameters
+ ----------
+ config_path : Path
+ Path to the evaluation config. See ./configs/.
+ metrics_path : Path
+ Path to the CSV file where metrics will be saved.
+ chronos_model_id : str, optional, default = "amazon/chronos-t5-small"
+ HuggingFace ID of the Chronos model or local path
+ Available models on HuggingFace:
+ Chronos:
+ - amazon/chronos-t5-tiny
+ - amazon/chronos-t5-mini
+ - amazon/chronos-t5-small
+ - amazon/chronos-t5-base
+ - amazon/chronos-t5-large
+ Chronos-Bolt:
+ - amazon/chronos-bolt-tiny
+ - amazon/chronos-bolt-mini
+ - amazon/chronos-bolt-small
+ - amazon/chronos-bolt-base
+ device : str, optional, default = "cuda"
+ Device on which inference will be performed
+ torch_dtype : str, optional
+ Model's dtype, by default "bfloat16"
+ batch_size : int, optional, default = 32
+ Batch size for inference. For Chronos-Bolt models, significantly larger
+ batch sizes can be used
+ num_samples : int, optional, default = 20
+ Number of samples to draw when using the original Chronos models
+ temperature : Optional[float], optional, default = 1.0
+ Softmax temperature to used for the original Chronos models
+ top_k : Optional[int], optional, default = 50
+ Top-K sampling, by default None
+ top_p : Optional[float], optional, default = 1.0
+ Top-p sampling, by default None
+ """
if isinstance(torch_dtype, str):
torch_dtype = getattr(torch, torch_dtype)
assert isinstance(torch_dtype, torch.dtype)
# Load Chronos
- pipeline = ChronosPipeline.from_pretrained(
+ pipeline = BaseChronosPipeline.from_pretrained(
chronos_model_id,
device_map=device,
torch_dtype=torch_dtype,
)
+ if isinstance(pipeline, ChronosPipeline):
+ predict_kwargs = dict(
+ num_samples=num_samples,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+ elif isinstance(pipeline, ChronosBoltPipeline):
+ predict_kwargs = {}
+
# Load backtest configs
with open(config_path) as fp:
backtest_configs = yaml.safe_load(fp)
@@ -301,21 +362,18 @@ def main(
f"Generating forecasts for {dataset_name} "
f"({len(test_data.input)} time series)"
)
- sample_forecasts = generate_sample_forecasts(
+ forecasts = generate_forecasts(
test_data.input,
pipeline=pipeline,
prediction_length=prediction_length,
batch_size=batch_size,
- num_samples=num_samples,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
+ **predict_kwargs,
)
logger.info(f"Evaluating forecasts for {dataset_name}")
metrics = (
evaluate_forecasts(
- sample_forecasts,
+ forecasts,
test_data=test_data,
metrics=[
MASE(),
diff --git a/scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv b/scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv
new file mode 100644
index 0000000..3912150
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.6800133628315155
+in-domain,WQL,0.5339263811489279
+zero-shot,MASE,0.7914551113353537
+zero-shot,WQL,0.6241424984163773
diff --git a/scripts/evaluation/results/chronos-bolt-base-in-domain.csv b/scripts/evaluation/results/chronos-bolt-base-in-domain.csv
new file mode 100644
index 0000000..abe3347
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-base-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-bolt-base,0.41069374835605243,0.0703533790998506
+m4_daily,amazon/chronos-bolt-base,3.205192517121196,0.02110308498174413
+m4_hourly,amazon/chronos-bolt-base,0.8350129849014075,0.025353803894164
+m4_monthly,amazon/chronos-bolt-base,0.9491758928362231,0.09382496106659234
+m4_weekly,amazon/chronos-bolt-base,2.0847827409162742,0.03816605075768161
+monash_electricity_hourly,amazon/chronos-bolt-base,1.254966217685461,0.09442192616975713
+monash_electricity_weekly,amazon/chronos-bolt-base,1.8391546050108039,0.06410971963960499
+monash_kdd_cup_2018,amazon/chronos-bolt-base,0.6405985809360102,0.2509172188706336
+monash_london_smart_meters,amazon/chronos-bolt-base,0.701398572604996,0.3218915088923906
+monash_pedestrian_counts,amazon/chronos-bolt-base,0.2646412642278343,0.18789459806066328
+monash_rideshare,amazon/chronos-bolt-base,0.7695376426829713,0.11637119433040358
+monash_temperature_rain,amazon/chronos-bolt-base,0.8983612698773724,0.6050555216496304
+taxi_30min,amazon/chronos-bolt-base,0.7688908266765317,0.2363178601205094
+uber_tlc_daily,amazon/chronos-bolt-base,0.8231767493519677,0.0926036406916842
+uber_tlc_hourly,amazon/chronos-bolt-base,0.6632193728217927,0.14987786887626975
diff --git a/scripts/evaluation/results/chronos-bolt-base-zero-shot.csv b/scripts/evaluation/results/chronos-bolt-base-zero-shot.csv
new file mode 100644
index 0000000..833658a
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-base-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-bolt-base,0.7479154031956647,0.07062173821055001
+ETTm,amazon/chronos-bolt-base,0.6334357237512225,0.052261607745858835
+dominick,amazon/chronos-bolt-base,0.8560272479913918,0.3453573743726445
+ercot,amazon/chronos-bolt-base,0.6933217425507392,0.02142183038021456
+exchange_rate,amazon/chronos-bolt-base,1.7095176257412634,0.01200682136751536
+m4_quarterly,amazon/chronos-bolt-base,1.2244670010522907,0.0771066518089854
+m4_yearly,amazon/chronos-bolt-base,3.513752058541554,0.12142798053483984
+m5,amazon/chronos-bolt-base,0.9152230096463854,0.561999688057527
+monash_australian_electricity,amazon/chronos-bolt-base,0.7403239930185613,0.03584034231329335
+monash_car_parts,amazon/chronos-bolt-base,0.8550263912438314,0.9945122291263591
+monash_cif_2016,amazon/chronos-bolt-base,0.9988541862779904,0.016456104842296485
+monash_covid_deaths,amazon/chronos-bolt-base,38.901749109066415,0.047410971217640714
+monash_fred_md,amazon/chronos-bolt-base,0.6468787708795645,0.04185083716355386
+monash_hospital,amazon/chronos-bolt-base,0.6883138394434054,0.057032869931903894
+monash_m1_monthly,amazon/chronos-bolt-base,1.0997677446267855,0.1392311148066238
+monash_m1_quarterly,amazon/chronos-bolt-base,1.7737851980875563,0.1007118219350403
+monash_m1_yearly,amazon/chronos-bolt-base,4.404672537832342,0.1504617654430952
+monash_m3_monthly,amazon/chronos-bolt-base,0.8510696834878182,0.09269673913736748
+monash_m3_quarterly,amazon/chronos-bolt-base,1.2890908822598466,0.07615133571216029
+monash_m3_yearly,amazon/chronos-bolt-base,2.9067097980770082,0.12934285625258413
+monash_nn5_weekly,amazon/chronos-bolt-base,0.9158766337957451,0.08352114810139548
+monash_tourism_monthly,amazon/chronos-bolt-base,1.5283388458731357,0.09026425492612797
+monash_tourism_quarterly,amazon/chronos-bolt-base,1.756127005530011,0.06448060953595125
+monash_tourism_yearly,amazon/chronos-bolt-base,3.691545772463519,0.16548820700844424
+monash_traffic,amazon/chronos-bolt-base,0.7843310867739336,0.23148632068725078
+monash_weather,amazon/chronos-bolt-base,0.8115247139672316,0.13350830777170594
+nn5,amazon/chronos-bolt-base,0.5764084996361287,0.1500519584148468
diff --git a/scripts/evaluation/results/chronos-bolt-mini-agg-rel-scores.csv b/scripts/evaluation/results/chronos-bolt-mini-agg-rel-scores.csv
new file mode 100644
index 0000000..9fba001
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-mini-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.7268373301543752
+in-domain,WQL,0.565140251955324
+zero-shot,MASE,0.8221798917822493
+zero-shot,WQL,0.6441645845380903
diff --git a/scripts/evaluation/results/chronos-bolt-mini-in-domain.csv b/scripts/evaluation/results/chronos-bolt-mini-in-domain.csv
new file mode 100644
index 0000000..fac01bf
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-mini-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-bolt-mini,0.44185193304080733,0.0731477927531107
+m4_daily,amazon/chronos-bolt-mini,3.1342608828747456,0.0206872246743766
+m4_hourly,amazon/chronos-bolt-mini,0.9218285923038745,0.024383114886067574
+m4_monthly,amazon/chronos-bolt-mini,0.9628339921394529,0.09502498697494888
+m4_weekly,amazon/chronos-bolt-mini,2.2330452369879255,0.039393515325238534
+monash_electricity_hourly,amazon/chronos-bolt-mini,1.6195944363428718,0.11468972600782207
+monash_electricity_weekly,amazon/chronos-bolt-mini,1.866105365159433,0.06019900031840434
+monash_kdd_cup_2018,amazon/chronos-bolt-mini,0.74790954883436,0.3012661161484388
+monash_london_smart_meters,amazon/chronos-bolt-mini,0.7187830347765344,0.32984510693830227
+monash_pedestrian_counts,amazon/chronos-bolt-mini,0.308633944815819,0.23331301029432483
+monash_rideshare,amazon/chronos-bolt-mini,0.818948044410056,0.1297966960374544
+monash_temperature_rain,amazon/chronos-bolt-mini,0.9035244443682741,0.605031064086567
+taxi_30min,amazon/chronos-bolt-mini,0.812010120941363,0.25232294549917317
+uber_tlc_daily,amazon/chronos-bolt-mini,0.8507256206478295,0.10101757743084538
+uber_tlc_hourly,amazon/chronos-bolt-mini,0.6685484898085609,0.1515245941548974
diff --git a/scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv b/scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv
new file mode 100644
index 0000000..2060004
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-bolt-mini,0.8057126710113404,0.07740387596411452
+ETTm,amazon/chronos-bolt-mini,0.6100793941108849,0.05129333450944573
+dominick,amazon/chronos-bolt-mini,0.8664152477208024,0.3499696999160997
+ercot,amazon/chronos-bolt-mini,0.6871250728215426,0.02448804863744021
+exchange_rate,amazon/chronos-bolt-mini,1.3520551553333662,0.00934663373172766
+m4_quarterly,amazon/chronos-bolt-mini,1.2569644266281508,0.07833787023275976
+m4_yearly,amazon/chronos-bolt-mini,3.7611003052413796,0.12931927951165456
+m5,amazon/chronos-bolt-mini,0.9188876472137485,0.5661303206519673
+monash_australian_electricity,amazon/chronos-bolt-mini,0.8823559450287066,0.04493688824488474
+monash_car_parts,amazon/chronos-bolt-mini,0.8604081423647779,1.0041876404811494
+monash_cif_2016,amazon/chronos-bolt-mini,1.0762361363763873,0.017641893717784202
+monash_covid_deaths,amazon/chronos-bolt-mini,38.83915011538576,0.06098317835750057
+monash_fred_md,amazon/chronos-bolt-mini,0.6169859211923081,0.03256236965040934
+monash_hospital,amazon/chronos-bolt-mini,0.6924431064606051,0.05766349075348645
+monash_m1_monthly,amazon/chronos-bolt-mini,1.147893030263777,0.13270222658510553
+monash_m1_quarterly,amazon/chronos-bolt-mini,1.8662100001165818,0.09846363409254102
+monash_m1_yearly,amazon/chronos-bolt-mini,5.319154632748303,0.16167328827180308
+monash_m3_monthly,amazon/chronos-bolt-mini,0.8758452776118432,0.09493431248614057
+monash_m3_quarterly,amazon/chronos-bolt-mini,1.3555175243802005,0.07808062465932723
+monash_m3_yearly,amazon/chronos-bolt-mini,3.605769430055575,0.15711010456482008
+monash_nn5_weekly,amazon/chronos-bolt-mini,0.9347141924977239,0.08522899825844342
+monash_tourism_monthly,amazon/chronos-bolt-mini,1.649587479665881,0.0979648261309891
+monash_tourism_quarterly,amazon/chronos-bolt-mini,1.8471553663088986,0.06501077791766902
+monash_tourism_yearly,amazon/chronos-bolt-mini,3.9932920493826245,0.1743539122097316
+monash_traffic,amazon/chronos-bolt-mini,0.8355442361271347,0.24351051123330386
+monash_weather,amazon/chronos-bolt-mini,0.800013628350165,0.13041050756802045
+nn5,amazon/chronos-bolt-mini,0.611917632501032,0.1570111102680171
diff --git a/scripts/evaluation/results/chronos-bolt-small-agg-rel-scores.csv b/scripts/evaluation/results/chronos-bolt-small-agg-rel-scores.csv
new file mode 100644
index 0000000..0652750
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-small-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.7030801652116672
+in-domain,WQL,0.5443547623341555
+zero-shot,MASE,0.8192127745093378
+zero-shot,WQL,0.6356097843099521
diff --git a/scripts/evaluation/results/chronos-bolt-small-in-domain.csv b/scripts/evaluation/results/chronos-bolt-small-in-domain.csv
new file mode 100644
index 0000000..7c17064
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-small-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-bolt-small,0.44920089250026723,0.08115291306964295
+m4_daily,amazon/chronos-bolt-small,3.201966619014735,0.02143368277732494
+m4_hourly,amazon/chronos-bolt-small,0.8686298207618999,0.020368729287465817
+m4_monthly,amazon/chronos-bolt-small,0.9537717737278778,0.0939247807527992
+m4_weekly,amazon/chronos-bolt-small,2.1236755094789177,0.03785184715517262
+monash_electricity_hourly,amazon/chronos-bolt-small,1.3728906161330452,0.09452411472431674
+monash_electricity_weekly,amazon/chronos-bolt-small,1.8703239487242378,0.06648479071326366
+monash_kdd_cup_2018,amazon/chronos-bolt-small,0.6458631909979771,0.25148489931571666
+monash_london_smart_meters,amazon/chronos-bolt-small,0.7126939688565166,0.326874529903459
+monash_pedestrian_counts,amazon/chronos-bolt-small,0.3015070035798365,0.2285590441093863
+monash_rideshare,amazon/chronos-bolt-small,0.823726965741684,0.12409769473500927
+monash_temperature_rain,amazon/chronos-bolt-small,0.8980348827836525,0.5984819599873311
+taxi_30min,amazon/chronos-bolt-small,0.7597818149895785,0.2348569752311862
+uber_tlc_daily,amazon/chronos-bolt-small,0.8460854328036702,0.09666483354735897
+uber_tlc_hourly,amazon/chronos-bolt-small,0.6662547495017634,0.1524256346268063
diff --git a/scripts/evaluation/results/chronos-bolt-small-zero-shot.csv b/scripts/evaluation/results/chronos-bolt-small-zero-shot.csv
new file mode 100644
index 0000000..e9f4134
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-small-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-bolt-small,0.792521748651108,0.07590654063011319
+ETTm,amazon/chronos-bolt-small,0.6209623928936988,0.05056189722606397
+dominick,amazon/chronos-bolt-small,0.8706134610400587,0.34811141409475416
+ercot,amazon/chronos-bolt-small,0.7562857616685997,0.02596064260343696
+exchange_rate,amazon/chronos-bolt-small,1.774835301692689,0.011363548847621512
+m4_quarterly,amazon/chronos-bolt-small,1.2478142413437487,0.07808795122806232
+m4_yearly,amazon/chronos-bolt-small,3.6925595655002574,0.12772564181388502
+m5,amazon/chronos-bolt-small,0.9195435643571084,0.5668430814831332
+monash_australian_electricity,amazon/chronos-bolt-small,0.8128424798841111,0.041509852162861564
+monash_car_parts,amazon/chronos-bolt-small,0.8584574663781737,1.0074689402521324
+monash_cif_2016,amazon/chronos-bolt-small,1.0182471909074982,0.01581964877692293
+monash_covid_deaths,amazon/chronos-bolt-small,36.467595559655145,0.0427382859406882
+monash_fred_md,amazon/chronos-bolt-small,0.6132863794635253,0.03730410577241995
+monash_hospital,amazon/chronos-bolt-small,0.6954489513780618,0.058119864671526154
+monash_m1_monthly,amazon/chronos-bolt-small,1.1277621848099244,0.1335656174632902
+monash_m1_quarterly,amazon/chronos-bolt-small,1.8356144904231688,0.09363028483838018
+monash_m1_yearly,amazon/chronos-bolt-small,5.098146069746402,0.15669928873371905
+monash_m3_monthly,amazon/chronos-bolt-small,0.8685125121306435,0.09396568468255145
+monash_m3_quarterly,amazon/chronos-bolt-small,1.3269103591066727,0.07691022995374203
+monash_m3_yearly,amazon/chronos-bolt-small,3.40993282700627,0.1547639821304127
+monash_nn5_weekly,amazon/chronos-bolt-small,0.9266513350636507,0.08452821221908001
+monash_tourism_monthly,amazon/chronos-bolt-small,1.6106732721197876,0.09362336754317802
+monash_tourism_quarterly,amazon/chronos-bolt-small,1.8357819365308639,0.06734337535269994
+monash_tourism_yearly,amazon/chronos-bolt-small,3.8963100495394194,0.16766064312072784
+monash_traffic,amazon/chronos-bolt-small,0.8598507749866499,0.25173786112983054
+monash_weather,amazon/chronos-bolt-small,0.8020408743877911,0.13258563963844888
+nn5,amazon/chronos-bolt-small,0.5833047644729239,0.15066847836762787
diff --git a/scripts/evaluation/results/chronos-bolt-tiny-agg-rel-scores.csv b/scripts/evaluation/results/chronos-bolt-tiny-agg-rel-scores.csv
new file mode 100644
index 0000000..72e98ec
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-tiny-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.7403252781013574
+in-domain,WQL,0.5733728165523524
+zero-shot,MASE,0.8445407343705457
+zero-shot,WQL,0.6678781905023173
diff --git a/scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv b/scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv
new file mode 100644
index 0000000..bf2ffb5
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-bolt-tiny,0.4676384089765091,0.0861229808117837
+m4_daily,amazon/chronos-bolt-tiny,3.1789994761356795,0.020961883512815756
+m4_hourly,amazon/chronos-bolt-tiny,0.9348005698736752,0.021087527284114574
+m4_monthly,amazon/chronos-bolt-tiny,0.965298729632761,0.0950380483243082
+m4_weekly,amazon/chronos-bolt-tiny,2.261575511029903,0.04093653263178429
+monash_electricity_hourly,amazon/chronos-bolt-tiny,1.5739346351263623,0.10808418398945202
+monash_electricity_weekly,amazon/chronos-bolt-tiny,1.8628689103722829,0.05773335283584782
+monash_kdd_cup_2018,amazon/chronos-bolt-tiny,0.6869549985391232,0.28012801092758166
+monash_london_smart_meters,amazon/chronos-bolt-tiny,0.7284234905933779,0.33496438244693033
+monash_pedestrian_counts,amazon/chronos-bolt-tiny,0.32338947321773864,0.2530637833749087
+monash_rideshare,amazon/chronos-bolt-tiny,0.8562780835002918,0.1304317657933891
+monash_temperature_rain,amazon/chronos-bolt-tiny,0.9030707620825977,0.6064087080755548
+taxi_30min,amazon/chronos-bolt-tiny,0.9122159603256838,0.28002194370731626
+uber_tlc_daily,amazon/chronos-bolt-tiny,0.9087055420190513,0.11193388685815164
+uber_tlc_hourly,amazon/chronos-bolt-tiny,0.6716569179590032,0.15310845458208555
diff --git a/scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv b/scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv
new file mode 100644
index 0000000..4ec181e
--- /dev/null
+++ b/scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-bolt-tiny,0.7941225847155844,0.07480860969990633
+ETTm,amazon/chronos-bolt-tiny,0.6508270995240056,0.05440068825429993
+dominick,amazon/chronos-bolt-tiny,0.876060127216559,0.35175949052933253
+ercot,amazon/chronos-bolt-tiny,0.7309134980173839,0.02468604544464515
+exchange_rate,amazon/chronos-bolt-tiny,1.6857262567077134,0.011477224264784112
+m4_quarterly,amazon/chronos-bolt-tiny,1.2605908919338378,0.0789049420017836
+m4_yearly,amazon/chronos-bolt-tiny,3.7118394116161757,0.1286932555969197
+m5,amazon/chronos-bolt-tiny,0.9195469670062033,0.5634881835998845
+monash_australian_electricity,amazon/chronos-bolt-tiny,0.8419304693259403,0.042040993880313904
+monash_car_parts,amazon/chronos-bolt-tiny,0.8625579150452282,1.0009987800801836
+monash_cif_2016,amazon/chronos-bolt-tiny,1.095219642027011,0.017550336784241796
+monash_covid_deaths,amazon/chronos-bolt-tiny,40.674057986280744,0.06723714516685976
+monash_fred_md,amazon/chronos-bolt-tiny,0.6127387450520702,0.04747523852271518
+monash_hospital,amazon/chronos-bolt-tiny,0.6980246281225624,0.05864223243167421
+monash_m1_monthly,amazon/chronos-bolt-tiny,1.1625495971731141,0.13142237467151166
+monash_m1_quarterly,amazon/chronos-bolt-tiny,1.8941765599193754,0.09972207844232561
+monash_m1_yearly,amazon/chronos-bolt-tiny,5.136332694531757,0.160331813128038
+monash_m3_monthly,amazon/chronos-bolt-tiny,0.8744553726704598,0.09435519378597752
+monash_m3_quarterly,amazon/chronos-bolt-tiny,1.364563776692303,0.07875066385737857
+monash_m3_yearly,amazon/chronos-bolt-tiny,3.3685961410254928,0.15158076519486274
+monash_nn5_weekly,amazon/chronos-bolt-tiny,0.9324436794013877,0.0847385189968909
+monash_tourism_monthly,amazon/chronos-bolt-tiny,1.7895936775088157,0.1058167042693116
+monash_tourism_quarterly,amazon/chronos-bolt-tiny,2.095262637810499,0.0710732570354461
+monash_tourism_yearly,amazon/chronos-bolt-tiny,4.042821441327848,0.172613367251472
+monash_traffic,amazon/chronos-bolt-tiny,0.8836032533767518,0.2574297134210491
+monash_weather,amazon/chronos-bolt-tiny,0.8005348255663177,0.13111355494466076
+nn5,amazon/chronos-bolt-tiny,0.7228248498869763,0.1816913098894226
diff --git a/scripts/evaluation/results/chronos-t5-base-agg-rel-scores.csv b/scripts/evaluation/results/chronos-t5-base-agg-rel-scores.csv
new file mode 100644
index 0000000..56492d3
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-base-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.7007558507277635
+in-domain,WQL,0.5786300105297922
+zero-shot,MASE,0.8155209321160994
+zero-shot,WQL,0.6424634919486323
diff --git a/scripts/evaluation/results/chronos-t5-base-in-domain.csv b/scripts/evaluation/results/chronos-t5-base-in-domain.csv
new file mode 100644
index 0000000..7925271
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-base-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-t5-base,0.39879754957261204,0.07738953262286181
+m4_daily,amazon/chronos-t5-base,3.160575865614404,0.02194256368254537
+m4_hourly,amazon/chronos-t5-base,0.6938747745332102,0.026354948301302205
+m4_monthly,amazon/chronos-t5-base,0.971951848755026,0.10355213196432872
+m4_weekly,amazon/chronos-t5-base,2.0143841267657945,0.03639741235815474
+monash_electricity_hourly,amazon/chronos-t5-base,1.5717251971297332,0.1078882125804548
+monash_electricity_weekly,amazon/chronos-t5-base,1.7862927210886668,0.06255982783148449
+monash_kdd_cup_2018,amazon/chronos-t5-base,0.6335225775496138,0.2684272353843692
+monash_london_smart_meters,amazon/chronos-t5-base,0.8362014889190201,0.4265549499082726
+monash_pedestrian_counts,amazon/chronos-t5-base,0.2817708325561419,0.20810108090665583
+monash_rideshare,amazon/chronos-t5-base,0.8614480533175364,0.1356591190888703
+monash_temperature_rain,amazon/chronos-t5-base,0.9692405156151607,0.660155448791624
+taxi_30min,amazon/chronos-t5-base,0.8186287575356217,0.26236060366367003
+uber_tlc_daily,amazon/chronos-t5-base,0.8338648311528079,0.0970875577681834
+uber_tlc_hourly,amazon/chronos-t5-base,0.6647193438331641,0.15436646659512512
diff --git a/scripts/evaluation/results/chronos-t5-base-zero-shot.csv b/scripts/evaluation/results/chronos-t5-base-zero-shot.csv
new file mode 100644
index 0000000..6078311
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-base-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-t5-base,0.7653491494991778,0.08087267701042929
+ETTm,amazon/chronos-t5-base,0.7737006634032871,0.07008650633028274
+dominick,amazon/chronos-t5-base,0.8194044957573132,0.33201307438298133
+ercot,amazon/chronos-t5-base,0.5014399265038706,0.013589435745554596
+exchange_rate,amazon/chronos-t5-base,2.055616906406159,0.011066070028466317
+m4_quarterly,amazon/chronos-t5-base,1.2253036947743137,0.08327936201395683
+m4_yearly,amazon/chronos-t5-base,3.639991540990927,0.13539258375263963
+m5,amazon/chronos-t5-base,0.9391874615167101,0.5867234116216755
+monash_australian_electricity,amazon/chronos-t5-base,1.2944069383163321,0.07070604202031877
+monash_car_parts,amazon/chronos-t5-base,0.9071940271035218,1.077797124337994
+monash_cif_2016,amazon/chronos-t5-base,0.9840747802099565,0.011825556826558836
+monash_covid_deaths,amazon/chronos-t5-base,42.68503365359237,0.042229910495746356
+monash_fred_md,amazon/chronos-t5-base,0.4857773806790164,0.021204829049512715
+monash_hospital,amazon/chronos-t5-base,0.7053005021431749,0.05630687524507516
+monash_m1_monthly,amazon/chronos-t5-base,1.1153039466137842,0.12724419775326076
+monash_m1_quarterly,amazon/chronos-t5-base,1.746093728928804,0.1123583549291933
+monash_m1_yearly,amazon/chronos-t5-base,4.401291522370069,0.18541586641719554
+monash_m3_monthly,amazon/chronos-t5-base,0.8627172231908679,0.09640536232169555
+monash_m3_quarterly,amazon/chronos-t5-base,1.1696030904401578,0.07392876900131434
+monash_m3_yearly,amazon/chronos-t5-base,3.1298600218573775,0.1486674447940158
+monash_nn5_weekly,amazon/chronos-t5-base,0.9334860602210187,0.08972736821598823
+monash_tourism_monthly,amazon/chronos-t5-base,1.7937702435879332,0.10260220444264027
+monash_tourism_quarterly,amazon/chronos-t5-base,1.7791494997972261,0.06852507950474919
+monash_tourism_yearly,amazon/chronos-t5-base,3.8359926053603197,0.20722699382964643
+monash_traffic,amazon/chronos-t5-base,0.8015262383138622,0.25565153982140926
+monash_weather,amazon/chronos-t5-base,0.8159511190589147,0.13802320967454584
+nn5,amazon/chronos-t5-base,0.5927076179914024,0.1630476065585159
diff --git a/scripts/evaluation/results/chronos-t5-large-agg-rel-scores.csv b/scripts/evaluation/results/chronos-t5-large-agg-rel-scores.csv
new file mode 100644
index 0000000..94beef1
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-large-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.6944869734691035
+in-domain,WQL,0.5596857927462495
+zero-shot,MASE,0.8213682201405101
+zero-shot,WQL,0.6504834081319559
diff --git a/scripts/evaluation/results/chronos-t5-large-in-domain.csv b/scripts/evaluation/results/chronos-t5-large-in-domain.csv
new file mode 100644
index 0000000..63f4b18
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-large-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-t5-large,0.3866310906621673,0.07759528615667297
+m4_daily,amazon/chronos-t5-large,3.134560968849699,0.02158279722410466
+m4_hourly,amazon/chronos-t5-large,0.6975930649233378,0.02086427219957674
+m4_monthly,amazon/chronos-t5-large,0.9585550091429409,0.10091221432814867
+m4_weekly,amazon/chronos-t5-large,2.0191422600104425,0.036912838355537186
+monash_electricity_hourly,amazon/chronos-t5-large,1.4069912853901292,0.09642382339452431
+monash_electricity_weekly,amazon/chronos-t5-large,1.7501880036182798,0.05765306465830232
+monash_kdd_cup_2018,amazon/chronos-t5-large,0.6788042816175427,0.2853553329804835
+monash_london_smart_meters,amazon/chronos-t5-large,0.8290300790418726,0.4235436387853963
+monash_pedestrian_counts,amazon/chronos-t5-large,0.2764118100521592,0.18692234491663473
+monash_rideshare,amazon/chronos-t5-large,0.8758058784466208,0.140260325368757
+monash_temperature_rain,amazon/chronos-t5-large,0.9738403865035117,0.6604571928063249
+taxi_30min,amazon/chronos-t5-large,0.8245662397270109,0.2653520120326771
+uber_tlc_daily,amazon/chronos-t5-large,0.8044165990021739,0.09499035584302248
+uber_tlc_hourly,amazon/chronos-t5-large,0.6700665937164474,0.15190288476653066
diff --git a/scripts/evaluation/results/chronos-t5-large-zero-shot.csv b/scripts/evaluation/results/chronos-t5-large-zero-shot.csv
new file mode 100644
index 0000000..020c100
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-large-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-t5-large,0.78160443631164,0.07884375667736107
+ETTm,amazon/chronos-t5-large,0.7325919639389967,0.06656858270921162
+dominick,amazon/chronos-t5-large,0.8200108271155829,0.3311575649734524
+ercot,amazon/chronos-t5-large,0.6050812633742764,0.01822996942395577
+exchange_rate,amazon/chronos-t5-large,2.3439287001928744,0.014841231672174684
+m4_quarterly,amazon/chronos-t5-large,1.2169666607868148,0.08235162400898562
+m4_yearly,amazon/chronos-t5-large,3.5524979814018947,0.1325675848907479
+m5,amazon/chronos-t5-large,0.9422990989146737,0.585615077637479
+monash_australian_electricity,amazon/chronos-t5-large,1.480849838497958,0.07973968848149568
+monash_car_parts,amazon/chronos-t5-large,0.901547374873302,1.0467398096496576
+monash_cif_2016,amazon/chronos-t5-large,0.9906388185665337,0.011966178555329998
+monash_covid_deaths,amazon/chronos-t5-large,44.07354193681227,0.06108999981222163
+monash_fred_md,amazon/chronos-t5-large,0.5184400880318044,0.01675533888399231
+monash_hospital,amazon/chronos-t5-large,0.7055308474630898,0.0552450850258613
+monash_m1_monthly,amazon/chronos-t5-large,1.0888995301234758,0.12729911122909737
+monash_m1_quarterly,amazon/chronos-t5-large,1.7477134564031453,0.10618253695380094
+monash_m1_yearly,amazon/chronos-t5-large,4.250667049416348,0.17128879333643188
+monash_m3_monthly,amazon/chronos-t5-large,0.8559326975903808,0.09572577431396007
+monash_m3_quarterly,amazon/chronos-t5-large,1.1867267751420676,0.07449254281607631
+monash_m3_yearly,amazon/chronos-t5-large,3.0239493021840635,0.14814710375646464
+monash_nn5_weekly,amazon/chronos-t5-large,0.9228721852437364,0.08948447200571868
+monash_tourism_monthly,amazon/chronos-t5-large,1.7304427846580348,0.09983169221760163
+monash_tourism_quarterly,amazon/chronos-t5-large,1.6437184365114073,0.0690906057781915
+monash_tourism_yearly,amazon/chronos-t5-large,3.6268503118928535,0.17732007043832695
+monash_traffic,amazon/chronos-t5-large,0.7985975530866148,0.25313515740581755
+monash_weather,amazon/chronos-t5-large,0.8187388457436171,0.1387756772600068
+nn5,amazon/chronos-t5-large,0.5755260854173723,0.15733693855465292
diff --git a/scripts/evaluation/results/chronos-t5-mini-agg-rel-scores.csv b/scripts/evaluation/results/chronos-t5-mini-agg-rel-scores.csv
new file mode 100644
index 0000000..abd06f0
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-mini-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.7249816823595568
+in-domain,WQL,0.5965372489622094
+zero-shot,MASE,0.8411995116926901
+zero-shot,WQL,0.6888397962259065
diff --git a/scripts/evaluation/results/chronos-t5-mini-in-domain.csv b/scripts/evaluation/results/chronos-t5-mini-in-domain.csv
new file mode 100644
index 0000000..196d3e9
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-mini-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-t5-mini,0.4446629660227641,0.08114657599239496
+m4_daily,amazon/chronos-t5-mini,3.1533349226194005,0.022000507013584743
+m4_hourly,amazon/chronos-t5-mini,0.7616830292996938,0.024630575107847653
+m4_monthly,amazon/chronos-t5-mini,0.9934074425853089,0.10402168689068064
+m4_weekly,amazon/chronos-t5-mini,2.1407189608104416,0.04138058102434373
+monash_electricity_hourly,amazon/chronos-t5-mini,1.3698378948313894,0.09189698159081384
+monash_electricity_weekly,amazon/chronos-t5-mini,1.9238345295706893,0.07015383787479901
+monash_kdd_cup_2018,amazon/chronos-t5-mini,0.6027861468526459,0.25493489598663444
+monash_london_smart_meters,amazon/chronos-t5-mini,0.8570035850603943,0.4356582737588471
+monash_pedestrian_counts,amazon/chronos-t5-mini,0.30374539593979855,0.2374083216051065
+monash_rideshare,amazon/chronos-t5-mini,0.8157349455509949,0.12963515638823117
+monash_temperature_rain,amazon/chronos-t5-mini,1.010161905102516,0.6919171702485583
+taxi_30min,amazon/chronos-t5-mini,0.9318379552979712,0.31229508015999674
+uber_tlc_daily,amazon/chronos-t5-mini,0.9213437323817685,0.10475291429149586
+uber_tlc_hourly,amazon/chronos-t5-mini,0.6812621470377416,0.15982192635434303
diff --git a/scripts/evaluation/results/chronos-t5-mini-zero-shot.csv b/scripts/evaluation/results/chronos-t5-mini-zero-shot.csv
new file mode 100644
index 0000000..bdc383f
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-mini-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-t5-mini,0.789678971785092,0.08068969536800001
+ETTm,amazon/chronos-t5-mini,0.7521219674190734,0.06791782942706617
+dominick,amazon/chronos-t5-mini,0.8207116999488602,0.34004499734299765
+ercot,amazon/chronos-t5-mini,0.5462749489237783,0.015035001020343136
+exchange_rate,amazon/chronos-t5-mini,2.1326718165798657,0.015073846769933199
+m4_quarterly,amazon/chronos-t5-mini,1.271761811062081,0.08575942238385105
+m4_yearly,amazon/chronos-t5-mini,3.7340853642679126,0.13938781939783162
+m5,amazon/chronos-t5-mini,0.9421556321929742,0.5961689098871504
+monash_australian_electricity,amazon/chronos-t5-mini,1.046297291920238,0.05424453772723559
+monash_car_parts,amazon/chronos-t5-mini,0.8913523483805221,1.0174797526818506
+monash_cif_2016,amazon/chronos-t5-mini,1.0674111822055679,0.016800831829085764
+monash_covid_deaths,amazon/chronos-t5-mini,43.69727825485175,0.08788117644141617
+monash_fred_md,amazon/chronos-t5-mini,0.46227452519609524,0.01871860604459728
+monash_hospital,amazon/chronos-t5-mini,0.7112593459108532,0.05831005112661489
+monash_m1_monthly,amazon/chronos-t5-mini,1.1756557848450433,0.14192178371159841
+monash_m1_quarterly,amazon/chronos-t5-mini,1.795009199698074,0.11760148522768847
+monash_m1_yearly,amazon/chronos-t5-mini,5.078889706085604,0.1882823108615221
+monash_m3_monthly,amazon/chronos-t5-mini,0.900404391663476,0.09935931092075681
+monash_m3_quarterly,amazon/chronos-t5-mini,1.2604342624229292,0.07807204797138119
+monash_m3_yearly,amazon/chronos-t5-mini,3.4395976709464255,0.16085249526114198
+monash_nn5_weekly,amazon/chronos-t5-mini,0.9459117943913629,0.09042762527674755
+monash_tourism_monthly,amazon/chronos-t5-mini,1.920865545569713,0.10791754513335952
+monash_tourism_quarterly,amazon/chronos-t5-mini,1.7957439111869486,0.07514539225156464
+monash_tourism_yearly,amazon/chronos-t5-mini,4.134958090482728,0.2202036957350168
+monash_traffic,amazon/chronos-t5-mini,0.8546792774237857,0.2668831661775284
+monash_weather,amazon/chronos-t5-mini,0.8607748244159247,0.15031866806333247
+nn5,amazon/chronos-t5-mini,0.6497211196906223,0.17352254058241523
diff --git a/scripts/evaluation/results/chronos-t5-small-agg-rel-scores.csv b/scripts/evaluation/results/chronos-t5-small-agg-rel-scores.csv
new file mode 100644
index 0000000..43947fb
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-small-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.7296140269944743
+in-domain,WQL,0.6086958548874499
+zero-shot,MASE,0.8303721909132112
+zero-shot,WQL,0.6649587072099045
diff --git a/scripts/evaluation/results/chronos-t5-small-in-domain.csv b/scripts/evaluation/results/chronos-t5-small-in-domain.csv
new file mode 100644
index 0000000..206ec82
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-small-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-t5-small,0.4115559557750193,0.08085148902238105
+m4_daily,amazon/chronos-t5-small,3.1384304946608896,0.02129901023419818
+m4_hourly,amazon/chronos-t5-small,0.7300874075370588,0.024686127211237932
+m4_monthly,amazon/chronos-t5-small,0.9797264456494642,0.10297069145186107
+m4_weekly,amazon/chronos-t5-small,2.0802214537692607,0.03959222330783002
+monash_electricity_hourly,amazon/chronos-t5-small,1.530308399040219,0.10765947926209926
+monash_electricity_weekly,amazon/chronos-t5-small,1.9249616494404531,0.07593976499899265
+monash_kdd_cup_2018,amazon/chronos-t5-small,0.6911172359201715,0.2863722811236367
+monash_london_smart_meters,amazon/chronos-t5-small,0.8405756252443325,0.4300875548402115
+monash_pedestrian_counts,amazon/chronos-t5-small,0.30836963006151696,0.2442543970311678
+monash_rideshare,amazon/chronos-t5-small,0.8436277753840817,0.1363421932158997
+monash_temperature_rain,amazon/chronos-t5-small,1.0176003932416664,0.6847726381172435
+taxi_30min,amazon/chronos-t5-small,0.976277213614167,0.32770172988517626
+uber_tlc_daily,amazon/chronos-t5-small,0.8694727058784919,0.0994889223610958
+uber_tlc_hourly,amazon/chronos-t5-small,0.6738672444888639,0.1573990617753753
diff --git a/scripts/evaluation/results/chronos-t5-small-zero-shot.csv b/scripts/evaluation/results/chronos-t5-small-zero-shot.csv
new file mode 100644
index 0000000..f3a970a
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-small-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-t5-small,0.8516754221042285,0.08667817580712385
+ETTm,amazon/chronos-t5-small,0.6825432730635727,0.06076472147001207
+dominick,amazon/chronos-t5-small,0.8108766032127683,0.3368104617474581
+ercot,amazon/chronos-t5-small,0.564879593858422,0.015547628920969682
+exchange_rate,amazon/chronos-t5-small,1.8143459139100264,0.014492477372711763
+m4_quarterly,amazon/chronos-t5-small,1.2415331521819728,0.08383826063189778
+m4_yearly,amazon/chronos-t5-small,3.738749650935195,0.1384514201649314
+m5,amazon/chronos-t5-small,0.9368713240675598,0.5896066252181699
+monash_australian_electricity,amazon/chronos-t5-small,1.2241146217392032,0.06951399165882449
+monash_car_parts,amazon/chronos-t5-small,0.8917508090523597,1.0314986717260015
+monash_cif_2016,amazon/chronos-t5-small,1.0187937383419037,0.014633240218233142
+monash_covid_deaths,amazon/chronos-t5-small,42.298997211368935,0.06339512778191682
+monash_fred_md,amazon/chronos-t5-small,0.4742159923922472,0.01486734736993978
+monash_hospital,amazon/chronos-t5-small,0.709814741753487,0.05704674270057172
+monash_m1_monthly,amazon/chronos-t5-small,1.1723041163998773,0.13799049510465802
+monash_m1_quarterly,amazon/chronos-t5-small,1.8077827825737092,0.11323432989795904
+monash_m1_yearly,amazon/chronos-t5-small,4.739967673537301,0.1730738338876877
+monash_m3_monthly,amazon/chronos-t5-small,0.8856577322724943,0.09985251429658573
+monash_m3_quarterly,amazon/chronos-t5-small,1.278907982396775,0.08094041554590593
+monash_m3_yearly,amazon/chronos-t5-small,3.382470310192457,0.157363937435307
+monash_nn5_weekly,amazon/chronos-t5-small,0.9277396908126303,0.08963913763368506
+monash_tourism_monthly,amazon/chronos-t5-small,1.9251180766131313,0.10943962474253494
+monash_tourism_quarterly,amazon/chronos-t5-small,1.7623454951333655,0.06862432764377493
+monash_tourism_yearly,amazon/chronos-t5-small,3.987690476709746,0.19960492460202509
+monash_traffic,amazon/chronos-t5-small,0.8204223927835267,0.2571189517024486
+monash_weather,amazon/chronos-t5-small,0.8550633590487968,0.1479701971025123
+nn5,amazon/chronos-t5-small,0.6130789183153671,0.16771392719859998
diff --git a/scripts/evaluation/results/chronos-t5-tiny-agg-rel-scores.csv b/scripts/evaluation/results/chronos-t5-tiny-agg-rel-scores.csv
new file mode 100644
index 0000000..440cc39
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-tiny-agg-rel-scores.csv
@@ -0,0 +1,5 @@
+benchmark,metric,value
+in-domain,MASE,0.7649019745781727
+in-domain,WQL,0.6288613368129368
+zero-shot,MASE,0.8704764463925718
+zero-shot,WQL,0.7108912052035352
diff --git a/scripts/evaluation/results/chronos-t5-tiny-in-domain.csv b/scripts/evaluation/results/chronos-t5-tiny-in-domain.csv
new file mode 100644
index 0000000..e30632f
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-tiny-in-domain.csv
@@ -0,0 +1,16 @@
+dataset,model,MASE,WQL
+electricity_15min,amazon/chronos-t5-tiny,0.5091784254243783,0.08236334376190152
+m4_daily,amazon/chronos-t5-tiny,3.203164895930929,0.022152192084951595
+m4_hourly,amazon/chronos-t5-tiny,0.8171321441164723,0.027490760558343874
+m4_monthly,amazon/chronos-t5-tiny,1.005839207921131,0.10388015368939435
+m4_weekly,amazon/chronos-t5-tiny,2.2148332313370735,0.043429655561156084
+monash_electricity_hourly,amazon/chronos-t5-tiny,1.6190021089002615,0.10967453530956882
+monash_electricity_weekly,amazon/chronos-t5-tiny,2.0774597917676734,0.08159998975612164
+monash_kdd_cup_2018,amazon/chronos-t5-tiny,0.6730886827096076,0.2616610603634618
+monash_london_smart_meters,amazon/chronos-t5-tiny,0.8830447519225436,0.4499607073491794
+monash_pedestrian_counts,amazon/chronos-t5-tiny,0.3042105240185045,0.23387631681117071
+monash_rideshare,amazon/chronos-t5-tiny,0.8431350112476247,0.1378817076926394
+monash_temperature_rain,amazon/chronos-t5-tiny,0.9887398447367799,0.6957797286648015
+taxi_30min,amazon/chronos-t5-tiny,1.035544060665179,0.3450476958104713
+uber_tlc_daily,amazon/chronos-t5-tiny,0.93025919000775,0.1105323649942084
+uber_tlc_hourly,amazon/chronos-t5-tiny,0.697558054147913,0.16320255844336232
diff --git a/scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv b/scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv
new file mode 100644
index 0000000..83d377c
--- /dev/null
+++ b/scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv
@@ -0,0 +1,28 @@
+dataset,model,MASE,WQL
+ETTh,amazon/chronos-t5-tiny,0.8184074113571701,0.08578203438707048
+ETTm,amazon/chronos-t5-tiny,0.9103621000781905,0.07975361086322658
+dominick,amazon/chronos-t5-tiny,0.8538295532466194,0.3597090770361857
+ercot,amazon/chronos-t5-tiny,0.7273437589773705,0.020843170924006626
+exchange_rate,amazon/chronos-t5-tiny,1.6621128608546154,0.01085145980896454
+m4_quarterly,amazon/chronos-t5-tiny,1.2696259955861924,0.0861404188925996
+m4_yearly,amazon/chronos-t5-tiny,3.5293881164900527,0.13281575565500411
+m5,amazon/chronos-t5-tiny,0.9394059505709506,0.5981531758388589
+monash_australian_electricity,amazon/chronos-t5-tiny,1.4558820561269024,0.07673567331332948
+monash_car_parts,amazon/chronos-t5-tiny,0.9058206654011024,1.0236307963149358
+monash_cif_2016,amazon/chronos-t5-tiny,1.09349564130852,0.014066593076202984
+monash_covid_deaths,amazon/chronos-t5-tiny,46.53079664940016,0.09201919385053775
+monash_fred_md,amazon/chronos-t5-tiny,0.48008374212956456,0.03219550761153211
+monash_hospital,amazon/chronos-t5-tiny,0.7062562198194838,0.05790409320432609
+monash_m1_monthly,amazon/chronos-t5-tiny,1.214892145549996,0.14723095246308077
+monash_m1_quarterly,amazon/chronos-t5-tiny,1.8968576926613199,0.11026972972622998
+monash_m1_yearly,amazon/chronos-t5-tiny,4.829453202075546,0.17286063726000958
+monash_m3_monthly,amazon/chronos-t5-tiny,0.9095746605884618,0.10117875324490073
+monash_m3_quarterly,amazon/chronos-t5-tiny,1.3234957548639883,0.08209032993637215
+monash_m3_yearly,amazon/chronos-t5-tiny,3.1489371074890093,0.1492445630072877
+monash_nn5_weekly,amazon/chronos-t5-tiny,0.9637480731663901,0.09205994784693056
+monash_tourism_monthly,amazon/chronos-t5-tiny,2.151677532807024,0.11356761694754255
+monash_tourism_quarterly,amazon/chronos-t5-tiny,1.9116538900950555,0.07191734222366106
+monash_tourism_yearly,amazon/chronos-t5-tiny,3.820615532600914,0.19709256337364625
+monash_traffic,amazon/chronos-t5-tiny,0.878709088458116,0.2632101606272236
+monash_weather,amazon/chronos-t5-tiny,0.8504899606521996,0.14787595319625085
+nn5,amazon/chronos-t5-tiny,0.7021735456568664,0.19071330483289695
diff --git a/src/chronos/__init__.py b/src/chronos/__init__.py
index 4474e8e..3088d7e 100644
--- a/src/chronos/__init__.py
+++ b/src/chronos/__init__.py
@@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
+from .base import BaseChronosPipeline, ForecastType
from .chronos import (
ChronosConfig,
ChronosModel,
@@ -8,11 +9,16 @@ from .chronos import (
ChronosTokenizer,
MeanScaleUniformBins,
)
+from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
__all__ = [
+ "BaseChronosPipeline",
+ "ForecastType",
"ChronosConfig",
"ChronosModel",
"ChronosPipeline",
"ChronosTokenizer",
"MeanScaleUniformBins",
+ "ChronosBoltConfig",
+ "ChronosBoltPipeline",
]
diff --git a/src/chronos/base.py b/src/chronos/base.py
new file mode 100644
index 0000000..3dc1775
--- /dev/null
+++ b/src/chronos/base.py
@@ -0,0 +1,162 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# Authors: Caner Turkmen , Abdul Fatir Ansari , Lorenzo Stella
+# Original source:
+# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/base.py
+
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+from .utils import left_pad_and_stack_1D
+
+
+class ForecastType(Enum):
+ SAMPLES = "samples"
+ QUANTILES = "quantiles"
+
+
+class PipelineRegistry(type):
+ REGISTRY: Dict[str, "PipelineRegistry"] = {}
+
+ def __new__(cls, name, bases, attrs):
+ """See, https://github.com/faif/python-patterns."""
+ new_cls = type.__new__(cls, name, bases, attrs)
+ if name is not None:
+ cls.REGISTRY[name] = new_cls
+
+ return new_cls
+
+
+class BaseChronosPipeline(metaclass=PipelineRegistry):
+ forecast_type: ForecastType
+ dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}
+
+ def __init__(self, inner_model: "PreTrainedModel"):
+ """
+ Parameters
+ ----------
+ inner_model : PreTrainedModel
+ A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
+ """
+ # for easy access to the inner HF-style model
+ self.inner_model = inner_model
+
+ def _prepare_and_validate_context(
+ self, context: Union[torch.Tensor, List[torch.Tensor]]
+ ):
+ if isinstance(context, list):
+ context = left_pad_and_stack_1D(context)
+ assert isinstance(context, torch.Tensor)
+ if context.ndim == 1:
+ context = context.unsqueeze(0)
+ assert context.ndim == 2
+
+ return context
+
+ def predict(
+ self,
+ context: Union[torch.Tensor, List[torch.Tensor]],
+ prediction_length: Optional[int] = None,
+ **kwargs,
+ ):
+ """
+ Get forecasts for the given time series.
+
+ Parameters
+ ----------
+ context
+ Input series. This is either a 1D tensor, or a list
+ of 1D tensors, or a 2D tensor whose first dimension
+ is batch. In the latter case, use left-padding with
+ ``torch.nan`` to align series of different lengths.
+ prediction_length
+ Time steps to predict. Defaults to a model-dependent
+ value if not given.
+
+ Returns
+ -------
+ forecasts
+ Tensor containing forecasts. The layout and meaning
+ of the forecasts values depends on ``self.forecast_type``.
+ """
+ raise NotImplementedError()
+
+ def predict_quantiles(
+ self,
+ context: Union[torch.Tensor, List[torch.Tensor]],
+ prediction_length: Optional[int] = None,
+ quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get quantile and mean forecasts for given time series.
+
+ Parameters
+ ----------
+ context : Union[torch.Tensor, List[torch.Tensor]]
+ Input series. This is either a 1D tensor, or a list
+ of 1D tensors, or a 2D tensor whose first dimension
+ is batch. In the latter case, use left-padding with
+ ``torch.nan`` to align series of different lengths.
+ prediction_length : Optional[int], optional
+ Time steps to predict. Defaults to a model-dependent
+ value if not given.
+ quantile_levels : List[float], optional
+ Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
+
+ Returns
+ -------
+ quantiles
+ Tensor containing quantile forecasts. Shape
+ (batch_size, prediction_length, num_quantiles)
+ mean
+ Tensor containing mean (point) forecasts. Shape
+ (batch_size, prediction_length)
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, Path],
+ *model_args,
+ **kwargs,
+ ):
+ """
+ Load the model, either from a local path or from the HuggingFace Hub.
+ Supports the same arguments as ``AutoConfig`` and ``AutoModel``
+ from ``transformers``.
+ """
+ from transformers import AutoConfig
+
+ torch_dtype = kwargs.get("torch_dtype", "auto")
+ if torch_dtype != "auto" and isinstance(torch_dtype, str):
+ kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
+
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(
+ config, "chronos_config"
+ )
+
+ if not is_valid_config:
+ raise ValueError("Not a Chronos config file")
+
+ pipeline_class_name = getattr(
+ config, "chronos_pipeline_class", "ChronosPipeline"
+ )
+ class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
+ if class_ is None:
+ raise ValueError(
+ f"Trying to load unknown pipeline class: {pipeline_class_name}"
+ )
+
+ return class_.from_pretrained( # type: ignore[attr-defined]
+ pretrained_model_name_or_path, *model_args, **kwargs
+ )
diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py
index 000dc67..ef226f6 100644
--- a/src/chronos/chronos.py
+++ b/src/chronos/chronos.py
@@ -1,7 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
-import warnings
+# Authors: Abdul Fatir Ansari , Lorenzo Stella , Caner Turkmen
+
+import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
@@ -16,6 +18,10 @@ from transformers import (
)
import chronos
+from chronos.base import BaseChronosPipeline, ForecastType
+from chronos.utils import left_pad_and_stack_1D
+
+logger = logging.getLogger(__file__)
@dataclass
@@ -364,21 +370,7 @@ class ChronosModel(nn.Module):
return preds.reshape(input_ids.size(0), num_samples, -1)
-def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
- max_len = max(len(c) for c in tensors)
- padded = []
- for c in tensors:
- assert isinstance(c, torch.Tensor)
- assert c.ndim == 1
- padding = torch.full(
- size=(max_len - len(c),), fill_value=torch.nan, device=c.device
- )
- padded.append(torch.concat((padding, c), dim=-1))
- return torch.stack(padded).to(tensors[0])
-
-
-@dataclass
-class ChronosPipeline:
+class ChronosPipeline(BaseChronosPipeline):
"""
A ``ChronosPipeline`` uses the given tokenizer and model to forecast
input time series.
@@ -396,6 +388,12 @@ class ChronosPipeline:
tokenizer: ChronosTokenizer
model: ChronosModel
+ forecast_type: ForecastType = ForecastType.SAMPLES
+
+ def __init__(self, tokenizer, model):
+ super().__init__(inner_model=model.model)
+ self.tokenizer = tokenizer
+ self.model = model
def _prepare_and_validate_context(
self, context: Union[torch.Tensor, List[torch.Tensor]]
@@ -445,7 +443,7 @@ class ChronosPipeline:
).cpu()
return embeddings, tokenizer_state
- def predict(
+ def predict( # type: ignore[override]
self,
context: Union[torch.Tensor, List[torch.Tensor]],
prediction_length: Optional[int] = None,
@@ -453,21 +451,16 @@ class ChronosPipeline:
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
- limit_prediction_length: bool = True,
+ limit_prediction_length: bool = False,
) -> torch.Tensor:
"""
Get forecasts for the given time series.
- Parameters
- ----------
- context
- Input series. This is either a 1D tensor, or a list
- of 1D tensors, or a 2D tensor whose first dimension
- is batch. In the latter case, use left-padding with
- ``torch.nan`` to align series of different lengths.
- prediction_length
- Time steps to predict. Defaults to what specified
- in ``self.model.config``.
+ Refer to the base method (``BaseChronosPipeline.predict``)
+ for details on shared parameters.
+
+ Additional parameters
+ ---------------------
num_samples
Number of sample paths to predict. Defaults to what
specified in ``self.model.config``.
@@ -482,7 +475,7 @@ class ChronosPipeline:
Defaults to what specified in ``self.model.config``.
limit_prediction_length
Force prediction length smaller or equal than the
- built-in prediction length from the model. True by
+ built-in prediction length from the model. False by
default. When true, fail loudly if longer predictions
are requested, otherwise longer predictions are allowed.
@@ -505,7 +498,7 @@ class ChronosPipeline:
if limit_prediction_length:
msg += "You can turn off this check by setting `limit_prediction_length=False`."
raise ValueError(msg)
- warnings.warn(msg)
+ logger.warning(msg)
input_dtype = context_tensor.dtype
input_device = context_tensor.device
@@ -542,6 +535,30 @@ class ChronosPipeline:
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)
+ def predict_quantiles(
+ self,
+ context: Union[torch.Tensor, List[torch.Tensor]],
+ prediction_length: Optional[int] = None,
+ quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+ **predict_kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Refer to the base method (``BaseChronosPipeline.predict_quantiles``).
+ """
+ prediction_samples = (
+ self.predict(context, prediction_length=prediction_length, **predict_kwargs)
+ .detach()
+ .swapaxes(1, 2)
+ )
+ mean = prediction_samples.mean(dim=-1)
+ quantiles = torch.quantile(
+ prediction_samples,
+ q=torch.tensor(quantile_levels, dtype=prediction_samples.dtype),
+ dim=-1,
+ ).permute(1, 2, 0)
+
+ return quantiles, mean
+
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""
diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py
new file mode 100644
index 0000000..e3182f9
--- /dev/null
+++ b/src/chronos/chronos_bolt.py
@@ -0,0 +1,587 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# Authors: Abdul Fatir Ansari , Caner Turkmen , Lorenzo Stella
+# Original source:
+# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py
+
+import copy
+import logging
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig
+from transformers.models.t5.modeling_t5 import (
+ ACT2FN,
+ T5Config,
+ T5LayerNorm,
+ T5PreTrainedModel,
+ T5Stack,
+)
+from transformers.utils import ModelOutput
+
+from .base import BaseChronosPipeline, ForecastType
+
+
+logger = logging.getLogger(__file__)
+
+
+@dataclass
+class ChronosBoltConfig:
+ context_length: int
+ prediction_length: int
+ input_patch_size: int
+ input_patch_stride: int
+ quantiles: List[float]
+ use_reg_token: bool = False
+
+
+@dataclass
+class ChronosBoltOutput(ModelOutput):
+ loss: Optional[torch.Tensor] = None
+ quantile_preds: Optional[torch.Tensor] = None
+ attentions: Optional[torch.Tensor] = None
+ cross_attentions: Optional[torch.Tensor] = None
+
+
+class Patch(nn.Module):
+ def __init__(self, patch_size: int, patch_stride: int) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ length = x.shape[-1]
+
+ if length % self.patch_size != 0:
+ padding_size = (
+ *x.shape[:-1],
+ self.patch_size - (length % self.patch_size),
+ )
+ padding = torch.full(
+ size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device
+ )
+ x = torch.concat((padding, x), dim=-1)
+
+ x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
+ return x
+
+
+class InstanceNorm(nn.Module):
+ """
+ See, also, RevIN. Apply standardization along the last dimension.
+ """
+
+ def __init__(self, eps: float = 1e-5) -> None:
+ super().__init__()
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ if loc_scale is None:
+ loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)
+ scale = torch.nan_to_num(
+ torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0
+ )
+ scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
+ else:
+ loc, scale = loc_scale
+
+ return (x - loc) / scale, (loc, scale)
+
+ def inverse(
+ self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]
+ ) -> torch.Tensor:
+ loc, scale = loc_scale
+ return x * scale + loc
+
+
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ h_dim: int,
+ out_dim: int,
+ act_fn_name: str,
+ dropout_p: float = 0.0,
+ use_layer_norm: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.dropout = nn.Dropout(dropout_p)
+ self.hidden_layer = nn.Linear(in_dim, h_dim)
+ self.act = ACT2FN[act_fn_name]
+ self.output_layer = nn.Linear(h_dim, out_dim)
+ self.residual_layer = nn.Linear(in_dim, out_dim)
+
+ self.use_layer_norm = use_layer_norm
+ if use_layer_norm:
+ self.layer_norm = T5LayerNorm(out_dim)
+
+ def forward(self, x: torch.Tensor):
+ hid = self.act(self.hidden_layer(x))
+ out = self.dropout(self.output_layer(hid))
+ res = self.residual_layer(x)
+
+ out = out + res
+
+ if self.use_layer_norm:
+ return self.layer_norm(out)
+ return out
+
+
+class ChronosBoltModelForForecasting(T5PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [
+ r"input_patch_embedding\.",
+ r"output_patch_embedding\.",
+ ]
+ _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config: T5Config):
+ assert hasattr(config, "chronos_config"), "Not a Chronos config file"
+
+ super().__init__(config)
+ self.model_dim = config.d_model
+
+ self.chronos_config = ChronosBoltConfig(**config.chronos_config)
+
+ # Only decoder_start_id (and optionally REG token)
+ if self.chronos_config.use_reg_token:
+ config.reg_token_id = 1
+
+ config.vocab_size = 2 if self.chronos_config.use_reg_token else 1
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ # Input patch embedding layer
+ self.input_patch_embedding = ResidualBlock(
+ in_dim=self.chronos_config.input_patch_size * 2,
+ h_dim=config.d_ff,
+ out_dim=config.d_model,
+ act_fn_name=config.dense_act_fn,
+ dropout_p=config.dropout_rate,
+ )
+
+ # patching layer
+ self.patch = Patch(
+ patch_size=self.chronos_config.input_patch_size,
+ patch_stride=self.chronos_config.input_patch_stride,
+ )
+
+ # instance normalization, also referred to as "scaling" in Chronos and GluonTS
+ self.instance_norm = InstanceNorm()
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = T5Stack(encoder_config, self.shared)
+
+ self._init_decoder(config)
+
+ self.num_quantiles = len(self.chronos_config.quantiles)
+ quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)
+ self.register_buffer("quantiles", quantiles, persistent=False)
+
+ self.output_patch_embedding = ResidualBlock(
+ in_dim=config.d_model,
+ h_dim=config.d_ff,
+ out_dim=self.num_quantiles * self.chronos_config.prediction_length,
+ act_fn_name=config.dense_act_fn,
+ dropout_p=config.dropout_rate,
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ """Initialize the weights"""
+ factor = self.config.initializer_factor
+ if isinstance(module, (self.__class__)):
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+ elif isinstance(module, ResidualBlock):
+ module.hidden_layer.weight.data.normal_(
+ mean=0.0,
+ std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
+ )
+ if (
+ hasattr(module.hidden_layer, "bias")
+ and module.hidden_layer.bias is not None
+ ):
+ module.hidden_layer.bias.data.zero_()
+
+ module.residual_layer.weight.data.normal_(
+ mean=0.0,
+ std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
+ )
+ if (
+ hasattr(module.residual_layer, "bias")
+ and module.residual_layer.bias is not None
+ ):
+ module.residual_layer.bias.data.zero_()
+
+ module.output_layer.weight.data.normal_(
+ mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
+ )
+ if (
+ hasattr(module.output_layer, "bias")
+ and module.output_layer.bias is not None
+ ):
+ module.output_layer.bias.data.zero_()
+
+ def forward(
+ self,
+ context: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ target: Optional[torch.Tensor] = None,
+ target_mask: Optional[torch.Tensor] = None,
+ ) -> ChronosBoltOutput:
+ mask = (
+ mask.to(context.dtype)
+ if mask is not None
+ else torch.isnan(context).logical_not().to(context.dtype)
+ )
+
+ batch_size, _ = context.shape
+ if context.shape[-1] > self.chronos_config.context_length:
+ context = context[..., -self.chronos_config.context_length :]
+ mask = mask[..., -self.chronos_config.context_length :]
+
+ # scaling
+ context, loc_scale = self.instance_norm(context)
+
+ # the scaling op above is done in 32-bit precision,
+ # then the context is moved to model's dtype
+ context = context.to(self.dtype)
+ mask = mask.to(self.dtype)
+
+ # patching
+ patched_context = self.patch(context)
+ patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
+ patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)
+ # concat context and mask along patch dim
+ patched_context = torch.cat([patched_context, patched_mask], dim=-1)
+
+ # attention_mask = 1 if at least one item in the patch is observed
+ attention_mask = (
+ patched_mask.sum(dim=-1) > 0
+ ) # (batch_size, patched_seq_length)
+
+ input_embeds = self.input_patch_embedding(patched_context)
+
+ if self.chronos_config.use_reg_token:
+ # Append [REG]
+ reg_input_ids = torch.full(
+ (batch_size, 1),
+ self.config.reg_token_id,
+ device=input_embeds.device,
+ )
+ reg_embeds = self.shared(reg_input_ids)
+ input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
+ attention_mask = torch.cat(
+ [
+ attention_mask.to(self.dtype),
+ torch.ones_like(reg_input_ids).to(self.dtype),
+ ],
+ dim=-1,
+ )
+
+ encoder_outputs = self.encoder(
+ attention_mask=attention_mask,
+ inputs_embeds=input_embeds,
+ )
+ hidden_states = encoder_outputs[0]
+
+ sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
+
+ quantile_preds_shape = (
+ batch_size,
+ self.num_quantiles,
+ self.chronos_config.prediction_length,
+ )
+ quantile_preds = self.output_patch_embedding(sequence_output).view(
+ *quantile_preds_shape
+ )
+
+ loss = None
+ if target is not None:
+ # normalize target
+ target, _ = self.instance_norm(target, loc_scale)
+ target = target.unsqueeze(1) # type: ignore
+ assert self.chronos_config.prediction_length >= target.shape[-1]
+
+ target = target.to(quantile_preds.device)
+ target_mask = (
+ target_mask.unsqueeze(1).to(quantile_preds.device)
+ if target_mask is not None
+ else ~torch.isnan(target)
+ )
+ target[~target_mask] = 0.0
+
+ # pad target and target_mask if they are shorter than model's prediction_length
+ if self.chronos_config.prediction_length > target.shape[-1]:
+ padding_shape = (
+ *target.shape[:-1],
+ self.chronos_config.prediction_length - target.shape[-1],
+ )
+ target = torch.cat(
+ [target, torch.zeros(padding_shape).to(target)], dim=-1
+ )
+ target_mask = torch.cat(
+ [target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1
+ )
+
+ loss = (
+ 2
+ * torch.abs(
+ (target - quantile_preds)
+ * (
+ (target <= quantile_preds).float()
+ - self.quantiles.view(1, self.num_quantiles, 1)
+ )
+ )
+ * target_mask.float()
+ )
+ loss = loss.mean(dim=-2) # Mean over prediction horizon
+ loss = loss.sum(dim=-1) # Sum over quantile levels
+ loss = loss.mean() # Mean over batch
+
+ # Unscale predictions
+ quantile_preds = self.instance_norm.inverse(
+ quantile_preds.view(batch_size, -1),
+ loc_scale,
+ ).view(*quantile_preds_shape)
+
+ return ChronosBoltOutput(
+ loss=loss,
+ quantile_preds=quantile_preds,
+ )
+
+ def _init_decoder(self, config):
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = T5Stack(decoder_config, self.shared)
+
+ def decode(
+ self,
+ input_embeds,
+ attention_mask,
+ hidden_states,
+ output_attentions=False,
+ ):
+ """
+ Parameters
+ ----------
+ input_embeds: torch.Tensor
+ Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model)
+ attention_mask: torch.Tensor
+ Attention mask for the patched context. Shape (batch_size, patched_context_length), type: torch.int64
+ hidden_states: torch.Tensor
+ Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)
+
+ Returns
+ -------
+ last_hidden_state
+ Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)
+ """
+ batch_size = input_embeds.shape[0]
+ decoder_input_ids = torch.full(
+ (batch_size, 1),
+ self.config.decoder_start_token_id,
+ device=input_embeds.device,
+ )
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ return_dict=True,
+ )
+
+ return decoder_outputs.last_hidden_state # sequence_outputs, b x 1 x d_model
+
+
+class ChronosBoltPipeline(BaseChronosPipeline):
+ forecast_type: ForecastType = ForecastType.QUANTILES
+ default_context_length: int = 2048
+
+ def __init__(self, model: ChronosBoltModelForForecasting):
+ super().__init__(inner_model=model)
+ self.model = model
+
+ @property
+ def quantiles(self) -> List[float]:
+ return self.model.config.chronos_config["quantiles"]
+
+ def predict( # type: ignore[override]
+ self,
+ context: Union[torch.Tensor, List[torch.Tensor]],
+ prediction_length: Optional[int] = None,
+ limit_prediction_length: bool = False,
+ ) -> torch.Tensor:
+ """
+ Get forecasts for the given time series.
+
+ Refer to the base method (``BaseChronosPipeline.predict``)
+ for details on shared parameters.
+ Additional parameters
+ ---------------------
+ limit_prediction_length
+ Force prediction length smaller or equal than the
+ built-in prediction length from the model. False by
+ default. When true, fail loudly if longer predictions
+ are requested, otherwise longer predictions are allowed.
+
+ Returns
+ -------
+ torch.Tensor
+ Forecasts of shape (batch_size, num_quantiles, prediction_length)
+ where num_quantiles is the number of quantiles the model has been
+ trained to output. For official Chronos-Bolt models, the value of
+ num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles.
+
+ Raises
+ ------
+ ValueError
+ When limit_prediction_length is True and the prediction_length is
+ greater than model's trainig prediction_length.
+ """
+ context_tensor = self._prepare_and_validate_context(context=context)
+
+ model_context_length = self.model.config.chronos_config["context_length"]
+ model_prediction_length = self.model.config.chronos_config["prediction_length"]
+ if prediction_length is None:
+ prediction_length = model_prediction_length
+
+ if prediction_length > model_prediction_length:
+ msg = (
+ f"We recommend keeping prediction length <= {model_prediction_length}. "
+ "The quality of longer predictions may degrade since the model is not optimized for it. "
+ )
+ if limit_prediction_length:
+ msg += "You can turn off this check by setting `limit_prediction_length=False`."
+ raise ValueError(msg)
+ warnings.warn(msg)
+
+ predictions = []
+ remaining = prediction_length
+
+ # We truncate the context here because otherwise batches with very long
+ # context could take up large amounts of GPU memory unnecessarily.
+ if context_tensor.shape[-1] > model_context_length:
+ context_tensor = context_tensor[..., -model_context_length:]
+
+ # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
+ # horizon that the model was trained with (i.e., 64). This results in variance collapsing
+ # every 64 steps.
+ while remaining > 0:
+ with torch.no_grad():
+ prediction = self.model(
+ context=context_tensor.to(
+ device=self.model.device,
+ dtype=torch.float32, # scaling should be done in 32-bit precision
+ ),
+ ).quantile_preds.to(context_tensor)
+
+ predictions.append(prediction)
+ remaining -= prediction.shape[-1]
+
+ if remaining <= 0:
+ break
+
+ central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
+ central_prediction = prediction[:, central_idx]
+
+ context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
+
+ return torch.cat(predictions, dim=-1)[..., :prediction_length]
+
+ def predict_quantiles(
+ self,
+ context: Union[torch.Tensor, List[torch.Tensor]],
+ prediction_length: Optional[int] = None,
+ quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+ **predict_kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Refer to the base method (``BaseChronosPipeline.predict_quantiles``).
+ """
+ # shape (batch_size, prediction_length, len(training_quantile_levels))
+ predictions = (
+ self.predict(context, prediction_length=prediction_length, **predict_kwargs)
+ .detach()
+ .swapaxes(1, 2)
+ )
+
+ training_quantile_levels = self.quantiles
+
+ if set(quantile_levels).issubset(set(training_quantile_levels)):
+ # no need to perform intra/extrapolation
+ quantiles = predictions[
+ ..., [training_quantile_levels.index(q) for q in quantile_levels]
+ ]
+ else:
+ # we rely on torch for interpolating quantiles if quantiles that
+ # Chronos Bolt was trained on were not provided
+ if min(quantile_levels) < min(training_quantile_levels) or max(
+ quantile_levels
+ ) > max(training_quantile_levels):
+ logger.warning(
+ f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of "
+ f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). "
+ "Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt "
+ "was trained on. This may significantly affect the quality of the predictions."
+ )
+
+ # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)
+ # made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9).
+ # While this holds for official Chronos-Bolt models, this may not be true in the future, and this
+ # function may have to be revised.
+ augmented_predictions = torch.cat(
+ [predictions[..., [0]], predictions, predictions[..., [-1]]],
+ dim=-1,
+ )
+ quantiles = torch.quantile(
+ augmented_predictions,
+ q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
+ dim=-1,
+ ).permute(1, 2, 0)
+ # NOTE: the median is returned as the mean here
+ mean = predictions[:, :, training_quantile_levels.index(0.5)]
+ return quantiles, mean
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ """
+ Load the model, either from a local path or from the HuggingFace Hub.
+ Supports the same arguments as ``AutoConfig`` and ``AutoModel``
+ from ``transformers``.
+ """
+
+ config = AutoConfig.from_pretrained(*args, **kwargs)
+ assert hasattr(config, "chronos_config"), "Not a Chronos config file"
+
+ architecture = config.architectures[0]
+ class_ = globals().get(architecture)
+
+ if class_ is None:
+ logger.warning(
+ f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting"
+ )
+ class_ = ChronosBoltModelForForecasting
+
+ model = class_.from_pretrained(*args, **kwargs)
+ return cls(model=model)
diff --git a/src/chronos/utils.py b/src/chronos/utils.py
new file mode 100644
index 0000000..9248d5a
--- /dev/null
+++ b/src/chronos/utils.py
@@ -0,0 +1,20 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from typing import List
+
+import torch
+
+
+def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
+ max_len = max(len(c) for c in tensors)
+ padded = []
+ for c in tensors:
+ assert isinstance(c, torch.Tensor)
+ assert c.ndim == 1
+ padding = torch.full(
+ size=(max_len - len(c),), fill_value=torch.nan, device=c.device
+ )
+ padded.append(torch.concat((padding, c), dim=-1))
+ return torch.stack(padded).to(tensors[0])
diff --git a/test/dummy-chronos-bolt-model/config.json b/test/dummy-chronos-bolt-model/config.json
new file mode 100644
index 0000000..96eb29a
--- /dev/null
+++ b/test/dummy-chronos-bolt-model/config.json
@@ -0,0 +1,50 @@
+{
+ "architectures": [
+ "ChronosBoltModelForForecasting"
+ ],
+ "chronos_config": {
+ "context_length": 512,
+ "input_patch_size": 16,
+ "input_patch_stride": 16,
+ "prediction_length": 64,
+ "quantiles": [
+ 0.1,
+ 0.2,
+ 0.3,
+ 0.4,
+ 0.5,
+ 0.6,
+ 0.7,
+ 0.8,
+ 0.9
+ ],
+ "use_reg_token": true
+ },
+ "chronos_pipeline_class": "ChronosBoltPipeline",
+ "classifier_dropout": 0.0,
+ "d_ff": 8,
+ "d_kv": 4,
+ "d_model": 8,
+ "decoder_start_token_id": 0,
+ "dense_act_fn": "relu",
+ "dropout_rate": 0.1,
+ "eos_token_id": 1,
+ "feed_forward_proj": "relu",
+ "initializer_factor": 0.05,
+ "is_encoder_decoder": true,
+ "is_gated_act": false,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "t5",
+ "n_positions": 512,
+ "num_decoder_layers": 4,
+ "num_heads": 4,
+ "num_layers": 4,
+ "pad_token_id": 0,
+ "reg_token_id": 1,
+ "relative_attention_max_distance": 128,
+ "relative_attention_num_buckets": 32,
+ "torch_dtype": "float32",
+ "transformers_version": "4.40.2",
+ "use_cache": true,
+ "vocab_size": 2
+}
diff --git a/test/dummy-chronos-bolt-model/model.safetensors b/test/dummy-chronos-bolt-model/model.safetensors
new file mode 100644
index 0000000..b7b5b4e
Binary files /dev/null and b/test/dummy-chronos-bolt-model/model.safetensors differ
diff --git a/test/test_chronos.py b/test/test_chronos.py
index e2b71dc..b0235c0 100644
--- a/test/test_chronos.py
+++ b/test/test_chronos.py
@@ -2,12 +2,21 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
-from typing import Tuple
+from typing import Optional, Tuple
import pytest
import torch
-from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
+from chronos import (
+ BaseChronosPipeline,
+ ChronosConfig,
+ ChronosPipeline,
+ MeanScaleUniformBins,
+)
+
+
+def test_base_chronos_pipeline_loads_from_huggingface():
+ BaseChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", device_map="cpu")
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
@@ -157,10 +166,14 @@ def test_tokenizer_random_data(use_eos_token: bool):
assert samples.shape == (2, 10, 4)
-def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None:
+def validate_tensor(
+ a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
+) -> None:
assert isinstance(a, torch.Tensor)
assert a.shape == shape
- assert a.dtype == dtype
+
+ if dtype is not None:
+ assert a.dtype == dtype
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@@ -179,7 +192,9 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
with pytest.raises(ValueError):
- samples = pipeline.predict(context, num_samples=7, prediction_length=65)
+ samples = pipeline.predict(
+ context, num_samples=7, prediction_length=65, limit_prediction_length=True
+ )
samples = pipeline.predict(
context, num_samples=7, prediction_length=65, limit_prediction_length=False
@@ -192,7 +207,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
with pytest.raises(ValueError):
- samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
+ samples = pipeline.predict(
+ list(context),
+ num_samples=7,
+ prediction_length=65,
+ limit_prediction_length=True,
+ )
samples = pipeline.predict(
list(context),
@@ -208,17 +228,73 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)
with pytest.raises(ValueError):
- samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
+ samples = pipeline.predict(
+ context[0, ...],
+ num_samples=7,
+ prediction_length=65,
+ limit_prediction_length=True,
+ )
samples = pipeline.predict(
context[0, ...],
num_samples=7,
prediction_length=65,
- limit_prediction_length=False,
)
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)
+@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("prediction_length", [3, 65])
+@pytest.mark.parametrize(
+ "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
+)
+def test_pipeline_predict_quantiles(
+ model_dtype: torch.dtype,
+ prediction_length: int,
+ quantile_levels: list[int],
+):
+ pipeline = ChronosPipeline.from_pretrained(
+ Path(__file__).parent / "dummy-chronos-model",
+ device_map="cpu",
+ torch_dtype=model_dtype,
+ )
+ context = 10 * torch.rand(size=(4, 16)) + 10
+
+ num_expected_quantiles = len(quantile_levels)
+ # input: tensor of shape (batch_size, context_length)
+
+ quantiles, mean = pipeline.predict_quantiles(
+ context,
+ num_samples=12,
+ prediction_length=prediction_length,
+ quantile_levels=quantile_levels,
+ )
+ validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
+ validate_tensor(mean, (4, prediction_length))
+
+ # input: batch_size-long list of tensors of shape (context_length,)
+
+ quantiles, mean = pipeline.predict_quantiles(
+ list(context),
+ num_samples=12,
+ prediction_length=prediction_length,
+ quantile_levels=quantile_levels,
+ )
+ validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
+ validate_tensor(mean, (4, prediction_length))
+
+ # input: tensor of shape (context_length,)
+
+ quantiles, mean = pipeline.predict_quantiles(
+ context[0, ...],
+ num_samples=12,
+ prediction_length=prediction_length,
+ quantile_levels=quantile_levels,
+ )
+ validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
+ validate_tensor(mean, (1, prediction_length))
+
+
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
diff --git a/test/test_chronos_bolt.py b/test/test_chronos_bolt.py
new file mode 100644
index 0000000..c4c3db7
--- /dev/null
+++ b/test/test_chronos_bolt.py
@@ -0,0 +1,246 @@
+from pathlib import Path
+from typing import Tuple
+
+import pytest
+import torch
+
+from chronos import BaseChronosPipeline, ChronosBoltPipeline
+from chronos.chronos_bolt import InstanceNorm, Patch
+
+
+def validate_tensor(input: torch.Tensor, shape: Tuple[int, ...]) -> None:
+ assert isinstance(input, torch.Tensor)
+ assert input.shape == shape
+
+
+def test_base_chronos_pipeline_loads_from_huggingface():
+ BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-tiny", device_map="cpu")
+
+
+@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
+def test_pipeline_predict(torch_dtype: str):
+ pipeline = ChronosBoltPipeline.from_pretrained(
+ Path(__file__).parent / "dummy-chronos-bolt-model",
+ device_map="cpu",
+ torch_dtype=torch_dtype,
+ )
+ context = 10 * torch.rand(size=(4, 16)) + 10
+ expected_num_quantiles = len(pipeline.quantiles)
+
+ # input: tensor of shape (batch_size, context_length)
+
+ quantiles = pipeline.predict(context, prediction_length=3)
+ validate_tensor(quantiles, (4, expected_num_quantiles, 3))
+
+ with pytest.raises(ValueError):
+ quantiles = pipeline.predict(
+ context, prediction_length=65, limit_prediction_length=True
+ )
+
+ quantiles = pipeline.predict(context, prediction_length=65)
+ validate_tensor(quantiles, (4, expected_num_quantiles, 65))
+
+ # input: batch_size-long list of tensors of shape (context_length,)
+
+ quantiles = pipeline.predict(list(context), prediction_length=3)
+ validate_tensor(quantiles, (4, expected_num_quantiles, 3))
+
+ with pytest.raises(ValueError):
+ quantiles = pipeline.predict(
+ list(context),
+ prediction_length=65,
+ limit_prediction_length=True,
+ )
+
+ quantiles = pipeline.predict(list(context), prediction_length=65)
+ validate_tensor(quantiles, (4, expected_num_quantiles, 65))
+
+ # input: tensor of shape (context_length,)
+
+ quantiles = pipeline.predict(context[0, ...], prediction_length=3)
+ validate_tensor(quantiles, (1, expected_num_quantiles, 3))
+
+ with pytest.raises(ValueError):
+ quantiles = pipeline.predict(
+ context[0, ...],
+ prediction_length=65,
+ limit_prediction_length=True,
+ )
+
+ quantiles = pipeline.predict(
+ context[0, ...],
+ prediction_length=65,
+ )
+ validate_tensor(quantiles, (1, expected_num_quantiles, 65))
+
+
+@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("prediction_length", [3, 65])
+@pytest.mark.parametrize(
+ "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
+)
+def test_pipeline_predict_quantiles(
+ torch_dtype: str, prediction_length: int, quantile_levels: list[int]
+):
+ pipeline = ChronosBoltPipeline.from_pretrained(
+ Path(__file__).parent / "dummy-chronos-bolt-model",
+ device_map="cpu",
+ torch_dtype=torch_dtype,
+ )
+ context = 10 * torch.rand(size=(4, 16)) + 10
+
+ num_expected_quantiles = len(quantile_levels)
+ # input: tensor of shape (batch_size, context_length)
+
+ quantiles, mean = pipeline.predict_quantiles(
+ context,
+ prediction_length=prediction_length,
+ quantile_levels=quantile_levels,
+ )
+ validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
+ validate_tensor(mean, (4, prediction_length))
+
+ # input: batch_size-long list of tensors of shape (context_length,)
+
+ quantiles, mean = pipeline.predict_quantiles(
+ list(context),
+ prediction_length=prediction_length,
+ quantile_levels=quantile_levels,
+ )
+ validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
+ validate_tensor(mean, (4, prediction_length))
+
+ # input: tensor of shape (context_length,)
+
+ quantiles, mean = pipeline.predict_quantiles(
+ context[0, ...],
+ prediction_length=prediction_length,
+ quantile_levels=quantile_levels,
+ )
+ validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
+ validate_tensor(mean, (1, prediction_length))
+
+
+# The following tests have been taken from
+# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
+# Author: Caner Turkmen
+
+
+def test_given_even_data_patch_operator_output_is_correct():
+ batch_size = 17
+ patch_len = 16
+
+ patch = Patch(patch_len, patch_len)
+
+ batch = (
+ torch.stack([torch.arange(512)] * batch_size)
+ + torch.arange(batch_size)[:, None]
+ )
+ output = patch(batch)
+
+ assert output.shape == (batch_size, 512 // patch_len, patch_len)
+
+ assert torch.allclose(
+ output[:, 0],
+ torch.stack([torch.arange(patch_len)] * batch_size)
+ + torch.arange(batch_size)[:, None],
+ atol=1e-5,
+ )
+ assert torch.allclose(
+ output[:, 1],
+ torch.stack([torch.arange(patch_len, 2 * patch_len)] * batch_size)
+ + torch.arange(batch_size)[:, None],
+ atol=1e-5,
+ )
+ assert not torch.isnan(output).any()
+
+
+def test_given_even_data_and_strides_patch_operator_output_is_correct():
+ batch_size = 17
+ patch_len, patch_stride = 16, 8
+
+ patch = Patch(patch_len, patch_stride)
+
+ offset = torch.arange(batch_size)[:, None]
+ batch = torch.stack([torch.arange(512)] * batch_size) + offset
+ output = patch(batch)
+
+ assert torch.allclose(
+ output[:, 1],
+ torch.stack([torch.arange(patch_stride, patch_stride + patch_len)] * batch_size)
+ + offset,
+ atol=1e-5,
+ )
+ assert not torch.isnan(output).any()
+
+
+def test_given_uneven_data_patch_operator_pads_and_output_is_correct():
+ batch_size = 17
+ patch_len = 16
+
+ patch = Patch(patch_len, patch_len)
+
+ batch = (
+ torch.stack([torch.arange(512 - patch_len + 1)] * batch_size)
+ + torch.arange(batch_size)[:, None]
+ ).float()
+ output = patch(batch)
+
+ assert output.shape == (batch_size, 512 // patch_len, patch_len)
+
+ # check the first portion is padded
+ assert torch.isnan(output[:, 0, :-1]).all()
+
+ # check nowhere else is nan
+ assert not torch.isnan(output[:, 1:]).any()
+
+
+def test_when_instancenorm_applied_then_standardization_correct():
+ inorm = InstanceNorm()
+
+ input_ = torch.tensor(
+ [
+ [1, 2, 3, 4, 5],
+ [2, 3, 4, 5, 6],
+ ]
+ ).float()
+
+ normalized, (loc, scale) = inorm(input_)
+
+ assert normalized.shape == input_.shape
+ assert torch.allclose(normalized[0], normalized[1])
+ assert torch.allclose(loc.squeeze(), torch.tensor([3.0, 4.0]))
+ assert torch.allclose(scale.squeeze(), torch.tensor(1.41421))
+
+
+def test_when_instancenorm_applied_and_reversed_then_nans_preserved():
+ inorm = InstanceNorm()
+
+ input_ = torch.tensor(
+ [
+ [1, torch.nan, 3, 4, 5],
+ [2, 3, 4, 5, torch.nan],
+ ]
+ ).float()
+
+ normalized, (loc, scale) = inorm(input_)
+ assert torch.allclose(normalized.isnan(), input_.isnan())
+
+ output = inorm.inverse(normalized, (loc, scale))
+ assert torch.allclose(output, input_, equal_nan=True)
+
+
+def test_when_instancenorm_applied_and_reversed_then_output_correct():
+ inorm = InstanceNorm()
+
+ input_ = torch.tensor(
+ [
+ [1, 2, 3, 4, 5],
+ [2, 3, 4, 5, 1000],
+ ]
+ ).float()
+
+ normalized, loc_scale = inorm(input_)
+ output = inorm.inverse(normalized, loc_scale)
+
+ assert torch.allclose(output, input_)