mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Add generation params to eval script (#138)
*Description of changes:* Adds generation params to command line options for the evaluation script. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
df67c3eb44
commit
9d59057b72
1 changed files with 12 additions and 4 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
|
@ -232,8 +232,9 @@ def generate_sample_forecasts(
|
|||
test_data_input: Iterable,
|
||||
pipeline: ChronosPipeline,
|
||||
prediction_length: int,
|
||||
num_samples: int,
|
||||
batch_size: int,
|
||||
num_samples: int,
|
||||
**predict_kwargs,
|
||||
):
|
||||
# Generate forecast samples
|
||||
forecast_samples = []
|
||||
|
|
@ -244,6 +245,7 @@ def generate_sample_forecasts(
|
|||
context,
|
||||
prediction_length=prediction_length,
|
||||
num_samples=num_samples,
|
||||
**predict_kwargs,
|
||||
).numpy()
|
||||
)
|
||||
forecast_samples = np.concatenate(forecast_samples)
|
||||
|
|
@ -268,6 +270,9 @@ def main(
|
|||
torch_dtype: str = "bfloat16",
|
||||
batch_size: int = 32,
|
||||
num_samples: int = 20,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
):
|
||||
if isinstance(torch_dtype, str):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
|
|
@ -277,7 +282,7 @@ def main(
|
|||
pipeline = ChronosPipeline.from_pretrained(
|
||||
chronos_model_id,
|
||||
device_map=device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
# Load backtest configs
|
||||
|
|
@ -300,8 +305,11 @@ def main(
|
|||
test_data.input,
|
||||
pipeline=pipeline,
|
||||
prediction_length=prediction_length,
|
||||
num_samples=num_samples,
|
||||
batch_size=batch_size,
|
||||
num_samples=num_samples,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
logger.info(f"Evaluating forecasts for {dataset_name}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue