mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 01:29:48 +00:00
Enhance training script: auto tf32 detection and reorder default seed setting (#91)
*Description of changes:* Automatically set `tf32` to `False` if used on an older NVIDIA GPU. Reorder seed so that the seed is saved as part of the training config. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
parent
b0bdbd9d1a
commit
6bcd4584a3
1 changed files with 19 additions and 6 deletions
|
|
@ -491,6 +491,25 @@ def main(
|
|||
top_p: float = 1.0,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
if tf32 and not (
|
||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
|
||||
):
|
||||
# TF32 floating point format is available only on NVIDIA GPUs
|
||||
# with compute capability 8 and above. See link for details.
|
||||
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability-8-x
|
||||
log_on_main(
|
||||
"TF32 format is only available on devices with compute capability >= 8. "
|
||||
"Setting tf32 to False.",
|
||||
logger,
|
||||
)
|
||||
tf32 = False
|
||||
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**32)
|
||||
|
||||
log_on_main(f"Using SEED: {seed}", logger)
|
||||
transformers.set_seed(seed=seed)
|
||||
|
||||
raw_training_config = deepcopy(locals())
|
||||
output_dir = Path(output_dir)
|
||||
training_data_paths = ast.literal_eval(training_data_paths)
|
||||
|
|
@ -511,12 +530,6 @@ def main(
|
|||
if not model_type == "seq2seq":
|
||||
raise NotImplementedError("Only seq2seq models are currently supported")
|
||||
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**32)
|
||||
|
||||
log_on_main(f"Using SEED: {seed}", logger)
|
||||
transformers.set_seed(seed=seed)
|
||||
|
||||
output_dir = get_next_path("run", base_dir=output_dir, file_type="")
|
||||
|
||||
log_on_main(f"Logging dir: {output_dir}", logger)
|
||||
|
|
|
|||
Loading…
Reference in a new issue