From 050d600f6430c0608904ae6e3a90c1bf6b93cf84 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 23 Jul 2024 10:37:14 +0200 Subject: [PATCH] Bound number of workers by number of datasets (#157) *Issue #, if available:* Fixes #154 *Description of changes:* Prior to the fix, some workers have no dataset to consume if `dataloader_num_workers > len(training_data_paths)`. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- scripts/training/train.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/scripts/training/train.py b/scripts/training/train.py index 01e1678..c16092e 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -569,6 +569,16 @@ def main( probability = [1.0 / len(training_data_paths)] * len(training_data_paths) assert isinstance(probability, list) + assert len(training_data_paths) == len(probability) + + if dataloader_num_workers > len(training_data_paths): + log_on_main( + f"Setting the number of data loader workers to {len(training_data_paths)}, " + f"instead of {dataloader_num_workers}.", + logger, + ) + dataloader_num_workers = len(training_data_paths) + if isinstance(tokenizer_kwargs, str): tokenizer_kwargs = ast.literal_eval(tokenizer_kwargs) assert isinstance(tokenizer_kwargs, dict)