mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Chronos-2: Add option to specify callbacks in fit (#405)
*Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
514019fb1c
commit
5cde42eb1f
1 changed files with 6 additions and 1 deletions
|
|
@ -19,6 +19,7 @@ from transformers import AutoConfig
|
|||
from transformers.utils.import_utils import is_peft_available
|
||||
from transformers.utils.peft_utils import find_adapter_config_file
|
||||
|
||||
|
||||
import chronos.chronos2
|
||||
from chronos.base import BaseChronosPipeline, ForecastType
|
||||
from chronos.chronos2 import Chronos2Model
|
||||
|
|
@ -31,6 +32,7 @@ if TYPE_CHECKING:
|
|||
import fev
|
||||
import pandas as pd
|
||||
from peft import LoraConfig
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -111,6 +113,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
output_dir: Path | str | None = None,
|
||||
min_past: int | None = None,
|
||||
finetuned_ckpt_name: str = "finetuned-ckpt",
|
||||
callbacks: list["TrainerCallback"] | None = None,
|
||||
**extra_trainer_kwargs,
|
||||
) -> "Chronos2Pipeline":
|
||||
"""
|
||||
|
|
@ -151,6 +154,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
are filtered out, by default set equal to prediction_length
|
||||
finetuned_ckpt_name
|
||||
The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt"
|
||||
callbacks
|
||||
A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
|
||||
**extra_trainer_kwargs
|
||||
Extra kwargs are directly forwarded to `TrainingArguments`
|
||||
|
||||
|
|
@ -276,7 +281,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
)
|
||||
|
||||
eval_dataset = None
|
||||
callbacks = []
|
||||
callbacks = callbacks or []
|
||||
if validation_inputs is not None:
|
||||
# construct validation dataset
|
||||
eval_dataset = Chronos2Dataset.convert_inputs(
|
||||
|
|
|
|||
Loading…
Reference in a new issue