Chronos-2: Add option to specify callbacks in fit (#405)

*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-11-28 16:55:35 +01:00 committed by GitHub
parent 514019fb1c
commit 5cde42eb1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -19,6 +19,7 @@ from transformers import AutoConfig
from transformers.utils.import_utils import is_peft_available
from transformers.utils.peft_utils import find_adapter_config_file
import chronos.chronos2
from chronos.base import BaseChronosPipeline, ForecastType
from chronos.chronos2 import Chronos2Model
@ -31,6 +32,7 @@ if TYPE_CHECKING:
import fev
import pandas as pd
from peft import LoraConfig
from transformers.trainer_callback import TrainerCallback
logger = logging.getLogger(__name__)
@ -111,6 +113,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
output_dir: Path | str | None = None,
min_past: int | None = None,
finetuned_ckpt_name: str = "finetuned-ckpt",
callbacks: list["TrainerCallback"] | None = None,
**extra_trainer_kwargs,
) -> "Chronos2Pipeline":
"""
@ -151,6 +154,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
are filtered out, by default set equal to prediction_length
finetuned_ckpt_name
The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt"
callbacks
A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
**extra_trainer_kwargs
Extra kwargs are directly forwarded to `TrainingArguments`
@ -276,7 +281,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
)
eval_dataset = None
callbacks = []
callbacks = callbacks or []
if validation_inputs is not None:
# construct validation dataset
eval_dataset = Chronos2Dataset.convert_inputs(