mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Chronos-2: Add option to disable DataParallel (#434)
*Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
eb5b61234a
commit
71ff0d64ba
1 changed files with 8 additions and 0 deletions
|
|
@ -114,6 +114,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
finetuned_ckpt_name: str = "finetuned-ckpt",
|
||||
callbacks: list["TrainerCallback"] | None = None,
|
||||
remove_printer_callback: bool = False,
|
||||
disable_data_parallel: bool = True,
|
||||
**extra_trainer_kwargs,
|
||||
) -> "Chronos2Pipeline":
|
||||
"""
|
||||
|
|
@ -158,6 +159,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
|
||||
remove_printer_callback
|
||||
If True, all instances of `PrinterCallback` are removed from callbacks
|
||||
disable_data_parallel
|
||||
If True, ensures that DataParallel is disabled and training happens on a single GPU
|
||||
**extra_trainer_kwargs
|
||||
Extra kwargs are directly forwarded to `TrainingArguments`
|
||||
|
||||
|
|
@ -319,6 +322,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
training_args = TrainingArguments(**training_kwargs)
|
||||
|
||||
if disable_data_parallel and not use_cpu:
|
||||
# This is a hack to disable the default `transformers` behavior of using DataParallel
|
||||
training_args._n_gpu = 1
|
||||
assert training_args.n_gpu == 1 # Ensure that the hack worked
|
||||
|
||||
trainer = Chronos2Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
|
|
|||
Loading…
Reference in a new issue