mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
⚡ Add support for Chronos-Bolt models (#204)
*Issue #, if available:* N/A *Description of changes:* This PR adds support for Chronos-Bolt models. TODOs: - [x] Update evaluation script - [x] Fix and add tests for Bolt - [x] Update docstrings - [x] Update README example and mention Chronos-Bolt - [x] Update results bar plot in README - [x] Add versions for libraries in `pyproject.toml` - [x] Check that the training and eval scripts work - [x] Change `autogluon` -> `amazon` in model names Post Merge: - [ ] Update Citation style in README, both Github and HuggingFace repos - [ ] Remove note about AutoGluon - [ ] Update READMEs of original Chronos models to refer to Chronos-Bolt NOTE: To be merged after Chronos-Bolt models are available under the `amazon` namespace on HF. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de> Co-authored-by: Caner Turkmen <turkmen.ac@gmail.com> Co-authored-by: Lorenzo Stella <stellalo@amazon.com>
This commit is contained in:
parent
d0c114c81d
commit
72ab64166c
42 changed files with 6693 additions and 84 deletions
43
README.md
43
README.md
|
|
@ -17,7 +17,8 @@
|
|||
|
||||
## 🚀 News
|
||||
|
||||
- **27 June 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper.
|
||||
- **26 Nov 2024**: ⚡️ Chronos-Bolt models released [on HuggingFace](https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444). Chronos-Bolt models are more accurate (5% lower error), up to 250x faster and 20x more memory efficient than the original Chronos models of the same size!
|
||||
- **27 Jun 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper.
|
||||
- **17 May 2024**: 🐛 Fixed an off-by-one error in bin indices in the `output_transform`. This simple fix significantly improves the overall performance of Chronos. We will update the results in the next revision on ArXiv.
|
||||
- **10 May 2024**: 🚀 We added the code for pretraining and fine-tuning Chronos models. You can find it in [this folder](./scripts/training). We also added [a script](./scripts/kernel-synth.py) for generating synthetic time series data from Gaussian processes (KernelSynth; see Section 4.2 in the paper for details). Check out the [usage examples](./scripts/).
|
||||
- **19 Apr 2024**: 🚀 Chronos is now supported on [AutoGluon-TimeSeries](https://auto.gluon.ai/stable/tutorials/timeseries/index.html), the powerful AutoML package for time series forecasting which enables model ensembles, cloud deployments, and much more. Get started with the [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
|
||||
|
|
@ -52,62 +53,72 @@ The models in this repository are based on the [T5 architecture](https://arxiv.o
|
|||
| [**chronos-t5-small**](https://huggingface.co/amazon/chronos-t5-small) | 46M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
|
||||
| [**chronos-t5-base**](https://huggingface.co/amazon/chronos-t5-base) | 200M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
|
||||
| [**chronos-t5-large**](https://huggingface.co/amazon/chronos-t5-large) | 710M | [t5-efficient-large](https://huggingface.co/google/t5-efficient-large) |
|
||||
| [**chronos-bolt-tiny**](https://huggingface.co/amazon/chronos-bolt-tiny) | 9M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) |
|
||||
| [**chronos-bolt-mini**](https://huggingface.co/amazon/chronos-bolt-mini) | 21M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) |
|
||||
| [**chronos-bolt-small**](https://huggingface.co/amazon/chronos-bolt-small) | 48M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
|
||||
| [**chronos-bolt-base**](https://huggingface.co/amazon/chronos-bolt-base) | 205M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
|
||||
|
||||
</div>
|
||||
|
||||
### Zero-Shot Results
|
||||
|
||||
The following figure showcases the remarkable **zero-shot** performance of Chronos models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815).
|
||||
The following figure showcases the remarkable **zero-shot** performance of Chronos and Chronos-Bolt models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815).
|
||||
|
||||
<p align="center">
|
||||
<img src="figures/zero_shot-agg_scaled_score.png" width="80%">
|
||||
<img src="figures/zero_shot-agg_scaled_score.svg" width="100%">
|
||||
<br />
|
||||
<span>
|
||||
Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets <b>not seen</b> by Chronos models during training. This benchmark provides insights into the zero-shot performance of Chronos models against local statistical models, which fit parameters individually for each time series, task-specific models <i>trained on each task</i>, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively.
|
||||
Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets <b>not seen</b> by Chronos and Chronos-Bolt models during training. This benchmark provides insights into the zero-shot performance of Chronos and Chronos-Bolt models against local statistical models, which fit parameters individually for each time series, task-specific models <i>trained on each task</i>, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively.
|
||||
</span>
|
||||
</p>
|
||||
|
||||
## 📈 Usage
|
||||
|
||||
To perform inference with Chronos models, install this package by running:
|
||||
To perform inference with Chronos or Chronos-Bolt models, install this package by running:
|
||||
|
||||
```
|
||||
pip install git+https://github.com/amazon-science/chronos-forecasting.git
|
||||
```
|
||||
> [!TIP]
|
||||
> The recommended way of using Chronos for production use cases is through [AutoGluon](https://auto.gluon.ai), which features ensembling with other statistical and machine learning models for time series forecasting as well as seamless deployments on AWS with SageMaker 🧠. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
|
||||
> This repository is intended for research purposes and provides a minimal interface to Chronos models. The recommended way of using Chronos for production use cases is through [AutoGluon](https://auto.gluon.ai), which features effortless fine-tuning, augmenting Chronos models with exogenous information through covariate regressors, ensembling with other statistical and machine learning models, as well as seamless deployments on AWS with SageMaker 🧠. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
|
||||
|
||||
### Forecasting
|
||||
|
||||
A minimal example showing how to perform forecasting using Chronos models:
|
||||
A minimal example showing how to perform forecasting using Chronos and Chronos-Bolt models:
|
||||
|
||||
```python
|
||||
import pandas as pd # requires: pip install pandas
|
||||
import torch
|
||||
from chronos import ChronosPipeline
|
||||
from chronos import BaseChronosPipeline
|
||||
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
"amazon/chronos-t5-small",
|
||||
pipeline = BaseChronosPipeline.from_pretrained(
|
||||
"amazon/chronos-t5-small", # use "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model
|
||||
device_map="cuda", # use "cpu" for CPU inference and "mps" for Apple Silicon
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
|
||||
df = pd.read_csv(
|
||||
"https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
|
||||
)
|
||||
|
||||
# context must be either a 1D tensor, a list of 1D tensors,
|
||||
# or a left-padded 2D tensor with batch as the first dimension
|
||||
# forecast shape: [num_series, num_samples, prediction_length]
|
||||
# The original Chronos models generate forecast samples, so forecast has shape
|
||||
# [num_series, num_samples, prediction_length].
|
||||
# Chronos-Bolt models generate quantile forecasts, so forecast has shape
|
||||
# [num_series, num_quantiles, prediction_length].
|
||||
forecast = pipeline.predict(
|
||||
context=torch.tensor(df["#Passengers"]),
|
||||
prediction_length=12,
|
||||
num_samples=20,
|
||||
context=torch.tensor(df["#Passengers"]), prediction_length=12
|
||||
)
|
||||
```
|
||||
|
||||
More options for `pipeline.predict` can be found with:
|
||||
|
||||
```python
|
||||
print(ChronosPipeline.predict.__doc__)
|
||||
from chronos import ChronosPipeline, ChronosBoltPipeline
|
||||
|
||||
print(ChronosPipeline.predict.__doc__) # for Chronos models
|
||||
print(ChronosBoltPipeline.predict.__doc__) # for Chronos-Bolt models
|
||||
```
|
||||
|
||||
We can now visualize the forecast:
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 318 KiB |
4875
figures/zero_shot-agg_scaled_score.svg
Normal file
4875
figures/zero_shot-agg_scaled_score.svg
Normal file
File diff suppressed because it is too large
Load diff
|
After Width: | Height: | Size: 149 KiB |
|
|
@ -1,19 +1,19 @@
|
|||
[project]
|
||||
name = "chronos"
|
||||
version = "1.2.1"
|
||||
version = "1.3.0"
|
||||
requires-python = ">=3.8"
|
||||
license = { file = "LICENSE" }
|
||||
dependencies = [
|
||||
"torch~=2.0", # package was tested on 2.2
|
||||
"transformers~=4.30",
|
||||
"accelerate",
|
||||
"torch>=2.0,<2.6", # package was tested on 2.2
|
||||
"transformers>=4.30,<4.48",
|
||||
"accelerate>=0.32,<1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = ["pytest~=8.0", "numpy~=1.21"]
|
||||
typecheck = ["mypy~=1.9"]
|
||||
training = ["gluonts[pro]", "numpy", "tensorboard", "typer", "typer-config", "joblib", "scikit-learn"]
|
||||
evaluation = ["gluonts[pro]", "datasets", "numpy", "typer"]
|
||||
training = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer", "typer-config", "joblib", "scikit-learn", "tensorboard"]
|
||||
evaluation = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer"]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
|
|
|
|||
60
scripts/evaluation/agg-relative-score.py
Normal file
60
scripts/evaluation/agg-relative-score.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import pandas as pd
|
||||
import typer
|
||||
from scipy.stats import gmean
|
||||
from pathlib import Path
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
DEFAULT_RESULTS_DIR = Path(__file__).parent / "results"
|
||||
|
||||
|
||||
def agg_relative_score(model_csv: Path, baseline_csv: Path):
|
||||
model_df = pd.read_csv(model_csv).set_index("dataset")
|
||||
baseline_df = pd.read_csv(baseline_csv).set_index("dataset")
|
||||
relative_score = model_df.drop("model", axis="columns") / baseline_df.drop(
|
||||
"model", axis="columns"
|
||||
)
|
||||
return relative_score.agg(gmean)
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
model_name: str,
|
||||
baseline_name: str = "seasonal-naive",
|
||||
results_dir: Path = DEFAULT_RESULTS_DIR,
|
||||
):
|
||||
"""
|
||||
Compute the aggregated relative score as reported in the Chronos paper.
|
||||
Results will be saved to {results_dir}/{model_name}-agg-rel-scores.csv
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str
|
||||
Name of the model used in the CSV files. The in-domain and zero-shot CSVs
|
||||
are expected to be named {model_name}-in-domain.csv and {model_name}-zero-shot.csv.
|
||||
results_dir : Path, optional, default = results/
|
||||
Directory where results CSVs generated by evaluate.py are stored
|
||||
"""
|
||||
|
||||
in_domain_agg_score_df = agg_relative_score(
|
||||
results_dir / f"{model_name}-in-domain.csv",
|
||||
results_dir / f"{baseline_name}-in-domain.csv",
|
||||
)
|
||||
in_domain_agg_score_df.name = "value"
|
||||
in_domain_agg_score_df.index.name = "metric"
|
||||
|
||||
zero_shot_agg_score_df = agg_relative_score(
|
||||
results_dir / f"{model_name}-zero-shot.csv",
|
||||
results_dir / f"{baseline_name}-zero-shot.csv",
|
||||
)
|
||||
zero_shot_agg_score_df.name = "value"
|
||||
zero_shot_agg_score_df.index.name = "metric"
|
||||
|
||||
agg_score_df = pd.concat(
|
||||
{"in-domain": in_domain_agg_score_df, "zero-shot": zero_shot_agg_score_df},
|
||||
names=["benchmark"],
|
||||
)
|
||||
agg_score_df.to_csv(f"{results_dir}/{model_name}-agg-rel-scores.csv")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
|
@ -12,10 +12,15 @@ from gluonts.dataset.split import split
|
|||
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
|
||||
from gluonts.itertools import batcher
|
||||
from gluonts.model.evaluation import evaluate_forecasts
|
||||
from gluonts.model.forecast import SampleForecast
|
||||
from gluonts.model.forecast import QuantileForecast, SampleForecast
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from chronos import ChronosPipeline
|
||||
from chronos import (
|
||||
BaseChronosPipeline,
|
||||
ChronosBoltPipeline,
|
||||
ChronosPipeline,
|
||||
ForecastType,
|
||||
)
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
|
||||
|
|
@ -228,37 +233,45 @@ def load_and_split_dataset(backtest_config: dict):
|
|||
return test_data
|
||||
|
||||
|
||||
def generate_sample_forecasts(
|
||||
def generate_forecasts(
|
||||
test_data_input: Iterable,
|
||||
pipeline: ChronosPipeline,
|
||||
pipeline: BaseChronosPipeline,
|
||||
prediction_length: int,
|
||||
batch_size: int,
|
||||
num_samples: int,
|
||||
**predict_kwargs,
|
||||
):
|
||||
# Generate forecast samples
|
||||
forecast_samples = []
|
||||
# Generate forecasts
|
||||
forecast_outputs = []
|
||||
for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
|
||||
context = [torch.tensor(entry["target"]) for entry in batch]
|
||||
forecast_samples.append(
|
||||
forecast_outputs.append(
|
||||
pipeline.predict(
|
||||
context,
|
||||
prediction_length=prediction_length,
|
||||
num_samples=num_samples,
|
||||
**predict_kwargs,
|
||||
).numpy()
|
||||
)
|
||||
forecast_samples = np.concatenate(forecast_samples)
|
||||
forecast_outputs = np.concatenate(forecast_outputs)
|
||||
|
||||
# Convert forecast samples into gluonts SampleForecast objects
|
||||
sample_forecasts = []
|
||||
for item, ts in zip(forecast_samples, test_data_input):
|
||||
# Convert forecast samples into gluonts Forecast objects
|
||||
forecasts = []
|
||||
for item, ts in zip(forecast_outputs, test_data_input):
|
||||
forecast_start_date = ts["start"] + len(ts["target"])
|
||||
sample_forecasts.append(
|
||||
SampleForecast(samples=item, start_date=forecast_start_date)
|
||||
)
|
||||
|
||||
return sample_forecasts
|
||||
if pipeline.forecast_type == ForecastType.SAMPLES:
|
||||
forecasts.append(
|
||||
SampleForecast(samples=item, start_date=forecast_start_date)
|
||||
)
|
||||
elif pipeline.forecast_type == ForecastType.QUANTILES:
|
||||
forecasts.append(
|
||||
QuantileForecast(
|
||||
forecast_arrays=item,
|
||||
forecast_keys=list(map(str, pipeline.quantiles)),
|
||||
start_date=forecast_start_date,
|
||||
)
|
||||
)
|
||||
|
||||
return forecasts
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
@ -274,17 +287,65 @@ def main(
|
|||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
):
|
||||
"""Evaluate Chronos models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : Path
|
||||
Path to the evaluation config. See ./configs/.
|
||||
metrics_path : Path
|
||||
Path to the CSV file where metrics will be saved.
|
||||
chronos_model_id : str, optional, default = "amazon/chronos-t5-small"
|
||||
HuggingFace ID of the Chronos model or local path
|
||||
Available models on HuggingFace:
|
||||
Chronos:
|
||||
- amazon/chronos-t5-tiny
|
||||
- amazon/chronos-t5-mini
|
||||
- amazon/chronos-t5-small
|
||||
- amazon/chronos-t5-base
|
||||
- amazon/chronos-t5-large
|
||||
Chronos-Bolt:
|
||||
- amazon/chronos-bolt-tiny
|
||||
- amazon/chronos-bolt-mini
|
||||
- amazon/chronos-bolt-small
|
||||
- amazon/chronos-bolt-base
|
||||
device : str, optional, default = "cuda"
|
||||
Device on which inference will be performed
|
||||
torch_dtype : str, optional
|
||||
Model's dtype, by default "bfloat16"
|
||||
batch_size : int, optional, default = 32
|
||||
Batch size for inference. For Chronos-Bolt models, significantly larger
|
||||
batch sizes can be used
|
||||
num_samples : int, optional, default = 20
|
||||
Number of samples to draw when using the original Chronos models
|
||||
temperature : Optional[float], optional, default = 1.0
|
||||
Softmax temperature to used for the original Chronos models
|
||||
top_k : Optional[int], optional, default = 50
|
||||
Top-K sampling, by default None
|
||||
top_p : Optional[float], optional, default = 1.0
|
||||
Top-p sampling, by default None
|
||||
"""
|
||||
if isinstance(torch_dtype, str):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
assert isinstance(torch_dtype, torch.dtype)
|
||||
|
||||
# Load Chronos
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
pipeline = BaseChronosPipeline.from_pretrained(
|
||||
chronos_model_id,
|
||||
device_map=device,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
if isinstance(pipeline, ChronosPipeline):
|
||||
predict_kwargs = dict(
|
||||
num_samples=num_samples,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
elif isinstance(pipeline, ChronosBoltPipeline):
|
||||
predict_kwargs = {}
|
||||
|
||||
# Load backtest configs
|
||||
with open(config_path) as fp:
|
||||
backtest_configs = yaml.safe_load(fp)
|
||||
|
|
@ -301,21 +362,18 @@ def main(
|
|||
f"Generating forecasts for {dataset_name} "
|
||||
f"({len(test_data.input)} time series)"
|
||||
)
|
||||
sample_forecasts = generate_sample_forecasts(
|
||||
forecasts = generate_forecasts(
|
||||
test_data.input,
|
||||
pipeline=pipeline,
|
||||
prediction_length=prediction_length,
|
||||
batch_size=batch_size,
|
||||
num_samples=num_samples,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
**predict_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"Evaluating forecasts for {dataset_name}")
|
||||
metrics = (
|
||||
evaluate_forecasts(
|
||||
sample_forecasts,
|
||||
forecasts,
|
||||
test_data=test_data,
|
||||
metrics=[
|
||||
MASE(),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.6800133628315155
|
||||
in-domain,WQL,0.5339263811489279
|
||||
zero-shot,MASE,0.7914551113353537
|
||||
zero-shot,WQL,0.6241424984163773
|
||||
|
16
scripts/evaluation/results/chronos-bolt-base-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-bolt-base-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-bolt-base,0.41069374835605243,0.0703533790998506
|
||||
m4_daily,amazon/chronos-bolt-base,3.205192517121196,0.02110308498174413
|
||||
m4_hourly,amazon/chronos-bolt-base,0.8350129849014075,0.025353803894164
|
||||
m4_monthly,amazon/chronos-bolt-base,0.9491758928362231,0.09382496106659234
|
||||
m4_weekly,amazon/chronos-bolt-base,2.0847827409162742,0.03816605075768161
|
||||
monash_electricity_hourly,amazon/chronos-bolt-base,1.254966217685461,0.09442192616975713
|
||||
monash_electricity_weekly,amazon/chronos-bolt-base,1.8391546050108039,0.06410971963960499
|
||||
monash_kdd_cup_2018,amazon/chronos-bolt-base,0.6405985809360102,0.2509172188706336
|
||||
monash_london_smart_meters,amazon/chronos-bolt-base,0.701398572604996,0.3218915088923906
|
||||
monash_pedestrian_counts,amazon/chronos-bolt-base,0.2646412642278343,0.18789459806066328
|
||||
monash_rideshare,amazon/chronos-bolt-base,0.7695376426829713,0.11637119433040358
|
||||
monash_temperature_rain,amazon/chronos-bolt-base,0.8983612698773724,0.6050555216496304
|
||||
taxi_30min,amazon/chronos-bolt-base,0.7688908266765317,0.2363178601205094
|
||||
uber_tlc_daily,amazon/chronos-bolt-base,0.8231767493519677,0.0926036406916842
|
||||
uber_tlc_hourly,amazon/chronos-bolt-base,0.6632193728217927,0.14987786887626975
|
||||
|
28
scripts/evaluation/results/chronos-bolt-base-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-bolt-base-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-bolt-base,0.7479154031956647,0.07062173821055001
|
||||
ETTm,amazon/chronos-bolt-base,0.6334357237512225,0.052261607745858835
|
||||
dominick,amazon/chronos-bolt-base,0.8560272479913918,0.3453573743726445
|
||||
ercot,amazon/chronos-bolt-base,0.6933217425507392,0.02142183038021456
|
||||
exchange_rate,amazon/chronos-bolt-base,1.7095176257412634,0.01200682136751536
|
||||
m4_quarterly,amazon/chronos-bolt-base,1.2244670010522907,0.0771066518089854
|
||||
m4_yearly,amazon/chronos-bolt-base,3.513752058541554,0.12142798053483984
|
||||
m5,amazon/chronos-bolt-base,0.9152230096463854,0.561999688057527
|
||||
monash_australian_electricity,amazon/chronos-bolt-base,0.7403239930185613,0.03584034231329335
|
||||
monash_car_parts,amazon/chronos-bolt-base,0.8550263912438314,0.9945122291263591
|
||||
monash_cif_2016,amazon/chronos-bolt-base,0.9988541862779904,0.016456104842296485
|
||||
monash_covid_deaths,amazon/chronos-bolt-base,38.901749109066415,0.047410971217640714
|
||||
monash_fred_md,amazon/chronos-bolt-base,0.6468787708795645,0.04185083716355386
|
||||
monash_hospital,amazon/chronos-bolt-base,0.6883138394434054,0.057032869931903894
|
||||
monash_m1_monthly,amazon/chronos-bolt-base,1.0997677446267855,0.1392311148066238
|
||||
monash_m1_quarterly,amazon/chronos-bolt-base,1.7737851980875563,0.1007118219350403
|
||||
monash_m1_yearly,amazon/chronos-bolt-base,4.404672537832342,0.1504617654430952
|
||||
monash_m3_monthly,amazon/chronos-bolt-base,0.8510696834878182,0.09269673913736748
|
||||
monash_m3_quarterly,amazon/chronos-bolt-base,1.2890908822598466,0.07615133571216029
|
||||
monash_m3_yearly,amazon/chronos-bolt-base,2.9067097980770082,0.12934285625258413
|
||||
monash_nn5_weekly,amazon/chronos-bolt-base,0.9158766337957451,0.08352114810139548
|
||||
monash_tourism_monthly,amazon/chronos-bolt-base,1.5283388458731357,0.09026425492612797
|
||||
monash_tourism_quarterly,amazon/chronos-bolt-base,1.756127005530011,0.06448060953595125
|
||||
monash_tourism_yearly,amazon/chronos-bolt-base,3.691545772463519,0.16548820700844424
|
||||
monash_traffic,amazon/chronos-bolt-base,0.7843310867739336,0.23148632068725078
|
||||
monash_weather,amazon/chronos-bolt-base,0.8115247139672316,0.13350830777170594
|
||||
nn5,amazon/chronos-bolt-base,0.5764084996361287,0.1500519584148468
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.7268373301543752
|
||||
in-domain,WQL,0.565140251955324
|
||||
zero-shot,MASE,0.8221798917822493
|
||||
zero-shot,WQL,0.6441645845380903
|
||||
|
16
scripts/evaluation/results/chronos-bolt-mini-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-bolt-mini-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-bolt-mini,0.44185193304080733,0.0731477927531107
|
||||
m4_daily,amazon/chronos-bolt-mini,3.1342608828747456,0.0206872246743766
|
||||
m4_hourly,amazon/chronos-bolt-mini,0.9218285923038745,0.024383114886067574
|
||||
m4_monthly,amazon/chronos-bolt-mini,0.9628339921394529,0.09502498697494888
|
||||
m4_weekly,amazon/chronos-bolt-mini,2.2330452369879255,0.039393515325238534
|
||||
monash_electricity_hourly,amazon/chronos-bolt-mini,1.6195944363428718,0.11468972600782207
|
||||
monash_electricity_weekly,amazon/chronos-bolt-mini,1.866105365159433,0.06019900031840434
|
||||
monash_kdd_cup_2018,amazon/chronos-bolt-mini,0.74790954883436,0.3012661161484388
|
||||
monash_london_smart_meters,amazon/chronos-bolt-mini,0.7187830347765344,0.32984510693830227
|
||||
monash_pedestrian_counts,amazon/chronos-bolt-mini,0.308633944815819,0.23331301029432483
|
||||
monash_rideshare,amazon/chronos-bolt-mini,0.818948044410056,0.1297966960374544
|
||||
monash_temperature_rain,amazon/chronos-bolt-mini,0.9035244443682741,0.605031064086567
|
||||
taxi_30min,amazon/chronos-bolt-mini,0.812010120941363,0.25232294549917317
|
||||
uber_tlc_daily,amazon/chronos-bolt-mini,0.8507256206478295,0.10101757743084538
|
||||
uber_tlc_hourly,amazon/chronos-bolt-mini,0.6685484898085609,0.1515245941548974
|
||||
|
28
scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-bolt-mini-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-bolt-mini,0.8057126710113404,0.07740387596411452
|
||||
ETTm,amazon/chronos-bolt-mini,0.6100793941108849,0.05129333450944573
|
||||
dominick,amazon/chronos-bolt-mini,0.8664152477208024,0.3499696999160997
|
||||
ercot,amazon/chronos-bolt-mini,0.6871250728215426,0.02448804863744021
|
||||
exchange_rate,amazon/chronos-bolt-mini,1.3520551553333662,0.00934663373172766
|
||||
m4_quarterly,amazon/chronos-bolt-mini,1.2569644266281508,0.07833787023275976
|
||||
m4_yearly,amazon/chronos-bolt-mini,3.7611003052413796,0.12931927951165456
|
||||
m5,amazon/chronos-bolt-mini,0.9188876472137485,0.5661303206519673
|
||||
monash_australian_electricity,amazon/chronos-bolt-mini,0.8823559450287066,0.04493688824488474
|
||||
monash_car_parts,amazon/chronos-bolt-mini,0.8604081423647779,1.0041876404811494
|
||||
monash_cif_2016,amazon/chronos-bolt-mini,1.0762361363763873,0.017641893717784202
|
||||
monash_covid_deaths,amazon/chronos-bolt-mini,38.83915011538576,0.06098317835750057
|
||||
monash_fred_md,amazon/chronos-bolt-mini,0.6169859211923081,0.03256236965040934
|
||||
monash_hospital,amazon/chronos-bolt-mini,0.6924431064606051,0.05766349075348645
|
||||
monash_m1_monthly,amazon/chronos-bolt-mini,1.147893030263777,0.13270222658510553
|
||||
monash_m1_quarterly,amazon/chronos-bolt-mini,1.8662100001165818,0.09846363409254102
|
||||
monash_m1_yearly,amazon/chronos-bolt-mini,5.319154632748303,0.16167328827180308
|
||||
monash_m3_monthly,amazon/chronos-bolt-mini,0.8758452776118432,0.09493431248614057
|
||||
monash_m3_quarterly,amazon/chronos-bolt-mini,1.3555175243802005,0.07808062465932723
|
||||
monash_m3_yearly,amazon/chronos-bolt-mini,3.605769430055575,0.15711010456482008
|
||||
monash_nn5_weekly,amazon/chronos-bolt-mini,0.9347141924977239,0.08522899825844342
|
||||
monash_tourism_monthly,amazon/chronos-bolt-mini,1.649587479665881,0.0979648261309891
|
||||
monash_tourism_quarterly,amazon/chronos-bolt-mini,1.8471553663088986,0.06501077791766902
|
||||
monash_tourism_yearly,amazon/chronos-bolt-mini,3.9932920493826245,0.1743539122097316
|
||||
monash_traffic,amazon/chronos-bolt-mini,0.8355442361271347,0.24351051123330386
|
||||
monash_weather,amazon/chronos-bolt-mini,0.800013628350165,0.13041050756802045
|
||||
nn5,amazon/chronos-bolt-mini,0.611917632501032,0.1570111102680171
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.7030801652116672
|
||||
in-domain,WQL,0.5443547623341555
|
||||
zero-shot,MASE,0.8192127745093378
|
||||
zero-shot,WQL,0.6356097843099521
|
||||
|
16
scripts/evaluation/results/chronos-bolt-small-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-bolt-small-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-bolt-small,0.44920089250026723,0.08115291306964295
|
||||
m4_daily,amazon/chronos-bolt-small,3.201966619014735,0.02143368277732494
|
||||
m4_hourly,amazon/chronos-bolt-small,0.8686298207618999,0.020368729287465817
|
||||
m4_monthly,amazon/chronos-bolt-small,0.9537717737278778,0.0939247807527992
|
||||
m4_weekly,amazon/chronos-bolt-small,2.1236755094789177,0.03785184715517262
|
||||
monash_electricity_hourly,amazon/chronos-bolt-small,1.3728906161330452,0.09452411472431674
|
||||
monash_electricity_weekly,amazon/chronos-bolt-small,1.8703239487242378,0.06648479071326366
|
||||
monash_kdd_cup_2018,amazon/chronos-bolt-small,0.6458631909979771,0.25148489931571666
|
||||
monash_london_smart_meters,amazon/chronos-bolt-small,0.7126939688565166,0.326874529903459
|
||||
monash_pedestrian_counts,amazon/chronos-bolt-small,0.3015070035798365,0.2285590441093863
|
||||
monash_rideshare,amazon/chronos-bolt-small,0.823726965741684,0.12409769473500927
|
||||
monash_temperature_rain,amazon/chronos-bolt-small,0.8980348827836525,0.5984819599873311
|
||||
taxi_30min,amazon/chronos-bolt-small,0.7597818149895785,0.2348569752311862
|
||||
uber_tlc_daily,amazon/chronos-bolt-small,0.8460854328036702,0.09666483354735897
|
||||
uber_tlc_hourly,amazon/chronos-bolt-small,0.6662547495017634,0.1524256346268063
|
||||
|
28
scripts/evaluation/results/chronos-bolt-small-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-bolt-small-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-bolt-small,0.792521748651108,0.07590654063011319
|
||||
ETTm,amazon/chronos-bolt-small,0.6209623928936988,0.05056189722606397
|
||||
dominick,amazon/chronos-bolt-small,0.8706134610400587,0.34811141409475416
|
||||
ercot,amazon/chronos-bolt-small,0.7562857616685997,0.02596064260343696
|
||||
exchange_rate,amazon/chronos-bolt-small,1.774835301692689,0.011363548847621512
|
||||
m4_quarterly,amazon/chronos-bolt-small,1.2478142413437487,0.07808795122806232
|
||||
m4_yearly,amazon/chronos-bolt-small,3.6925595655002574,0.12772564181388502
|
||||
m5,amazon/chronos-bolt-small,0.9195435643571084,0.5668430814831332
|
||||
monash_australian_electricity,amazon/chronos-bolt-small,0.8128424798841111,0.041509852162861564
|
||||
monash_car_parts,amazon/chronos-bolt-small,0.8584574663781737,1.0074689402521324
|
||||
monash_cif_2016,amazon/chronos-bolt-small,1.0182471909074982,0.01581964877692293
|
||||
monash_covid_deaths,amazon/chronos-bolt-small,36.467595559655145,0.0427382859406882
|
||||
monash_fred_md,amazon/chronos-bolt-small,0.6132863794635253,0.03730410577241995
|
||||
monash_hospital,amazon/chronos-bolt-small,0.6954489513780618,0.058119864671526154
|
||||
monash_m1_monthly,amazon/chronos-bolt-small,1.1277621848099244,0.1335656174632902
|
||||
monash_m1_quarterly,amazon/chronos-bolt-small,1.8356144904231688,0.09363028483838018
|
||||
monash_m1_yearly,amazon/chronos-bolt-small,5.098146069746402,0.15669928873371905
|
||||
monash_m3_monthly,amazon/chronos-bolt-small,0.8685125121306435,0.09396568468255145
|
||||
monash_m3_quarterly,amazon/chronos-bolt-small,1.3269103591066727,0.07691022995374203
|
||||
monash_m3_yearly,amazon/chronos-bolt-small,3.40993282700627,0.1547639821304127
|
||||
monash_nn5_weekly,amazon/chronos-bolt-small,0.9266513350636507,0.08452821221908001
|
||||
monash_tourism_monthly,amazon/chronos-bolt-small,1.6106732721197876,0.09362336754317802
|
||||
monash_tourism_quarterly,amazon/chronos-bolt-small,1.8357819365308639,0.06734337535269994
|
||||
monash_tourism_yearly,amazon/chronos-bolt-small,3.8963100495394194,0.16766064312072784
|
||||
monash_traffic,amazon/chronos-bolt-small,0.8598507749866499,0.25173786112983054
|
||||
monash_weather,amazon/chronos-bolt-small,0.8020408743877911,0.13258563963844888
|
||||
nn5,amazon/chronos-bolt-small,0.5833047644729239,0.15066847836762787
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.7403252781013574
|
||||
in-domain,WQL,0.5733728165523524
|
||||
zero-shot,MASE,0.8445407343705457
|
||||
zero-shot,WQL,0.6678781905023173
|
||||
|
16
scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-bolt-tiny-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-bolt-tiny,0.4676384089765091,0.0861229808117837
|
||||
m4_daily,amazon/chronos-bolt-tiny,3.1789994761356795,0.020961883512815756
|
||||
m4_hourly,amazon/chronos-bolt-tiny,0.9348005698736752,0.021087527284114574
|
||||
m4_monthly,amazon/chronos-bolt-tiny,0.965298729632761,0.0950380483243082
|
||||
m4_weekly,amazon/chronos-bolt-tiny,2.261575511029903,0.04093653263178429
|
||||
monash_electricity_hourly,amazon/chronos-bolt-tiny,1.5739346351263623,0.10808418398945202
|
||||
monash_electricity_weekly,amazon/chronos-bolt-tiny,1.8628689103722829,0.05773335283584782
|
||||
monash_kdd_cup_2018,amazon/chronos-bolt-tiny,0.6869549985391232,0.28012801092758166
|
||||
monash_london_smart_meters,amazon/chronos-bolt-tiny,0.7284234905933779,0.33496438244693033
|
||||
monash_pedestrian_counts,amazon/chronos-bolt-tiny,0.32338947321773864,0.2530637833749087
|
||||
monash_rideshare,amazon/chronos-bolt-tiny,0.8562780835002918,0.1304317657933891
|
||||
monash_temperature_rain,amazon/chronos-bolt-tiny,0.9030707620825977,0.6064087080755548
|
||||
taxi_30min,amazon/chronos-bolt-tiny,0.9122159603256838,0.28002194370731626
|
||||
uber_tlc_daily,amazon/chronos-bolt-tiny,0.9087055420190513,0.11193388685815164
|
||||
uber_tlc_hourly,amazon/chronos-bolt-tiny,0.6716569179590032,0.15310845458208555
|
||||
|
28
scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-bolt-tiny-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-bolt-tiny,0.7941225847155844,0.07480860969990633
|
||||
ETTm,amazon/chronos-bolt-tiny,0.6508270995240056,0.05440068825429993
|
||||
dominick,amazon/chronos-bolt-tiny,0.876060127216559,0.35175949052933253
|
||||
ercot,amazon/chronos-bolt-tiny,0.7309134980173839,0.02468604544464515
|
||||
exchange_rate,amazon/chronos-bolt-tiny,1.6857262567077134,0.011477224264784112
|
||||
m4_quarterly,amazon/chronos-bolt-tiny,1.2605908919338378,0.0789049420017836
|
||||
m4_yearly,amazon/chronos-bolt-tiny,3.7118394116161757,0.1286932555969197
|
||||
m5,amazon/chronos-bolt-tiny,0.9195469670062033,0.5634881835998845
|
||||
monash_australian_electricity,amazon/chronos-bolt-tiny,0.8419304693259403,0.042040993880313904
|
||||
monash_car_parts,amazon/chronos-bolt-tiny,0.8625579150452282,1.0009987800801836
|
||||
monash_cif_2016,amazon/chronos-bolt-tiny,1.095219642027011,0.017550336784241796
|
||||
monash_covid_deaths,amazon/chronos-bolt-tiny,40.674057986280744,0.06723714516685976
|
||||
monash_fred_md,amazon/chronos-bolt-tiny,0.6127387450520702,0.04747523852271518
|
||||
monash_hospital,amazon/chronos-bolt-tiny,0.6980246281225624,0.05864223243167421
|
||||
monash_m1_monthly,amazon/chronos-bolt-tiny,1.1625495971731141,0.13142237467151166
|
||||
monash_m1_quarterly,amazon/chronos-bolt-tiny,1.8941765599193754,0.09972207844232561
|
||||
monash_m1_yearly,amazon/chronos-bolt-tiny,5.136332694531757,0.160331813128038
|
||||
monash_m3_monthly,amazon/chronos-bolt-tiny,0.8744553726704598,0.09435519378597752
|
||||
monash_m3_quarterly,amazon/chronos-bolt-tiny,1.364563776692303,0.07875066385737857
|
||||
monash_m3_yearly,amazon/chronos-bolt-tiny,3.3685961410254928,0.15158076519486274
|
||||
monash_nn5_weekly,amazon/chronos-bolt-tiny,0.9324436794013877,0.0847385189968909
|
||||
monash_tourism_monthly,amazon/chronos-bolt-tiny,1.7895936775088157,0.1058167042693116
|
||||
monash_tourism_quarterly,amazon/chronos-bolt-tiny,2.095262637810499,0.0710732570354461
|
||||
monash_tourism_yearly,amazon/chronos-bolt-tiny,4.042821441327848,0.172613367251472
|
||||
monash_traffic,amazon/chronos-bolt-tiny,0.8836032533767518,0.2574297134210491
|
||||
monash_weather,amazon/chronos-bolt-tiny,0.8005348255663177,0.13111355494466076
|
||||
nn5,amazon/chronos-bolt-tiny,0.7228248498869763,0.1816913098894226
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.7007558507277635
|
||||
in-domain,WQL,0.5786300105297922
|
||||
zero-shot,MASE,0.8155209321160994
|
||||
zero-shot,WQL,0.6424634919486323
|
||||
|
16
scripts/evaluation/results/chronos-t5-base-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-t5-base-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-t5-base,0.39879754957261204,0.07738953262286181
|
||||
m4_daily,amazon/chronos-t5-base,3.160575865614404,0.02194256368254537
|
||||
m4_hourly,amazon/chronos-t5-base,0.6938747745332102,0.026354948301302205
|
||||
m4_monthly,amazon/chronos-t5-base,0.971951848755026,0.10355213196432872
|
||||
m4_weekly,amazon/chronos-t5-base,2.0143841267657945,0.03639741235815474
|
||||
monash_electricity_hourly,amazon/chronos-t5-base,1.5717251971297332,0.1078882125804548
|
||||
monash_electricity_weekly,amazon/chronos-t5-base,1.7862927210886668,0.06255982783148449
|
||||
monash_kdd_cup_2018,amazon/chronos-t5-base,0.6335225775496138,0.2684272353843692
|
||||
monash_london_smart_meters,amazon/chronos-t5-base,0.8362014889190201,0.4265549499082726
|
||||
monash_pedestrian_counts,amazon/chronos-t5-base,0.2817708325561419,0.20810108090665583
|
||||
monash_rideshare,amazon/chronos-t5-base,0.8614480533175364,0.1356591190888703
|
||||
monash_temperature_rain,amazon/chronos-t5-base,0.9692405156151607,0.660155448791624
|
||||
taxi_30min,amazon/chronos-t5-base,0.8186287575356217,0.26236060366367003
|
||||
uber_tlc_daily,amazon/chronos-t5-base,0.8338648311528079,0.0970875577681834
|
||||
uber_tlc_hourly,amazon/chronos-t5-base,0.6647193438331641,0.15436646659512512
|
||||
|
28
scripts/evaluation/results/chronos-t5-base-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-t5-base-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-t5-base,0.7653491494991778,0.08087267701042929
|
||||
ETTm,amazon/chronos-t5-base,0.7737006634032871,0.07008650633028274
|
||||
dominick,amazon/chronos-t5-base,0.8194044957573132,0.33201307438298133
|
||||
ercot,amazon/chronos-t5-base,0.5014399265038706,0.013589435745554596
|
||||
exchange_rate,amazon/chronos-t5-base,2.055616906406159,0.011066070028466317
|
||||
m4_quarterly,amazon/chronos-t5-base,1.2253036947743137,0.08327936201395683
|
||||
m4_yearly,amazon/chronos-t5-base,3.639991540990927,0.13539258375263963
|
||||
m5,amazon/chronos-t5-base,0.9391874615167101,0.5867234116216755
|
||||
monash_australian_electricity,amazon/chronos-t5-base,1.2944069383163321,0.07070604202031877
|
||||
monash_car_parts,amazon/chronos-t5-base,0.9071940271035218,1.077797124337994
|
||||
monash_cif_2016,amazon/chronos-t5-base,0.9840747802099565,0.011825556826558836
|
||||
monash_covid_deaths,amazon/chronos-t5-base,42.68503365359237,0.042229910495746356
|
||||
monash_fred_md,amazon/chronos-t5-base,0.4857773806790164,0.021204829049512715
|
||||
monash_hospital,amazon/chronos-t5-base,0.7053005021431749,0.05630687524507516
|
||||
monash_m1_monthly,amazon/chronos-t5-base,1.1153039466137842,0.12724419775326076
|
||||
monash_m1_quarterly,amazon/chronos-t5-base,1.746093728928804,0.1123583549291933
|
||||
monash_m1_yearly,amazon/chronos-t5-base,4.401291522370069,0.18541586641719554
|
||||
monash_m3_monthly,amazon/chronos-t5-base,0.8627172231908679,0.09640536232169555
|
||||
monash_m3_quarterly,amazon/chronos-t5-base,1.1696030904401578,0.07392876900131434
|
||||
monash_m3_yearly,amazon/chronos-t5-base,3.1298600218573775,0.1486674447940158
|
||||
monash_nn5_weekly,amazon/chronos-t5-base,0.9334860602210187,0.08972736821598823
|
||||
monash_tourism_monthly,amazon/chronos-t5-base,1.7937702435879332,0.10260220444264027
|
||||
monash_tourism_quarterly,amazon/chronos-t5-base,1.7791494997972261,0.06852507950474919
|
||||
monash_tourism_yearly,amazon/chronos-t5-base,3.8359926053603197,0.20722699382964643
|
||||
monash_traffic,amazon/chronos-t5-base,0.8015262383138622,0.25565153982140926
|
||||
monash_weather,amazon/chronos-t5-base,0.8159511190589147,0.13802320967454584
|
||||
nn5,amazon/chronos-t5-base,0.5927076179914024,0.1630476065585159
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.6944869734691035
|
||||
in-domain,WQL,0.5596857927462495
|
||||
zero-shot,MASE,0.8213682201405101
|
||||
zero-shot,WQL,0.6504834081319559
|
||||
|
16
scripts/evaluation/results/chronos-t5-large-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-t5-large-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-t5-large,0.3866310906621673,0.07759528615667297
|
||||
m4_daily,amazon/chronos-t5-large,3.134560968849699,0.02158279722410466
|
||||
m4_hourly,amazon/chronos-t5-large,0.6975930649233378,0.02086427219957674
|
||||
m4_monthly,amazon/chronos-t5-large,0.9585550091429409,0.10091221432814867
|
||||
m4_weekly,amazon/chronos-t5-large,2.0191422600104425,0.036912838355537186
|
||||
monash_electricity_hourly,amazon/chronos-t5-large,1.4069912853901292,0.09642382339452431
|
||||
monash_electricity_weekly,amazon/chronos-t5-large,1.7501880036182798,0.05765306465830232
|
||||
monash_kdd_cup_2018,amazon/chronos-t5-large,0.6788042816175427,0.2853553329804835
|
||||
monash_london_smart_meters,amazon/chronos-t5-large,0.8290300790418726,0.4235436387853963
|
||||
monash_pedestrian_counts,amazon/chronos-t5-large,0.2764118100521592,0.18692234491663473
|
||||
monash_rideshare,amazon/chronos-t5-large,0.8758058784466208,0.140260325368757
|
||||
monash_temperature_rain,amazon/chronos-t5-large,0.9738403865035117,0.6604571928063249
|
||||
taxi_30min,amazon/chronos-t5-large,0.8245662397270109,0.2653520120326771
|
||||
uber_tlc_daily,amazon/chronos-t5-large,0.8044165990021739,0.09499035584302248
|
||||
uber_tlc_hourly,amazon/chronos-t5-large,0.6700665937164474,0.15190288476653066
|
||||
|
28
scripts/evaluation/results/chronos-t5-large-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-t5-large-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-t5-large,0.78160443631164,0.07884375667736107
|
||||
ETTm,amazon/chronos-t5-large,0.7325919639389967,0.06656858270921162
|
||||
dominick,amazon/chronos-t5-large,0.8200108271155829,0.3311575649734524
|
||||
ercot,amazon/chronos-t5-large,0.6050812633742764,0.01822996942395577
|
||||
exchange_rate,amazon/chronos-t5-large,2.3439287001928744,0.014841231672174684
|
||||
m4_quarterly,amazon/chronos-t5-large,1.2169666607868148,0.08235162400898562
|
||||
m4_yearly,amazon/chronos-t5-large,3.5524979814018947,0.1325675848907479
|
||||
m5,amazon/chronos-t5-large,0.9422990989146737,0.585615077637479
|
||||
monash_australian_electricity,amazon/chronos-t5-large,1.480849838497958,0.07973968848149568
|
||||
monash_car_parts,amazon/chronos-t5-large,0.901547374873302,1.0467398096496576
|
||||
monash_cif_2016,amazon/chronos-t5-large,0.9906388185665337,0.011966178555329998
|
||||
monash_covid_deaths,amazon/chronos-t5-large,44.07354193681227,0.06108999981222163
|
||||
monash_fred_md,amazon/chronos-t5-large,0.5184400880318044,0.01675533888399231
|
||||
monash_hospital,amazon/chronos-t5-large,0.7055308474630898,0.0552450850258613
|
||||
monash_m1_monthly,amazon/chronos-t5-large,1.0888995301234758,0.12729911122909737
|
||||
monash_m1_quarterly,amazon/chronos-t5-large,1.7477134564031453,0.10618253695380094
|
||||
monash_m1_yearly,amazon/chronos-t5-large,4.250667049416348,0.17128879333643188
|
||||
monash_m3_monthly,amazon/chronos-t5-large,0.8559326975903808,0.09572577431396007
|
||||
monash_m3_quarterly,amazon/chronos-t5-large,1.1867267751420676,0.07449254281607631
|
||||
monash_m3_yearly,amazon/chronos-t5-large,3.0239493021840635,0.14814710375646464
|
||||
monash_nn5_weekly,amazon/chronos-t5-large,0.9228721852437364,0.08948447200571868
|
||||
monash_tourism_monthly,amazon/chronos-t5-large,1.7304427846580348,0.09983169221760163
|
||||
monash_tourism_quarterly,amazon/chronos-t5-large,1.6437184365114073,0.0690906057781915
|
||||
monash_tourism_yearly,amazon/chronos-t5-large,3.6268503118928535,0.17732007043832695
|
||||
monash_traffic,amazon/chronos-t5-large,0.7985975530866148,0.25313515740581755
|
||||
monash_weather,amazon/chronos-t5-large,0.8187388457436171,0.1387756772600068
|
||||
nn5,amazon/chronos-t5-large,0.5755260854173723,0.15733693855465292
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.7249816823595568
|
||||
in-domain,WQL,0.5965372489622094
|
||||
zero-shot,MASE,0.8411995116926901
|
||||
zero-shot,WQL,0.6888397962259065
|
||||
|
16
scripts/evaluation/results/chronos-t5-mini-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-t5-mini-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-t5-mini,0.4446629660227641,0.08114657599239496
|
||||
m4_daily,amazon/chronos-t5-mini,3.1533349226194005,0.022000507013584743
|
||||
m4_hourly,amazon/chronos-t5-mini,0.7616830292996938,0.024630575107847653
|
||||
m4_monthly,amazon/chronos-t5-mini,0.9934074425853089,0.10402168689068064
|
||||
m4_weekly,amazon/chronos-t5-mini,2.1407189608104416,0.04138058102434373
|
||||
monash_electricity_hourly,amazon/chronos-t5-mini,1.3698378948313894,0.09189698159081384
|
||||
monash_electricity_weekly,amazon/chronos-t5-mini,1.9238345295706893,0.07015383787479901
|
||||
monash_kdd_cup_2018,amazon/chronos-t5-mini,0.6027861468526459,0.25493489598663444
|
||||
monash_london_smart_meters,amazon/chronos-t5-mini,0.8570035850603943,0.4356582737588471
|
||||
monash_pedestrian_counts,amazon/chronos-t5-mini,0.30374539593979855,0.2374083216051065
|
||||
monash_rideshare,amazon/chronos-t5-mini,0.8157349455509949,0.12963515638823117
|
||||
monash_temperature_rain,amazon/chronos-t5-mini,1.010161905102516,0.6919171702485583
|
||||
taxi_30min,amazon/chronos-t5-mini,0.9318379552979712,0.31229508015999674
|
||||
uber_tlc_daily,amazon/chronos-t5-mini,0.9213437323817685,0.10475291429149586
|
||||
uber_tlc_hourly,amazon/chronos-t5-mini,0.6812621470377416,0.15982192635434303
|
||||
|
28
scripts/evaluation/results/chronos-t5-mini-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-t5-mini-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-t5-mini,0.789678971785092,0.08068969536800001
|
||||
ETTm,amazon/chronos-t5-mini,0.7521219674190734,0.06791782942706617
|
||||
dominick,amazon/chronos-t5-mini,0.8207116999488602,0.34004499734299765
|
||||
ercot,amazon/chronos-t5-mini,0.5462749489237783,0.015035001020343136
|
||||
exchange_rate,amazon/chronos-t5-mini,2.1326718165798657,0.015073846769933199
|
||||
m4_quarterly,amazon/chronos-t5-mini,1.271761811062081,0.08575942238385105
|
||||
m4_yearly,amazon/chronos-t5-mini,3.7340853642679126,0.13938781939783162
|
||||
m5,amazon/chronos-t5-mini,0.9421556321929742,0.5961689098871504
|
||||
monash_australian_electricity,amazon/chronos-t5-mini,1.046297291920238,0.05424453772723559
|
||||
monash_car_parts,amazon/chronos-t5-mini,0.8913523483805221,1.0174797526818506
|
||||
monash_cif_2016,amazon/chronos-t5-mini,1.0674111822055679,0.016800831829085764
|
||||
monash_covid_deaths,amazon/chronos-t5-mini,43.69727825485175,0.08788117644141617
|
||||
monash_fred_md,amazon/chronos-t5-mini,0.46227452519609524,0.01871860604459728
|
||||
monash_hospital,amazon/chronos-t5-mini,0.7112593459108532,0.05831005112661489
|
||||
monash_m1_monthly,amazon/chronos-t5-mini,1.1756557848450433,0.14192178371159841
|
||||
monash_m1_quarterly,amazon/chronos-t5-mini,1.795009199698074,0.11760148522768847
|
||||
monash_m1_yearly,amazon/chronos-t5-mini,5.078889706085604,0.1882823108615221
|
||||
monash_m3_monthly,amazon/chronos-t5-mini,0.900404391663476,0.09935931092075681
|
||||
monash_m3_quarterly,amazon/chronos-t5-mini,1.2604342624229292,0.07807204797138119
|
||||
monash_m3_yearly,amazon/chronos-t5-mini,3.4395976709464255,0.16085249526114198
|
||||
monash_nn5_weekly,amazon/chronos-t5-mini,0.9459117943913629,0.09042762527674755
|
||||
monash_tourism_monthly,amazon/chronos-t5-mini,1.920865545569713,0.10791754513335952
|
||||
monash_tourism_quarterly,amazon/chronos-t5-mini,1.7957439111869486,0.07514539225156464
|
||||
monash_tourism_yearly,amazon/chronos-t5-mini,4.134958090482728,0.2202036957350168
|
||||
monash_traffic,amazon/chronos-t5-mini,0.8546792774237857,0.2668831661775284
|
||||
monash_weather,amazon/chronos-t5-mini,0.8607748244159247,0.15031866806333247
|
||||
nn5,amazon/chronos-t5-mini,0.6497211196906223,0.17352254058241523
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.7296140269944743
|
||||
in-domain,WQL,0.6086958548874499
|
||||
zero-shot,MASE,0.8303721909132112
|
||||
zero-shot,WQL,0.6649587072099045
|
||||
|
16
scripts/evaluation/results/chronos-t5-small-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-t5-small-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-t5-small,0.4115559557750193,0.08085148902238105
|
||||
m4_daily,amazon/chronos-t5-small,3.1384304946608896,0.02129901023419818
|
||||
m4_hourly,amazon/chronos-t5-small,0.7300874075370588,0.024686127211237932
|
||||
m4_monthly,amazon/chronos-t5-small,0.9797264456494642,0.10297069145186107
|
||||
m4_weekly,amazon/chronos-t5-small,2.0802214537692607,0.03959222330783002
|
||||
monash_electricity_hourly,amazon/chronos-t5-small,1.530308399040219,0.10765947926209926
|
||||
monash_electricity_weekly,amazon/chronos-t5-small,1.9249616494404531,0.07593976499899265
|
||||
monash_kdd_cup_2018,amazon/chronos-t5-small,0.6911172359201715,0.2863722811236367
|
||||
monash_london_smart_meters,amazon/chronos-t5-small,0.8405756252443325,0.4300875548402115
|
||||
monash_pedestrian_counts,amazon/chronos-t5-small,0.30836963006151696,0.2442543970311678
|
||||
monash_rideshare,amazon/chronos-t5-small,0.8436277753840817,0.1363421932158997
|
||||
monash_temperature_rain,amazon/chronos-t5-small,1.0176003932416664,0.6847726381172435
|
||||
taxi_30min,amazon/chronos-t5-small,0.976277213614167,0.32770172988517626
|
||||
uber_tlc_daily,amazon/chronos-t5-small,0.8694727058784919,0.0994889223610958
|
||||
uber_tlc_hourly,amazon/chronos-t5-small,0.6738672444888639,0.1573990617753753
|
||||
|
28
scripts/evaluation/results/chronos-t5-small-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-t5-small-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-t5-small,0.8516754221042285,0.08667817580712385
|
||||
ETTm,amazon/chronos-t5-small,0.6825432730635727,0.06076472147001207
|
||||
dominick,amazon/chronos-t5-small,0.8108766032127683,0.3368104617474581
|
||||
ercot,amazon/chronos-t5-small,0.564879593858422,0.015547628920969682
|
||||
exchange_rate,amazon/chronos-t5-small,1.8143459139100264,0.014492477372711763
|
||||
m4_quarterly,amazon/chronos-t5-small,1.2415331521819728,0.08383826063189778
|
||||
m4_yearly,amazon/chronos-t5-small,3.738749650935195,0.1384514201649314
|
||||
m5,amazon/chronos-t5-small,0.9368713240675598,0.5896066252181699
|
||||
monash_australian_electricity,amazon/chronos-t5-small,1.2241146217392032,0.06951399165882449
|
||||
monash_car_parts,amazon/chronos-t5-small,0.8917508090523597,1.0314986717260015
|
||||
monash_cif_2016,amazon/chronos-t5-small,1.0187937383419037,0.014633240218233142
|
||||
monash_covid_deaths,amazon/chronos-t5-small,42.298997211368935,0.06339512778191682
|
||||
monash_fred_md,amazon/chronos-t5-small,0.4742159923922472,0.01486734736993978
|
||||
monash_hospital,amazon/chronos-t5-small,0.709814741753487,0.05704674270057172
|
||||
monash_m1_monthly,amazon/chronos-t5-small,1.1723041163998773,0.13799049510465802
|
||||
monash_m1_quarterly,amazon/chronos-t5-small,1.8077827825737092,0.11323432989795904
|
||||
monash_m1_yearly,amazon/chronos-t5-small,4.739967673537301,0.1730738338876877
|
||||
monash_m3_monthly,amazon/chronos-t5-small,0.8856577322724943,0.09985251429658573
|
||||
monash_m3_quarterly,amazon/chronos-t5-small,1.278907982396775,0.08094041554590593
|
||||
monash_m3_yearly,amazon/chronos-t5-small,3.382470310192457,0.157363937435307
|
||||
monash_nn5_weekly,amazon/chronos-t5-small,0.9277396908126303,0.08963913763368506
|
||||
monash_tourism_monthly,amazon/chronos-t5-small,1.9251180766131313,0.10943962474253494
|
||||
monash_tourism_quarterly,amazon/chronos-t5-small,1.7623454951333655,0.06862432764377493
|
||||
monash_tourism_yearly,amazon/chronos-t5-small,3.987690476709746,0.19960492460202509
|
||||
monash_traffic,amazon/chronos-t5-small,0.8204223927835267,0.2571189517024486
|
||||
monash_weather,amazon/chronos-t5-small,0.8550633590487968,0.1479701971025123
|
||||
nn5,amazon/chronos-t5-small,0.6130789183153671,0.16771392719859998
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
benchmark,metric,value
|
||||
in-domain,MASE,0.7649019745781727
|
||||
in-domain,WQL,0.6288613368129368
|
||||
zero-shot,MASE,0.8704764463925718
|
||||
zero-shot,WQL,0.7108912052035352
|
||||
|
16
scripts/evaluation/results/chronos-t5-tiny-in-domain.csv
Normal file
16
scripts/evaluation/results/chronos-t5-tiny-in-domain.csv
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
dataset,model,MASE,WQL
|
||||
electricity_15min,amazon/chronos-t5-tiny,0.5091784254243783,0.08236334376190152
|
||||
m4_daily,amazon/chronos-t5-tiny,3.203164895930929,0.022152192084951595
|
||||
m4_hourly,amazon/chronos-t5-tiny,0.8171321441164723,0.027490760558343874
|
||||
m4_monthly,amazon/chronos-t5-tiny,1.005839207921131,0.10388015368939435
|
||||
m4_weekly,amazon/chronos-t5-tiny,2.2148332313370735,0.043429655561156084
|
||||
monash_electricity_hourly,amazon/chronos-t5-tiny,1.6190021089002615,0.10967453530956882
|
||||
monash_electricity_weekly,amazon/chronos-t5-tiny,2.0774597917676734,0.08159998975612164
|
||||
monash_kdd_cup_2018,amazon/chronos-t5-tiny,0.6730886827096076,0.2616610603634618
|
||||
monash_london_smart_meters,amazon/chronos-t5-tiny,0.8830447519225436,0.4499607073491794
|
||||
monash_pedestrian_counts,amazon/chronos-t5-tiny,0.3042105240185045,0.23387631681117071
|
||||
monash_rideshare,amazon/chronos-t5-tiny,0.8431350112476247,0.1378817076926394
|
||||
monash_temperature_rain,amazon/chronos-t5-tiny,0.9887398447367799,0.6957797286648015
|
||||
taxi_30min,amazon/chronos-t5-tiny,1.035544060665179,0.3450476958104713
|
||||
uber_tlc_daily,amazon/chronos-t5-tiny,0.93025919000775,0.1105323649942084
|
||||
uber_tlc_hourly,amazon/chronos-t5-tiny,0.697558054147913,0.16320255844336232
|
||||
|
28
scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv
Normal file
28
scripts/evaluation/results/chronos-t5-tiny-zero-shot.csv
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
dataset,model,MASE,WQL
|
||||
ETTh,amazon/chronos-t5-tiny,0.8184074113571701,0.08578203438707048
|
||||
ETTm,amazon/chronos-t5-tiny,0.9103621000781905,0.07975361086322658
|
||||
dominick,amazon/chronos-t5-tiny,0.8538295532466194,0.3597090770361857
|
||||
ercot,amazon/chronos-t5-tiny,0.7273437589773705,0.020843170924006626
|
||||
exchange_rate,amazon/chronos-t5-tiny,1.6621128608546154,0.01085145980896454
|
||||
m4_quarterly,amazon/chronos-t5-tiny,1.2696259955861924,0.0861404188925996
|
||||
m4_yearly,amazon/chronos-t5-tiny,3.5293881164900527,0.13281575565500411
|
||||
m5,amazon/chronos-t5-tiny,0.9394059505709506,0.5981531758388589
|
||||
monash_australian_electricity,amazon/chronos-t5-tiny,1.4558820561269024,0.07673567331332948
|
||||
monash_car_parts,amazon/chronos-t5-tiny,0.9058206654011024,1.0236307963149358
|
||||
monash_cif_2016,amazon/chronos-t5-tiny,1.09349564130852,0.014066593076202984
|
||||
monash_covid_deaths,amazon/chronos-t5-tiny,46.53079664940016,0.09201919385053775
|
||||
monash_fred_md,amazon/chronos-t5-tiny,0.48008374212956456,0.03219550761153211
|
||||
monash_hospital,amazon/chronos-t5-tiny,0.7062562198194838,0.05790409320432609
|
||||
monash_m1_monthly,amazon/chronos-t5-tiny,1.214892145549996,0.14723095246308077
|
||||
monash_m1_quarterly,amazon/chronos-t5-tiny,1.8968576926613199,0.11026972972622998
|
||||
monash_m1_yearly,amazon/chronos-t5-tiny,4.829453202075546,0.17286063726000958
|
||||
monash_m3_monthly,amazon/chronos-t5-tiny,0.9095746605884618,0.10117875324490073
|
||||
monash_m3_quarterly,amazon/chronos-t5-tiny,1.3234957548639883,0.08209032993637215
|
||||
monash_m3_yearly,amazon/chronos-t5-tiny,3.1489371074890093,0.1492445630072877
|
||||
monash_nn5_weekly,amazon/chronos-t5-tiny,0.9637480731663901,0.09205994784693056
|
||||
monash_tourism_monthly,amazon/chronos-t5-tiny,2.151677532807024,0.11356761694754255
|
||||
monash_tourism_quarterly,amazon/chronos-t5-tiny,1.9116538900950555,0.07191734222366106
|
||||
monash_tourism_yearly,amazon/chronos-t5-tiny,3.820615532600914,0.19709256337364625
|
||||
monash_traffic,amazon/chronos-t5-tiny,0.878709088458116,0.2632101606272236
|
||||
monash_weather,amazon/chronos-t5-tiny,0.8504899606521996,0.14787595319625085
|
||||
nn5,amazon/chronos-t5-tiny,0.7021735456568664,0.19071330483289695
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .base import BaseChronosPipeline, ForecastType
|
||||
from .chronos import (
|
||||
ChronosConfig,
|
||||
ChronosModel,
|
||||
|
|
@ -8,11 +9,16 @@ from .chronos import (
|
|||
ChronosTokenizer,
|
||||
MeanScaleUniformBins,
|
||||
)
|
||||
from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
|
||||
|
||||
__all__ = [
|
||||
"BaseChronosPipeline",
|
||||
"ForecastType",
|
||||
"ChronosConfig",
|
||||
"ChronosModel",
|
||||
"ChronosPipeline",
|
||||
"ChronosTokenizer",
|
||||
"MeanScaleUniformBins",
|
||||
"ChronosBoltConfig",
|
||||
"ChronosBoltPipeline",
|
||||
]
|
||||
|
|
|
|||
162
src/chronos/base.py
Normal file
162
src/chronos/base.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Authors: Caner Turkmen <atturkm@amazon.com>, Abdul Fatir Ansari <ansarnd@amazon.com>, Lorenzo Stella <stellalo@amazon.com>
|
||||
# Original source:
|
||||
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/base.py
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .utils import left_pad_and_stack_1D
|
||||
|
||||
|
||||
class ForecastType(Enum):
|
||||
SAMPLES = "samples"
|
||||
QUANTILES = "quantiles"
|
||||
|
||||
|
||||
class PipelineRegistry(type):
|
||||
REGISTRY: Dict[str, "PipelineRegistry"] = {}
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
"""See, https://github.com/faif/python-patterns."""
|
||||
new_cls = type.__new__(cls, name, bases, attrs)
|
||||
if name is not None:
|
||||
cls.REGISTRY[name] = new_cls
|
||||
|
||||
return new_cls
|
||||
|
||||
|
||||
class BaseChronosPipeline(metaclass=PipelineRegistry):
|
||||
forecast_type: ForecastType
|
||||
dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}
|
||||
|
||||
def __init__(self, inner_model: "PreTrainedModel"):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
inner_model : PreTrainedModel
|
||||
A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
|
||||
"""
|
||||
# for easy access to the inner HF-style model
|
||||
self.inner_model = inner_model
|
||||
|
||||
def _prepare_and_validate_context(
|
||||
self, context: Union[torch.Tensor, List[torch.Tensor]]
|
||||
):
|
||||
if isinstance(context, list):
|
||||
context = left_pad_and_stack_1D(context)
|
||||
assert isinstance(context, torch.Tensor)
|
||||
if context.ndim == 1:
|
||||
context = context.unsqueeze(0)
|
||||
assert context.ndim == 2
|
||||
|
||||
return context
|
||||
|
||||
def predict(
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
prediction_length: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Get forecasts for the given time series.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context
|
||||
Input series. This is either a 1D tensor, or a list
|
||||
of 1D tensors, or a 2D tensor whose first dimension
|
||||
is batch. In the latter case, use left-padding with
|
||||
``torch.nan`` to align series of different lengths.
|
||||
prediction_length
|
||||
Time steps to predict. Defaults to a model-dependent
|
||||
value if not given.
|
||||
|
||||
Returns
|
||||
-------
|
||||
forecasts
|
||||
Tensor containing forecasts. The layout and meaning
|
||||
of the forecasts values depends on ``self.forecast_type``.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict_quantiles(
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
prediction_length: Optional[int] = None,
|
||||
quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get quantile and mean forecasts for given time series.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context : Union[torch.Tensor, List[torch.Tensor]]
|
||||
Input series. This is either a 1D tensor, or a list
|
||||
of 1D tensors, or a 2D tensor whose first dimension
|
||||
is batch. In the latter case, use left-padding with
|
||||
``torch.nan`` to align series of different lengths.
|
||||
prediction_length : Optional[int], optional
|
||||
Time steps to predict. Defaults to a model-dependent
|
||||
value if not given.
|
||||
quantile_levels : List[float], optional
|
||||
Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
|
||||
|
||||
Returns
|
||||
-------
|
||||
quantiles
|
||||
Tensor containing quantile forecasts. Shape
|
||||
(batch_size, prediction_length, num_quantiles)
|
||||
mean
|
||||
Tensor containing mean (point) forecasts. Shape
|
||||
(batch_size, prediction_length)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, Path],
|
||||
*model_args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load the model, either from a local path or from the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel``
|
||||
from ``transformers``.
|
||||
"""
|
||||
from transformers import AutoConfig
|
||||
|
||||
torch_dtype = kwargs.get("torch_dtype", "auto")
|
||||
if torch_dtype != "auto" and isinstance(torch_dtype, str):
|
||||
kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(
|
||||
config, "chronos_config"
|
||||
)
|
||||
|
||||
if not is_valid_config:
|
||||
raise ValueError("Not a Chronos config file")
|
||||
|
||||
pipeline_class_name = getattr(
|
||||
config, "chronos_pipeline_class", "ChronosPipeline"
|
||||
)
|
||||
class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
|
||||
if class_ is None:
|
||||
raise ValueError(
|
||||
f"Trying to load unknown pipeline class: {pipeline_class_name}"
|
||||
)
|
||||
|
||||
return class_.from_pretrained( # type: ignore[attr-defined]
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import warnings
|
||||
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>, Lorenzo Stella <stellalo@amazon.com>, Caner Turkmen <atturkm@amazon.com>
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
|
|
@ -16,6 +18,10 @@ from transformers import (
|
|||
)
|
||||
|
||||
import chronos
|
||||
from chronos.base import BaseChronosPipeline, ForecastType
|
||||
from chronos.utils import left_pad_and_stack_1D
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -364,21 +370,7 @@ class ChronosModel(nn.Module):
|
|||
return preds.reshape(input_ids.size(0), num_samples, -1)
|
||||
|
||||
|
||||
def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
|
||||
max_len = max(len(c) for c in tensors)
|
||||
padded = []
|
||||
for c in tensors:
|
||||
assert isinstance(c, torch.Tensor)
|
||||
assert c.ndim == 1
|
||||
padding = torch.full(
|
||||
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
|
||||
)
|
||||
padded.append(torch.concat((padding, c), dim=-1))
|
||||
return torch.stack(padded).to(tensors[0])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosPipeline:
|
||||
class ChronosPipeline(BaseChronosPipeline):
|
||||
"""
|
||||
A ``ChronosPipeline`` uses the given tokenizer and model to forecast
|
||||
input time series.
|
||||
|
|
@ -396,6 +388,12 @@ class ChronosPipeline:
|
|||
|
||||
tokenizer: ChronosTokenizer
|
||||
model: ChronosModel
|
||||
forecast_type: ForecastType = ForecastType.SAMPLES
|
||||
|
||||
def __init__(self, tokenizer, model):
|
||||
super().__init__(inner_model=model.model)
|
||||
self.tokenizer = tokenizer
|
||||
self.model = model
|
||||
|
||||
def _prepare_and_validate_context(
|
||||
self, context: Union[torch.Tensor, List[torch.Tensor]]
|
||||
|
|
@ -445,7 +443,7 @@ class ChronosPipeline:
|
|||
).cpu()
|
||||
return embeddings, tokenizer_state
|
||||
|
||||
def predict(
|
||||
def predict( # type: ignore[override]
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
prediction_length: Optional[int] = None,
|
||||
|
|
@ -453,21 +451,16 @@ class ChronosPipeline:
|
|||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
limit_prediction_length: bool = True,
|
||||
limit_prediction_length: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get forecasts for the given time series.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context
|
||||
Input series. This is either a 1D tensor, or a list
|
||||
of 1D tensors, or a 2D tensor whose first dimension
|
||||
is batch. In the latter case, use left-padding with
|
||||
``torch.nan`` to align series of different lengths.
|
||||
prediction_length
|
||||
Time steps to predict. Defaults to what specified
|
||||
in ``self.model.config``.
|
||||
Refer to the base method (``BaseChronosPipeline.predict``)
|
||||
for details on shared parameters.
|
||||
|
||||
Additional parameters
|
||||
---------------------
|
||||
num_samples
|
||||
Number of sample paths to predict. Defaults to what
|
||||
specified in ``self.model.config``.
|
||||
|
|
@ -482,7 +475,7 @@ class ChronosPipeline:
|
|||
Defaults to what specified in ``self.model.config``.
|
||||
limit_prediction_length
|
||||
Force prediction length smaller or equal than the
|
||||
built-in prediction length from the model. True by
|
||||
built-in prediction length from the model. False by
|
||||
default. When true, fail loudly if longer predictions
|
||||
are requested, otherwise longer predictions are allowed.
|
||||
|
||||
|
|
@ -505,7 +498,7 @@ class ChronosPipeline:
|
|||
if limit_prediction_length:
|
||||
msg += "You can turn off this check by setting `limit_prediction_length=False`."
|
||||
raise ValueError(msg)
|
||||
warnings.warn(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
input_dtype = context_tensor.dtype
|
||||
input_device = context_tensor.device
|
||||
|
|
@ -542,6 +535,30 @@ class ChronosPipeline:
|
|||
|
||||
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)
|
||||
|
||||
def predict_quantiles(
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
prediction_length: Optional[int] = None,
|
||||
quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
**predict_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Refer to the base method (``BaseChronosPipeline.predict_quantiles``).
|
||||
"""
|
||||
prediction_samples = (
|
||||
self.predict(context, prediction_length=prediction_length, **predict_kwargs)
|
||||
.detach()
|
||||
.swapaxes(1, 2)
|
||||
)
|
||||
mean = prediction_samples.mean(dim=-1)
|
||||
quantiles = torch.quantile(
|
||||
prediction_samples,
|
||||
q=torch.tensor(quantile_levels, dtype=prediction_samples.dtype),
|
||||
dim=-1,
|
||||
).permute(1, 2, 0)
|
||||
|
||||
return quantiles, mean
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
587
src/chronos/chronos_bolt.py
Normal file
587
src/chronos/chronos_bolt.py
Normal file
|
|
@ -0,0 +1,587 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>, Caner Turkmen <atturkm@amazon.com>, Lorenzo Stella <stellalo@amazon.com>
|
||||
# Original source:
|
||||
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.t5.modeling_t5 import (
|
||||
ACT2FN,
|
||||
T5Config,
|
||||
T5LayerNorm,
|
||||
T5PreTrainedModel,
|
||||
T5Stack,
|
||||
)
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from .base import BaseChronosPipeline, ForecastType
|
||||
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosBoltConfig:
|
||||
context_length: int
|
||||
prediction_length: int
|
||||
input_patch_size: int
|
||||
input_patch_stride: int
|
||||
quantiles: List[float]
|
||||
use_reg_token: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosBoltOutput(ModelOutput):
|
||||
loss: Optional[torch.Tensor] = None
|
||||
quantile_preds: Optional[torch.Tensor] = None
|
||||
attentions: Optional[torch.Tensor] = None
|
||||
cross_attentions: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class Patch(nn.Module):
|
||||
def __init__(self, patch_size: int, patch_stride: int) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
length = x.shape[-1]
|
||||
|
||||
if length % self.patch_size != 0:
|
||||
padding_size = (
|
||||
*x.shape[:-1],
|
||||
self.patch_size - (length % self.patch_size),
|
||||
)
|
||||
padding = torch.full(
|
||||
size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device
|
||||
)
|
||||
x = torch.concat((padding, x), dim=-1)
|
||||
|
||||
x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
|
||||
return x
|
||||
|
||||
|
||||
class InstanceNorm(nn.Module):
|
||||
"""
|
||||
See, also, RevIN. Apply standardization along the last dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, eps: float = 1e-5) -> None:
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if loc_scale is None:
|
||||
loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)
|
||||
scale = torch.nan_to_num(
|
||||
torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0
|
||||
)
|
||||
scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
|
||||
else:
|
||||
loc, scale = loc_scale
|
||||
|
||||
return (x - loc) / scale, (loc, scale)
|
||||
|
||||
def inverse(
|
||||
self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
loc, scale = loc_scale
|
||||
return x * scale + loc
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
h_dim: int,
|
||||
out_dim: int,
|
||||
act_fn_name: str,
|
||||
dropout_p: float = 0.0,
|
||||
use_layer_norm: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
self.hidden_layer = nn.Linear(in_dim, h_dim)
|
||||
self.act = ACT2FN[act_fn_name]
|
||||
self.output_layer = nn.Linear(h_dim, out_dim)
|
||||
self.residual_layer = nn.Linear(in_dim, out_dim)
|
||||
|
||||
self.use_layer_norm = use_layer_norm
|
||||
if use_layer_norm:
|
||||
self.layer_norm = T5LayerNorm(out_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
hid = self.act(self.hidden_layer(x))
|
||||
out = self.dropout(self.output_layer(hid))
|
||||
res = self.residual_layer(x)
|
||||
|
||||
out = out + res
|
||||
|
||||
if self.use_layer_norm:
|
||||
return self.layer_norm(out)
|
||||
return out
|
||||
|
||||
|
||||
class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"input_patch_embedding\.",
|
||||
r"output_patch_embedding\.",
|
||||
]
|
||||
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: T5Config):
|
||||
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
|
||||
|
||||
super().__init__(config)
|
||||
self.model_dim = config.d_model
|
||||
|
||||
self.chronos_config = ChronosBoltConfig(**config.chronos_config)
|
||||
|
||||
# Only decoder_start_id (and optionally REG token)
|
||||
if self.chronos_config.use_reg_token:
|
||||
config.reg_token_id = 1
|
||||
|
||||
config.vocab_size = 2 if self.chronos_config.use_reg_token else 1
|
||||
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
||||
|
||||
# Input patch embedding layer
|
||||
self.input_patch_embedding = ResidualBlock(
|
||||
in_dim=self.chronos_config.input_patch_size * 2,
|
||||
h_dim=config.d_ff,
|
||||
out_dim=config.d_model,
|
||||
act_fn_name=config.dense_act_fn,
|
||||
dropout_p=config.dropout_rate,
|
||||
)
|
||||
|
||||
# patching layer
|
||||
self.patch = Patch(
|
||||
patch_size=self.chronos_config.input_patch_size,
|
||||
patch_stride=self.chronos_config.input_patch_stride,
|
||||
)
|
||||
|
||||
# instance normalization, also referred to as "scaling" in Chronos and GluonTS
|
||||
self.instance_norm = InstanceNorm()
|
||||
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.use_cache = False
|
||||
encoder_config.is_encoder_decoder = False
|
||||
self.encoder = T5Stack(encoder_config, self.shared)
|
||||
|
||||
self._init_decoder(config)
|
||||
|
||||
self.num_quantiles = len(self.chronos_config.quantiles)
|
||||
quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)
|
||||
self.register_buffer("quantiles", quantiles, persistent=False)
|
||||
|
||||
self.output_patch_embedding = ResidualBlock(
|
||||
in_dim=config.d_model,
|
||||
h_dim=config.d_ff,
|
||||
out_dim=self.num_quantiles * self.chronos_config.prediction_length,
|
||||
act_fn_name=config.dense_act_fn,
|
||||
dropout_p=config.dropout_rate,
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
# Model parallel
|
||||
self.model_parallel = False
|
||||
self.device_map = None
|
||||
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, (self.__class__)):
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
elif isinstance(module, ResidualBlock):
|
||||
module.hidden_layer.weight.data.normal_(
|
||||
mean=0.0,
|
||||
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
|
||||
)
|
||||
if (
|
||||
hasattr(module.hidden_layer, "bias")
|
||||
and module.hidden_layer.bias is not None
|
||||
):
|
||||
module.hidden_layer.bias.data.zero_()
|
||||
|
||||
module.residual_layer.weight.data.normal_(
|
||||
mean=0.0,
|
||||
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
|
||||
)
|
||||
if (
|
||||
hasattr(module.residual_layer, "bias")
|
||||
and module.residual_layer.bias is not None
|
||||
):
|
||||
module.residual_layer.bias.data.zero_()
|
||||
|
||||
module.output_layer.weight.data.normal_(
|
||||
mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
|
||||
)
|
||||
if (
|
||||
hasattr(module.output_layer, "bias")
|
||||
and module.output_layer.bias is not None
|
||||
):
|
||||
module.output_layer.bias.data.zero_()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
target: Optional[torch.Tensor] = None,
|
||||
target_mask: Optional[torch.Tensor] = None,
|
||||
) -> ChronosBoltOutput:
|
||||
mask = (
|
||||
mask.to(context.dtype)
|
||||
if mask is not None
|
||||
else torch.isnan(context).logical_not().to(context.dtype)
|
||||
)
|
||||
|
||||
batch_size, _ = context.shape
|
||||
if context.shape[-1] > self.chronos_config.context_length:
|
||||
context = context[..., -self.chronos_config.context_length :]
|
||||
mask = mask[..., -self.chronos_config.context_length :]
|
||||
|
||||
# scaling
|
||||
context, loc_scale = self.instance_norm(context)
|
||||
|
||||
# the scaling op above is done in 32-bit precision,
|
||||
# then the context is moved to model's dtype
|
||||
context = context.to(self.dtype)
|
||||
mask = mask.to(self.dtype)
|
||||
|
||||
# patching
|
||||
patched_context = self.patch(context)
|
||||
patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
|
||||
patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)
|
||||
# concat context and mask along patch dim
|
||||
patched_context = torch.cat([patched_context, patched_mask], dim=-1)
|
||||
|
||||
# attention_mask = 1 if at least one item in the patch is observed
|
||||
attention_mask = (
|
||||
patched_mask.sum(dim=-1) > 0
|
||||
) # (batch_size, patched_seq_length)
|
||||
|
||||
input_embeds = self.input_patch_embedding(patched_context)
|
||||
|
||||
if self.chronos_config.use_reg_token:
|
||||
# Append [REG]
|
||||
reg_input_ids = torch.full(
|
||||
(batch_size, 1),
|
||||
self.config.reg_token_id,
|
||||
device=input_embeds.device,
|
||||
)
|
||||
reg_embeds = self.shared(reg_input_ids)
|
||||
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
attention_mask.to(self.dtype),
|
||||
torch.ones_like(reg_input_ids).to(self.dtype),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=input_embeds,
|
||||
)
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
|
||||
|
||||
quantile_preds_shape = (
|
||||
batch_size,
|
||||
self.num_quantiles,
|
||||
self.chronos_config.prediction_length,
|
||||
)
|
||||
quantile_preds = self.output_patch_embedding(sequence_output).view(
|
||||
*quantile_preds_shape
|
||||
)
|
||||
|
||||
loss = None
|
||||
if target is not None:
|
||||
# normalize target
|
||||
target, _ = self.instance_norm(target, loc_scale)
|
||||
target = target.unsqueeze(1) # type: ignore
|
||||
assert self.chronos_config.prediction_length >= target.shape[-1]
|
||||
|
||||
target = target.to(quantile_preds.device)
|
||||
target_mask = (
|
||||
target_mask.unsqueeze(1).to(quantile_preds.device)
|
||||
if target_mask is not None
|
||||
else ~torch.isnan(target)
|
||||
)
|
||||
target[~target_mask] = 0.0
|
||||
|
||||
# pad target and target_mask if they are shorter than model's prediction_length
|
||||
if self.chronos_config.prediction_length > target.shape[-1]:
|
||||
padding_shape = (
|
||||
*target.shape[:-1],
|
||||
self.chronos_config.prediction_length - target.shape[-1],
|
||||
)
|
||||
target = torch.cat(
|
||||
[target, torch.zeros(padding_shape).to(target)], dim=-1
|
||||
)
|
||||
target_mask = torch.cat(
|
||||
[target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1
|
||||
)
|
||||
|
||||
loss = (
|
||||
2
|
||||
* torch.abs(
|
||||
(target - quantile_preds)
|
||||
* (
|
||||
(target <= quantile_preds).float()
|
||||
- self.quantiles.view(1, self.num_quantiles, 1)
|
||||
)
|
||||
)
|
||||
* target_mask.float()
|
||||
)
|
||||
loss = loss.mean(dim=-2) # Mean over prediction horizon
|
||||
loss = loss.sum(dim=-1) # Sum over quantile levels
|
||||
loss = loss.mean() # Mean over batch
|
||||
|
||||
# Unscale predictions
|
||||
quantile_preds = self.instance_norm.inverse(
|
||||
quantile_preds.view(batch_size, -1),
|
||||
loc_scale,
|
||||
).view(*quantile_preds_shape)
|
||||
|
||||
return ChronosBoltOutput(
|
||||
loss=loss,
|
||||
quantile_preds=quantile_preds,
|
||||
)
|
||||
|
||||
def _init_decoder(self, config):
|
||||
decoder_config = copy.deepcopy(config)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.is_encoder_decoder = False
|
||||
decoder_config.num_layers = config.num_decoder_layers
|
||||
self.decoder = T5Stack(decoder_config, self.shared)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
input_embeds,
|
||||
attention_mask,
|
||||
hidden_states,
|
||||
output_attentions=False,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
input_embeds: torch.Tensor
|
||||
Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model)
|
||||
attention_mask: torch.Tensor
|
||||
Attention mask for the patched context. Shape (batch_size, patched_context_length), type: torch.int64
|
||||
hidden_states: torch.Tensor
|
||||
Hidden states returned by the encoder. Shape (batch_size, patched_context_length, d_model)
|
||||
|
||||
Returns
|
||||
-------
|
||||
last_hidden_state
|
||||
Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)
|
||||
"""
|
||||
batch_size = input_embeds.shape[0]
|
||||
decoder_input_ids = torch.full(
|
||||
(batch_size, 1),
|
||||
self.config.decoder_start_token_id,
|
||||
device=input_embeds.device,
|
||||
)
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return decoder_outputs.last_hidden_state # sequence_outputs, b x 1 x d_model
|
||||
|
||||
|
||||
class ChronosBoltPipeline(BaseChronosPipeline):
|
||||
forecast_type: ForecastType = ForecastType.QUANTILES
|
||||
default_context_length: int = 2048
|
||||
|
||||
def __init__(self, model: ChronosBoltModelForForecasting):
|
||||
super().__init__(inner_model=model)
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def quantiles(self) -> List[float]:
|
||||
return self.model.config.chronos_config["quantiles"]
|
||||
|
||||
def predict( # type: ignore[override]
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
prediction_length: Optional[int] = None,
|
||||
limit_prediction_length: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get forecasts for the given time series.
|
||||
|
||||
Refer to the base method (``BaseChronosPipeline.predict``)
|
||||
for details on shared parameters.
|
||||
Additional parameters
|
||||
---------------------
|
||||
limit_prediction_length
|
||||
Force prediction length smaller or equal than the
|
||||
built-in prediction length from the model. False by
|
||||
default. When true, fail loudly if longer predictions
|
||||
are requested, otherwise longer predictions are allowed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Forecasts of shape (batch_size, num_quantiles, prediction_length)
|
||||
where num_quantiles is the number of quantiles the model has been
|
||||
trained to output. For official Chronos-Bolt models, the value of
|
||||
num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
When limit_prediction_length is True and the prediction_length is
|
||||
greater than model's trainig prediction_length.
|
||||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=context)
|
||||
|
||||
model_context_length = self.model.config.chronos_config["context_length"]
|
||||
model_prediction_length = self.model.config.chronos_config["prediction_length"]
|
||||
if prediction_length is None:
|
||||
prediction_length = model_prediction_length
|
||||
|
||||
if prediction_length > model_prediction_length:
|
||||
msg = (
|
||||
f"We recommend keeping prediction length <= {model_prediction_length}. "
|
||||
"The quality of longer predictions may degrade since the model is not optimized for it. "
|
||||
)
|
||||
if limit_prediction_length:
|
||||
msg += "You can turn off this check by setting `limit_prediction_length=False`."
|
||||
raise ValueError(msg)
|
||||
warnings.warn(msg)
|
||||
|
||||
predictions = []
|
||||
remaining = prediction_length
|
||||
|
||||
# We truncate the context here because otherwise batches with very long
|
||||
# context could take up large amounts of GPU memory unnecessarily.
|
||||
if context_tensor.shape[-1] > model_context_length:
|
||||
context_tensor = context_tensor[..., -model_context_length:]
|
||||
|
||||
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
|
||||
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
|
||||
# every 64 steps.
|
||||
while remaining > 0:
|
||||
with torch.no_grad():
|
||||
prediction = self.model(
|
||||
context=context_tensor.to(
|
||||
device=self.model.device,
|
||||
dtype=torch.float32, # scaling should be done in 32-bit precision
|
||||
),
|
||||
).quantile_preds.to(context_tensor)
|
||||
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
|
||||
central_prediction = prediction[:, central_idx]
|
||||
|
||||
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
|
||||
|
||||
return torch.cat(predictions, dim=-1)[..., :prediction_length]
|
||||
|
||||
def predict_quantiles(
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
prediction_length: Optional[int] = None,
|
||||
quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
**predict_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Refer to the base method (``BaseChronosPipeline.predict_quantiles``).
|
||||
"""
|
||||
# shape (batch_size, prediction_length, len(training_quantile_levels))
|
||||
predictions = (
|
||||
self.predict(context, prediction_length=prediction_length, **predict_kwargs)
|
||||
.detach()
|
||||
.swapaxes(1, 2)
|
||||
)
|
||||
|
||||
training_quantile_levels = self.quantiles
|
||||
|
||||
if set(quantile_levels).issubset(set(training_quantile_levels)):
|
||||
# no need to perform intra/extrapolation
|
||||
quantiles = predictions[
|
||||
..., [training_quantile_levels.index(q) for q in quantile_levels]
|
||||
]
|
||||
else:
|
||||
# we rely on torch for interpolating quantiles if quantiles that
|
||||
# Chronos Bolt was trained on were not provided
|
||||
if min(quantile_levels) < min(training_quantile_levels) or max(
|
||||
quantile_levels
|
||||
) > max(training_quantile_levels):
|
||||
logger.warning(
|
||||
f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of "
|
||||
f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). "
|
||||
"Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt "
|
||||
"was trained on. This may significantly affect the quality of the predictions."
|
||||
)
|
||||
|
||||
# TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)
|
||||
# made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9).
|
||||
# While this holds for official Chronos-Bolt models, this may not be true in the future, and this
|
||||
# function may have to be revised.
|
||||
augmented_predictions = torch.cat(
|
||||
[predictions[..., [0]], predictions, predictions[..., [-1]]],
|
||||
dim=-1,
|
||||
)
|
||||
quantiles = torch.quantile(
|
||||
augmented_predictions,
|
||||
q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
|
||||
dim=-1,
|
||||
).permute(1, 2, 0)
|
||||
# NOTE: the median is returned as the mean here
|
||||
mean = predictions[:, :, training_quantile_levels.index(0.5)]
|
||||
return quantiles, mean
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
"""
|
||||
Load the model, either from a local path or from the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel``
|
||||
from ``transformers``.
|
||||
"""
|
||||
|
||||
config = AutoConfig.from_pretrained(*args, **kwargs)
|
||||
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
|
||||
|
||||
architecture = config.architectures[0]
|
||||
class_ = globals().get(architecture)
|
||||
|
||||
if class_ is None:
|
||||
logger.warning(
|
||||
f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting"
|
||||
)
|
||||
class_ = ChronosBoltModelForForecasting
|
||||
|
||||
model = class_.from_pretrained(*args, **kwargs)
|
||||
return cls(model=model)
|
||||
20
src/chronos/utils.py
Normal file
20
src/chronos/utils.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
|
||||
max_len = max(len(c) for c in tensors)
|
||||
padded = []
|
||||
for c in tensors:
|
||||
assert isinstance(c, torch.Tensor)
|
||||
assert c.ndim == 1
|
||||
padding = torch.full(
|
||||
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
|
||||
)
|
||||
padded.append(torch.concat((padding, c), dim=-1))
|
||||
return torch.stack(padded).to(tensors[0])
|
||||
50
test/dummy-chronos-bolt-model/config.json
Normal file
50
test/dummy-chronos-bolt-model/config.json
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
{
|
||||
"architectures": [
|
||||
"ChronosBoltModelForForecasting"
|
||||
],
|
||||
"chronos_config": {
|
||||
"context_length": 512,
|
||||
"input_patch_size": 16,
|
||||
"input_patch_stride": 16,
|
||||
"prediction_length": 64,
|
||||
"quantiles": [
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9
|
||||
],
|
||||
"use_reg_token": true
|
||||
},
|
||||
"chronos_pipeline_class": "ChronosBoltPipeline",
|
||||
"classifier_dropout": 0.0,
|
||||
"d_ff": 8,
|
||||
"d_kv": 4,
|
||||
"d_model": 8,
|
||||
"decoder_start_token_id": 0,
|
||||
"dense_act_fn": "relu",
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "relu",
|
||||
"initializer_factor": 0.05,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": false,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"n_positions": 512,
|
||||
"num_decoder_layers": 4,
|
||||
"num_heads": 4,
|
||||
"num_layers": 4,
|
||||
"pad_token_id": 0,
|
||||
"reg_token_id": 1,
|
||||
"relative_attention_max_distance": 128,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.40.2",
|
||||
"use_cache": true,
|
||||
"vocab_size": 2
|
||||
}
|
||||
BIN
test/dummy-chronos-bolt-model/model.safetensors
Normal file
BIN
test/dummy-chronos-bolt-model/model.safetensors
Normal file
Binary file not shown.
|
|
@ -2,12 +2,21 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
|
||||
from chronos import (
|
||||
BaseChronosPipeline,
|
||||
ChronosConfig,
|
||||
ChronosPipeline,
|
||||
MeanScaleUniformBins,
|
||||
)
|
||||
|
||||
|
||||
def test_base_chronos_pipeline_loads_from_huggingface():
|
||||
BaseChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", device_map="cpu")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
|
||||
|
|
@ -157,10 +166,14 @@ def test_tokenizer_random_data(use_eos_token: bool):
|
|||
assert samples.shape == (2, 10, 4)
|
||||
|
||||
|
||||
def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None:
|
||||
def validate_tensor(
|
||||
a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None
|
||||
) -> None:
|
||||
assert isinstance(a, torch.Tensor)
|
||||
assert a.shape == shape
|
||||
assert a.dtype == dtype
|
||||
|
||||
if dtype is not None:
|
||||
assert a.dtype == dtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
|
|
@ -179,7 +192,9 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
|
||||
samples = pipeline.predict(
|
||||
context, num_samples=7, prediction_length=65, limit_prediction_length=True
|
||||
)
|
||||
|
||||
samples = pipeline.predict(
|
||||
context, num_samples=7, prediction_length=65, limit_prediction_length=False
|
||||
|
|
@ -192,7 +207,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
|
||||
samples = pipeline.predict(
|
||||
list(context),
|
||||
num_samples=7,
|
||||
prediction_length=65,
|
||||
limit_prediction_length=True,
|
||||
)
|
||||
|
||||
samples = pipeline.predict(
|
||||
list(context),
|
||||
|
|
@ -208,17 +228,73 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
|
||||
samples = pipeline.predict(
|
||||
context[0, ...],
|
||||
num_samples=7,
|
||||
prediction_length=65,
|
||||
limit_prediction_length=True,
|
||||
)
|
||||
|
||||
samples = pipeline.predict(
|
||||
context[0, ...],
|
||||
num_samples=7,
|
||||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("prediction_length", [3, 65])
|
||||
@pytest.mark.parametrize(
|
||||
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
|
||||
)
|
||||
def test_pipeline_predict_quantiles(
|
||||
model_dtype: torch.dtype,
|
||||
prediction_length: int,
|
||||
quantile_levels: list[int],
|
||||
):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
|
||||
num_expected_quantiles = len(quantile_levels)
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
quantiles, mean = pipeline.predict_quantiles(
|
||||
context,
|
||||
num_samples=12,
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
quantiles, mean = pipeline.predict_quantiles(
|
||||
list(context),
|
||||
num_samples=12,
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
quantiles, mean = pipeline.predict_quantiles(
|
||||
context[0, ...],
|
||||
num_samples=12,
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (1, prediction_length))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
|
|
|
|||
246
test/test_chronos_bolt.py
Normal file
246
test/test_chronos_bolt.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from chronos import BaseChronosPipeline, ChronosBoltPipeline
|
||||
from chronos.chronos_bolt import InstanceNorm, Patch
|
||||
|
||||
|
||||
def validate_tensor(input: torch.Tensor, shape: Tuple[int, ...]) -> None:
|
||||
assert isinstance(input, torch.Tensor)
|
||||
assert input.shape == shape
|
||||
|
||||
|
||||
def test_base_chronos_pipeline_loads_from_huggingface():
|
||||
BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-tiny", device_map="cpu")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_predict(torch_dtype: str):
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-bolt-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
expected_num_quantiles = len(pipeline.quantiles)
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
quantiles = pipeline.predict(context, prediction_length=3)
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quantiles = pipeline.predict(
|
||||
context, prediction_length=65, limit_prediction_length=True
|
||||
)
|
||||
|
||||
quantiles = pipeline.predict(context, prediction_length=65)
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 65))
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
quantiles = pipeline.predict(list(context), prediction_length=3)
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quantiles = pipeline.predict(
|
||||
list(context),
|
||||
prediction_length=65,
|
||||
limit_prediction_length=True,
|
||||
)
|
||||
|
||||
quantiles = pipeline.predict(list(context), prediction_length=65)
|
||||
validate_tensor(quantiles, (4, expected_num_quantiles, 65))
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
quantiles = pipeline.predict(context[0, ...], prediction_length=3)
|
||||
validate_tensor(quantiles, (1, expected_num_quantiles, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quantiles = pipeline.predict(
|
||||
context[0, ...],
|
||||
prediction_length=65,
|
||||
limit_prediction_length=True,
|
||||
)
|
||||
|
||||
quantiles = pipeline.predict(
|
||||
context[0, ...],
|
||||
prediction_length=65,
|
||||
)
|
||||
validate_tensor(quantiles, (1, expected_num_quantiles, 65))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("prediction_length", [3, 65])
|
||||
@pytest.mark.parametrize(
|
||||
"quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
|
||||
)
|
||||
def test_pipeline_predict_quantiles(
|
||||
torch_dtype: str, prediction_length: int, quantile_levels: list[int]
|
||||
):
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-bolt-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
|
||||
num_expected_quantiles = len(quantile_levels)
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
quantiles, mean = pipeline.predict_quantiles(
|
||||
context,
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
quantiles, mean = pipeline.predict_quantiles(
|
||||
list(context),
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (4, prediction_length))
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
quantiles, mean = pipeline.predict_quantiles(
|
||||
context[0, ...],
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles))
|
||||
validate_tensor(mean, (1, prediction_length))
|
||||
|
||||
|
||||
# The following tests have been taken from
|
||||
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
|
||||
# Author: Caner Turkmen <atturkm@amazon.com>
|
||||
|
||||
|
||||
def test_given_even_data_patch_operator_output_is_correct():
|
||||
batch_size = 17
|
||||
patch_len = 16
|
||||
|
||||
patch = Patch(patch_len, patch_len)
|
||||
|
||||
batch = (
|
||||
torch.stack([torch.arange(512)] * batch_size)
|
||||
+ torch.arange(batch_size)[:, None]
|
||||
)
|
||||
output = patch(batch)
|
||||
|
||||
assert output.shape == (batch_size, 512 // patch_len, patch_len)
|
||||
|
||||
assert torch.allclose(
|
||||
output[:, 0],
|
||||
torch.stack([torch.arange(patch_len)] * batch_size)
|
||||
+ torch.arange(batch_size)[:, None],
|
||||
atol=1e-5,
|
||||
)
|
||||
assert torch.allclose(
|
||||
output[:, 1],
|
||||
torch.stack([torch.arange(patch_len, 2 * patch_len)] * batch_size)
|
||||
+ torch.arange(batch_size)[:, None],
|
||||
atol=1e-5,
|
||||
)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_given_even_data_and_strides_patch_operator_output_is_correct():
|
||||
batch_size = 17
|
||||
patch_len, patch_stride = 16, 8
|
||||
|
||||
patch = Patch(patch_len, patch_stride)
|
||||
|
||||
offset = torch.arange(batch_size)[:, None]
|
||||
batch = torch.stack([torch.arange(512)] * batch_size) + offset
|
||||
output = patch(batch)
|
||||
|
||||
assert torch.allclose(
|
||||
output[:, 1],
|
||||
torch.stack([torch.arange(patch_stride, patch_stride + patch_len)] * batch_size)
|
||||
+ offset,
|
||||
atol=1e-5,
|
||||
)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_given_uneven_data_patch_operator_pads_and_output_is_correct():
|
||||
batch_size = 17
|
||||
patch_len = 16
|
||||
|
||||
patch = Patch(patch_len, patch_len)
|
||||
|
||||
batch = (
|
||||
torch.stack([torch.arange(512 - patch_len + 1)] * batch_size)
|
||||
+ torch.arange(batch_size)[:, None]
|
||||
).float()
|
||||
output = patch(batch)
|
||||
|
||||
assert output.shape == (batch_size, 512 // patch_len, patch_len)
|
||||
|
||||
# check the first portion is padded
|
||||
assert torch.isnan(output[:, 0, :-1]).all()
|
||||
|
||||
# check nowhere else is nan
|
||||
assert not torch.isnan(output[:, 1:]).any()
|
||||
|
||||
|
||||
def test_when_instancenorm_applied_then_standardization_correct():
|
||||
inorm = InstanceNorm()
|
||||
|
||||
input_ = torch.tensor(
|
||||
[
|
||||
[1, 2, 3, 4, 5],
|
||||
[2, 3, 4, 5, 6],
|
||||
]
|
||||
).float()
|
||||
|
||||
normalized, (loc, scale) = inorm(input_)
|
||||
|
||||
assert normalized.shape == input_.shape
|
||||
assert torch.allclose(normalized[0], normalized[1])
|
||||
assert torch.allclose(loc.squeeze(), torch.tensor([3.0, 4.0]))
|
||||
assert torch.allclose(scale.squeeze(), torch.tensor(1.41421))
|
||||
|
||||
|
||||
def test_when_instancenorm_applied_and_reversed_then_nans_preserved():
|
||||
inorm = InstanceNorm()
|
||||
|
||||
input_ = torch.tensor(
|
||||
[
|
||||
[1, torch.nan, 3, 4, 5],
|
||||
[2, 3, 4, 5, torch.nan],
|
||||
]
|
||||
).float()
|
||||
|
||||
normalized, (loc, scale) = inorm(input_)
|
||||
assert torch.allclose(normalized.isnan(), input_.isnan())
|
||||
|
||||
output = inorm.inverse(normalized, (loc, scale))
|
||||
assert torch.allclose(output, input_, equal_nan=True)
|
||||
|
||||
|
||||
def test_when_instancenorm_applied_and_reversed_then_output_correct():
|
||||
inorm = InstanceNorm()
|
||||
|
||||
input_ = torch.tensor(
|
||||
[
|
||||
[1, 2, 3, 4, 5],
|
||||
[2, 3, 4, 5, 1000],
|
||||
]
|
||||
).float()
|
||||
|
||||
normalized, loc_scale = inorm(input_)
|
||||
output = inorm.inverse(normalized, loc_scale)
|
||||
|
||||
assert torch.allclose(output, input_)
|
||||
Loading…
Reference in a new issue