Chronos-2: Add option to remove PrinterCallback (#410)

*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-12-01 17:56:37 +01:00 committed by GitHub
parent 1da6965318
commit c1237a5259
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -19,7 +19,6 @@ from transformers import AutoConfig
from transformers.utils.import_utils import is_peft_available
from transformers.utils.peft_utils import find_adapter_config_file
import chronos.chronos2
from chronos.base import BaseChronosPipeline, ForecastType
from chronos.chronos2 import Chronos2Model
@ -114,6 +113,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
min_past: int | None = None,
finetuned_ckpt_name: str = "finetuned-ckpt",
callbacks: list["TrainerCallback"] | None = None,
remove_printer_callback: bool = False,
**extra_trainer_kwargs,
) -> "Chronos2Pipeline":
"""
@ -156,6 +156,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt"
callbacks
A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
remove_printer_callback
If True, all instances of `PrinterCallback` are removed from callbacks
**extra_trainer_kwargs
Extra kwargs are directly forwarded to `TrainingArguments`
@ -165,6 +167,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
"""
import torch.cuda
from transformers.trainer_callback import PrinterCallback
from transformers.training_args import TrainingArguments
if finetune_mode == "lora":
@ -322,6 +325,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
eval_dataset=eval_dataset,
callbacks=callbacks,
)
if remove_printer_callback:
trainer.pop_callback(PrinterCallback)
trainer.train()
# update max_output_patches, if the model was fine-tuned with longer prediction_length