Add generation params to eval script (#138)

*Description of changes:* Adds generation params to command line options
for the evaluation script.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2024-06-27 23:11:05 +02:00 committed by GitHub
parent df67c3eb44
commit 9d59057b72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Iterable
from typing import Iterable, Optional
import datasets
import numpy as np
@ -232,8 +232,9 @@ def generate_sample_forecasts(
test_data_input: Iterable,
pipeline: ChronosPipeline,
prediction_length: int,
num_samples: int,
batch_size: int,
num_samples: int,
**predict_kwargs,
):
# Generate forecast samples
forecast_samples = []
@ -244,6 +245,7 @@ def generate_sample_forecasts(
context,
prediction_length=prediction_length,
num_samples=num_samples,
**predict_kwargs,
).numpy()
)
forecast_samples = np.concatenate(forecast_samples)
@ -268,6 +270,9 @@ def main(
torch_dtype: str = "bfloat16",
batch_size: int = 32,
num_samples: int = 20,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
):
if isinstance(torch_dtype, str):
torch_dtype = getattr(torch, torch_dtype)
@ -277,7 +282,7 @@ def main(
pipeline = ChronosPipeline.from_pretrained(
chronos_model_id,
device_map=device,
torch_dtype=torch.bfloat16,
torch_dtype=torch_dtype,
)
# Load backtest configs
@ -300,8 +305,11 @@ def main(
test_data.input,
pipeline=pipeline,
prediction_length=prediction_length,
num_samples=num_samples,
batch_size=batch_size,
num_samples=num_samples,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
logger.info(f"Evaluating forecasts for {dataset_name}")