diff --git a/scripts/training/train.py b/scripts/training/train.py index bbdf6cd..ada4a99 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -491,6 +491,25 @@ def main( top_p: float = 1.0, seed: Optional[int] = None, ): + if tf32 and not ( + torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 + ): + # TF32 floating point format is available only on NVIDIA GPUs + # with compute capability 8 and above. See link for details. + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability-8-x + log_on_main( + "TF32 format is only available on devices with compute capability >= 8. " + "Setting tf32 to False.", + logger, + ) + tf32 = False + + if seed is None: + seed = random.randint(0, 2**32) + + log_on_main(f"Using SEED: {seed}", logger) + transformers.set_seed(seed=seed) + raw_training_config = deepcopy(locals()) output_dir = Path(output_dir) training_data_paths = ast.literal_eval(training_data_paths) @@ -511,12 +530,6 @@ def main( if not model_type == "seq2seq": raise NotImplementedError("Only seq2seq models are currently supported") - if seed is None: - seed = random.randint(0, 2**32) - - log_on_main(f"Using SEED: {seed}", logger) - transformers.set_seed(seed=seed) - output_dir = get_next_path("run", base_dir=output_dir, file_type="") log_on_main(f"Logging dir: {output_dir}", logger)