diff --git a/test/test_chronos2.py b/test/test_chronos2.py index e0e658a..07d41a6 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -1154,13 +1154,7 @@ def test_pipeline_can_be_finetuned_with_preprocessed_hf_dataset(pipeline): # Preprocess and convert to HF Dataset (simulating Arrow-based lazy loading) prepared_tasks = prepare_tasks(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train") - hf_dataset = datasets.Dataset.from_dict({ - "context": [task["context"].numpy() for task in prepared_tasks], - "future_covariates": [task["future_covariates"].numpy() for task in prepared_tasks], - "n_targets": [task["n_targets"] for task in prepared_tasks], - "n_covariates": [task["n_covariates"] for task in prepared_tasks], - "n_future_covariates": [task["n_future_covariates"] for task in prepared_tasks], - }).with_format("torch") + hf_dataset = datasets.Dataset.from_list(prepared_tasks).with_format("torch") # Fine-tune with preprocessed inputs ft_pipeline = pipeline.fit(