mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
Fix YOLOE training test (#23448)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
8e589dfdf7
commit
8e6eab3643
2 changed files with 11 additions and 11 deletions
|
|
@ -701,7 +701,7 @@ def test_yoloe(tmp_path):
|
|||
"""Test YOLOE models with MobileClip support."""
|
||||
# Predict
|
||||
# text-prompts
|
||||
model = YOLO(WEIGHTS_DIR / "yoloe-26s-seg.pt")
|
||||
model = YOLO(WEIGHTS_DIR / "yoloe-11s-seg.pt")
|
||||
names = ["person", "bus"]
|
||||
model.set_classes(names, model.get_text_pe(names))
|
||||
model(SOURCE, conf=0.01)
|
||||
|
|
@ -721,7 +721,7 @@ def test_yoloe(tmp_path):
|
|||
)
|
||||
|
||||
# Val
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg.pt")
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg.pt")
|
||||
# text prompts
|
||||
model.val(data="coco128-seg.yaml", imgsz=32)
|
||||
# visual prompts
|
||||
|
|
@ -730,7 +730,7 @@ def test_yoloe(tmp_path):
|
|||
# Train, fine-tune
|
||||
from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer, YOLOESegTrainerFromScratch
|
||||
|
||||
model = YOLOE("yoloe-26s-seg.pt")
|
||||
model = YOLOE("yoloe-11s-seg.pt")
|
||||
model.train(
|
||||
data="coco128-seg.yaml",
|
||||
epochs=1,
|
||||
|
|
@ -739,12 +739,11 @@ def test_yoloe(tmp_path):
|
|||
imgsz=32,
|
||||
)
|
||||
# Train, from scratch
|
||||
model = YOLOE("yoloe-26s-seg.yaml")
|
||||
|
||||
data_dict = dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"]))
|
||||
data_yaml = tmp_path / "yoloe-data.yaml"
|
||||
YAML.save(data=data_dict, file=data_yaml)
|
||||
for data in {data_dict, data_yaml}:
|
||||
for data in [data_dict, data_yaml]:
|
||||
model = YOLOE("yoloe-11s-seg.yaml")
|
||||
model.train(
|
||||
data=data,
|
||||
epochs=1,
|
||||
|
|
@ -755,10 +754,10 @@ def test_yoloe(tmp_path):
|
|||
|
||||
# prompt-free
|
||||
# predict
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg-pf.pt")
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg-pf.pt")
|
||||
model.predict(SOURCE)
|
||||
# val
|
||||
model = YOLOE("yoloe-26s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
|
||||
model = YOLOE("yoloe-11s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
|
||||
model.val(data="coco128-seg.yaml", imgsz=32)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
|
@ -103,17 +104,17 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
||||
|
||||
@staticmethod
|
||||
def check_data_config(data: dict | str) -> dict:
|
||||
def check_data_config(data: dict | str | Path) -> dict:
|
||||
"""Check and load the data configuration from a YAML file or dictionary.
|
||||
|
||||
Args:
|
||||
data (dict | str): Data configuration as a dictionary or path to a YAML file.
|
||||
data (dict | str | Path): Data configuration as a dictionary or path to a YAML file.
|
||||
|
||||
Returns:
|
||||
(dict): Data configuration dictionary loaded from YAML file or passed directly.
|
||||
"""
|
||||
# If string, load from YAML file
|
||||
if isinstance(data, str):
|
||||
if not isinstance(data, dict):
|
||||
from ultralytics.utils import YAML
|
||||
|
||||
return YAML.load(check_file(data))
|
||||
|
|
|
|||
Loading…
Reference in a new issue