diff --git a/src/chronos/chronos2/trainer.py b/src/chronos/chronos2/trainer.py index 561da9e..bd4e464 100644 --- a/src/chronos/chronos2/trainer.py +++ b/src/chronos/chronos2/trainer.py @@ -3,6 +3,7 @@ # Authors: Abdul Fatir Ansari +import warnings from typing import TYPE_CHECKING, cast from torch.utils.data import DataLoader, Dataset @@ -48,11 +49,16 @@ class Chronos2Trainer(Trainer): train_dataset = cast("Chronos2Dataset", self.train_dataset) - assert train_dataset.batch_size == self.args.train_batch_size, ( - f"The batch_size of the train_dataset ({train_dataset.batch_size}) does not match the batch_size " - f"in TrainingArguments ({self.args.train_batch_size}). If you're using a machine with multiple GPUs, " - f"ensure that only a single GPU is visible by setting the CUDA_VISIBLE_DEVICES environment variable." - ) + if self.args.train_batch_size > train_dataset.batch_size: + warnings.warn( + f"The batch_size of the train_dataset ({train_dataset.batch_size}) does not match the batch_size " + f"in TrainingArguments ({self.args.train_batch_size}). On machines with multiple GPUs, this may indicate " + f"that multiple GPUs are visible and transformers is using DataParallel for training by default. " + f"This may lead to unnecessary slowdown and unexpected behavior. We strongly recommend setting the CUDA_VISIBLE_DEVICES " + f"environment variable to ensure that only a single GPU is visible.", + category=UserWarning, + stacklevel=3, + ) dataloader_params = { # Disable automatic batching as we handle batching ourselves @@ -74,11 +80,16 @@ class Chronos2Trainer(Trainer): eval_dataset = cast("Chronos2Dataset", self.eval_dataset) - assert eval_dataset.batch_size == self.args.eval_batch_size, ( - f"The batch_size of the eval_dataset ({eval_dataset.batch_size}) does not match the batch_size " - f"in TrainingArguments ({self.args.eval_batch_size}). If you're using a machine with multiple GPUs, " - f"ensure that only a single GPU is visible by setting the CUDA_VISIBLE_DEVICES environment variable." - ) + if self.args.eval_batch_size > eval_dataset.batch_size: + warnings.warn( + f"The batch_size of the eval_dataset ({eval_dataset.batch_size}) does not match the batch_size " + f"in TrainingArguments ({self.args.eval_batch_size}). On machines with multiple GPUs, this may indicate " + f"that multiple GPUs are visible and transformers is using DataParallel for training by default. " + f"This may lead to unnecessary slowdown and unexpected behavior. We strongly recommend setting the CUDA_VISIBLE_DEVICES " + f"environment variable to ensure that only a single GPU is visible.", + category=UserWarning, + stacklevel=3, + ) dataloader_params = { # Disable automatic batching as we handle batching ourselves