ultralytics 8.3.101 YOLOE visual prompt inference fix for video sources (#19959)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2025-04-03 19:49:04 +08:00 committed by GitHub
parent c2fbfc5c4a
commit 6ebf893880
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 21 additions and 2 deletions

View file

@ -160,6 +160,11 @@ Object detection is straightforward with the `predict` method, as illustrated be
=== "Visual Prompt"
!!! note
If `source` is a video or stream, its first frame is automatically used as `refer_image`. Alternatively, you can pass any frame from the video/stream directly via the `refer_image` argument.
Prompts in source image:
```python

View file

@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
__version__ = "8.3.100"
__version__ = "8.3.101"
import os

View file

@ -2,6 +2,7 @@
from pathlib import Path
from ultralytics.data.build import load_inference_source
from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import (
@ -267,7 +268,14 @@ class YOLOE(Model):
f"{len(visual_prompts['cls'])} respectively"
)
self.predictor = (predictor or self._smart_load("predictor"))(
overrides={"task": "segment", "mode": "predict", "save": False, "verbose": False}, _callbacks=self.callbacks
overrides={
"task": self.model.task,
"mode": "predict",
"save": False,
"verbose": refer_image is None,
"batch": 1,
},
_callbacks=self.callbacks,
)
if len(visual_prompts):
@ -281,6 +289,12 @@ class YOLOE(Model):
self.predictor.set_prompts(visual_prompts.copy())
self.predictor.setup_model(model=self.model)
if refer_image is None:
dataset = load_inference_source(source)
if dataset.mode in {"video", "stream"}:
# NOTE: set the first frame as refer image for videos/streams inference
refer_image = next(iter(dataset))[1][0]
if refer_image is not None and len(visual_prompts):
vpe = self.predictor.get_vpe(refer_image)
self.model.set_classes(self.model.names, vpe)