Enhance training script: auto tf32 detection and reorder default seed setting (#91)

*Description of changes:* Automatically set `tf32` to `False` if used on
an older NVIDIA GPU. Reorder seed so that the seed is saved as part of
the training config.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
Abdul Fatir 2024-05-31 15:13:49 +02:00 committed by GitHub
parent b0bdbd9d1a
commit 6bcd4584a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -491,6 +491,25 @@ def main(
top_p: float = 1.0,
seed: Optional[int] = None,
):
if tf32 and not (
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
):
# TF32 floating point format is available only on NVIDIA GPUs
# with compute capability 8 and above. See link for details.
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability-8-x
log_on_main(
"TF32 format is only available on devices with compute capability >= 8. "
"Setting tf32 to False.",
logger,
)
tf32 = False
if seed is None:
seed = random.randint(0, 2**32)
log_on_main(f"Using SEED: {seed}", logger)
transformers.set_seed(seed=seed)
raw_training_config = deepcopy(locals())
output_dir = Path(output_dir)
training_data_paths = ast.literal_eval(training_data_paths)
@ -511,12 +530,6 @@ def main(
if not model_type == "seq2seq":
raise NotImplementedError("Only seq2seq models are currently supported")
if seed is None:
seed = random.randint(0, 2**32)
log_on_main(f"Using SEED: {seed}", logger)
transformers.set_seed(seed=seed)
output_dir = get_next_path("run", base_dir=output_dir, file_type="")
log_on_main(f"Logging dir: {output_dir}", logger)