mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
Feat: support multiple data config via yaml for YOLOE training (#23427)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
parent
25ded59194
commit
8e589dfdf7
4 changed files with 75 additions and 17 deletions
|
|
@ -321,6 +321,7 @@ This approach provides a powerful means of customizing state-of-the-art [object
|
|||
from ultralytics import YOLOWorld
|
||||
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
|
||||
|
||||
# Option 1: Use Python dictionary
|
||||
data = dict(
|
||||
train=dict(
|
||||
yolo_data=["Objects365.yaml"],
|
||||
|
|
@ -337,8 +338,27 @@ This approach provides a powerful means of customizing state-of-the-art [object
|
|||
),
|
||||
val=dict(yolo_data=["lvis.yaml"]),
|
||||
)
|
||||
|
||||
# Option 2: Use YAML file (yolo_world_data.yaml)
|
||||
# train:
|
||||
# yolo_data:
|
||||
# - Objects365.yaml
|
||||
# grounding_data:
|
||||
# - img_path: flickr/full_images/
|
||||
# json_file: flickr/annotations/final_flickr_separateGT_train_segm.json
|
||||
# - img_path: mixed_grounding/gqa/images
|
||||
# json_file: mixed_grounding/annotations/final_mixed_train_no_coco_segm.json
|
||||
# val:
|
||||
# yolo_data:
|
||||
# - lvis.yaml
|
||||
|
||||
model = YOLOWorld("yolov8s-worldv2.yaml")
|
||||
model.train(data=data, batch=128, epochs=100, trainer=WorldTrainerFromScratch)
|
||||
model.train(
|
||||
data=data, # or data="yolo_world_data.yaml" if using YAML file
|
||||
batch=128,
|
||||
epochs=100,
|
||||
trainer=WorldTrainerFromScratch,
|
||||
)
|
||||
```
|
||||
|
||||
## Citations and Acknowledgments
|
||||
|
|
|
|||
|
|
@ -541,6 +541,7 @@ The export process is similar to other YOLO models, with the added flexibility o
|
|||
from ultralytics import YOLOE
|
||||
from ultralytics.models.yolo.yoloe import YOLOESegTrainerFromScratch
|
||||
|
||||
# Option 1: Use Python dictionary
|
||||
data = dict(
|
||||
train=dict(
|
||||
yolo_data=["Objects365.yaml"],
|
||||
|
|
@ -558,9 +559,22 @@ The export process is similar to other YOLO models, with the added flexibility o
|
|||
val=dict(yolo_data=["lvis.yaml"]),
|
||||
)
|
||||
|
||||
# Option 2: Use YAML file (yoloe_data.yaml)
|
||||
# train:
|
||||
# yolo_data:
|
||||
# - Objects365.yaml
|
||||
# grounding_data:
|
||||
# - img_path: flickr/full_images/
|
||||
# json_file: flickr/annotations/final_flickr_separateGT_train_segm.json
|
||||
# - img_path: mixed_grounding/gqa/images
|
||||
# json_file: mixed_grounding/annotations/final_mixed_train_no_coco_segm.json
|
||||
# val:
|
||||
# yolo_data:
|
||||
# - lvis.yaml
|
||||
|
||||
model = YOLOE("yoloe-26l-seg.yaml")
|
||||
model.train(
|
||||
data=data,
|
||||
data=data, # or data="yoloe_data.yaml" if using YAML file
|
||||
batch=128,
|
||||
epochs=30,
|
||||
close_mosaic=2,
|
||||
|
|
|
|||
|
|
@ -697,11 +697,11 @@ def test_yolo_world():
|
|||
checks.IS_PYTHON_3_8 and LINUX and ARM64,
|
||||
reason="YOLOE with CLIP is not supported in Python 3.8 and aarch64 Linux",
|
||||
)
|
||||
def test_yoloe():
|
||||
def test_yoloe(tmp_path):
|
||||
"""Test YOLOE models with MobileClip support."""
|
||||
# Predict
|
||||
# text-prompts
|
||||
model = YOLO(WEIGHTS_DIR / "yoloe-11s-seg.pt")
|
||||
model = YOLO(WEIGHTS_DIR / "yoloe-26s-seg.pt")
|
||||
names = ["person", "bus"]
|
||||
model.set_classes(names, model.get_text_pe(names))
|
||||
model(SOURCE, conf=0.01)
|
||||
|
|
@ -721,7 +721,7 @@ def test_yoloe():
|
|||
)
|
||||
|
||||
# Val
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg.pt")
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg.pt")
|
||||
# text prompts
|
||||
model.val(data="coco128-seg.yaml", imgsz=32)
|
||||
# visual prompts
|
||||
|
|
@ -730,7 +730,7 @@ def test_yoloe():
|
|||
# Train, fine-tune
|
||||
from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer, YOLOESegTrainerFromScratch
|
||||
|
||||
model = YOLOE("yoloe-11s-seg.pt")
|
||||
model = YOLOE("yoloe-26s-seg.pt")
|
||||
model.train(
|
||||
data="coco128-seg.yaml",
|
||||
epochs=1,
|
||||
|
|
@ -739,21 +739,26 @@ def test_yoloe():
|
|||
imgsz=32,
|
||||
)
|
||||
# Train, from scratch
|
||||
model = YOLOE("yoloe-11s-seg.yaml")
|
||||
model.train(
|
||||
data=dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"])),
|
||||
epochs=1,
|
||||
close_mosaic=1,
|
||||
trainer=YOLOESegTrainerFromScratch,
|
||||
imgsz=32,
|
||||
)
|
||||
model = YOLOE("yoloe-26s-seg.yaml")
|
||||
|
||||
data_dict = dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"]))
|
||||
data_yaml = tmp_path / "yoloe-data.yaml"
|
||||
YAML.save(data=data_dict, file=data_yaml)
|
||||
# Exercise from-scratch training with both accepted config forms: the in-memory
# dictionary and the path of the YAML file that was saved from it.
# A list is used on purpose: a dict is unhashable, so a set literal holding
# `data_dict` raises TypeError, and a list also keeps the iteration order fixed.
for data in [data_dict, data_yaml]:
    model.train(
        data=data,
        epochs=1,
        close_mosaic=1,
        trainer=YOLOESegTrainerFromScratch,
        imgsz=32,
    )
|
||||
|
||||
# prompt-free
|
||||
# predict
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg-pf.pt")
|
||||
model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg-pf.pt")
|
||||
model.predict(SOURCE)
|
||||
# val
|
||||
model = YOLOE("yoloe-11s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
|
||||
model = YOLOE("yoloe-26s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes
|
||||
model.val(data="coco128-seg.yaml", imgsz=32)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -6,6 +7,7 @@ from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_data
|
|||
from ultralytics.data.utils import check_det_dataset
|
||||
from ultralytics.models.yolo.world import WorldTrainer
|
||||
from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER
|
||||
from ultralytics.utils.checks import check_file
|
||||
from ultralytics.utils.torch_utils import unwrap_model
|
||||
|
||||
|
||||
|
|
@ -100,6 +102,23 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
self.set_text_embeddings(datasets, batch) # cache text embeddings to accelerate training
|
||||
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
||||
|
||||
@staticmethod
|
||||
def check_data_config(data: dict | str) -> dict:
|
||||
"""Check and load the data configuration from a YAML file or dictionary.
|
||||
|
||||
Args:
|
||||
data (dict | str): Data configuration as a dictionary or path to a YAML file.
|
||||
|
||||
Returns:
|
||||
(dict): Data configuration dictionary loaded from YAML file or passed directly.
|
||||
"""
|
||||
# If string, load from YAML file
|
||||
if isinstance(data, str):
|
||||
from ultralytics.utils import YAML
|
||||
|
||||
return YAML.load(check_file(data))
|
||||
return data
|
||||
|
||||
def get_dataset(self):
|
||||
"""Get train and validation paths from data dictionary.
|
||||
|
||||
|
|
@ -114,7 +133,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
|
||||
"""
|
||||
final_data = {}
|
||||
data_yaml = self.args.data
|
||||
self.args.data = data_yaml = self.check_data_config(self.args.data)
|
||||
assert data_yaml.get("train", False), "train dataset not found" # object365.yaml
|
||||
assert data_yaml.get("val", False), "validation dataset not found" # lvis.yaml
|
||||
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
|
||||
|
|
|
|||
Loading…
Reference in a new issue