diff --git a/docs/en/models/yolo-world.md b/docs/en/models/yolo-world.md index 2945f4aed1..ec1a25e812 100644 --- a/docs/en/models/yolo-world.md +++ b/docs/en/models/yolo-world.md @@ -321,6 +321,7 @@ This approach provides a powerful means of customizing state-of-the-art [object from ultralytics import YOLOWorld from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch + # Option 1: Use Python dictionary data = dict( train=dict( yolo_data=["Objects365.yaml"], @@ -337,8 +338,27 @@ This approach provides a powerful means of customizing state-of-the-art [object ), val=dict(yolo_data=["lvis.yaml"]), ) + + # Option 2: Use YAML file (yolo_world_data.yaml) + # train: + # yolo_data: + # - Objects365.yaml + # grounding_data: + # - img_path: flickr/full_images/ + # json_file: flickr/annotations/final_flickr_separateGT_train_segm.json + # - img_path: mixed_grounding/gqa/images + # json_file: mixed_grounding/annotations/final_mixed_train_no_coco_segm.json + # val: + # yolo_data: + # - lvis.yaml + model = YOLOWorld("yolov8s-worldv2.yaml") - model.train(data=data, batch=128, epochs=100, trainer=WorldTrainerFromScratch) + model.train( + data=data, # or data="yolo_world_data.yaml" if using YAML file + batch=128, + epochs=100, + trainer=WorldTrainerFromScratch, + ) ``` ## Citations and Acknowledgments diff --git a/docs/en/models/yoloe.md b/docs/en/models/yoloe.md index 3f355aa70c..d0da2b1ed0 100644 --- a/docs/en/models/yoloe.md +++ b/docs/en/models/yoloe.md @@ -541,6 +541,7 @@ The export process is similar to other YOLO models, with the added flexibility o from ultralytics import YOLOE from ultralytics.models.yolo.yoloe import YOLOESegTrainerFromScratch + # Option 1: Use Python dictionary data = dict( train=dict( yolo_data=["Objects365.yaml"], @@ -558,9 +559,22 @@ The export process is similar to other YOLO models, with the added flexibility o val=dict(yolo_data=["lvis.yaml"]), ) + # Option 2: Use YAML file (yoloe_data.yaml) + # train: + # yolo_data: + # - 
Objects365.yaml + # grounding_data: + # - img_path: flickr/full_images/ + # json_file: flickr/annotations/final_flickr_separateGT_train_segm.json + # - img_path: mixed_grounding/gqa/images + # json_file: mixed_grounding/annotations/final_mixed_train_no_coco_segm.json + # val: + # yolo_data: + # - lvis.yaml + model = YOLOE("yoloe-26l-seg.yaml") model.train( - data=data, + data=data, # or data="yoloe_data.yaml" if using YAML file batch=128, epochs=30, close_mosaic=2, diff --git a/tests/test_python.py b/tests/test_python.py index 8acdb2daf4..d42a751b2d 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -697,11 +697,11 @@ def test_yolo_world(): checks.IS_PYTHON_3_8 and LINUX and ARM64, reason="YOLOE with CLIP is not supported in Python 3.8 and aarch64 Linux", ) -def test_yoloe(): +def test_yoloe(tmp_path): """Test YOLOE models with MobileClip support.""" # Predict # text-prompts - model = YOLO(WEIGHTS_DIR / "yoloe-11s-seg.pt") + model = YOLO(WEIGHTS_DIR / "yoloe-26s-seg.pt") names = ["person", "bus"] model.set_classes(names, model.get_text_pe(names)) model(SOURCE, conf=0.01) @@ -721,7 +721,7 @@ def test_yoloe(): ) # Val - model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg.pt") + model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg.pt") # text prompts model.val(data="coco128-seg.yaml", imgsz=32) # visual prompts @@ -730,7 +730,7 @@ def test_yoloe(): # Train, fine-tune from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer, YOLOESegTrainerFromScratch - model = YOLOE("yoloe-11s-seg.pt") + model = YOLOE("yoloe-26s-seg.pt") model.train( data="coco128-seg.yaml", epochs=1, @@ -739,21 +739,26 @@ def test_yoloe(): imgsz=32, ) # Train, from scratch - model = YOLOE("yoloe-11s-seg.yaml") - model.train( - data=dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"])), - epochs=1, - close_mosaic=1, - trainer=YOLOESegTrainerFromScratch, - imgsz=32, - ) + model = YOLOE("yoloe-26s-seg.yaml") + + data_dict = 
dict(train=dict(yolo_data=["coco128-seg.yaml"]), val=dict(yolo_data=["coco128-seg.yaml"])) + data_yaml = tmp_path / "yoloe-data.yaml" + YAML.save(data=data_dict, file=data_yaml) + for data in (data_dict, data_yaml): + model.train( + data=data, + epochs=1, + close_mosaic=1, + trainer=YOLOESegTrainerFromScratch, + imgsz=32, + ) # prompt-free # predict - model = YOLOE(WEIGHTS_DIR / "yoloe-11s-seg-pf.pt") + model = YOLOE(WEIGHTS_DIR / "yoloe-26s-seg-pf.pt") model.predict(SOURCE) # val - model = YOLOE("yoloe-11s-seg.pt") # or select yoloe-m/l-seg.pt for different sizes + model = YOLOE("yoloe-26s-seg.pt") # or select yoloe-26m/l-seg.pt for different sizes model.val(data="coco128-seg.yaml", imgsz=32) diff --git a/ultralytics/models/yolo/world/train_world.py b/ultralytics/models/yolo/world/train_world.py index e02c0fec50..44c701e830 100644 --- a/ultralytics/models/yolo/world/train_world.py +++ b/ultralytics/models/yolo/world/train_world.py @@ -1,4 +1,5 @@ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +from __future__ import annotations from pathlib import Path @@ -6,6 +7,7 @@ from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_data from ultralytics.data.utils import check_det_dataset from ultralytics.models.yolo.world import WorldTrainer from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER +from ultralytics.utils.checks import check_file from ultralytics.utils.torch_utils import unwrap_model @@ -100,6 +102,23 @@ class WorldTrainerFromScratch(WorldTrainer): self.set_text_embeddings(datasets, batch) # cache text embeddings to accelerate training return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0] + @staticmethod + def check_data_config(data: dict | str) -> dict: + """Check and load the data configuration from a YAML file or dictionary. + + Args: + data (dict | str): Data configuration as a dictionary or path to a YAML file. 
+ + Returns: + (dict): Data configuration dictionary loaded from YAML file or passed directly. + """ + # If string/Path, load from YAML file + if isinstance(data, (str, Path)): + from ultralytics.utils import YAML + + return YAML.load(check_file(data)) + return data + def get_dataset(self): """Get train and validation paths from data dictionary. @@ -114,7 +133,7 @@ class WorldTrainerFromScratch(WorldTrainer): AssertionError: If train or validation datasets are not found, or if validation has multiple datasets. """ final_data = {} - data_yaml = self.args.data + self.args.data = data_yaml = self.check_data_config(self.args.data) assert data_yaml.get("train", False), "train dataset not found" # object365.yaml assert data_yaml.get("val", False), "validation dataset not found" # lvis.yaml data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}