diff --git a/scripts/evaluation/evaluate.py b/scripts/evaluation/evaluate.py index 6d18db1..756f544 100644 --- a/scripts/evaluation/evaluate.py +++ b/scripts/evaluation/evaluate.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Iterable +from typing import Iterable, Optional import datasets import numpy as np @@ -232,8 +232,9 @@ def generate_sample_forecasts( test_data_input: Iterable, pipeline: ChronosPipeline, prediction_length: int, - num_samples: int, batch_size: int, + num_samples: int, + **predict_kwargs, ): # Generate forecast samples forecast_samples = [] @@ -244,6 +245,7 @@ def generate_sample_forecasts( context, prediction_length=prediction_length, num_samples=num_samples, + **predict_kwargs, ).numpy() ) forecast_samples = np.concatenate(forecast_samples) @@ -268,6 +270,9 @@ def main( torch_dtype: str = "bfloat16", batch_size: int = 32, num_samples: int = 20, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, ): if isinstance(torch_dtype, str): torch_dtype = getattr(torch, torch_dtype) @@ -277,7 +282,7 @@ def main( pipeline = ChronosPipeline.from_pretrained( chronos_model_id, device_map=device, - torch_dtype=torch.bfloat16, + torch_dtype=torch_dtype, ) # Load backtest configs @@ -300,8 +305,11 @@ def main( test_data.input, pipeline=pipeline, prediction_length=prediction_length, - num_samples=num_samples, batch_size=batch_size, + num_samples=num_samples, + temperature=temperature, + top_k=top_k, + top_p=top_p, ) logger.info(f"Evaluating forecasts for {dataset_name}")