mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
Refactor test_engine.py and add OBB/Pose test coverage (#24197)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
616c032b8b
commit
10beb3fa60
1 changed files with 52 additions and 76 deletions
|
|
@ -11,7 +11,7 @@ from tests import MODEL, SOURCE, TASK_MODEL_DATA
|
|||
from ultralytics import YOLO
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.models.yolo import classify, detect, segment
|
||||
from ultralytics.models.yolo import classify, detect, obb, pose, segment
|
||||
from ultralytics.nn.tasks import load_checkpoint
|
||||
from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
|
||||
|
||||
|
|
@ -30,107 +30,83 @@ def test_export():
|
|||
YOLO(f)(SOURCE) # exported model inference
|
||||
|
||||
|
||||
def test_detect():
|
||||
"""Test YOLO object detection training, validation, and prediction functionality."""
|
||||
overrides = {"data": "coco8.yaml", "model": "yolo26n.yaml", "imgsz": 32, "epochs": 1, "save": False}
|
||||
cfg = get_cfg(DEFAULT_CFG)
|
||||
cfg.data = "coco8.yaml"
|
||||
cfg.imgsz = 32
|
||||
|
||||
# Trainer
|
||||
trainer = detect.DetectionTrainer(overrides=overrides)
|
||||
trainer.add_callback("on_train_start", test_func)
|
||||
assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
|
||||
trainer.train()
|
||||
|
||||
# Validator
|
||||
val = detect.DetectionValidator(args=cfg)
|
||||
val.add_callback("on_val_start", test_func)
|
||||
assert test_func in val.callbacks["on_val_start"], "callback test failed"
|
||||
val(model=trainer.best) # validate best.pt
|
||||
|
||||
# Predictor
|
||||
pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
|
||||
pred.add_callback("on_predict_start", test_func)
|
||||
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
|
||||
# Confirm there is no issue with sys.argv being empty
|
||||
with mock.patch.object(sys, "argv", []):
|
||||
result = pred(source=ASSETS, model=MODEL)
|
||||
assert len(result), "predictor test failed"
|
||||
|
||||
# Test resume functionality
|
||||
with pytest.raises(AssertionError):
|
||||
detect.DetectionTrainer(overrides={**overrides, "resume": trainer.last}).train()
|
||||
|
||||
|
||||
def test_segment():
|
||||
"""Test image segmentation training, validation, and prediction pipelines using YOLO models."""
|
||||
@pytest.mark.parametrize(
|
||||
"trainer_cls,validator_cls,predictor_cls,data,model,weights",
|
||||
[
|
||||
(
|
||||
detect.DetectionTrainer,
|
||||
detect.DetectionValidator,
|
||||
detect.DetectionPredictor,
|
||||
"coco8.yaml",
|
||||
"yolo26n.yaml",
|
||||
MODEL,
|
||||
),
|
||||
(
|
||||
segment.SegmentationTrainer,
|
||||
segment.SegmentationValidator,
|
||||
segment.SegmentationPredictor,
|
||||
"coco8-seg.yaml",
|
||||
"yolo26n-seg.yaml",
|
||||
WEIGHTS_DIR / "yolo26n-seg.pt",
|
||||
),
|
||||
(
|
||||
classify.ClassificationTrainer,
|
||||
classify.ClassificationValidator,
|
||||
classify.ClassificationPredictor,
|
||||
"imagenet10",
|
||||
"yolo26n-cls.yaml",
|
||||
None,
|
||||
),
|
||||
(obb.OBBTrainer, obb.OBBValidator, obb.OBBPredictor, "dota8.yaml", "yolo26n-obb.yaml", None),
|
||||
(pose.PoseTrainer, pose.PoseValidator, pose.PosePredictor, "coco8-pose.yaml", "yolo26n-pose.yaml", None),
|
||||
],
|
||||
)
|
||||
def test_task(trainer_cls, validator_cls, predictor_cls, data, model, weights):
|
||||
"""Test YOLO training, validation, and prediction for various tasks."""
|
||||
overrides = {
|
||||
"data": "coco8-seg.yaml",
|
||||
"model": "yolo26n-seg.yaml",
|
||||
"data": data,
|
||||
"model": model,
|
||||
"imgsz": 32,
|
||||
"epochs": 1,
|
||||
"save": False,
|
||||
"mask_ratio": 1,
|
||||
"overlap_mask": False,
|
||||
}
|
||||
cfg = get_cfg(DEFAULT_CFG)
|
||||
cfg.data = "coco8-seg.yaml"
|
||||
cfg.imgsz = 32
|
||||
|
||||
# Trainer
|
||||
trainer = segment.SegmentationTrainer(overrides=overrides)
|
||||
trainer = trainer_cls(overrides=overrides)
|
||||
trainer.add_callback("on_train_start", test_func)
|
||||
assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
|
||||
trainer.train()
|
||||
|
||||
# Validator
|
||||
val = segment.SegmentationValidator(args=cfg)
|
||||
val.add_callback("on_val_start", test_func)
|
||||
assert test_func in val.callbacks["on_val_start"], "callback test failed"
|
||||
val(model=trainer.best) # validate best.pt
|
||||
|
||||
# Predictor
|
||||
pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
|
||||
pred.add_callback("on_predict_start", test_func)
|
||||
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
|
||||
result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolo26n-seg.pt")
|
||||
assert len(result), "predictor test failed"
|
||||
|
||||
# Test resume functionality
|
||||
with pytest.raises(AssertionError):
|
||||
segment.SegmentationTrainer(overrides={**overrides, "resume": trainer.last}).train()
|
||||
|
||||
|
||||
def test_classify():
|
||||
"""Test image classification including training, validation, and prediction phases."""
|
||||
overrides = {"data": "imagenet10", "model": "yolo26n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False}
|
||||
cfg = get_cfg(DEFAULT_CFG)
|
||||
cfg.data = "imagenet10"
|
||||
cfg.data = data
|
||||
cfg.imgsz = 32
|
||||
|
||||
# Trainer
|
||||
trainer = classify.ClassificationTrainer(overrides=overrides)
|
||||
trainer.add_callback("on_train_start", test_func)
|
||||
assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
|
||||
trainer.train()
|
||||
|
||||
# Validator
|
||||
val = classify.ClassificationValidator(args=cfg)
|
||||
val = validator_cls(args=cfg)
|
||||
val.add_callback("on_val_start", test_func)
|
||||
assert test_func in val.callbacks["on_val_start"], "callback test failed"
|
||||
val(model=trainer.best)
|
||||
|
||||
# Predictor
|
||||
pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
|
||||
pred = predictor_cls(overrides={"imgsz": [64, 64]})
|
||||
pred.add_callback("on_predict_start", test_func)
|
||||
assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
|
||||
result = pred(source=ASSETS, model=trainer.best)
|
||||
assert len(result), "predictor test failed"
|
||||
|
||||
# Determine model path for prediction
|
||||
model_path = weights if weights else trainer.best
|
||||
if model == "yolo26n.yaml": # only for detection
|
||||
# Confirm there is no issue with sys.argv being empty
|
||||
with mock.patch.object(sys, "argv", []):
|
||||
result = pred(source=ASSETS, model=model_path)
|
||||
assert len(result), "predictor test failed"
|
||||
else:
|
||||
result = pred(source=ASSETS, model=model_path)
|
||||
assert len(result), "predictor test failed"
|
||||
|
||||
# Test resume functionality
|
||||
with pytest.raises(AssertionError):
|
||||
classify.ClassificationTrainer(overrides={**overrides, "resume": trainer.last}).train()
|
||||
trainer_cls(overrides={**overrides, "resume": trainer.last}).train()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task,weight,data", TASK_MODEL_DATA)
|
||||
|
|
|
|||
Loading…
Reference in a new issue