Save training job info (#80)

*Description of changes:* This PR updates the training script to also
save the training details in the final checkpoint.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
Abdul Fatir 2024-05-27 09:57:18 +02:00 committed by GitHub
parent 55166d3227
commit 16f927ccfe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -5,11 +5,14 @@ import ast
import logging
import os
import re
import sys
import json
import itertools
import random
from copy import deepcopy
from pathlib import Path
from functools import partial
from typing import List, Iterator, Optional
from typing import List, Iterator, Optional, Dict
import typer
from typer_config import use_yaml_config
@ -26,6 +29,8 @@ from transformers import (
Trainer,
TrainingArguments,
)
import accelerate
import gluonts
from gluonts.dataset.common import FileDataset
from gluonts.itertools import Cyclic, Map, Filter
from gluonts.transform import (
@ -59,6 +64,63 @@ def log_on_main(msg: str, logger: logging.Logger, log_level: int = logging.INFO)
logger.log(log_level, msg)
def get_training_job_info() -> Dict:
"""
Returns info about this training job.
"""
job_info = {}
# CUDA info
job_info["cuda_available"] = torch.cuda.is_available()
print(job_info)
if torch.cuda.is_available():
job_info["device_count"] = torch.cuda.device_count()
print(job_info)
job_info["device_names"] = {
idx: torch.cuda.get_device_name(idx)
for idx in range(torch.cuda.device_count())
}
print(job_info)
print(torch.cuda.mem_get_info(device=0))
print(torch.cuda.mem_get_info(device=1))
job_info["mem_info"] = {
idx: torch.cuda.mem_get_info(device=idx)
for idx in range(torch.cuda.device_count())
}
print(job_info)
# DDP info
job_info["torchelastic_launched"] = dist.is_torchelastic_launched()
if dist.is_torchelastic_launched():
job_info["world_size"] = dist.get_world_size()
# Versions
job_info["python_version"] = sys.version.replace("\n", " ")
job_info["torch_version"] = torch.__version__
job_info["numpy_version"] = np.__version__
job_info["gluonts_version"] = gluonts.__version__
job_info["transformers_version"] = transformers.__version__
job_info["accelerate_version"] = accelerate.__version__
return job_info
def save_training_info(ckpt_path: Path, training_config: Dict):
"""
Save info about this training job in a json file for documentation.
"""
assert ckpt_path.is_dir()
with open(ckpt_path / "training_info.json", "w") as fp:
json.dump(
{"training_config": training_config, "job_info": get_training_job_info()},
fp,
indent=4,
)
def get_next_path(
base_fname: str,
base_dir: Path,
@ -407,7 +469,7 @@ def main(
model_type: str = "seq2seq",
random_init: bool = False,
tie_embeddings: bool = False,
output_dir: Path = Path("./output/"),
output_dir: str = "./output/",
tf32: bool = True,
torch_compile: bool = True,
tokenizer_class: str = "MeanScaleUniformBins",
@ -427,6 +489,8 @@ def main(
top_p: float = 1.0,
seed: Optional[int] = None,
):
raw_training_config = deepcopy(locals())
output_dir = Path(output_dir)
training_data_paths = ast.literal_eval(training_data_paths)
assert isinstance(training_data_paths, list)
@ -554,6 +618,9 @@ def main(
if is_main_process():
model.save_pretrained(output_dir / "checkpoint-final")
save_training_info(
output_dir / "checkpoint-final", training_config=raw_training_config
)
if __name__ == "__main__":