mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Chronos-2: Add option to remove PrinterCallback (#410)
*Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
1da6965318
commit
c1237a5259
1 changed files with 8 additions and 1 deletions
|
|
@ -19,7 +19,6 @@ from transformers import AutoConfig
|
|||
from transformers.utils.import_utils import is_peft_available
|
||||
from transformers.utils.peft_utils import find_adapter_config_file
|
||||
|
||||
|
||||
import chronos.chronos2
|
||||
from chronos.base import BaseChronosPipeline, ForecastType
|
||||
from chronos.chronos2 import Chronos2Model
|
||||
|
|
@ -114,6 +113,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
min_past: int | None = None,
|
||||
finetuned_ckpt_name: str = "finetuned-ckpt",
|
||||
callbacks: list["TrainerCallback"] | None = None,
|
||||
remove_printer_callback: bool = False,
|
||||
**extra_trainer_kwargs,
|
||||
) -> "Chronos2Pipeline":
|
||||
"""
|
||||
|
|
@ -156,6 +156,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt"
|
||||
callbacks
|
||||
A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
|
||||
remove_printer_callback
|
||||
If True, all instances of `PrinterCallback` are removed from callbacks
|
||||
**extra_trainer_kwargs
|
||||
Extra kwargs are directly forwarded to `TrainingArguments`
|
||||
|
||||
|
|
@ -165,6 +167,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
"""
|
||||
|
||||
import torch.cuda
|
||||
from transformers.trainer_callback import PrinterCallback
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
if finetune_mode == "lora":
|
||||
|
|
@ -322,6 +325,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
eval_dataset=eval_dataset,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if remove_printer_callback:
|
||||
trainer.pop_callback(PrinterCallback)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# update max_output_patches, if the model was fine-tuned with longer prediction_length
|
||||
|
|
|
|||
Loading…
Reference in a new issue