mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Fix references
This commit is contained in:
parent
5daf273ca6
commit
3a1a44e252
2 changed files with 3 additions and 3 deletions
|
|
@ -164,7 +164,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
If True, ensures that DataParallel is disabled and training happens on a single GPU
|
||||
convert_inputs
|
||||
If True (default), preprocess raw inputs (convert tensors, encode categoricals, validate).
|
||||
If False, inputs are expected to be already preprocessed using `chronos.chronos2.dataset.prepare_tasks`.
|
||||
If False, inputs are expected to be already preprocessed using `chronos.chronos2.dataset.prepare_inputs`.
|
||||
This allows for efficient training on large datasets that don't fit in memory.
|
||||
**extra_trainer_kwargs
|
||||
Extra kwargs are directly forwarded to `TrainingArguments`
|
||||
|
|
|
|||
|
|
@ -1147,13 +1147,13 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
|
|||
|
||||
def test_pipeline_can_be_finetuned_with_preprocessed_hf_dataset(pipeline):
|
||||
"""Test that fine-tuning works with preprocessed inputs from a HuggingFace Dataset."""
|
||||
from chronos.chronos2.dataset import prepare_tasks
|
||||
from chronos.chronos2.dataset import prepare_inputs
|
||||
|
||||
prediction_length = 8
|
||||
raw_inputs = [{"target": torch.rand(20)}, {"target": torch.rand(25)}, {"target": torch.rand(30)}]
|
||||
|
||||
# Preprocess and convert to HF Dataset (simulating Arrow-based lazy loading)
|
||||
prepared_tasks = prepare_tasks(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train")
|
||||
prepared_tasks = prepare_inputs(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train")
|
||||
hf_dataset = datasets.Dataset.from_list(prepared_tasks).with_format("torch")
|
||||
|
||||
# Fine-tune with preprocessed inputs
|
||||
|
|
|
|||
Loading…
Reference in a new issue