mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
ultralytics 8.3.190 Autobackend torch-native NMS (#21862)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
parent
b1e3530657
commit
a8e3450e62
15 changed files with 386 additions and 207 deletions
20
docs/en/reference/utils/nms.md
Normal file
20
docs/en/reference/utils/nms.md
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
---
|
||||
description: Custom NMS implementation for Ultralytics YOLO with TorchNMS class for torchvision-free inference and fast-nms for oriented bounding boxes. Optimized for speed and accuracy.
|
||||
keywords: NMS, non-maximum suppression, TorchNMS, YOLO, torchvision-free, rotated NMS, object detection, bounding boxes, IoU threshold, custom implementation
|
||||
---
|
||||
|
||||
# Reference for `ultralytics/utils/nms.py`
|
||||
|
||||
!!! note
|
||||
|
||||
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/nms.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/nms.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/utils/nms.py) 🛠️. Thank you 🙏!
|
||||
|
||||
<br>
|
||||
|
||||
## ::: ultralytics.utils.nms.TorchNMS
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.utils.nms.non_max_suppression
|
||||
|
||||
<br><br>
|
||||
|
|
@ -27,14 +27,6 @@ keywords: Ultralytics, utility operations, non-max suppression, bounding box tra
|
|||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.utils.ops.nms_rotated
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.utils.ops.non_max_suppression
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.utils.ops.clip_boxes
|
||||
|
||||
<br><br><hr><br>
|
||||
|
|
|
|||
|
|
@ -8,9 +8,8 @@ import numpy as np
|
|||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
import ultralytics.utils.ops as ops
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import ASSETS, YAML
|
||||
from ultralytics.utils import ASSETS, YAML, nms, ops
|
||||
from ultralytics.utils.checks import check_yaml
|
||||
|
||||
|
||||
|
|
@ -139,7 +138,7 @@ class YOLOv8Seg:
|
|||
(List[Results]): Processed detection results containing bounding boxes and segmentation masks.
|
||||
"""
|
||||
preds, protos = [torch.from_numpy(p) for p in outs]
|
||||
preds = ops.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))
|
||||
preds = nms.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))
|
||||
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
|
|
|
|||
|
|
@ -686,6 +686,7 @@ nav:
|
|||
- logger: reference/utils/logger.md
|
||||
- loss: reference/utils/loss.md
|
||||
- metrics: reference/utils/metrics.md
|
||||
- nms: reference/utils/nms.md
|
||||
- ops: reference/utils/ops.md
|
||||
- patches: reference/utils/patches.md
|
||||
- plotting: reference/utils/plotting.md
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
__version__ = "8.3.189"
|
||||
__version__ = "8.3.190"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -107,7 +107,9 @@ from ultralytics.utils.checks import (
|
|||
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
|
||||
from ultralytics.utils.export import export_engine, export_onnx
|
||||
from ultralytics.utils.files import file_size, spaces_in_path
|
||||
from ultralytics.utils.ops import Profile, nms_rotated
|
||||
from ultralytics.utils.metrics import batch_probiou
|
||||
from ultralytics.utils.nms import TorchNMS
|
||||
from ultralytics.utils.ops import Profile
|
||||
from ultralytics.utils.patches import arange_patch
|
||||
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
|
||||
|
||||
|
|
@ -1562,12 +1564,13 @@ class NMSModel(torch.nn.Module):
|
|||
nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
|
||||
nms_fn = (
|
||||
partial(
|
||||
nms_rotated,
|
||||
TorchNMS.fast_nms,
|
||||
use_triu=not (
|
||||
self.is_tf
|
||||
or (self.args.opset or 14) < 14
|
||||
or (self.args.format == "openvino" and self.args.int8) # OpenVINO int8 error with triu
|
||||
),
|
||||
iou_func=batch_probiou,
|
||||
)
|
||||
if self.obb
|
||||
else nms
|
||||
|
|
|
|||
|
|
@ -101,6 +101,8 @@ class BaseValidator:
|
|||
args (SimpleNamespace, optional): Configuration for the validator.
|
||||
_callbacks (dict, optional): Dictionary to store various callback functions.
|
||||
"""
|
||||
import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)
|
||||
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.dataloader = dataloader
|
||||
self.stride = None
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import ops
|
||||
from ultralytics.utils import nms, ops
|
||||
|
||||
|
||||
class DetectionPredictor(BasePredictor):
|
||||
|
|
@ -53,7 +53,7 @@ class DetectionPredictor(BasePredictor):
|
|||
>>> processed_results = predictor.postprocess(preds, img, orig_imgs)
|
||||
"""
|
||||
save_feats = getattr(self, "_feats", None) is not None
|
||||
preds = ops.non_max_suppression(
|
||||
preds = nms.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import torch
|
|||
|
||||
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
|
||||
from ultralytics.engine.validator import BaseValidator
|
||||
from ultralytics.utils import LOGGER, ops
|
||||
from ultralytics.utils import LOGGER, nms, ops
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
|
@ -115,7 +115,7 @@ class DetectionValidator(BaseValidator):
|
|||
(List[Dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
|
||||
'bboxes', 'conf', 'cls', and 'extra' tensors.
|
||||
"""
|
||||
outputs = ops.non_max_suppression(
|
||||
outputs = nms.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import torch
|
|||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import LOGGER, ops
|
||||
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
|
||||
from ultralytics.utils.nms import TorchNMS
|
||||
|
||||
|
||||
class OBBValidator(DetectionValidator):
|
||||
|
|
@ -281,7 +282,7 @@ class OBBValidator(DetectionValidator):
|
|||
b = bbox[:, :5].clone()
|
||||
b[:, :2] += c
|
||||
# 0.3 could get results close to the ones from official merging script, even slightly better.
|
||||
i = ops.nms_rotated(b, scores, 0.3)
|
||||
i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
|
||||
bbox = bbox[i]
|
||||
|
||||
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
|
||||
|
|
|
|||
|
|
@ -850,8 +850,6 @@ class AutoBackend(nn.Module):
|
|||
Args:
|
||||
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
|
||||
"""
|
||||
import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)
|
||||
|
||||
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
|
||||
if any(warmup_types) and (self.device.type != "cpu" or self.triton):
|
||||
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||
|
|
|
|||
346
ultralytics/utils/nms.py
Normal file
346
ultralytics/utils/nms.py
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.metrics import batch_probiou, box_iou
|
||||
from ultralytics.utils.ops import xywh2xyxy
|
||||
|
||||
|
||||
def non_max_suppression(
    prediction,
    conf_thres: float = 0.25,
    iou_thres: float = 0.45,
    classes=None,
    agnostic: bool = False,
    multi_label: bool = False,
    labels=(),
    max_det: int = 300,
    nc: int = 0,  # number of classes (optional)
    max_time_img: float = 0.05,
    max_nms: int = 30000,
    max_wh: int = 7680,
    rotated: bool = False,
    end2end: bool = False,
    return_idxs: bool = False,
):
    """
    Perform non-maximum suppression (NMS) on prediction results.

    Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
    detection formats including standard boxes, rotated boxes, and masks.

    Args:
        prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing boxes, classes, and optional masks. A list/tuple is also accepted, in which case only its
            first element is used. For non-rotated boxes the first four channels are converted from xywh to xyxy
            by assignment into this tensor.
        conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
        iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
        classes (List[int], optional): List of class indices to consider. If None, all classes are considered.
        agnostic (bool): Whether to perform class-agnostic NMS.
        multi_label (bool): Whether each box can have multiple labels. Ignored when there is a single class.
        labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.
        max_det (int): Maximum number of detections to keep per image.
        nc (int): Number of classes. Indices after this are considered masks. If 0, it is inferred as
            `prediction.shape[1] - 4`.
        max_time_img (float): Maximum time in seconds for processing one image.
        max_nms (int): Maximum number of boxes for NMS.
        max_wh (int): Maximum box width and height in pixels. Also used as the per-class coordinate offset.
        rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
        end2end (bool): Whether the model is end-to-end and doesn't require NMS.
        return_idxs (bool): Whether to return the indices of kept detections.

    Returns:
        output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
            containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
        keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True. Not returned on the
            end-to-end path, which always returns `output` alone.

    Raises:
        AssertionError: If `conf_thres` or `iou_thres` is outside [0, 1].
    """
    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output
    if classes is not None:
        classes = torch.tensor(classes, device=prediction.device)

    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
        # No NMS needed: only threshold on confidence (column 4), cap at max_det and optionally filter by class
        output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
        if classes is not None:
            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
        return output  # NOTE: return_idxs is not honoured on this path

    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    extra = prediction.shape[1] - nc - 4  # number of extra info
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
    # Per-image column vector 0..num_boxes-1, filtered in lockstep with the predictions to report kept indices
    xinds = torch.arange(prediction.shape[-1], device=prediction.device).expand(bs, -1)[..., None]  # to track idxs

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if not rotated:
        # NOTE: assigns into the transposed tensor, so the caller's prediction tensor is modified as well
        prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy

    t = time.time()
    # Defaults for images that yield no detections (or are skipped after the time limit); entries are replaced below
    output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
    keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
    for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        filt = xc[xi]  # confidence
        x = x[filt]
        if return_idxs:
            xk = xk[filt]

        # Cat apriori labels if autolabelling
        # NOTE(review): xk is not extended with rows for these labels, so combining `labels` with
        # return_idxs=True looks like it would misalign x and xk in the filters below - confirm with callers
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, extra), 1)

        if multi_label:
            # One output row per (box, class) pair whose class score exceeds the threshold
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
            if return_idxs:
                xk = xk[i]
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            filt = conf.view(-1) > conf_thres
            x = torch.cat((box, conf, j.float(), mask), 1)[filt]
            if return_idxs:
                xk = xk[filt]

        # Filter by class
        if classes is not None:
            filt = (x[:, 5:6] == classes).any(1)
            x = x[filt]
            if return_idxs:
                xk = xk[filt]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
            x = x[filt]
            if return_idxs:
                xk = xk[filt]

        # Shift boxes by class index * max_wh so boxes of different classes never overlap (no shift if agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores
        if rotated:
            # Only the box centre is offset; width, height and the last column (angle) are left untouched
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = TorchNMS.fast_nms(boxes, scores, iou_thres, iou_func=batch_probiou)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            # Speed strategy: torchvision for val or already loaded (faster), TorchNMS for predict (lower latency)
            if "torchvision" in sys.modules:
                import torchvision  # scope as slow import

                i = torchvision.ops.nms(boxes, scores, iou_thres)
            else:
                i = TorchNMS.nms(boxes, scores, iou_thres)
        i = i[:max_det]  # limit detections

        output[xi] = x[i]
        if return_idxs:
            keepi[xi] = xk[i].view(-1)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return (output, keepi) if return_idxs else output
|
||||
|
||||
|
||||
class TorchNMS:
    """
    Ultralytics custom NMS implementation optimized for YOLO.

    This class provides static methods for performing non-maximum suppression (NMS) operations on bounding boxes,
    including matrix-based Fast-NMS, standard greedy NMS and batched NMS for multi-class scenarios.

    Methods:
        fast_nms: Fast-NMS on a pairwise IoU matrix with a pluggable IoU function (e.g. probiou for rotated boxes).
        nms: Greedy NMS with an early-exit shortcut, intended to match torchvision.ops.nms results.
        batched_nms: Batched NMS for class-aware suppression.

    Examples:
        Perform standard NMS on boxes and scores
        >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
        >>> scores = torch.tensor([0.9, 0.8])
        >>> keep = TorchNMS.nms(boxes, scores, 0.5)
    """

    @staticmethod
    def fast_nms(
        boxes: torch.Tensor,
        scores: torch.Tensor,
        iou_threshold: float,
        use_triu: bool = True,
        iou_func=box_iou,
    ) -> torch.Tensor:
        """
        Fast-NMS implementation from https://arxiv.org/pdf/1904.02689 using upper triangular matrix operations.

        A box is dropped when any higher-scoring box has IoU >= `iou_threshold` with it, regardless of whether that
        higher-scoring box was itself dropped.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, D) in the format expected by `iou_func`, e.g.
                (N, 4) xyxy for the default `box_iou`.
            scores (torch.Tensor): Confidence scores with shape (N,). Modified in place when `use_triu=False`:
                the scores of suppressed boxes are set to 0.
            iou_threshold (float): IoU threshold for suppression.
            use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.
            iou_func (callable): Function to compute the pairwise (N, N) IoU matrix between boxes.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS, ordered by descending score. With
                `use_triu=False` the result always has length N; suppressed boxes come last and are identifiable
                by their zeroed entry in `scores`.

        Examples:
            Apply Fast-NMS to a set of boxes
            >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> keep = TorchNMS.fast_nms(boxes, scores, 0.5)
        """
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        sorted_idx = torch.argsort(scores, descending=True)
        boxes = boxes[sorted_idx]
        ious = iou_func(boxes, boxes)
        if use_triu:
            # Keep only pairs (i, j) with i < j, i.e. IoU of each box with strictly higher-ranked boxes
            ious = ious.triu_(diagonal=1)
            # NOTE: a column-wise reduction picks the survivors without any data-dependent if-else branch
            pick = torch.nonzero((ious >= iou_threshold).sum(0) <= 0).squeeze_(-1)
        else:
            # Same upper-triangular masking built from index grids for backends where triu is unavailable
            n = boxes.shape[0]
            row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
            col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
            upper_mask = row_idx < col_idx
            ious = ious * upper_mask
            # The suppression mask is expressed in sorted order, so it must be applied to the sorted scores.
            # Zeroing these scores ensures the additional indices would not affect the final results
            sorted_scores = scores[sorted_idx]
            sorted_scores[~((ious >= iou_threshold).sum(0) <= 0)] = 0
            scores[sorted_idx] = sorted_scores  # propagate the zeroed scores back to the caller's tensor
            # NOTE: return indices with fixed length to avoid TFLite reshape error
            pick = torch.topk(sorted_scores, sorted_scores.shape[0]).indices
        return sorted_idx[pick]

    @staticmethod
    def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
        """
        Greedy NMS with an early-exit shortcut, intended to match torchvision.ops.nms results.

        Repeatedly keeps the highest-scoring remaining box and discards every remaining box whose IoU with it
        exceeds `iou_threshold`.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
            scores (torch.Tensor): Confidence scores with shape (N,).
            iou_threshold (float): IoU threshold for suppression.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS, ordered by descending score.

        Examples:
            Apply NMS to a set of boxes
            >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> keep = TorchNMS.nms(boxes, scores, 0.5)
        """
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        # Extract coordinates and areas once
        x1, y1, x2, y2 = boxes.unbind(1)
        areas = (x2 - x1) * (y2 - y1)

        # Sort by scores descending
        _, order = scores.sort(0, descending=True)

        # Pre-allocate keep buffer with maximum possible size
        keep = torch.zeros(order.numel(), dtype=torch.int64, device=boxes.device)
        keep_idx = 0

        while order.numel() > 0:
            i = order[0]
            keep[keep_idx] = i
            keep_idx += 1

            if order.numel() == 1:
                break

            # Vectorized intersection of the current box with all remaining boxes
            rest = order[1:]
            xx1 = torch.maximum(x1[i], x1[rest])
            yy1 = torch.maximum(y1[i], y1[rest])
            xx2 = torch.minimum(x2[i], x2[rest])
            yy2 = torch.minimum(y2[i], y2[rest])

            w = (xx2 - xx1).clamp_(min=0)
            h = (yy2 - yy1).clamp_(min=0)
            inter = w * h

            # Early exit: skip the IoU division when the current box intersects none of the remaining boxes.
            # The remaining boxes may still overlap each other, so they must keep going through the loop rather
            # than being accepted wholesale.
            if inter.sum() == 0:
                order = rest
                continue

            iou = inter / (areas[i] + areas[rest] - inter)

            # Keep boxes with IoU <= threshold
            order = rest[iou <= iou_threshold]

        return keep[:keep_idx]

    @staticmethod
    def batched_nms(
        boxes: torch.Tensor,
        scores: torch.Tensor,
        idxs: torch.Tensor,
        iou_threshold: float,
        use_fast_nms: bool = False,
    ) -> torch.Tensor:
        """
        Batched NMS for class-aware suppression.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
            scores (torch.Tensor): Confidence scores with shape (N,).
            idxs (torch.Tensor): Class indices with shape (N,).
            iou_threshold (float): IoU threshold for suppression.
            use_fast_nms (bool): Whether to use the Fast-NMS implementation.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS.

        Examples:
            Apply batched NMS across multiple classes
            >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> idxs = torch.tensor([0, 1])
            >>> keep = TorchNMS.batched_nms(boxes, scores, idxs, 0.5)
        """
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        # Strategy: offset boxes by class index to prevent cross-class suppression
        max_coordinate = boxes.max()
        offsets = idxs.to(boxes) * (max_coordinate + 1)
        boxes_for_nms = boxes + offsets[:, None]

        return (
            TorchNMS.fast_nms(boxes_for_nms, scores, iou_threshold)
            if use_fast_nms
            else TorchNMS.nms(boxes_for_nms, scores, iou_threshold)
        )
|
||||
|
|
@ -11,8 +11,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.utils import LOGGER, NOT_MACOS14
|
||||
from ultralytics.utils.metrics import batch_probiou
|
||||
from ultralytics.utils import NOT_MACOS14
|
||||
|
||||
|
||||
class Profile(contextlib.ContextDecorator):
|
||||
|
|
@ -154,188 +153,6 @@ def make_divisible(x: int, divisor):
|
|||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
|
||||
def nms_rotated(boxes, scores, threshold: float = 0.45, use_triu: bool = True):
    """
    Perform NMS on oriented bounding boxes using probiou and fast-nms.

    A box is dropped when any higher-scoring box has probiou >= `threshold` with it.

    Args:
        boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.
        scores (torch.Tensor): Confidence scores with shape (N,). Modified in place when `use_triu=False`: the
            scores of suppressed boxes are set to 0.
        threshold (float): IoU threshold for NMS.
        use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.

    Returns:
        (torch.Tensor): Indices of boxes to keep after NMS, ordered by descending score. With `use_triu=False`
            the result always has length N; suppressed boxes come last and have a zeroed entry in `scores`.
    """
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = batch_probiou(boxes, boxes)
    if use_triu:
        # Keep only pairs (i, j) with i < j, i.e. IoU of each box with strictly higher-ranked boxes
        ious = ious.triu_(diagonal=1)
        # NOTE: a column-wise reduction picks the survivors without any data-dependent if-else branch
        pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
    else:
        # Same upper-triangular masking built from index grids for backends where triu is unavailable
        n = boxes.shape[0]
        row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
        col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
        upper_mask = row_idx < col_idx
        ious = ious * upper_mask
        # The suppression mask is expressed in sorted order, so it must be applied to the sorted scores.
        # Zeroing these scores ensures the additional indices would not affect the final results
        sorted_scores = scores[sorted_idx]
        sorted_scores[~((ious >= threshold).sum(0) <= 0)] = 0
        scores[sorted_idx] = sorted_scores  # propagate the zeroed scores back to the caller's tensor
        # NOTE: return indices with fixed length to avoid TFLite reshape error
        pick = torch.topk(sorted_scores, sorted_scores.shape[0]).indices
    return sorted_idx[pick]
|
||||
|
||||
|
||||
def non_max_suppression(
    prediction,
    conf_thres: float = 0.25,
    iou_thres: float = 0.45,
    classes=None,
    agnostic: bool = False,
    multi_label: bool = False,
    labels=(),
    max_det: int = 300,
    nc: int = 0,  # number of classes (optional)
    max_time_img: float = 0.05,
    max_nms: int = 30000,
    max_wh: int = 7680,
    in_place: bool = True,
    rotated: bool = False,
    end2end: bool = False,
    return_idxs: bool = False,
):
    """
    Perform non-maximum suppression (NMS) on prediction results.

    Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
    detection formats including standard boxes, rotated boxes, and masks.

    Args:
        prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing boxes, classes, and optional masks. A list/tuple is also accepted, in which case only its
            first element is used.
        conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
        iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
        classes (List[int], optional): List of class indices to consider. If None, all classes are considered.
        agnostic (bool): Whether to perform class-agnostic NMS.
        multi_label (bool): Whether each box can have multiple labels. Ignored when there is a single class.
        labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.
        max_det (int): Maximum number of detections to keep per image.
        nc (int): Number of classes. Indices after this are considered masks. If 0, it is inferred as
            `prediction.shape[1] - 4`.
        max_time_img (float): Maximum time in seconds for processing one image.
        max_nms (int): Maximum number of boxes for torchvision.ops.nms().
        max_wh (int): Maximum box width and height in pixels. Also used as the per-class coordinate offset.
        in_place (bool): Whether to modify the input prediction tensor in place.
        rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
        end2end (bool): Whether the model is end-to-end and doesn't require NMS.
        return_idxs (bool): Whether to return the indices of kept detections.

    Returns:
        output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
            containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
        keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True. Not returned on the
            end-to-end path, which always returns `output` alone.

    Raises:
        AssertionError: If `conf_thres` or `iou_thres` is outside [0, 1].
    """
    import torchvision  # scope for faster 'import ultralytics'

    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output
    if classes is not None:
        classes = torch.tensor(classes, device=prediction.device)

    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
        # No NMS needed: only threshold on confidence (column 4), cap at max_det and optionally filter by class
        output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
        if classes is not None:
            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
        return output  # NOTE: return_idxs is not honoured on this path

    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    extra = prediction.shape[1] - nc - 4  # number of extra info
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
    # Per-image column vector 0..num_boxes-1, filtered in lockstep with the predictions to report kept indices
    xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None]  # to track idxs

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if not rotated:
        if in_place:
            # Assigns into the transposed tensor, so the caller's prediction tensor is modified as well
            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
        else:
            # Builds a new tensor instead, leaving the caller's prediction tensor untouched
            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy

    t = time.time()
    # Defaults for images that yield no detections (or are skipped after the time limit); entries are replaced below
    output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
    keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
    for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        filt = xc[xi]  # confidence
        x, xk = x[filt], xk[filt]

        # Cat apriori labels if autolabelling
        # NOTE(review): xk is not extended with rows for these labels, so the x/xk filters below look like they
        # would misalign whenever `labels` is used - confirm with callers
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, extra), 1)

        if multi_label:
            # One output row per (box, class) pair whose class score exceeds the threshold
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
            xk = xk[i]
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            filt = conf.view(-1) > conf_thres
            x = torch.cat((box, conf, j.float(), mask), 1)[filt]
            xk = xk[filt]

        # Filter by class
        if classes is not None:
            filt = (x[:, 5:6] == classes).any(1)
            x, xk = x[filt], xk[filt]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
            x, xk = x[filt], xk[filt]

        # Batched NMS
        # Shift boxes by class index * max_wh so boxes of different classes never overlap (no shift if agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores
        if rotated:
            # Only the box centre is offset; width, height and the last column (angle) are left untouched
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = nms_rotated(boxes, scores, iou_thres)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        output[xi], keepi[xi] = x[i], xk[i].reshape(-1)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return (output, keepi) if return_idxs else output
|
||||
|
||||
|
||||
def clip_boxes(boxes, shape):
|
||||
"""
|
||||
Clip bounding boxes to image boundaries.
|
||||
|
|
|
|||
|
|
@ -387,7 +387,7 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
|||
if xywh:
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
wh = x2y2 - x1y1
|
||||
return torch.cat((c_xy, wh), dim) # xywh bbox
|
||||
return torch.cat([c_xy, wh], dim) # xywh bbox
|
||||
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ class TQDM:
|
|||
|
||||
# For bytes with scaling, use binary units
|
||||
if self.unit in ("B", "bytes") and self.unit_scale:
|
||||
for threshold, unit in [(1024**3, "GB/s"), (1024**2, "MB/s"), (1024, "KB/s")]:
|
||||
for threshold, unit in [(1073741824, "GB/s"), (1048576, "MB/s"), (1024, "KB/s")]: # 1 << 30, << 20, << 10
|
||||
if rate >= threshold:
|
||||
return f"{rate / threshold:.1f}{unit}"
|
||||
return f"{rate:.1f}B/s"
|
||||
|
|
|
|||
Loading…
Reference in a new issue