diff --git a/test/test_chronos2.py b/test/test_chronos2.py
index 3d8884e..e0e658a 100644
--- a/test/test_chronos2.py
+++ b/test/test_chronos2.py
@@ -1143,3 +1143,50 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
     for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
         # Should match exactly or very close (numerical precision)
         assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
+
+
+def test_pipeline_can_be_finetuned_with_preprocessed_hf_dataset(pipeline):
+    """Check that ``fit`` accepts already-prepared tasks held in a HuggingFace Dataset.
+
+    Raw series are passed through ``prepare_tasks``, stored column-wise in a
+    torch-formatted ``datasets.Dataset`` and given to ``fit`` with
+    ``convert_inputs=False``. The returned pipeline must then forecast the
+    raw series with the expected shape and without NaNs.
+    """
+    from chronos.chronos2.dataset import prepare_tasks
+
+    prediction_length = 8
+    raw_inputs = [{"target": torch.rand(length)} for length in (20, 25, 30)]
+
+    # Preprocess up front; the prepared tasks are what ends up in the dataset
+    prepared_tasks = prepare_tasks(
+        raw_inputs,
+        prediction_length=prediction_length,
+        min_past=1,
+        mode="train",
+    )
+
+    # Build the HF Dataset column by column (simulating Arrow-based lazy loading)
+    numpy_fields = ("context", "future_covariates")
+    plain_fields = ("n_targets", "n_covariates", "n_future_covariates")
+    columns = {name: [task[name].numpy() for task in prepared_tasks] for name in numpy_fields}
+    columns.update({name: [task[name] for task in prepared_tasks] for name in plain_fields})
+    hf_dataset = datasets.Dataset.from_dict(columns).with_format("torch")
+
+    # Fine-tune directly on the preprocessed dataset
+    ft_pipeline = pipeline.fit(
+        hf_dataset,
+        prediction_length=prediction_length,
+        num_steps=5,
+        min_past=1,
+        batch_size=32,
+        convert_inputs=False,
+    )
+
+    # The fine-tuned pipeline must still produce one clean forecast per raw series
+    ft_outputs = ft_pipeline.predict(raw_inputs, prediction_length=prediction_length)
+    expected_shape = (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length)
+    assert len(ft_outputs) == len(raw_inputs)
+    for forecast in ft_outputs:
+        assert forecast.shape == expected_shape
+        assert not torch.isnan(forecast).any()