mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
Fix train resume for non-end2end models (#24173)
Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
parent
fec04ba66c
commit
f06a9bbadd
2 changed files with 43 additions and 2 deletions
|
|
@ -7,11 +7,12 @@ from unittest import mock
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from tests import MODEL, SOURCE
|
||||
from tests import MODEL, SOURCE, TASK_MODEL_DATA
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.models.yolo import classify, detect, segment
|
||||
from ultralytics.nn.tasks import load_checkpoint
|
||||
from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
|
||||
|
||||
|
||||
|
|
@ -127,6 +128,46 @@ def test_classify():
|
|||
result = pred(source=ASSETS, model=trainer.best)
|
||||
assert len(result), "predictor test failed"
|
||||
|
||||
# Test resume functionality
|
||||
with pytest.raises(AssertionError):
|
||||
classify.ClassificationTrainer(overrides={**overrides, "resume": trainer.last}).train()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task,weight,data", TASK_MODEL_DATA)
|
||||
def test_resume_incomplete(task, weight, data, tmp_path):
|
||||
"""Test training resumes from an incomplete checkpoint."""
|
||||
train_args = {
|
||||
"data": data,
|
||||
"epochs": 2,
|
||||
"save": True,
|
||||
"plots": False,
|
||||
"workers": 0,
|
||||
"project": tmp_path,
|
||||
"name": task,
|
||||
"imgsz": 32,
|
||||
"exist_ok": True,
|
||||
}
|
||||
|
||||
def stop_after_first_epoch(trainer):
|
||||
if trainer.epoch == 0:
|
||||
trainer.stop = True
|
||||
|
||||
def disable_final_eval(trainer):
|
||||
trainer.final_eval = lambda: None
|
||||
|
||||
model = YOLO(weight)
|
||||
model.add_callback("on_train_start", disable_final_eval)
|
||||
model.add_callback("on_train_epoch_end", stop_after_first_epoch)
|
||||
model.train(**train_args)
|
||||
last_path = model.trainer.last
|
||||
_, ckpt = load_checkpoint(last_path)
|
||||
assert ckpt["epoch"] == 0, "checkpoint should be resumable"
|
||||
|
||||
# Resume training using the checkpoint
|
||||
resume_model = YOLO(last_path)
|
||||
resume_model.train(resume=True, **train_args)
|
||||
assert resume_model.trainer.start_epoch == resume_model.trainer.epoch == 1, "resume test failed"
|
||||
|
||||
|
||||
def test_nan_recovery():
|
||||
"""Test NaN loss detection and recovery during training."""
|
||||
|
|
|
|||
|
|
@ -961,7 +961,7 @@ class BaseTrainer:
|
|||
)
|
||||
self.epochs += ckpt["epoch"] # finetune additional epochs
|
||||
self._load_checkpoint_state(ckpt)
|
||||
if unwrap_model(self.model).end2end:
|
||||
if getattr(unwrap_model(self.model), "end2end", False):
|
||||
# initialize loss and resume o2o and o2m args
|
||||
unwrap_model(self.model).criterion = unwrap_model(self.model).init_criterion()
|
||||
unwrap_model(self.model).criterion.updates = start_epoch - 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue