Report WAPE in the evaluation results and configure logging at module level.

* Import gluonts' ND (normalized deviation) metric and add ND() to the
  `metrics=[...]` list in eval_pipeline_and_save_results.
* Rename the resulting column to "WAPE". The rename key is "ND[0.5]" because
  gluonts suffixes the metric name with its forecast type (as it does for
  "MASE[0.5]"); a bare "ND" key would be silently ignored by DataFrame.rename.
* Move logging.basicConfig and the "Chronos Evaluation" logger setup out of
  the `if __name__ == "__main__":` guard to module level.

NOTE(review): leading indentation and blank context lines inside the hunks
were reconstructed from a whitespace-collapsed copy of this patch -- confirm
them against scripts/evaluation/evaluate.py before applying.

diff --git a/scripts/evaluation/evaluate.py b/scripts/evaluation/evaluate.py
index 594f0c2..e928420 100644
--- a/scripts/evaluation/evaluate.py
+++ b/scripts/evaluation/evaluate.py
@@ -9,7 +9,7 @@ import torch
 import typer
 import yaml
 from gluonts.dataset.split import split
-from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
+from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss, ND
 from gluonts.itertools import batcher
 from gluonts.model.evaluation import evaluate_forecasts
 from gluonts.model.forecast import QuantileForecast, SampleForecast
@@ -19,6 +19,10 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline, ChronosBoltPipeline,
 
 app = typer.Typer(pretty_exceptions_enable=False)
 
+logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger("Chronos Evaluation")
+logger.setLevel(logging.INFO)
+
 QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
 
 
@@ -150,6 +154,7 @@ def eval_pipeline_and_save_results(
                 metrics=[
                     MASE(),
                     MeanWeightedSumQuantileLoss(QUANTILES),
+                    ND(),
                 ],
                 batch_size=5000,
             )
@@ -162,7 +167,7 @@ def eval_pipeline_and_save_results(
     results_df = (
         pd.DataFrame(result_rows)
         .rename(
-            {"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"},
+            {"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL", "ND[0.5]": "WAPE"},
             axis="columns",
         )
         .sort_values(by="dataset")
@@ -340,7 +345,4 @@ def chronos_2(
 
 
 if __name__ == "__main__":
-    logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-    logger = logging.getLogger("Chronos Evaluation")
-    logger.setLevel(logging.INFO)
     app()