Fix references

This commit is contained in:
Oleksandr Shchur 2026-02-19 09:28:20 +00:00
parent 5daf273ca6
commit 3a1a44e252
2 changed files with 3 additions and 3 deletions

View file

@@ -164,7 +164,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
If True, ensures that DataParallel is disabled and training happens on a single GPU
convert_inputs
If True (default), preprocess raw inputs (convert tensors, encode categoricals, validate).
-If False, inputs are expected to be already preprocessed using `chronos.chronos2.dataset.prepare_tasks`.
+If False, inputs are expected to be already preprocessed using `chronos.chronos2.dataset.prepare_inputs`.
This allows for efficient training on large datasets that don't fit in memory.
**extra_trainer_kwargs
Extra kwargs are directly forwarded to `TrainingArguments`

View file

@@ -1147,13 +1147,13 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
def test_pipeline_can_be_finetuned_with_preprocessed_hf_dataset(pipeline):
"""Test that fine-tuning works with preprocessed inputs from a HuggingFace Dataset."""
-from chronos.chronos2.dataset import prepare_tasks
+from chronos.chronos2.dataset import prepare_inputs
prediction_length = 8
raw_inputs = [{"target": torch.rand(20)}, {"target": torch.rand(25)}, {"target": torch.rand(30)}]
# Preprocess and convert to HF Dataset (simulating Arrow-based lazy loading)
-prepared_tasks = prepare_tasks(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train")
+prepared_tasks = prepare_inputs(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train")
hf_dataset = datasets.Dataset.from_list(prepared_tasks).with_format("torch")
# Fine-tune with preprocessed inputs