Chronos-2: Add option to disable DataParallel (#434)

*Issue #, if available:*

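*Description of changes:* Adds a `disable_data_parallel` option (default `True`) to the fine-tuning arguments of `Chronos2Pipeline`. When it is enabled and training is not on CPU, `training_args._n_gpu` is set to 1 right after `TrainingArguments` is constructed, so that the `transformers` `Trainer` trains on a single GPU instead of falling back to its default DataParallel behavior.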


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-12-15 19:27:05 +01:00 committed by GitHub
parent eb5b61234a
commit 71ff0d64ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -114,6 +114,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
finetuned_ckpt_name: str = "finetuned-ckpt",
callbacks: list["TrainerCallback"] | None = None,
remove_printer_callback: bool = False,
disable_data_parallel: bool = True,
**extra_trainer_kwargs,
) -> "Chronos2Pipeline":
"""
@ -158,6 +159,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
remove_printer_callback
If True, all instances of `PrinterCallback` are removed from callbacks
disable_data_parallel
If True, ensures that DataParallel is disabled and training happens on a single GPU
**extra_trainer_kwargs
Extra kwargs are directly forwarded to `TrainingArguments`
@ -319,6 +322,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
training_args = TrainingArguments(**training_kwargs)
if disable_data_parallel and not use_cpu:
# This is a hack to disable the default `transformers` behavior of using DataParallel
training_args._n_gpu = 1
assert training_args.n_gpu == 1 # Ensure that the hack worked
trainer = Chronos2Trainer(
model=model,
args=training_args,