Fix train resume for non-end2end models (#24173)

Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
fatih akyon 2026-04-10 18:29:30 +03:00 committed by GitHub
parent fec04ba66c
commit f06a9bbadd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 2 deletions

View file

@ -7,11 +7,12 @@ from unittest import mock
import pytest
import torch
from tests import MODEL, SOURCE
from tests import MODEL, SOURCE, TASK_MODEL_DATA
from ultralytics import YOLO
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.models.yolo import classify, detect, segment
from ultralytics.nn.tasks import load_checkpoint
from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
@ -127,6 +128,46 @@ def test_classify():
result = pred(source=ASSETS, model=trainer.best)
assert len(result), "predictor test failed"
# Test resume functionality
with pytest.raises(AssertionError):
classify.ClassificationTrainer(overrides={**overrides, "resume": trainer.last}).train()
@pytest.mark.parametrize("task,weight,data", TASK_MODEL_DATA)
def test_resume_incomplete(task, weight, data, tmp_path):
"""Test training resumes from an incomplete checkpoint."""
train_args = {
"data": data,
"epochs": 2,
"save": True,
"plots": False,
"workers": 0,
"project": tmp_path,
"name": task,
"imgsz": 32,
"exist_ok": True,
}
def stop_after_first_epoch(trainer):
if trainer.epoch == 0:
trainer.stop = True
def disable_final_eval(trainer):
trainer.final_eval = lambda: None
model = YOLO(weight)
model.add_callback("on_train_start", disable_final_eval)
model.add_callback("on_train_epoch_end", stop_after_first_epoch)
model.train(**train_args)
last_path = model.trainer.last
_, ckpt = load_checkpoint(last_path)
assert ckpt["epoch"] == 0, "checkpoint should be resumable"
# Resume training using the checkpoint
resume_model = YOLO(last_path)
resume_model.train(resume=True, **train_args)
assert resume_model.trainer.start_epoch == resume_model.trainer.epoch == 1, "resume test failed"
def test_nan_recovery():
"""Test NaN loss detection and recovery during training."""

View file

@ -961,7 +961,7 @@ class BaseTrainer:
)
self.epochs += ckpt["epoch"] # finetune additional epochs
self._load_checkpoint_state(ckpt)
if unwrap_model(self.model).end2end:
if getattr(unwrap_model(self.model), "end2end", False):
# initialize loss and resume o2o and o2m args
unwrap_model(self.model).criterion = unwrap_model(self.model).init_criterion()
unwrap_model(self.model).criterion.updates = start_epoch - 1