diff --git a/scripts/training/train.py b/scripts/training/train.py index 01e1678..c16092e 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -569,6 +569,16 @@ def main( probability = [1.0 / len(training_data_paths)] * len(training_data_paths) assert isinstance(probability, list) + assert len(training_data_paths) == len(probability) + + if dataloader_num_workers > len(training_data_paths): + log_on_main( + f"Setting the number of data loader workers to {len(training_data_paths)}, " + f"instead of {dataloader_num_workers}.", + logger, + ) + dataloader_num_workers = len(training_data_paths) + if isinstance(tokenizer_kwargs, str): tokenizer_kwargs = ast.literal_eval(tokenizer_kwargs) assert isinstance(tokenizer_kwargs, dict)