From c1237a52593550c1d580d9105187997af449df79 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Mon, 1 Dec 2025 17:56:37 +0100 Subject: [PATCH] Chronos-2: Add option to remove PrinterCallback (#410) *Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos2/pipeline.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 94b2b3e..9f8a0dc 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -19,7 +19,6 @@ from transformers import AutoConfig from transformers.utils.import_utils import is_peft_available from transformers.utils.peft_utils import find_adapter_config_file - import chronos.chronos2 from chronos.base import BaseChronosPipeline, ForecastType from chronos.chronos2 import Chronos2Model @@ -114,6 +113,7 @@ class Chronos2Pipeline(BaseChronosPipeline): min_past: int | None = None, finetuned_ckpt_name: str = "finetuned-ckpt", callbacks: list["TrainerCallback"] | None = None, + remove_printer_callback: bool = False, **extra_trainer_kwargs, ) -> "Chronos2Pipeline": """ @@ -156,6 +156,8 @@ class Chronos2Pipeline(BaseChronosPipeline): The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt" callbacks A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer` + remove_printer_callback + If True, all instances of `PrinterCallback` are removed from callbacks **extra_trainer_kwargs Extra kwargs are directly forwarded to `TrainingArguments` @@ -165,6 +167,7 @@ class Chronos2Pipeline(BaseChronosPipeline): """ import torch.cuda + from transformers.trainer_callback import PrinterCallback from transformers.training_args import TrainingArguments if finetune_mode == "lora": @@ -322,6 +325,10 @@ class Chronos2Pipeline(BaseChronosPipeline): eval_dataset=eval_dataset, callbacks=callbacks, ) + + if remove_printer_callback: + trainer.pop_callback(PrinterCallback) + trainer.train() # update max_output_patches, if the model was fine-tuned with longer prediction_length