diff --git a/scripts/training/train.py b/scripts/training/train.py index 3c6b21d..1973121 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -5,11 +5,14 @@ import ast import logging import os import re +import sys +import json import itertools import random +from copy import deepcopy from pathlib import Path from functools import partial -from typing import List, Iterator, Optional +from typing import List, Iterator, Optional, Dict import typer from typer_config import use_yaml_config @@ -26,6 +29,8 @@ from transformers import ( Trainer, TrainingArguments, ) +import accelerate +import gluonts from gluonts.dataset.common import FileDataset from gluonts.itertools import Cyclic, Map, Filter from gluonts.transform import ( @@ -59,6 +64,63 @@ def log_on_main(msg: str, logger: logging.Logger, log_level: int = logging.INFO) logger.log(log_level, msg) +def get_training_job_info() -> Dict: + """ + Returns info about this training job. + """ + job_info = {} + + # CUDA info + job_info["cuda_available"] = torch.cuda.is_available() + print(job_info) + if torch.cuda.is_available(): + job_info["device_count"] = torch.cuda.device_count() + print(job_info) + + job_info["device_names"] = { + idx: torch.cuda.get_device_name(idx) + for idx in range(torch.cuda.device_count()) + } + print(job_info) + + print(torch.cuda.mem_get_info(device=0)) + print(torch.cuda.mem_get_info(device=1)) + job_info["mem_info"] = { + idx: torch.cuda.mem_get_info(device=idx) + for idx in range(torch.cuda.device_count()) + } + print(job_info) + + # DDP info + job_info["torchelastic_launched"] = dist.is_torchelastic_launched() + + if dist.is_torchelastic_launched(): + job_info["world_size"] = dist.get_world_size() + + # Versions + job_info["python_version"] = sys.version.replace("\n", " ") + job_info["torch_version"] = torch.__version__ + job_info["numpy_version"] = np.__version__ + job_info["gluonts_version"] = gluonts.__version__ + job_info["transformers_version"] = transformers.__version__ + job_info["accelerate_version"] = accelerate.__version__ + + return job_info + + +def save_training_info(ckpt_path: Path, training_config: Dict): + """ + Save info about this training job in a json file for documentation. + """ + assert ckpt_path.is_dir() + with open(ckpt_path / "training_info.json", "w") as fp: + json.dump( + {"training_config": training_config, "job_info": get_training_job_info()}, + fp, + indent=4, + ) + + def get_next_path( base_fname: str, base_dir: Path, @@ -407,7 +469,7 @@ def main( model_type: str = "seq2seq", random_init: bool = False, tie_embeddings: bool = False, - output_dir: Path = Path("./output/"), + output_dir: str = "./output/", tf32: bool = True, torch_compile: bool = True, tokenizer_class: str = "MeanScaleUniformBins", @@ -427,6 +489,8 @@ def main( top_p: float = 1.0, seed: Optional[int] = None, ): + raw_training_config = deepcopy(locals()) + output_dir = Path(output_dir) training_data_paths = ast.literal_eval(training_data_paths) assert isinstance(training_data_paths, list) @@ -554,6 +618,9 @@ def main( if is_main_process(): model.save_pretrained(output_dir / "checkpoint-final") + save_training_info( + output_dir / "checkpoint-final", training_config=raw_training_config + ) if __name__ == "__main__":