mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
ultralytics 8.4.40 Per-image Precision and Recall (#24089)
Some checks are pending
CI / Benchmarks (yolo26n, macos-26, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-24.04-arm, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-latest, 3.12) (push) Waiting to run
CI / Tests (macos-26, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / Tests (windows-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (macos-26, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.11.0, 0.12.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.12.0, 0.13.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.13.0, 0.14.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.0.0, 0.15.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.1.0, 0.16.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.10.0, 0.25.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.11.0, 0.26.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.2.0, 0.17.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.3.0, 0.18.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.4.0, 0.19.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.5.0, 0.20.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.6.0, 0.21.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.7.0, 0.22.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.8.0, 0.23.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.9.0, 0.24.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.10.0, 0.11.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.9.0, 0.10.0) (push) Waiting to run
CI / SlowTests (windows-latest, 3.12, latest) (push) Waiting to run
CI / GPU (push) Waiting to run
CI / RaspberryPi (push) Waiting to run
CI / NVIDIA_Jetson (JetPack5.1.2, 1.23.5, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.16.3-cp38-cp38-linux_aarch64.whl, 3.8, jetson-jp512, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.2.0-cp38-c… (push) Waiting to run
CI / NVIDIA_Jetson (JetPack6.2, 1.26.4, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.20.0-cp310-cp310-linux_aarch64.whl, 3.10, jetson-jp62, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.5.0a0+872d… (push) Waiting to run
CI / Conda (ubuntu-latest, 3.12) (push) Waiting to run
CI / Summary (push) Blocked by required conditions
Publish Docker Images / Build (push) Waiting to run
Publish Docker Images / trigger-actions (push) Blocked by required conditions
Publish Docker Images / notify (push) Blocked by required conditions
Publish Docs / Docs (push) Waiting to run
Publish to PyPI / check (push) Waiting to run
Publish to PyPI / build (push) Blocked by required conditions
Publish to PyPI / publish (push) Blocked by required conditions
Publish to PyPI / sbom (push) Blocked by required conditions
Publish to PyPI / notify (push) Blocked by required conditions
Some checks are pending
CI / Benchmarks (yolo26n, macos-26, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-24.04-arm, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-latest, 3.12) (push) Waiting to run
CI / Tests (macos-26, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / Tests (windows-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (macos-26, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.11.0, 0.12.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.12.0, 0.13.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.13.0, 0.14.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.0.0, 0.15.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.1.0, 0.16.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.10.0, 0.25.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.11.0, 0.26.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.2.0, 0.17.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.3.0, 0.18.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.4.0, 0.19.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.5.0, 0.20.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.6.0, 0.21.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.7.0, 0.22.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.8.0, 0.23.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.9.0, 0.24.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.10.0, 0.11.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.9.0, 0.10.0) (push) Waiting to run
CI / SlowTests (windows-latest, 3.12, latest) (push) Waiting to run
CI / GPU (push) Waiting to run
CI / RaspberryPi (push) Waiting to run
CI / NVIDIA_Jetson (JetPack5.1.2, 1.23.5, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.16.3-cp38-cp38-linux_aarch64.whl, 3.8, jetson-jp512, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.2.0-cp38-c… (push) Waiting to run
CI / NVIDIA_Jetson (JetPack6.2, 1.26.4, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.20.0-cp310-cp310-linux_aarch64.whl, 3.10, jetson-jp62, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.5.0a0+872d… (push) Waiting to run
CI / Conda (ubuntu-latest, 3.12) (push) Waiting to run
CI / Summary (push) Blocked by required conditions
Publish Docker Images / Build (push) Waiting to run
Publish Docker Images / trigger-actions (push) Blocked by required conditions
Publish Docker Images / notify (push) Blocked by required conditions
Publish Docs / Docs (push) Waiting to run
Publish to PyPI / check (push) Waiting to run
Publish to PyPI / build (push) Blocked by required conditions
Publish to PyPI / publish (push) Blocked by required conditions
Publish to PyPI / sbom (push) Blocked by required conditions
Publish to PyPI / notify (push) Blocked by required conditions
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d86cb96689
commit
901dcb0e47
12 changed files with 146 additions and 2 deletions
|
|
@ -93,6 +93,7 @@ This logs the mean F1 score across all classes and a per-class breakdown after e
|
||||||
| Attribute | Description |
|
| Attribute | Description |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `f1` | F1 score per class |
|
| `f1` | F1 score per class |
|
||||||
|
| `image_metrics` | Per-image metrics dictionary with precision, recall, F1, TP, FP, and FN |
|
||||||
| `p` | Precision per class |
|
| `p` | Precision per class |
|
||||||
| `r` | Recall per class |
|
| `r` | Recall per class |
|
||||||
| `ap50` | AP at IoU 0.5 per class |
|
| `ap50` | AP at IoU 0.5 per class |
|
||||||
|
|
|
||||||
|
|
@ -104,6 +104,7 @@ If you want to get a deeper understanding of your YOLO26 model's performance, yo
|
||||||
print("Mean results for different metrics:", results.box.mean_results)
|
print("Mean results for different metrics:", results.box.mean_results)
|
||||||
print("Mean precision:", results.box.mp)
|
print("Mean precision:", results.box.mp)
|
||||||
print("Mean recall:", results.box.mr)
|
print("Mean recall:", results.box.mr)
|
||||||
|
print("Per-image metrics:", results.box.image_metrics)
|
||||||
print("Precision:", results.box.p)
|
print("Precision:", results.box.p)
|
||||||
print("Precision curve:", results.box.p_curve)
|
print("Precision curve:", results.box.p_curve)
|
||||||
print("Precision values:", results.box.prec_values)
|
print("Precision values:", results.box.prec_values)
|
||||||
|
|
@ -112,7 +113,10 @@ If you want to get a deeper understanding of your YOLO26 model's performance, yo
|
||||||
print("Recall curve:", results.box.r_curve)
|
print("Recall curve:", results.box.r_curve)
|
||||||
```
|
```
|
||||||
|
|
||||||
The results object also includes speed metrics like preprocess time, inference time, loss, and postprocess time. By analyzing these metrics, you can fine-tune and optimize your YOLO26 model for better performance, making it more effective for your specific use case.
|
The results object also includes `image_metrics`, a per-image dictionary keyed by image filename with `precision`,
|
||||||
|
`recall`, `f1`, `tp`, `fp`, and `fn`, as well as speed metrics like preprocess time, inference time, loss, and
|
||||||
|
postprocess time. By analyzing these metrics, you can fine-tune and optimize your YOLO26 model for better performance,
|
||||||
|
making it more effective for your specific use case.
|
||||||
|
|
||||||
## How Does Fine-Tuning Work?
|
## How Does Fine-Tuning Work?
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,7 @@ Validate a trained YOLO26n model [accuracy](https://www.ultralytics.com/glossary
|
||||||
metrics.box.map50 # map50
|
metrics.box.map50 # map50
|
||||||
metrics.box.map75 # map75
|
metrics.box.map75 # map75
|
||||||
metrics.box.maps # a list containing mAP50-95 for each category
|
metrics.box.maps # a list containing mAP50-95 for each category
|
||||||
|
metrics.box.image_metrics # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
|
|
@ -137,6 +138,42 @@ The below examples showcase YOLO model validation with custom arguments in Pytho
|
||||||
print(results.confusion_matrix.to_df())
|
print(results.confusion_matrix.to_df())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
!!! tip "Per-Image Precision, Recall, and F1"
|
||||||
|
|
||||||
|
Validation stores per-image precision, recall, F1, TP, FP, and FN metrics (at IoU threshold 0.5) for all tasks
|
||||||
|
except classification. Access them through `results.box.image_metrics` for detection and OBB, `results.seg.image_metrics`
|
||||||
|
for segmentation, and `results.pose.image_metrics` for pose after validation completes.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# Load a model
|
||||||
|
model = YOLO("yolo26n.pt")
|
||||||
|
|
||||||
|
# Validate and access per-image metrics
|
||||||
|
results = model.val(data="coco8.yaml")
|
||||||
|
|
||||||
|
# image_metrics is a dictionary keyed by image filename (basename only, e.g. "image1.jpg", not the full path)
|
||||||
|
print(results.box.image_metrics)
|
||||||
|
# Output: {'image1.jpg': {'precision': 0.85, 'recall': 0.94, 'f1': 0.89, 'tp': 17, 'fp': 3, 'fn': 1}, ...}
|
||||||
|
|
||||||
|
# Access metrics for a specific image
|
||||||
|
results.box.image_metrics["image1.jpg"] # {'precision': 0.85, 'recall': 0.94, 'f1': 0.89, 'tp': 17, 'fp': 3, 'fn': 1}
|
||||||
|
```
|
||||||
|
|
||||||
|
Each entry in `image_metrics` contains the following keys:
|
||||||
|
|
||||||
|
| Key | Description |
|
||||||
|
|-------------|---------------------------------------------------|
|
||||||
|
| `precision` | Precision score for the image (`tp / (tp + fp)`); `0.0` when the image has no predictions. |
|
||||||
|
| `recall` | Recall score for the image (`tp / (tp + fn)`); `0.0` when the image has no ground-truth objects. |
|
||||||
|
| `f1` | Harmonic mean of precision and recall; `0.0` when both are zero. |
|
||||||
|
| `tp` | Number of true positives for the image. |
|
||||||
|
| `fp` | Number of false positives for the image. |
|
||||||
|
| `fn` | Number of false negatives for the image. |
|
||||||
|
|
||||||
|
This feature is available for detection, segmentation, pose, and OBB tasks.
|
||||||
|
|
||||||
| Method | Return Type | Description |
|
| Method | Return Type | Description |
|
||||||
| ----------- | ---------------------- | -------------------------------------------------------------------------- |
|
| ----------- | ---------------------- | -------------------------------------------------------------------------- |
|
||||||
| `summary()` | `List[Dict[str, Any]]` | Converts validation results to a summarized dictionary. |
|
| `summary()` | `List[Dict[str, Any]]` | Converts validation results to a summarized dictionary. |
|
||||||
|
|
@ -187,6 +224,7 @@ print(metrics.box.map) # mAP50-95
|
||||||
print(metrics.box.map50) # mAP50
|
print(metrics.box.map50) # mAP50
|
||||||
print(metrics.box.map75) # mAP75
|
print(metrics.box.map75) # mAP75
|
||||||
print(metrics.box.maps) # list of mAP50-95 for each category
|
print(metrics.box.maps) # list of mAP50-95 for each category
|
||||||
|
print(metrics.box.image_metrics) # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||||
```
|
```
|
||||||
|
|
||||||
For a complete performance evaluation, it's crucial to review all these metrics. For more details, refer to the [Key Features of Val Mode](#key-features-of-val-mode).
|
For a complete performance evaluation, it's crucial to review all these metrics. For more details, refer to the [Key Features of Val Mode](#key-features-of-val-mode).
|
||||||
|
|
|
||||||
|
|
@ -97,6 +97,7 @@ Validate trained YOLO26n model [accuracy](https://www.ultralytics.com/glossary/a
|
||||||
metrics.box.map50 # map50
|
metrics.box.map50 # map50
|
||||||
metrics.box.map75 # map75
|
metrics.box.map75 # map75
|
||||||
metrics.box.maps # a list containing mAP50-95 for each category
|
metrics.box.maps # a list containing mAP50-95 for each category
|
||||||
|
metrics.box.image_metrics # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
|
|
|
||||||
|
|
@ -127,6 +127,7 @@ Validate trained YOLO26n-obb model [accuracy](https://www.ultralytics.com/glossa
|
||||||
metrics.box.map50 # map50(B)
|
metrics.box.map50 # map50(B)
|
||||||
metrics.box.map75 # map75(B)
|
metrics.box.map75 # map75(B)
|
||||||
metrics.box.maps # a list containing mAP50-95(B) for each category
|
metrics.box.maps # a list containing mAP50-95(B) for each category
|
||||||
|
metrics.box.image_metrics # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
|
|
|
||||||
|
|
@ -120,10 +120,12 @@ Validate trained YOLO26n-pose model [accuracy](https://www.ultralytics.com/gloss
|
||||||
metrics.box.map50 # map50
|
metrics.box.map50 # map50
|
||||||
metrics.box.map75 # map75
|
metrics.box.map75 # map75
|
||||||
metrics.box.maps # a list containing mAP50-95 for each category
|
metrics.box.maps # a list containing mAP50-95 for each category
|
||||||
|
metrics.box.image_metrics # per-image metrics dictionary for box with precision, recall, F1, TP, FP, and FN
|
||||||
metrics.pose.map # map50-95(P)
|
metrics.pose.map # map50-95(P)
|
||||||
metrics.pose.map50 # map50(P)
|
metrics.pose.map50 # map50(P)
|
||||||
metrics.pose.map75 # map75(P)
|
metrics.pose.map75 # map75(P)
|
||||||
metrics.pose.maps # a list containing mAP50-95(P) for each category
|
metrics.pose.maps # a list containing mAP50-95(P) for each category
|
||||||
|
metrics.pose.image_metrics # per-image metrics dictionary for pose with precision, recall, F1, TP, FP, and FN
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
|
|
|
||||||
|
|
@ -98,10 +98,12 @@ Validate trained YOLO26n-seg model [accuracy](https://www.ultralytics.com/glossa
|
||||||
metrics.box.map50 # map50(B)
|
metrics.box.map50 # map50(B)
|
||||||
metrics.box.map75 # map75(B)
|
metrics.box.map75 # map75(B)
|
||||||
metrics.box.maps # a list containing mAP50-95(B) for each category
|
metrics.box.maps # a list containing mAP50-95(B) for each category
|
||||||
|
metrics.box.image_metrics # per-image metrics dictionary for box with precision, recall, F1, TP, FP, and FN
|
||||||
metrics.seg.map # map50-95(M)
|
metrics.seg.map # map50-95(M)
|
||||||
metrics.seg.map50 # map50(M)
|
metrics.seg.map50 # map50(M)
|
||||||
metrics.seg.map75 # map75(M)
|
metrics.seg.map75 # map75(M)
|
||||||
metrics.seg.maps # a list containing mAP50-95(M) for each category
|
metrics.seg.maps # a list containing mAP50-95(M) for each category
|
||||||
|
metrics.seg.image_metrics # per-image metrics dictionary for seg with precision, recall, F1, TP, FP, and FN
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||||
|
|
||||||
__version__ = "8.4.39"
|
__version__ = "8.4.40"
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,8 @@ class DetectionValidator(BaseValidator):
|
||||||
self.seen = 0
|
self.seen = 0
|
||||||
self.jdict = []
|
self.jdict = []
|
||||||
self.metrics.names = model.names
|
self.metrics.names = model.names
|
||||||
|
self.metrics.clear_stats()
|
||||||
|
self.metrics.clear_image_metrics()
|
||||||
self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)
|
self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)
|
||||||
|
|
||||||
def get_desc(self) -> str:
|
def get_desc(self) -> str:
|
||||||
|
|
@ -186,6 +188,7 @@ class DetectionValidator(BaseValidator):
|
||||||
"target_img": np.unique(cls),
|
"target_img": np.unique(cls),
|
||||||
"conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
|
"conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
|
||||||
"pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
|
"pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
|
||||||
|
"im_name": Path(pbatch["im_file"]).name,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# Evaluate
|
# Evaluate
|
||||||
|
|
@ -219,6 +222,19 @@ class DetectionValidator(BaseValidator):
|
||||||
self.metrics.confusion_matrix = self.confusion_matrix
|
self.metrics.confusion_matrix = self.confusion_matrix
|
||||||
self.metrics.save_dir = self.save_dir
|
self.metrics.save_dir = self.save_dir
|
||||||
|
|
||||||
|
def _gather_image_metrics(self, metric) -> None:
|
||||||
|
"""Gather per-image metrics from all GPUs for a single metric object."""
|
||||||
|
if RANK == 0:
|
||||||
|
gathered_image_metrics = [None] * dist.get_world_size()
|
||||||
|
dist.gather_object(metric.image_metrics, gathered_image_metrics, dst=0)
|
||||||
|
metric.clear_image_metrics()
|
||||||
|
for image_metrics in gathered_image_metrics:
|
||||||
|
if image_metrics:
|
||||||
|
metric.image_metrics.update(image_metrics)
|
||||||
|
elif RANK > 0:
|
||||||
|
dist.gather_object(metric.image_metrics, None, dst=0)
|
||||||
|
metric.clear_image_metrics()
|
||||||
|
|
||||||
def gather_stats(self) -> None:
|
def gather_stats(self) -> None:
|
||||||
"""Gather stats from all GPUs."""
|
"""Gather stats from all GPUs."""
|
||||||
if RANK == 0:
|
if RANK == 0:
|
||||||
|
|
@ -234,10 +250,12 @@ class DetectionValidator(BaseValidator):
|
||||||
for jdict in gathered_jdict:
|
for jdict in gathered_jdict:
|
||||||
self.jdict.extend(jdict)
|
self.jdict.extend(jdict)
|
||||||
self.metrics.stats = merged_stats
|
self.metrics.stats = merged_stats
|
||||||
|
self._gather_image_metrics(self.metrics.box)
|
||||||
self.seen = len(self.dataloader.dataset) # total image count from dataset
|
self.seen = len(self.dataloader.dataset) # total image count from dataset
|
||||||
elif RANK > 0:
|
elif RANK > 0:
|
||||||
dist.gather_object(self.metrics.stats, None, dst=0)
|
dist.gather_object(self.metrics.stats, None, dst=0)
|
||||||
dist.gather_object(self.jdict, None, dst=0)
|
dist.gather_object(self.jdict, None, dst=0)
|
||||||
|
self._gather_image_metrics(self.metrics.box)
|
||||||
self.jdict = []
|
self.jdict = []
|
||||||
self.metrics.clear_stats()
|
self.metrics.clear_stats()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -185,6 +185,11 @@ class PoseValidator(DetectionValidator):
|
||||||
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
|
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
|
||||||
return tp
|
return tp
|
||||||
|
|
||||||
|
def gather_stats(self) -> None:
|
||||||
|
"""Gather stats from all GPUs."""
|
||||||
|
super().gather_stats() # gather stats from DetectionValidator
|
||||||
|
self._gather_image_metrics(self.metrics.pose)
|
||||||
|
|
||||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||||
"""Save YOLO pose detections to a text file in normalized coordinates.
|
"""Save YOLO pose detections to a text file in normalized coordinates.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -141,6 +141,11 @@ class SegmentationValidator(DetectionValidator):
|
||||||
prepared_batch["masks"] = masks
|
prepared_batch["masks"] = masks
|
||||||
return prepared_batch
|
return prepared_batch
|
||||||
|
|
||||||
|
def gather_stats(self) -> None:
|
||||||
|
"""Gather stats from all GPUs."""
|
||||||
|
super().gather_stats() # gather stats from DetectionValidator
|
||||||
|
self._gather_image_metrics(self.metrics.seg)
|
||||||
|
|
||||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||||
"""Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
"""Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -879,6 +879,7 @@ class Metric(SimpleClass):
|
||||||
self.all_ap = [] # (nc, 10)
|
self.all_ap = [] # (nc, 10)
|
||||||
self.ap_class_index = [] # (nc, )
|
self.ap_class_index = [] # (nc, )
|
||||||
self.nc = 0
|
self.nc = 0
|
||||||
|
self.image_metrics = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ap50(self) -> np.ndarray | list:
|
def ap50(self) -> np.ndarray | list:
|
||||||
|
|
@ -993,6 +994,10 @@ class Metric(SimpleClass):
|
||||||
self.prec_values,
|
self.prec_values,
|
||||||
) = results
|
) = results
|
||||||
|
|
||||||
|
def clear_image_metrics(self) -> None:
|
||||||
|
"""Clear stored per-image metrics from the current validation run."""
|
||||||
|
self.image_metrics.clear()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def curves(self) -> list:
|
def curves(self) -> list:
|
||||||
"""Return a list of curves for accessing specific metrics curves."""
|
"""Return a list of curves for accessing specific metrics curves."""
|
||||||
|
|
@ -1008,6 +1013,33 @@ class Metric(SimpleClass):
|
||||||
[self.px, self.r_curve, "Confidence", "Recall"],
|
[self.px, self.r_curve, "Confidence", "Recall"],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def update_image_metrics(self, tp: np.ndarray, target_cls: np.ndarray, pred_cls: np.ndarray, im_name: str) -> None:
|
||||||
|
"""Update per-image precision, recall, F1, TP, FP, and FN at IoU threshold 0.5.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tp (np.ndarray): True positive array of shape (num_preds, num_iou_thresholds), where the first column (IoU
|
||||||
|
>= 0.5) is used.
|
||||||
|
target_cls (np.ndarray): Ground truth class labels for the image.
|
||||||
|
pred_cls (np.ndarray): Predicted class labels for the image.
|
||||||
|
im_name (str): The image filename used as the per-image key.
|
||||||
|
"""
|
||||||
|
# Use the default IoU=0.5 column to match the validator's image-level matching policy.
|
||||||
|
tp = int(tp[:, 0].sum())
|
||||||
|
num_preds = pred_cls.shape[0]
|
||||||
|
num_targets = target_cls.shape[0]
|
||||||
|
fp = num_preds - tp
|
||||||
|
fn = num_targets - tp
|
||||||
|
precision = tp / num_preds if num_preds else 0
|
||||||
|
recall = tp / num_targets if num_targets else 0
|
||||||
|
self.image_metrics[im_name] = {
|
||||||
|
"precision": float(precision),
|
||||||
|
"recall": float(recall),
|
||||||
|
"f1": float(2 * (precision * recall) / (precision + recall)) if (precision + recall) else 0.0,
|
||||||
|
"tp": int(tp),
|
||||||
|
"fp": int(fp),
|
||||||
|
"fn": int(fn),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class DetMetrics(SimpleClass, DataExportMixin):
|
class DetMetrics(SimpleClass, DataExportMixin):
|
||||||
"""Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
"""Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
||||||
|
|
@ -1059,6 +1091,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
||||||
"""
|
"""
|
||||||
for k in self.stats.keys():
|
for k in self.stats.keys():
|
||||||
self.stats[k].append(stat[k])
|
self.stats[k].append(stat[k])
|
||||||
|
self.box.update_image_metrics(stat["tp"], stat["target_cls"], stat["pred_cls"], stat["im_name"])
|
||||||
|
|
||||||
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
||||||
"""Process predicted results for object detection and update metrics.
|
"""Process predicted results for object detection and update metrics.
|
||||||
|
|
@ -1096,6 +1129,10 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
||||||
for v in self.stats.values():
|
for v in self.stats.values():
|
||||||
v.clear()
|
v.clear()
|
||||||
|
|
||||||
|
def clear_image_metrics(self) -> None:
|
||||||
|
"""Clear stored per-image metrics."""
|
||||||
|
self.box.clear_image_metrics()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self) -> list[str]:
|
def keys(self) -> list[str]:
|
||||||
"""Return a list of keys for accessing specific metrics."""
|
"""Return a list of keys for accessing specific metrics."""
|
||||||
|
|
@ -1211,6 +1248,21 @@ class SegmentMetrics(DetMetrics):
|
||||||
self.seg = Metric()
|
self.seg = Metric()
|
||||||
self.stats["tp_m"] = [] # add additional stats for masks
|
self.stats["tp_m"] = [] # add additional stats for masks
|
||||||
|
|
||||||
|
def update_stats(self, stat: dict[str, Any]) -> None:
|
||||||
|
"""Update statistics by appending new values to existing stat collections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stat (dict[str, Any]): Dictionary containing new statistical values to append. Keys should match existing
|
||||||
|
keys in self.stats.
|
||||||
|
"""
|
||||||
|
super().update_stats(stat) # update box stats
|
||||||
|
self.seg.update_image_metrics(stat["tp_m"], stat["target_cls"], stat["pred_cls"], stat["im_name"])
|
||||||
|
|
||||||
|
def clear_image_metrics(self) -> None:
|
||||||
|
"""Clear stored per-image metrics."""
|
||||||
|
super().clear_image_metrics()
|
||||||
|
self.seg.clear_image_metrics()
|
||||||
|
|
||||||
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
||||||
"""Process the detection and segmentation metrics over the given set of predictions.
|
"""Process the detection and segmentation metrics over the given set of predictions.
|
||||||
|
|
||||||
|
|
@ -1347,6 +1399,21 @@ class PoseMetrics(DetMetrics):
|
||||||
self.pose = Metric()
|
self.pose = Metric()
|
||||||
self.stats["tp_p"] = [] # add additional stats for pose
|
self.stats["tp_p"] = [] # add additional stats for pose
|
||||||
|
|
||||||
|
def update_stats(self, stat: dict[str, Any]) -> None:
|
||||||
|
"""Update statistics by appending new values to existing stat collections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stat (dict[str, Any]): Dictionary containing new statistical values to append. Keys should match existing
|
||||||
|
keys in self.stats.
|
||||||
|
"""
|
||||||
|
super().update_stats(stat) # update box stats
|
||||||
|
self.pose.update_image_metrics(stat["tp_p"], stat["target_cls"], stat["pred_cls"], stat["im_name"])
|
||||||
|
|
||||||
|
def clear_image_metrics(self) -> None:
|
||||||
|
"""Clear stored per-image metrics."""
|
||||||
|
super().clear_image_metrics()
|
||||||
|
self.pose.clear_image_metrics()
|
||||||
|
|
||||||
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
||||||
"""Process the detection and pose metrics over the given set of predictions.
|
"""Process the detection and pose metrics over the given set of predictions.
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue