feat(eval): add WAPE metric (ND) to evaluation pipeline and fix logger scope

This commit is contained in:
Jorge Emiliano 2026-01-28 10:43:23 -03:00
parent 1f099eb265
commit 0aca1079bc

View file

@ -9,7 +9,7 @@ import torch
import typer
import yaml
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss, ND
from gluonts.itertools import batcher
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import QuantileForecast, SampleForecast
@ -19,6 +19,10 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline, ChronosBoltPipeline,
app = typer.Typer(pretty_exceptions_enable=False)
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("Chronos Evaluation")
logger.setLevel(logging.INFO)
QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
@ -150,6 +154,7 @@ def eval_pipeline_and_save_results(
metrics=[
MASE(),
MeanWeightedSumQuantileLoss(QUANTILES),
ND(),
],
batch_size=5000,
)
@ -162,7 +167,7 @@ def eval_pipeline_and_save_results(
results_df = (
pd.DataFrame(result_rows)
.rename(
{"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"},
{"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL", "ND": "WAPE"},
axis="columns",
)
.sort_values(by="dataset")
@ -340,7 +345,4 @@ def chronos_2(
if __name__ == "__main__":
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("Chronos Evaluation")
logger.setLevel(logging.INFO)
app()