From 5cde42eb1f41d3041750ca70e7f6c3db07998073 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Fri, 28 Nov 2025 16:55:35 +0100 Subject: [PATCH] Chronos-2: Add option to specify callbacks in `fit` (#405) *Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos2/pipeline.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 30785ef..125d1d9 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -19,6 +19,7 @@ from transformers import AutoConfig from transformers.utils.import_utils import is_peft_available from transformers.utils.peft_utils import find_adapter_config_file + import chronos.chronos2 from chronos.base import BaseChronosPipeline, ForecastType from chronos.chronos2 import Chronos2Model @@ -31,6 +32,7 @@ if TYPE_CHECKING: import fev import pandas as pd from peft import LoraConfig + from transformers.trainer_callback import TrainerCallback logger = logging.getLogger(__name__) @@ -111,6 +113,7 @@ class Chronos2Pipeline(BaseChronosPipeline): output_dir: Path | str | None = None, min_past: int | None = None, finetuned_ckpt_name: str = "finetuned-ckpt", + callbacks: list["TrainerCallback"] | None = None, **extra_trainer_kwargs, ) -> "Chronos2Pipeline": """ @@ -151,6 +154,8 @@ class Chronos2Pipeline(BaseChronosPipeline): are filtered out, by default set equal to prediction_length finetuned_ckpt_name The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt" + callbacks + A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer` **extra_trainer_kwargs Extra kwargs are directly forwarded to `TrainingArguments` @@ -276,7 +281,7 @@ class Chronos2Pipeline(BaseChronosPipeline): ) eval_dataset = None - callbacks = [] + callbacks = callbacks or [] if validation_inputs is not None: # construct validation dataset eval_dataset = Chronos2Dataset.convert_inputs(