Bound number of workers by number of datasets (#157)

*Issue #, if available:* Fixes #154

*Description of changes:* Before this fix, some workers had no dataset
to consume when `dataloader_num_workers > len(training_data_paths)`.
The number of workers is now capped at `len(training_data_paths)`, and
the adjustment is logged.
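
A rough illustration of the failure mode, not the repository's actual sharding code: assuming each worker takes every `num_workers`-th path (an assumption made only for this sketch), any worker whose index is at least `len(training_data_paths)` ends up with nothing to read. The helper `shard_paths` and the file names are hypothetical.

```python
def shard_paths(paths, worker_id, num_workers):
    """Hypothetical round-robin split: worker i takes paths i, i + n, i + 2n, ..."""
    return paths[worker_id::num_workers]


training_data_paths = ["a.arrow", "b.arrow"]  # made-up file names
dataloader_num_workers = 4  # more workers than datasets

# One shard per worker; workers past the last dataset get an empty list.
shards = [
    shard_paths(training_data_paths, worker_id, dataloader_num_workers)
    for worker_id in range(dataloader_num_workers)
]
idle_workers = [worker_id for worker_id, shard in enumerate(shards) if not shard]
print(idle_workers)  # indices of workers with no dataset
```

With two datasets and four workers, the last two workers receive empty shards under this assumed scheme.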


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
Lorenzo Stella 2024-07-23 10:37:14 +02:00 committed by GitHub
parent 9d59057b72
commit 050d600f64
GPG key ID: B5690EEEBB952194

@@ -569,6 +569,16 @@ def main(
         probability = [1.0 / len(training_data_paths)] * len(training_data_paths)
     assert isinstance(probability, list)
     assert len(training_data_paths) == len(probability)
+    if dataloader_num_workers > len(training_data_paths):
+        log_on_main(
+            f"Setting the number of data loader workers to {len(training_data_paths)}, "
+            f"instead of {dataloader_num_workers}.",
+            logger,
+        )
+        dataloader_num_workers = len(training_data_paths)
     if isinstance(tokenizer_kwargs, str):
         tokenizer_kwargs = ast.literal_eval(tokenizer_kwargs)
     assert isinstance(tokenizer_kwargs, dict)
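
The added check amounts to taking the minimum of the requested worker count and the number of datasets. A standalone sketch of the same logic follows; the helper name `bound_num_workers` is hypothetical, and the standard `logging` module stands in for the script's `log_on_main`.

```python
import logging

logger = logging.getLogger(__name__)


def bound_num_workers(requested_workers: int, num_datasets: int) -> int:
    """Hypothetical helper: cap the worker count at the number of datasets."""
    if requested_workers > num_datasets:
        # Same message as in the diff, emitted through the standard logger.
        logger.info(
            "Setting the number of data loader workers to %d, instead of %d.",
            num_datasets,
            requested_workers,
        )
        return num_datasets
    return requested_workers


print(bound_num_workers(8, 3))  # capped at the number of datasets
print(bound_num_workers(2, 3))  # left unchanged
```

Logging only when the value actually changes keeps the output quiet in the common case where the configuration is already valid.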