From 71ff0d64baad4f28f2c4640413fc6b6b0d1b4b41 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Mon, 15 Dec 2025 19:27:05 +0100 Subject: [PATCH] Chronos-2: Add option to disable `DataParallel` (#434) *Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos2/pipeline.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 3eddcd7..e99d8d9 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -114,6 +114,7 @@ class Chronos2Pipeline(BaseChronosPipeline): finetuned_ckpt_name: str = "finetuned-ckpt", callbacks: list["TrainerCallback"] | None = None, remove_printer_callback: bool = False, + disable_data_parallel: bool = True, **extra_trainer_kwargs, ) -> "Chronos2Pipeline": """ @@ -158,6 +159,8 @@ class Chronos2Pipeline(BaseChronosPipeline): A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer` remove_printer_callback If True, all instances of `PrinterCallback` are removed from callbacks + disable_data_parallel + If True, ensures that DataParallel is disabled and training happens on a single GPU **extra_trainer_kwargs Extra kwargs are directly forwarded to `TrainingArguments` @@ -319,6 +322,11 @@ class Chronos2Pipeline(BaseChronosPipeline): training_args = TrainingArguments(**training_kwargs) + if disable_data_parallel and not use_cpu: + # This is a hack to disable the default `transformers` behavior of using DataParallel + training_args._n_gpu = 1 + assert training_args.n_gpu == 1 # Ensure that the hack worked + trainer = Chronos2Trainer( model=model, args=training_args,