mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Save training job info (#80)
*Description of changes:* This PR updates the training script to also save the training details in the final checkpoint. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
parent
55166d3227
commit
16f927ccfe
1 changed files with 69 additions and 2 deletions
|
|
@ -5,11 +5,14 @@ import ast
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import itertools
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from typing import List, Iterator, Optional
|
||||
from typing import List, Iterator, Optional, Dict
|
||||
|
||||
import typer
|
||||
from typer_config import use_yaml_config
|
||||
|
|
@ -26,6 +29,8 @@ from transformers import (
|
|||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
import accelerate
|
||||
import gluonts
|
||||
from gluonts.dataset.common import FileDataset
|
||||
from gluonts.itertools import Cyclic, Map, Filter
|
||||
from gluonts.transform import (
|
||||
|
|
@ -59,6 +64,63 @@ def log_on_main(msg: str, logger: logging.Logger, log_level: int = logging.INFO)
|
|||
logger.log(log_level, msg)
|
||||
|
||||
|
||||
def get_training_job_info() -> Dict:
    """
    Returns info about this training job.

    The returned dictionary records:

    - ``cuda_available``: whether ``torch.cuda.is_available()`` is true and,
      only in that case, ``device_count``, ``device_names`` and ``mem_info``
      (the latter two keyed by device index, one entry per visible device).
    - ``torchelastic_launched``: whether the process was started through
      torchelastic and, only in that case, ``world_size``.
    - version strings of Python and of the main libraries used for training.
    """
    job_info = {}

    # CUDA info
    cuda_available = torch.cuda.is_available()
    job_info["cuda_available"] = cuda_available
    if cuda_available:
        # Query the device count once and iterate over exactly the devices
        # that exist; never address a fixed index (e.g. device 1), which
        # fails on single-GPU hosts.
        device_count = torch.cuda.device_count()
        job_info["device_count"] = device_count
        job_info["device_names"] = {
            idx: torch.cuda.get_device_name(idx) for idx in range(device_count)
        }
        job_info["mem_info"] = {
            idx: torch.cuda.mem_get_info(device=idx) for idx in range(device_count)
        }

    # DDP info
    torchelastic_launched = dist.is_torchelastic_launched()
    job_info["torchelastic_launched"] = torchelastic_launched
    if torchelastic_launched:
        job_info["world_size"] = dist.get_world_size()

    # Versions
    job_info["python_version"] = sys.version.replace("\n", " ")
    job_info["torch_version"] = torch.__version__
    job_info["numpy_version"] = np.__version__
    job_info["gluonts_version"] = gluonts.__version__
    job_info["transformers_version"] = transformers.__version__
    job_info["accelerate_version"] = accelerate.__version__

    return job_info
|
||||
|
||||
|
||||
def save_training_info(ckpt_path: Path, training_config: Dict):
    """
    Save info about this training job in a json file for documentation.

    Writes ``training_info.json`` inside ``ckpt_path`` with two top-level
    keys: ``training_config`` (the given ``training_config``) and
    ``job_info`` (the result of ``get_training_job_info()``).

    Parameters
    ----------
    ckpt_path
        Existing checkpoint directory in which the file is created.
    training_config
        JSON-serializable mapping describing the training configuration.

    Raises
    ------
    AssertionError
        If ``ckpt_path`` is not an existing directory.
    TypeError
        If the payload contains values that ``json`` cannot serialize.
    """
    assert ckpt_path.is_dir(), f"{ckpt_path} is not an existing directory"

    # Build and serialize the payload BEFORE touching the file system, so a
    # failure while collecting job info or encoding the config does not
    # leave behind an empty or truncated training_info.json.
    training_info = json.dumps(
        {"training_config": training_config, "job_info": get_training_job_info()},
        indent=4,
    )
    (ckpt_path / "training_info.json").write_text(training_info, encoding="utf-8")
|
||||
|
||||
|
||||
def get_next_path(
|
||||
base_fname: str,
|
||||
base_dir: Path,
|
||||
|
|
@ -407,7 +469,7 @@ def main(
|
|||
model_type: str = "seq2seq",
|
||||
random_init: bool = False,
|
||||
tie_embeddings: bool = False,
|
||||
output_dir: Path = Path("./output/"),
|
||||
output_dir: str = "./output/",
|
||||
tf32: bool = True,
|
||||
torch_compile: bool = True,
|
||||
tokenizer_class: str = "MeanScaleUniformBins",
|
||||
|
|
@ -427,6 +489,8 @@ def main(
|
|||
top_p: float = 1.0,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
raw_training_config = deepcopy(locals())
|
||||
output_dir = Path(output_dir)
|
||||
training_data_paths = ast.literal_eval(training_data_paths)
|
||||
assert isinstance(training_data_paths, list)
|
||||
|
||||
|
|
@ -554,6 +618,9 @@ def main(
|
|||
|
||||
if is_main_process():
|
||||
model.save_pretrained(output_dir / "checkpoint-final")
|
||||
save_training_info(
|
||||
output_dir / "checkpoint-final", training_config=raw_training_config
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue