diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 30785ef..125d1d9 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -19,6 +19,7 @@ from transformers import AutoConfig from transformers.utils.import_utils import is_peft_available from transformers.utils.peft_utils import find_adapter_config_file + import chronos.chronos2 from chronos.base import BaseChronosPipeline, ForecastType from chronos.chronos2 import Chronos2Model @@ -31,6 +32,7 @@ if TYPE_CHECKING: import fev import pandas as pd from peft import LoraConfig + from transformers.trainer_callback import TrainerCallback logger = logging.getLogger(__name__) @@ -111,6 +113,7 @@ class Chronos2Pipeline(BaseChronosPipeline): output_dir: Path | str | None = None, min_past: int | None = None, finetuned_ckpt_name: str = "finetuned-ckpt", + callbacks: list["TrainerCallback"] | None = None, **extra_trainer_kwargs, ) -> "Chronos2Pipeline": """ @@ -151,6 +154,8 @@ class Chronos2Pipeline(BaseChronosPipeline): are filtered out, by default set equal to prediction_length finetuned_ckpt_name The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt" + callbacks + A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer` **extra_trainer_kwargs Extra kwargs are directly forwarded to `TrainingArguments` @@ -276,7 +281,7 @@ class Chronos2Pipeline(BaseChronosPipeline): ) eval_dataset = None - callbacks = [] + callbacks = callbacks or [] if validation_inputs is not None: # construct validation dataset eval_dataset = Chronos2Dataset.convert_inputs(