Fix YOLOE training test (#23448)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Jing Qiu 2026-01-27 22:19:34 +08:00 committed by GitHub
parent 8e589dfdf7
commit 8e6eab3643
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 11 deletions

View file

@ -701,7 +701,7 @@ def test_yoloe(tmp_path):
"""Test YOLOE models with MobileClip support."""
# Predict
# text-prompts
model = YOLO(WEIGHTS_DIR / "yoloe-26s-seg.pt")
model = YOLO(WEIGHTS_DIR / "yoloe-11s-seg.pt")
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))
model(SOURCE, conf=0.01)
@ -721,7 +721,7 @@ def test_yoloe(tmp_path):
)
# Val
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg.pt")
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg.pt")
# text prompts
model.val(data="coco128-seg.yaml", imgsz=32)
# visual prompts
@ -730,7 +730,7 @@ def test_yoloe(tmp_path):
# Train, fine-tune
from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer, YOLOESegTrainerFromScratch
model = YOLOE("yoloe-26s-seg.pt")
model = YOLOE("yoloe-11s-seg.pt")
model.train(
data="coco128-seg.yaml",
epochs=1,
@ -739,12 +739,11 @@ def test_yoloe(tmp_path):
imgsz=32,
)
# Train, from scratch
model = YOLOE("yoloe-26s-seg.yaml")
data_dict = dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"]))
data_yaml = tmp_path / "yoloe-data.yaml"
YAML.save(data=data_dict, file=data_yaml)
for data in {data_dict, data_yaml}:
for data in [data_dict, data_yaml]:
model = YOLOE("yoloe-11s-seg.yaml")
model.train(
data=data,
epochs=1,
@ -755,10 +754,10 @@ def test_yoloe(tmp_path):
# prompt-free
# predict
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg-pf.pt")
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg-pf.pt")
model.predict(SOURCE)
# val
model = YOLOE("yoloe-26s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
model = YOLOE("yoloe-11s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
model.val(data="coco128-seg.yaml", imgsz=32)

View file

@ -1,4 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
@ -103,17 +104,17 @@ class WorldTrainerFromScratch(WorldTrainer):
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
@staticmethod
def check_data_config(data: dict | str) -> dict:
def check_data_config(data: dict | str | Path) -> dict:
"""Check and load the data configuration from a YAML file or dictionary.
Args:
data (dict | str): Data configuration as a dictionary or path to a YAML file.
data (dict | str | Path): Data configuration as a dictionary or path to a YAML file.
Returns:
(dict): Data configuration dictionary loaded from YAML file or passed directly.
"""
# If string, load from YAML file
if isinstance(data, str):
if not isinstance(data, dict):
from ultralytics.utils import YAML
return YAML.load(check_file(data))