From f06a9bbadd28c4567071408f9a27a103ba7cf265 Mon Sep 17 00:00:00 2001 From: fatih akyon <34196005+fcakyon@users.noreply.github.com> Date: Fri, 10 Apr 2026 18:29:30 +0300 Subject: [PATCH] Fix train resume for non-end2end models (#24173) Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> --- tests/test_engine.py | 43 ++++++++++++++++++++++++++++++++++- ultralytics/engine/trainer.py | 2 +- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 1598ef51de..6440e49614 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -7,11 +7,12 @@ from unittest import mock import pytest import torch -from tests import MODEL, SOURCE +from tests import MODEL, SOURCE, TASK_MODEL_DATA from ultralytics import YOLO from ultralytics.cfg import get_cfg from ultralytics.engine.exporter import Exporter from ultralytics.models.yolo import classify, detect, segment +from ultralytics.nn.tasks import load_checkpoint from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR @@ -127,6 +128,46 @@ def test_classify(): result = pred(source=ASSETS, model=trainer.best) assert len(result), "predictor test failed" + # Test resume functionality + with pytest.raises(AssertionError): + classify.ClassificationTrainer(overrides={**overrides, "resume": trainer.last}).train() + + +@pytest.mark.parametrize("task,weight,data", TASK_MODEL_DATA) +def test_resume_incomplete(task, weight, data, tmp_path): + """Test training resumes from an incomplete checkpoint.""" + train_args = { + "data": data, + "epochs": 2, + "save": True, + "plots": False, + "workers": 0, + "project": tmp_path, + "name": task, + "imgsz": 32, + "exist_ok": True, + } + + def stop_after_first_epoch(trainer): + if trainer.epoch == 0: + trainer.stop = True + + def disable_final_eval(trainer): + trainer.final_eval = lambda: None + + model = YOLO(weight) + model.add_callback("on_train_start", disable_final_eval) + model.add_callback("on_train_epoch_end", stop_after_first_epoch) + model.train(**train_args) + last_path = model.trainer.last + _, ckpt = load_checkpoint(last_path) + assert ckpt["epoch"] == 0, "checkpoint should be resumable" + + # Resume training using the checkpoint + resume_model = YOLO(last_path) + resume_model.train(resume=True, **train_args) + assert resume_model.trainer.start_epoch == resume_model.trainer.epoch == 1, "resume test failed" + def test_nan_recovery(): """Test NaN loss detection and recovery during training.""" diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 4484f2e251..49b9d937ac 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -961,7 +961,7 @@ class BaseTrainer: ) self.epochs += ckpt["epoch"] # finetune additional epochs self._load_checkpoint_state(ckpt) - if unwrap_model(self.model).end2end: + if getattr(unwrap_model(self.model), "end2end", False): # initialize loss and resume o2o and o2m args unwrap_model(self.model).criterion = unwrap_model(self.model).init_criterion() unwrap_model(self.model).criterion.updates = start_epoch - 1