mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Add test for lazy dataset
This commit is contained in:
parent
4dc9f486d6
commit
19c1b72a94
1 changed files with 30 additions and 0 deletions
|
|
@ -1143,3 +1143,33 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
|
|||
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
|
||||
# Should match exactly or very close (numerical precision)
|
||||
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
|
||||
|
||||
|
||||
def test_pipeline_can_be_finetuned_with_preprocessed_hf_dataset(pipeline):
|
||||
"""Test that fine-tuning works with preprocessed inputs from a HuggingFace Dataset."""
|
||||
from chronos.chronos2.dataset import prepare_tasks
|
||||
|
||||
prediction_length = 8
|
||||
raw_inputs = [{"target": torch.rand(20)}, {"target": torch.rand(25)}, {"target": torch.rand(30)}]
|
||||
|
||||
# Preprocess and convert to HF Dataset (simulating Arrow-based lazy loading)
|
||||
prepared_tasks = prepare_tasks(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train")
|
||||
hf_dataset = datasets.Dataset.from_dict({
|
||||
"context": [task["context"].numpy() for task in prepared_tasks],
|
||||
"future_covariates": [task["future_covariates"].numpy() for task in prepared_tasks],
|
||||
"n_targets": [task["n_targets"] for task in prepared_tasks],
|
||||
"n_covariates": [task["n_covariates"] for task in prepared_tasks],
|
||||
"n_future_covariates": [task["n_future_covariates"] for task in prepared_tasks],
|
||||
}).with_format("torch")
|
||||
|
||||
# Fine-tune with preprocessed inputs
|
||||
ft_pipeline = pipeline.fit(
|
||||
hf_dataset, prediction_length=prediction_length, num_steps=5, min_past=1, batch_size=32, convert_inputs=False
|
||||
)
|
||||
|
||||
# Verify fine-tuned model can predict
|
||||
ft_outputs = ft_pipeline.predict(raw_inputs, prediction_length=prediction_length)
|
||||
assert len(ft_outputs) == len(raw_inputs)
|
||||
for ft_out in ft_outputs:
|
||||
assert ft_out.shape == (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length)
|
||||
assert not torch.isnan(ft_out).any()
|
||||
|
|
|
|||
Loading…
Reference in a new issue