Feat: support multiple data config via yaml for YOLOE training (#23427)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Louis 2026-01-27 16:11:46 +08:00 committed by GitHub
parent 25ded59194
commit 8e589dfdf7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 75 additions and 17 deletions

View file

@ -321,6 +321,7 @@ This approach provides a powerful means of customizing state-of-the-art [object
from ultralytics import YOLOWorld
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
# Option 1: Use Python dictionary
data = dict(
train=dict(
yolo_data=["Objects365.yaml"],
@ -337,8 +338,27 @@ This approach provides a powerful means of customizing state-of-the-art [object
),
val=dict(yolo_data=["lvis.yaml"]),
)
# Option 2: Use YAML file (yolo_world_data.yaml)
# train:
# yolo_data:
# - Objects365.yaml
# grounding_data:
# - img_path: flickr/full_images/
# json_file: flickr/annotations/final_flickr_separateGT_train_segm.json
# - img_path: mixed_grounding/gqa/images
# json_file: mixed_grounding/annotations/final_mixed_train_no_coco_segm.json
# val:
# yolo_data:
# - lvis.yaml
model = YOLOWorld("yolov8s-worldv2.yaml")
model.train(data=data, batch=128, epochs=100, trainer=WorldTrainerFromScratch)
model.train(
data=data, # or data="yolo_world_data.yaml" if using YAML file
batch=128,
epochs=100,
trainer=WorldTrainerFromScratch,
)
```
## Citations and Acknowledgments

View file

@ -541,6 +541,7 @@ The export process is similar to other YOLO models, with the added flexibility o
from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOESegTrainerFromScratch
# Option 1: Use Python dictionary
data = dict(
train=dict(
yolo_data=["Objects365.yaml"],
@ -558,9 +559,22 @@ The export process is similar to other YOLO models, with the added flexibility o
val=dict(yolo_data=["lvis.yaml"]),
)
# Option 2: Use YAML file (yoloe_data.yaml)
# train:
# yolo_data:
# - Objects365.yaml
# grounding_data:
# - img_path: flickr/full_images/
# json_file: flickr/annotations/final_flickr_separateGT_train_segm.json
# - img_path: mixed_grounding/gqa/images
# json_file: mixed_grounding/annotations/final_mixed_train_no_coco_segm.json
# val:
# yolo_data:
# - lvis.yaml
model = YOLOE("yoloe-26l-seg.yaml")
model.train(
data=data,
data=data, # or data="yoloe_data.yaml" if using YAML file
batch=128,
epochs=30,
close_mosaic=2,

View file

@ -697,11 +697,11 @@ def test_yolo_world():
checks.IS_PYTHON_3_8 and LINUX and ARM64,
reason="YOLOE with CLIP is not supported in Python 3.8 and aarch64 Linux",
)
def test_yoloe():
def test_yoloe(tmp_path):
"""Test YOLOE models with MobileClip support."""
# Predict
# text-prompts
model = YOLO(WEIGHTS_DIR / "yoloe-11s-seg.pt")
model = YOLO(WEIGHTS_DIR / "yoloe-26s-seg.pt")
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))
model(SOURCE, conf=0.01)
@ -721,7 +721,7 @@ def test_yoloe():
)
# Val
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg.pt")
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg.pt")
# text prompts
model.val(data="coco128-seg.yaml", imgsz=32)
# visual prompts
@ -730,7 +730,7 @@ def test_yoloe():
# Train, fine-tune
from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer, YOLOESegTrainerFromScratch
model = YOLOE("yoloe-11s-seg.pt")
model = YOLOE("yoloe-26s-seg.pt")
model.train(
data="coco128-seg.yaml",
epochs=1,
@ -739,21 +739,26 @@ def test_yoloe():
imgsz=32,
)
# Train, from scratch
model = YOLOE("yoloe-11s-seg.yaml")
model.train(
data=dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"])),
epochs=1,
close_mosaic=1,
trainer=YOLOESegTrainerFromScratch,
imgsz=32,
)
model = YOLOE("yoloe-26s-seg.yaml")
data_dict = dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"]))
data_yaml = tmp_path / "yoloe-data.yaml"
YAML.save(data=data_dict, file=data_yaml)
for data in {data_dict, data_yaml}:
model.train(
data=data,
epochs=1,
close_mosaic=1,
trainer=YOLOESegTrainerFromScratch,
imgsz=32,
)
# prompt-free
# predict
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg-pf.pt")
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg-pf.pt")
model.predict(SOURCE)
# val
model = YOLOE("yoloe-11s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
model = YOLOE("yoloe-26s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
model.val(data="coco128-seg.yaml", imgsz=32)

View file

@ -1,4 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
@ -6,6 +7,7 @@ from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_data
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER
from ultralytics.utils.checks import check_file
from ultralytics.utils.torch_utils import unwrap_model
@ -100,6 +102,23 @@ class WorldTrainerFromScratch(WorldTrainer):
self.set_text_embeddings(datasets, batch) # cache text embeddings to accelerate training
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
@staticmethod
def check_data_config(data: dict | str) -> dict:
"""Check and load the data configuration from a YAML file or dictionary.
Args:
data (dict | str): Data configuration as a dictionary or path to a YAML file.
Returns:
(dict): Data configuration dictionary loaded from YAML file or passed directly.
"""
# If string, load from YAML file
if isinstance(data, str):
from ultralytics.utils import YAML
return YAML.load(check_file(data))
return data
def get_dataset(self):
"""Get train and validation paths from data dictionary.
@ -114,7 +133,7 @@ class WorldTrainerFromScratch(WorldTrainer):
AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
"""
final_data = {}
data_yaml = self.args.data
self.args.data = data_yaml = self.check_data_config(self.args.data)
assert data_yaml.get("train", False), "train dataset not found" # object365.yaml
assert data_yaml.get("val", False), "validation dataset not found" # lvis.yaml
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}