From 6bcd4584a31453aa2db50b4a2e58ef80c98c4ece Mon Sep 17 00:00:00 2001
From: Abdul Fatir
Date: Fri, 31 May 2024 15:13:49 +0200
Subject: [PATCH] Enhance training script: auto tf32 detection and reorder default seed setting (#91)

*Description of changes:* Automatically set `tf32` to `False` if used on an older NVIDIA GPU. Reorder seed so that the seed is saved as part of the training config.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

---------

Co-authored-by: Abdul Fatir Ansari
---
 scripts/training/train.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/scripts/training/train.py b/scripts/training/train.py
index bbdf6cd..ada4a99 100644
--- a/scripts/training/train.py
+++ b/scripts/training/train.py
@@ -491,6 +491,25 @@ def main(
     top_p: float = 1.0,
     seed: Optional[int] = None,
 ):
+    if tf32 and not (
+        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+    ):
+        # TF32 floating point format is available only on NVIDIA GPUs
+        # with compute capability 8 and above. See link for details.
+        # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability-8-x
+        log_on_main(
+            "TF32 format is only available on devices with compute capability >= 8. "
+            "Setting tf32 to False.",
+            logger,
+        )
+        tf32 = False
+
+    if seed is None:
+        seed = random.randint(0, 2**32)
+
+    log_on_main(f"Using SEED: {seed}", logger)
+    transformers.set_seed(seed=seed)
+
     raw_training_config = deepcopy(locals())
     output_dir = Path(output_dir)
     training_data_paths = ast.literal_eval(training_data_paths)
@@ -511,12 +530,6 @@ def main(
     if not model_type == "seq2seq":
         raise NotImplementedError("Only seq2seq models are currently supported")

-    if seed is None:
-        seed = random.randint(0, 2**32)
-
-    log_on_main(f"Using SEED: {seed}", logger)
-    transformers.set_seed(seed=seed)
-
     output_dir = get_next_path("run", base_dir=output_dir, file_type="")

     log_on_main(f"Logging dir: {output_dir}", logger)
