Chronos-2: Convert assert about batch size to warning (#392)

This commit is contained in:
Abdul Fatir 2025-11-24 09:22:43 +01:00 committed by GitHub
parent 7daaa7194c
commit 972a09b626
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -3,6 +3,7 @@
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
import warnings
from typing import TYPE_CHECKING, cast
from torch.utils.data import DataLoader, Dataset
@@ -48,11 +49,16 @@ class Chronos2Trainer(Trainer):
train_dataset = cast("Chronos2Dataset", self.train_dataset)
assert train_dataset.batch_size == self.args.train_batch_size, (
f"The batch_size of the train_dataset ({train_dataset.batch_size}) does not match the batch_size "
f"in TrainingArguments ({self.args.train_batch_size}). If you're using a machine with multiple GPUs, "
f"ensure that only a single GPU is visible by setting the CUDA_VISIBLE_DEVICES environment variable."
)
if self.args.train_batch_size > train_dataset.batch_size:
warnings.warn(
f"The batch_size of the train_dataset ({train_dataset.batch_size}) does not match the batch_size "
f"in TrainingArguments ({self.args.train_batch_size}). On machines with multiple GPUs, this may indicate "
f"that multiple GPUs are visible and transformers is using DataParallel for training by default. "
f"This may lead to unnecessary slowdown and unexpected behavior. We strongly recommend setting the CUDA_VISIBLE_DEVICES "
f"environment variable to ensure that only a single GPU is visible.",
category=UserWarning,
stacklevel=3,
)
dataloader_params = {
# Disable automatic batching as we handle batching ourselves
@@ -74,11 +80,16 @@ class Chronos2Trainer(Trainer):
eval_dataset = cast("Chronos2Dataset", self.eval_dataset)
assert eval_dataset.batch_size == self.args.eval_batch_size, (
f"The batch_size of the eval_dataset ({eval_dataset.batch_size}) does not match the batch_size "
f"in TrainingArguments ({self.args.eval_batch_size}). If you're using a machine with multiple GPUs, "
f"ensure that only a single GPU is visible by setting the CUDA_VISIBLE_DEVICES environment variable."
)
if self.args.eval_batch_size > eval_dataset.batch_size:
warnings.warn(
f"The batch_size of the eval_dataset ({eval_dataset.batch_size}) does not match the batch_size "
f"in TrainingArguments ({self.args.eval_batch_size}). On machines with multiple GPUs, this may indicate "
f"that multiple GPUs are visible and transformers is using DataParallel for training by default. "
f"This may lead to unnecessary slowdown and unexpected behavior. We strongly recommend setting the CUDA_VISIBLE_DEVICES "
f"environment variable to ensure that only a single GPU is visible.",
category=UserWarning,
stacklevel=3,
)
dataloader_params = {
# Disable automatic batching as we handle batching ourselves