mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
ultralytics 8.4.40 Per-image Precision and Recall (#24089)
Some checks are pending
CI / Benchmarks (yolo26n, macos-26, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-24.04-arm, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-latest, 3.12) (push) Waiting to run
CI / Tests (macos-26, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / Tests (windows-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (macos-26, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.11.0, 0.12.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.12.0, 0.13.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.13.0, 0.14.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.0.0, 0.15.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.1.0, 0.16.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.10.0, 0.25.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.11.0, 0.26.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.2.0, 0.17.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.3.0, 0.18.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.4.0, 0.19.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.5.0, 0.20.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.6.0, 0.21.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.7.0, 0.22.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.8.0, 0.23.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.9.0, 0.24.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.10.0, 0.11.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.9.0, 0.10.0) (push) Waiting to run
CI / SlowTests (windows-latest, 3.12, latest) (push) Waiting to run
CI / GPU (push) Waiting to run
CI / RaspberryPi (push) Waiting to run
CI / NVIDIA_Jetson (JetPack5.1.2, 1.23.5, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.16.3-cp38-cp38-linux_aarch64.whl, 3.8, jetson-jp512, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.2.0-cp38-c… (push) Waiting to run
CI / NVIDIA_Jetson (JetPack6.2, 1.26.4, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.20.0-cp310-cp310-linux_aarch64.whl, 3.10, jetson-jp62, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.5.0a0+872d… (push) Waiting to run
CI / Conda (ubuntu-latest, 3.12) (push) Waiting to run
CI / Summary (push) Blocked by required conditions
Publish Docker Images / Build (push) Waiting to run
Publish Docker Images / trigger-actions (push) Blocked by required conditions
Publish Docker Images / notify (push) Blocked by required conditions
Publish Docs / Docs (push) Waiting to run
Publish to PyPI / check (push) Waiting to run
Publish to PyPI / build (push) Blocked by required conditions
Publish to PyPI / publish (push) Blocked by required conditions
Publish to PyPI / sbom (push) Blocked by required conditions
Publish to PyPI / notify (push) Blocked by required conditions
Some checks are pending
CI / Benchmarks (yolo26n, macos-26, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-24.04-arm, 3.12) (push) Waiting to run
CI / Benchmarks (yolo26n, ubuntu-latest, 3.12) (push) Waiting to run
CI / Tests (macos-26, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / Tests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / Tests (windows-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (macos-26, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-24.04-arm, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.11.0, 0.12.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.12.0, 0.13.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.10, 1.13.0, 0.14.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.0.0, 0.15.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.11, 2.1.0, 0.16.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.10.0, 0.25.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.11.0, 0.26.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.2.0, 0.17.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.3.0, 0.18.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.4.0, 0.19.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.5.0, 0.20.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.6.0, 0.21.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.7.0, 0.22.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.8.0, 0.23.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, 2.9.0, 0.24.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.12, latest) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.8, 1.8.0, 0.9.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.10.0, 0.11.0) (push) Waiting to run
CI / SlowTests (ubuntu-latest, 3.9, 1.9.0, 0.10.0) (push) Waiting to run
CI / SlowTests (windows-latest, 3.12, latest) (push) Waiting to run
CI / GPU (push) Waiting to run
CI / RaspberryPi (push) Waiting to run
CI / NVIDIA_Jetson (JetPack5.1.2, 1.23.5, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.16.3-cp38-cp38-linux_aarch64.whl, 3.8, jetson-jp512, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.2.0-cp38-c… (push) Waiting to run
CI / NVIDIA_Jetson (JetPack6.2, 1.26.4, https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.20.0-cp310-cp310-linux_aarch64.whl, 3.10, jetson-jp62, https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.5.0a0+872d… (push) Waiting to run
CI / Conda (ubuntu-latest, 3.12) (push) Waiting to run
CI / Summary (push) Blocked by required conditions
Publish Docker Images / Build (push) Waiting to run
Publish Docker Images / trigger-actions (push) Blocked by required conditions
Publish Docker Images / notify (push) Blocked by required conditions
Publish Docs / Docs (push) Waiting to run
Publish to PyPI / check (push) Waiting to run
Publish to PyPI / build (push) Blocked by required conditions
Publish to PyPI / publish (push) Blocked by required conditions
Publish to PyPI / sbom (push) Blocked by required conditions
Publish to PyPI / notify (push) Blocked by required conditions
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d86cb96689
commit
901dcb0e47
12 changed files with 146 additions and 2 deletions
|
|
@ -93,6 +93,7 @@ This logs the mean F1 score across all classes and a per-class breakdown after e
|
|||
| Attribute | Description |
|
||||
|---|---|
|
||||
| `f1` | F1 score per class |
|
||||
| `image_metrics` | Per-image metrics dictionary with precision, recall, F1, TP, FP, and FN |
|
||||
| `p` | Precision per class |
|
||||
| `r` | Recall per class |
|
||||
| `ap50` | AP at IoU 0.5 per class |
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ If you want to get a deeper understanding of your YOLO26 model's performance, yo
|
|||
print("Mean results for different metrics:", results.box.mean_results)
|
||||
print("Mean precision:", results.box.mp)
|
||||
print("Mean recall:", results.box.mr)
|
||||
print("Per-image metrics:", results.box.image_metrics)
|
||||
print("Precision:", results.box.p)
|
||||
print("Precision curve:", results.box.p_curve)
|
||||
print("Precision values:", results.box.prec_values)
|
||||
|
|
@ -112,7 +113,10 @@ If you want to get a deeper understanding of your YOLO26 model's performance, yo
|
|||
print("Recall curve:", results.box.r_curve)
|
||||
```
|
||||
|
||||
The results object also includes speed metrics like preprocess time, inference time, loss, and postprocess time. By analyzing these metrics, you can fine-tune and optimize your YOLO26 model for better performance, making it more effective for your specific use case.
|
||||
The results object also includes `image_metrics`, a per-image dictionary keyed by image filename with `precision`,
|
||||
`recall`, `f1`, `tp`, `fp`, and `fn`, as well as speed metrics like preprocess time, inference time, loss, and
|
||||
postprocess time. By analyzing these metrics, you can fine-tune and optimize your YOLO26 model for better performance,
|
||||
making it more effective for your specific use case.
|
||||
|
||||
## How Does Fine-Tuning Work?
|
||||
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ Validate a trained YOLO26n model [accuracy](https://www.ultralytics.com/glossary
|
|||
metrics.box.map50 # map50
|
||||
metrics.box.map75 # map75
|
||||
metrics.box.maps # a list containing mAP50-95 for each category
|
||||
metrics.box.image_metrics # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
|
@ -137,6 +138,42 @@ The below examples showcase YOLO model validation with custom arguments in Pytho
|
|||
print(results.confusion_matrix.to_df())
|
||||
```
|
||||
|
||||
!!! tip "Per-Image Precision, Recall, and F1"
|
||||
|
||||
Validation stores per-image precision, recall, F1, TP, FP, and FN metrics (at IoU threshold 0.5) for all tasks
|
||||
except classification. Access them through `results.box.image_metrics` for detection and OBB, `results.seg.image_metrics`
|
||||
for segmentation, and `results.pose.image_metrics` for pose after validation completes.
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load a model
|
||||
model = YOLO("yolo26n.pt")
|
||||
|
||||
# Validate and access per-image metrics
|
||||
results = model.val(data="coco8.yaml")
|
||||
|
||||
# image_metrics is a dictionary with image filenames as keys
|
||||
print(results.box.image_metrics)
|
||||
# Output: {'image1.jpg': {'precision': 0.85, 'recall': 0.92, 'f1': 0.88, 'tp': 17, 'fp': 3, 'fn': 1}, ...}
|
||||
|
||||
# Access metrics for a specific image
|
||||
results.box.image_metrics["image1.jpg"] # {'precision': 0.85, 'recall': 0.92, 'f1': 0.88, 'tp': 17, 'fp': 3, 'fn': 1}
|
||||
```
|
||||
|
||||
Each entry in `image_metrics` contains the following keys:
|
||||
|
||||
| Key | Description |
|
||||
|-------------|---------------------------------------------------|
|
||||
| `precision` | Precision score for the image (`tp / (tp + fp)`). |
|
||||
| `recall` | Recall score for the image (`tp / (tp + fn)`). |
|
||||
| `f1` | Harmonic mean of precision and recall. |
|
||||
| `tp` | Number of true positives for the image. |
|
||||
| `fp` | Number of false positives for the image. |
|
||||
| `fn` | Number of false negatives for the image. |
|
||||
|
||||
This feature is available for detection, segmentation, pose, and OBB tasks.
|
||||
|
||||
| Method | Return Type | Description |
|
||||
| ----------- | ---------------------- | -------------------------------------------------------------------------- |
|
||||
| `summary()` | `List[Dict[str, Any]]` | Converts validation results to a summarized dictionary. |
|
||||
|
|
@ -187,6 +224,7 @@ print(metrics.box.map) # mAP50-95
|
|||
print(metrics.box.map50) # mAP50
|
||||
print(metrics.box.map75) # mAP75
|
||||
print(metrics.box.maps) # list of mAP50-95 for each category
|
||||
print(metrics.box.image_metrics) # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||
```
|
||||
|
||||
For a complete performance evaluation, it's crucial to review all these metrics. For more details, refer to the [Key Features of Val Mode](#key-features-of-val-mode).
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ Validate trained YOLO26n model [accuracy](https://www.ultralytics.com/glossary/a
|
|||
metrics.box.map50 # map50
|
||||
metrics.box.map75 # map75
|
||||
metrics.box.maps # a list containing mAP50-95 for each category
|
||||
metrics.box.image_metrics # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ Validate trained YOLO26n-obb model [accuracy](https://www.ultralytics.com/glossa
|
|||
metrics.box.map50 # map50(B)
|
||||
metrics.box.map75 # map75(B)
|
||||
metrics.box.maps # a list containing mAP50-95(B) for each category
|
||||
metrics.box.image_metrics # per-image metrics dictionary with precision, recall, F1, TP, FP, and FN
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
|
|
|||
|
|
@ -120,10 +120,12 @@ Validate trained YOLO26n-pose model [accuracy](https://www.ultralytics.com/gloss
|
|||
metrics.box.map50 # map50
|
||||
metrics.box.map75 # map75
|
||||
metrics.box.maps # a list containing mAP50-95 for each category
|
||||
metrics.box.image_metrics # per-image metrics dictionary for box with precision, recall, F1, TP, FP, and FN
|
||||
metrics.pose.map # map50-95(P)
|
||||
metrics.pose.map50 # map50(P)
|
||||
metrics.pose.map75 # map75(P)
|
||||
metrics.pose.maps # a list containing mAP50-95(P) for each category
|
||||
metrics.pose.image_metrics # per-image metrics dictionary for pose with precision, recall, F1, TP, FP, and FN
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
|
|
|||
|
|
@ -98,10 +98,12 @@ Validate trained YOLO26n-seg model [accuracy](https://www.ultralytics.com/glossa
|
|||
metrics.box.map50 # map50(B)
|
||||
metrics.box.map75 # map75(B)
|
||||
metrics.box.maps # a list containing mAP50-95(B) for each category
|
||||
metrics.box.image_metrics # per-image metrics dictionary for det with precision, recall, F1, TP, FP, and FN
|
||||
metrics.seg.map # map50-95(M)
|
||||
metrics.seg.map50 # map50(M)
|
||||
metrics.seg.map75 # map75(M)
|
||||
metrics.seg.maps # a list containing mAP50-95(M) for each category
|
||||
metrics.seg.image_metrics # per-image metrics dictionary for seg with precision, recall, F1, TP, FP, and FN
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
__version__ = "8.4.39"
|
||||
__version__ = "8.4.40"
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -96,6 +96,8 @@ class DetectionValidator(BaseValidator):
|
|||
self.seen = 0
|
||||
self.jdict = []
|
||||
self.metrics.names = model.names
|
||||
self.metrics.clear_stats()
|
||||
self.metrics.clear_image_metrics()
|
||||
self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)
|
||||
|
||||
def get_desc(self) -> str:
|
||||
|
|
@ -186,6 +188,7 @@ class DetectionValidator(BaseValidator):
|
|||
"target_img": np.unique(cls),
|
||||
"conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
|
||||
"pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
|
||||
"im_name": Path(pbatch["im_file"]).name,
|
||||
}
|
||||
)
|
||||
# Evaluate
|
||||
|
|
@ -219,6 +222,19 @@ class DetectionValidator(BaseValidator):
|
|||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
self.metrics.save_dir = self.save_dir
|
||||
|
||||
def _gather_image_metrics(self, metric) -> None:
|
||||
"""Gather per-image metrics from all GPUs for a single metric object."""
|
||||
if RANK == 0:
|
||||
gathered_image_metrics = [None] * dist.get_world_size()
|
||||
dist.gather_object(metric.image_metrics, gathered_image_metrics, dst=0)
|
||||
metric.clear_image_metrics()
|
||||
for image_metrics in gathered_image_metrics:
|
||||
if image_metrics:
|
||||
metric.image_metrics.update(image_metrics)
|
||||
elif RANK > 0:
|
||||
dist.gather_object(metric.image_metrics, None, dst=0)
|
||||
metric.clear_image_metrics()
|
||||
|
||||
def gather_stats(self) -> None:
|
||||
"""Gather stats from all GPUs."""
|
||||
if RANK == 0:
|
||||
|
|
@ -234,10 +250,12 @@ class DetectionValidator(BaseValidator):
|
|||
for jdict in gathered_jdict:
|
||||
self.jdict.extend(jdict)
|
||||
self.metrics.stats = merged_stats
|
||||
self._gather_image_metrics(self.metrics.box)
|
||||
self.seen = len(self.dataloader.dataset) # total image count from dataset
|
||||
elif RANK > 0:
|
||||
dist.gather_object(self.metrics.stats, None, dst=0)
|
||||
dist.gather_object(self.jdict, None, dst=0)
|
||||
self._gather_image_metrics(self.metrics.box)
|
||||
self.jdict = []
|
||||
self.metrics.clear_stats()
|
||||
|
||||
|
|
|
|||
|
|
@ -185,6 +185,11 @@ class PoseValidator(DetectionValidator):
|
|||
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
|
||||
return tp
|
||||
|
||||
def gather_stats(self) -> None:
|
||||
"""Gather stats from all GPUs."""
|
||||
super().gather_stats() # gather stats from DetectionValidator
|
||||
self._gather_image_metrics(self.metrics.pose)
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""Save YOLO pose detections to a text file in normalized coordinates.
|
||||
|
||||
|
|
|
|||
|
|
@ -141,6 +141,11 @@ class SegmentationValidator(DetectionValidator):
|
|||
prepared_batch["masks"] = masks
|
||||
return prepared_batch
|
||||
|
||||
def gather_stats(self) -> None:
|
||||
"""Gather stats from all GPUs."""
|
||||
super().gather_stats() # gather stats from DetectionValidator
|
||||
self._gather_image_metrics(self.metrics.seg)
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
||||
|
||||
|
|
|
|||
|
|
@ -879,6 +879,7 @@ class Metric(SimpleClass):
|
|||
self.all_ap = [] # (nc, 10)
|
||||
self.ap_class_index = [] # (nc, )
|
||||
self.nc = 0
|
||||
self.image_metrics = {}
|
||||
|
||||
@property
|
||||
def ap50(self) -> np.ndarray | list:
|
||||
|
|
@ -993,6 +994,10 @@ class Metric(SimpleClass):
|
|||
self.prec_values,
|
||||
) = results
|
||||
|
||||
def clear_image_metrics(self) -> None:
|
||||
"""Clear stored per-image metrics from the current validation run."""
|
||||
self.image_metrics.clear()
|
||||
|
||||
@property
|
||||
def curves(self) -> list:
|
||||
"""Return a list of curves for accessing specific metrics curves."""
|
||||
|
|
@ -1008,6 +1013,33 @@ class Metric(SimpleClass):
|
|||
[self.px, self.r_curve, "Confidence", "Recall"],
|
||||
]
|
||||
|
||||
def update_image_metrics(self, tp: np.ndarray, target_cls: np.ndarray, pred_cls: np.ndarray, im_name: str) -> None:
|
||||
"""Update per-image precision, recall, F1, TP, FP, and FN at IoU threshold 0.5.
|
||||
|
||||
Args:
|
||||
tp (np.ndarray): True positive array of shape (num_preds, num_iou_thresholds), where the first column (IoU
|
||||
>= 0.5) is used.
|
||||
target_cls (np.ndarray): Ground truth class labels for the image.
|
||||
pred_cls (np.ndarray): Predicted class labels for the image.
|
||||
im_name (str): The image filename used as the per-image key.
|
||||
"""
|
||||
# Use the default IoU=0.5 column to match the validator's image-level matching policy.
|
||||
tp = int(tp[:, 0].sum())
|
||||
num_preds = pred_cls.shape[0]
|
||||
num_targets = target_cls.shape[0]
|
||||
fp = num_preds - tp
|
||||
fn = num_targets - tp
|
||||
precision = tp / num_preds if num_preds else 0
|
||||
recall = tp / num_targets if num_targets else 0
|
||||
self.image_metrics[im_name] = {
|
||||
"precision": float(precision),
|
||||
"recall": float(recall),
|
||||
"f1": float(2 * (precision * recall) / (precision + recall)) if (precision + recall) else 0.0,
|
||||
"tp": int(tp),
|
||||
"fp": int(fp),
|
||||
"fn": int(fn),
|
||||
}
|
||||
|
||||
|
||||
class DetMetrics(SimpleClass, DataExportMixin):
|
||||
"""Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
||||
|
|
@ -1059,6 +1091,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|||
"""
|
||||
for k in self.stats.keys():
|
||||
self.stats[k].append(stat[k])
|
||||
self.box.update_image_metrics(stat["tp"], stat["target_cls"], stat["pred_cls"], stat["im_name"])
|
||||
|
||||
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
||||
"""Process predicted results for object detection and update metrics.
|
||||
|
|
@ -1096,6 +1129,10 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|||
for v in self.stats.values():
|
||||
v.clear()
|
||||
|
||||
def clear_image_metrics(self) -> None:
|
||||
"""Clear stored per-image metrics."""
|
||||
self.box.clear_image_metrics()
|
||||
|
||||
@property
|
||||
def keys(self) -> list[str]:
|
||||
"""Return a list of keys for accessing specific metrics."""
|
||||
|
|
@ -1211,6 +1248,21 @@ class SegmentMetrics(DetMetrics):
|
|||
self.seg = Metric()
|
||||
self.stats["tp_m"] = [] # add additional stats for masks
|
||||
|
||||
def update_stats(self, stat: dict[str, Any]) -> None:
|
||||
"""Update statistics by appending new values to existing stat collections.
|
||||
|
||||
Args:
|
||||
stat (dict[str, Any]): Dictionary containing new statistical values to append. Keys should match existing
|
||||
keys in self.stats.
|
||||
"""
|
||||
super().update_stats(stat) # update box stats
|
||||
self.seg.update_image_metrics(stat["tp_m"], stat["target_cls"], stat["pred_cls"], stat["im_name"])
|
||||
|
||||
def clear_image_metrics(self) -> None:
|
||||
"""Clear stored per-image metrics."""
|
||||
super().clear_image_metrics()
|
||||
self.seg.clear_image_metrics()
|
||||
|
||||
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
||||
"""Process the detection and segmentation metrics over the given set of predictions.
|
||||
|
||||
|
|
@ -1347,6 +1399,21 @@ class PoseMetrics(DetMetrics):
|
|||
self.pose = Metric()
|
||||
self.stats["tp_p"] = [] # add additional stats for pose
|
||||
|
||||
def update_stats(self, stat: dict[str, Any]) -> None:
|
||||
"""Update statistics by appending new values to existing stat collections.
|
||||
|
||||
Args:
|
||||
stat (dict[str, Any]): Dictionary containing new statistical values to append. Keys should match existing
|
||||
keys in self.stats.
|
||||
"""
|
||||
super().update_stats(stat) # update box stats
|
||||
self.pose.update_image_metrics(stat["tp_p"], stat["target_cls"], stat["pred_cls"], stat["im_name"])
|
||||
|
||||
def clear_image_metrics(self) -> None:
|
||||
"""Clear stored per-image metrics."""
|
||||
super().clear_image_metrics()
|
||||
self.pose.clear_image_metrics()
|
||||
|
||||
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
||||
"""Process the detection and pose metrics over the given set of predictions.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue