Large Python files documentation update (#19695)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Glenn Jocher 2025-03-14 11:16:44 +01:00 committed by GitHub
parent f9fdff5a28
commit 23821bd2e7
GPG key ID: B5690EEEBB952194
17 changed files with 1534 additions and 645 deletions

View file

@ -11,7 +11,7 @@ keywords: Meituan YOLOv6, object detection, real-time applications, BiC module,
[Meituan](https://www.meituan.com/) YOLOv6 is a cutting-edge object detector that offers remarkable balance between speed and accuracy, making it a popular choice for real-time applications. This model introduces several notable enhancements on its architecture and training scheme, including the implementation of a Bi-directional Concatenation (BiC) module, an anchor-aided training (AAT) strategy, and an improved [backbone](https://www.ultralytics.com/glossary/backbone) and neck design for state-of-the-art accuracy on the COCO dataset.
![Meituan YOLOv6](https://github.com/ultralytics/docs/releases/download/0/meituan-yolov6.avif)
![Model example image](https://github.com/ultralytics/docs/releases/download/0/yolov6-architecture-diagram.avif) **Overview of YOLOv6.** Model architecture diagram showing the redesigned network components and training strategies that have led to significant performance improvements. (a) The neck of YOLOv6 (N and S are shown). Note for M/L, RepBlocks is replaced with CSPStackRep. (b) The structure of a BiC module. (c) A SimCSPSPPF block. ([source](https://arxiv.org/pdf/2301.05586.pdf)).
![Model example image](https://github.com/ultralytics/docs/releases/download/0/yolov6-architecture-diagram.avif) **Overview of YOLOv6.** Model architecture diagram showing the redesigned network components and training strategies that have led to significant performance improvements. (a) The neck of YOLOv6 (N and S are shown). Note for M/L, RepBlocks is replaced with CSPStackRep. (b) The structure of a BiC module. (c) A SimCSPSPPF block. ([source](https://arxiv.org/pdf/2301.05586)).
### Key Features

View file

@ -175,13 +175,8 @@ def visualize_image_annotations(image_path, txt_path, label_map):
adjusted for readability, depending on the background color's luminance.
Args:
image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL (e.g., .jpg, .png).
txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object with:
- class_id (int): The class index.
- x_center (float): The X center of the bounding box (relative to image width).
- y_center (float): The Y center of the bounding box (relative to image height).
- width (float): The width of the bounding box (relative to image width).
- height (float): The height of the bounding box (relative to image height).
image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL.
txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object.
label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings).
Examples:
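The per-object fields listed in the longer form of this docstring (class_id, x_center, y_center, width, height, with the last four relative to the image size) can be illustrated with a small parser. This is only a sketch: `parse_yolo_line` is a hypothetical helper and not a function in `ultralytics`, and it assumes whitespace-separated fields.

```python
def parse_yolo_line(line, img_w, img_h):
    """Convert one YOLO-format label line to (class_id, x1, y1, x2, y2) in pixels."""
    class_id, xc, yc, w, h = line.split()[:5]
    # Scale the relative centre and size back to pixel units.
    xc, w = float(xc) * img_w, float(w) * img_w
    yc, h = float(yc) * img_h, float(h) * img_h
    # Convert centre/size to corner coordinates.
    return int(class_id), xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2


print(parse_yolo_line("0 0.5 0.5 0.25 0.5", img_w=640, img_h=480))
# (0, 240.0, 120.0, 400.0, 360.0)
```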
@ -222,8 +217,8 @@ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
imgsz (tuple): The size of the image as (height, width).
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.
color (int, optional): The color value to fill in the polygons on the mask.
downsample_ratio (int, optional): Factor by which to downsample the mask.
Returns:
(np.ndarray): A binary mask of the specified image size with the polygons filled in.
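The requirement that M is even reflects that each polygon is a flat sequence of (x, y) pairs. The sketch below shows the idea in pure Python for a single polygon, using an even-odd ray-casting test at pixel centres. This is a swapped-in technique for illustration only, so boundary pixels may differ from what the library function produces, and `polygon_to_mask` is a hypothetical name.

```python
def polygon_to_mask(imgsz, polygon, color=1):
    """Rasterize one flat polygon [x0, y0, x1, y1, ...] into a nested-list mask."""
    if len(polygon) % 2:
        raise ValueError("polygon needs an even number of coordinates")
    h, w = imgsz  # (height, width), as in the docstring
    pts = list(zip(polygon[0::2], polygon[1::2]))  # flat list -> (x, y) pairs
    edges = list(zip(pts, pts[1:] + pts[:1]))  # close the polygon
    mask = [[0] * w for _ in range(h)]
    for row in range(h):
        for col in range(w):
            x, y = col + 0.5, row + 0.5  # sample at the pixel centre
            inside = False
            for (x1, y1), (x2, y2) in edges:
                # Toggle when a horizontal ray from (x, y) crosses this edge.
                if (y1 > y) != (y2 > y) and x < x1 + (y - y1) * (x2 - x1) / (y2 - y1):
                    inside = not inside
            if inside:
                mask[row][col] = color
    return mask


mask = polygon_to_mask((5, 5), [1, 1, 4, 1, 4, 4, 1, 4])
```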
@ -246,7 +241,7 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
color (int): The color value to fill in the polygons on the masks.
downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.
downsample_ratio (int, optional): Factor by which to downsample each mask.
Returns:
(np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
@ -281,8 +276,7 @@ def find_dataset_yaml(path: Path) -> Path:
Find and return the YAML file associated with a Detect, Segment or Pose dataset.
This function searches for a YAML file at the root level of the provided directory first, and if not found, it
performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
is raised if no YAML file is found or if multiple YAML files are found.
performs a recursive search. It prefers YAML files that have the same stem as the provided path.
Args:
path (Path): The directory path to search for the YAML file.
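The search order described above (root level first, then recursive, preferring a file whose stem matches the directory name) can be sketched with `pathlib`. The function name `find_yaml_sketch` is hypothetical, and the assertion messages are illustrative rather than the library's own.

```python
from pathlib import Path


def find_yaml_sketch(path: Path) -> Path:
    """Return the single YAML file for a dataset directory, preferring a matching stem."""
    # Root-level search first; fall back to a recursive search only if it finds nothing.
    files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))
    assert files, f"No YAML file found in '{path}'"
    if len(files) > 1:
        # Prefer files named after the directory itself, e.g. coco8/coco8.yaml.
        files = [f for f in files if f.stem == path.stem]
    assert len(files) == 1, f"Expected 1 YAML file in '{path}', found {len(files)}"
    return files[0]
```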
@ -308,7 +302,7 @@ def check_det_dataset(dataset, autodownload=True):
Args:
dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.
autodownload (bool, optional): Whether to automatically download the dataset if not found.
Returns:
(dict): Parsed dataset information and paths.
@ -400,7 +394,7 @@ def check_cls_dataset(dataset, split=""):
Args:
dataset (str | Path): The name of the dataset.
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
split (str, optional): The split of the dataset. Either 'val', 'test', or ''.
Returns:
(dict): A dictionary containing the following keys:
@ -634,8 +628,8 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
Args:
f (str): The path to the input image file.
f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
quality (int, optional): The image compression quality as a percentage. Default is 50%.
max_dim (int, optional): The maximum dimension (width or height) of the output image.
quality (int, optional): The image compression quality as a percentage.
Examples:
>>> from pathlib import Path
@ -664,9 +658,9 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
Args:
path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
path (Path, optional): Path to images directory.
weights (list | tuple, optional): Train, validation, and test split fractions.
annotated_only (bool, optional): If True, only images with an associated txt file are used.
Examples:
>>> from ultralytics.data.utils import autosplit
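As a rough illustration of how split fractions such as (0.9, 0.1, 0.0) can be applied, the sketch below assigns each image to a split by weighted random choice. It assumes per-image random assignment, so the realized split sizes only approximate the requested fractions. `assign_splits` and its `seed` parameter are hypothetical and not part of `ultralytics`.

```python
import random


def assign_splits(names, weights=(0.9, 0.1, 0.0), seed=0):
    """Assign each name to 'train', 'val' or 'test' with the given probabilities."""
    rng = random.Random(seed)  # local generator so the example is reproducible
    labels = ("train", "val", "test")
    picks = rng.choices(range(3), weights=weights, k=len(names))
    splits = {label: [] for label in labels}
    for name, i in zip(names, picks):
        splits[labels[i]].append(name)
    return splits


splits = assign_splits([f"img{i}.jpg" for i in range(100)])
print({k: len(v) for k, v in splits.items()})
```

A weight of 0.0 means that split receives no images, which is why the default configuration produces an empty test split.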

View file

@ -138,7 +138,7 @@ def validate_args(format, passed_args, valid_args):
Args:
format (str): The export format.
passed_args (Namespace): The arguments used during export.
valid_args (dict): List of valid arguments for the format.
valid_args (List): List of valid arguments for the format.
Raises:
AssertionError: If an unsupported argument is used, or if the format lacks supported argument listings.
@ -219,8 +219,8 @@ class Exporter:
Args:
cfg (str, optional): Path to a configuration file.
overrides (dict, optional): Configuration overrides.
_callbacks (dict, optional): Dictionary of callback functions.
overrides (Dict, optional): Configuration overrides.
_callbacks (Dict, optional): Dictionary of callback functions.
"""
self.args = get_cfg(cfg, overrides)
if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors
@ -1574,7 +1574,7 @@ class NMSModel(torch.nn.Module):
x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
Returns:
out (torch.Tensor): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape).
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the number of detections after NMS.
"""
from functools import partial

View file

@ -95,7 +95,7 @@ class BaseTrainer:
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the BaseTrainer class.
Initialize the BaseTrainer class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
@ -159,11 +159,11 @@ class BaseTrainer:
callbacks.add_integration_callbacks(self)
def add_callback(self, event: str, callback):
"""Appends the given callback."""
"""Append the given callback to the event's callback list."""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""Overrides the existing callbacks with the given callback."""
"""Override the existing callbacks with the given callback for the specified event."""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
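The difference between appending and overriding a callback can be shown with a minimal stand-alone registry. This is a sketch, not the trainer itself: `CallbackRegistry` is a hypothetical class, and the body of `run_callbacks` is an assumption, since the diff shows only its signature.

```python
from collections import defaultdict


class CallbackRegistry:
    """Minimal event -> list-of-callbacks registry."""

    def __init__(self):
        self.callbacks = defaultdict(list)

    def add_callback(self, event, callback):
        # Append: existing callbacks for this event are kept.
        self.callbacks[event].append(callback)

    def set_callback(self, event, callback):
        # Override: any existing callbacks for this event are discarded.
        self.callbacks[event] = [callback]

    def run_callbacks(self, event):
        # Assumed behaviour: call each registered callback with the owner object.
        for callback in self.callbacks.get(event, []):
            callback(self)
```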
@ -219,7 +219,7 @@ class BaseTrainer:
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
def _setup_ddp(self, world_size):
"""Initializes and sets the DistributedDataParallel parameters for training."""
"""Initialize and set the DistributedDataParallel parameters for training."""
torch.cuda.set_device(RANK)
self.device = torch.device("cuda", RANK)
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
@ -232,7 +232,7 @@ class BaseTrainer:
)
def _setup_train(self, world_size):
"""Builds dataloaders and optimizer on correct rank process."""
"""Build dataloaders and optimizer on correct rank process."""
# Model
self.run_callbacks("on_pretrain_routine_start")
ckpt = self.setup_model()
@ -320,7 +320,7 @@ class BaseTrainer:
self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, world_size=1):
"""Train completed, evaluate and plot if specified by arguments."""
"""Train the model with the specified world size."""
if world_size > 1:
self._setup_ddp(world_size)
self._setup_train(world_size)
@ -480,7 +480,7 @@ class BaseTrainer:
self.run_callbacks("teardown")
def auto_batch(self, max_num_obj=0):
"""Get batch size by calculating memory occupation of model."""
"""Calculate optimal batch size based on model and device memory constraints."""
return check_train_batch_size(
model=self.model,
imgsz=self.args.imgsz,
@ -490,7 +490,7 @@ class BaseTrainer:
) # returns batch size
def _get_memory(self, fraction=False):
"""Get accelerator memory utilization in GB or fraction."""
"""Get accelerator memory utilization in GB or as a fraction of total memory."""
memory, total = 0, 0
if self.device.type == "mps":
memory = torch.mps.driver_allocated_memory()
@ -505,7 +505,7 @@ class BaseTrainer:
return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
def _clear_memory(self):
"""Clear accelerator memory on different platforms."""
"""Clear accelerator memory by calling garbage collector and emptying cache."""
gc.collect()
if self.device.type == "mps":
torch.mps.empty_cache()
@ -515,7 +515,7 @@ class BaseTrainer:
torch.cuda.empty_cache()
def read_results_csv(self):
"""Read results.csv into a dict using pandas."""
"""Read results.csv into a dictionary using pandas."""
import pandas as pd # scope for faster 'import ultralytics'
return pd.read_csv(self.csv).to_dict(orient="list")
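For readers unfamiliar with `to_dict(orient="list")`, the resulting shape is one list of values per column name. The sketch below reproduces that shape with the standard `csv` module instead of pandas. Unlike pandas it keeps every value as a string, and `read_csv_as_columns` is a hypothetical name.

```python
import csv


def read_csv_as_columns(path):
    """Return {column_name: [values, ...]} with values kept as strings."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        columns = {name: [] for name in (reader.fieldnames or [])}
        for row in reader:
            for name in columns:
                columns[name].append(row[name])
    return columns
```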
@ -557,9 +557,10 @@ class BaseTrainer:
def get_dataset(self):
"""
Get train, val path from data dict if it exists.
Get train and validation datasets from data dictionary.
Returns None if data format is not recognized.
Returns:
(tuple): A tuple containing the training and validation/test datasets.
"""
try:
if self.args.task == "classify":
@ -583,7 +584,12 @@ class BaseTrainer:
return data["train"], data.get("val") or data.get("test")
def setup_model(self):
"""Load/create/download model for any task."""
"""
Load, create, or download model for any task.
Returns:
(dict): Optional checkpoint to resume training from.
"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
@ -613,9 +619,10 @@ class BaseTrainer:
def validate(self):
"""
Runs validation on test set using self.validator.
Run validation on test set using self.validator.
The returned dict is expected to contain "fitness" key.
Returns:
(tuple): A tuple containing metrics dictionary and fitness score.
"""
metrics = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
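The fallback on the last line can be read as: use the validator's "fitness" entry when it exists, otherwise use the negated loss so that a lower loss still ranks as a higher fitness. A minimal sketch with plain floats in place of tensors, where `split_fitness` is a hypothetical helper:

```python
def split_fitness(metrics, loss):
    """Return (metrics without 'fitness', fitness), falling back to -loss."""
    metrics = dict(metrics)  # copy so the caller's dict is not modified
    fitness = metrics.pop("fitness", -loss)
    return metrics, fitness


print(split_fitness({"mAP": 0.5, "fitness": 0.42}, loss=1.3))  # ({'mAP': 0.5}, 0.42)
print(split_fitness({"mAP": 0.5}, loss=1.3))  # ({'mAP': 0.5}, -1.3)
```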
@ -649,7 +656,7 @@ class BaseTrainer:
return {"loss": loss_items} if loss_items is not None else ["loss"]
def set_model_attributes(self):
"""To set or update model parameters before training."""
"""Set or update model parameters before training."""
self.model.names = self.data["names"]
def build_targets(self, preds, targets):
@ -670,7 +677,7 @@ class BaseTrainer:
pass
def save_metrics(self, metrics):
"""Saves training metrics to a CSV file."""
"""Save training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 2 # number of cols
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
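The header expression above repeats the `"%s,"` placeholder once per column, fills it with the column names, and strips the trailing comma. The sketch below isolates that step. The row formatter is an assumption added for illustration (the diff does not show how rows are written), and both function names are hypothetical.

```python
def format_header(keys):
    """Build the CSV header line: epoch, time, then one column per metric."""
    n = len(keys) + 2  # number of columns
    return ("%s," * n % tuple(["epoch", "time"] + list(keys))).rstrip(",") + "\n"


def format_row(epoch, elapsed, vals):
    """Assumed row format: the same trick with a numeric placeholder."""
    n = len(vals) + 2
    return ("%.6g," * n % tuple([epoch, elapsed] + list(vals))).rstrip(",") + "\n"


print(format_header(["loss", "mAP"]), end="")  # epoch,time,loss,mAP
print(format_row(1, 12.5, [0.25, 0.5]), end="")  # 1,12.5,0.25,0.5
```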
@ -688,7 +695,7 @@ class BaseTrainer:
self.plots[path] = {"data": data, "timestamp": time.time()}
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO model."""
"""Perform final evaluation and validation for object detection YOLO model."""
ckpt = {}
for f in self.last, self.best:
if f.exists():
@ -772,8 +779,7 @@ class BaseTrainer:
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
weight decay, and number of iterations.
Construct an optimizer for the given model.
Args:
model (torch.nn.Module): The model for which to build an optimizer.

View file

@ -176,7 +176,7 @@ class SAM2Model(torch.nn.Module):
compile_image_encoder: bool = False,
):
"""
Initializes the SAM2Model for video object segmentation with memory-based tracking.
Initialize the SAM2Model for video object segmentation with memory-based tracking.
Args:
image_encoder (nn.Module): Visual encoder for extracting image features.
@ -213,9 +213,9 @@ class SAM2Model(torch.nn.Module):
the encoder.
proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
encoding in object pointers.
use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance)
in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
and `add_tpos_enc_to_obj_ptrs=True`.
use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance)
in the temporal positional encoding in the object pointers, only relevant when both
`use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`.
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
during evaluation.
pred_obj_scores (bool): Whether to predict if there is an object in the frame.
@ -332,18 +332,18 @@ class SAM2Model(torch.nn.Module):
@property
def device(self):
"""Returns the device on which the model's parameters are stored."""
"""Return the device on which the model's parameters are stored."""
return next(self.parameters()).device
def forward(self, *args, **kwargs):
"""Processes image and prompt inputs to generate object masks and scores in video sequences."""
"""Process image and prompt inputs to generate object masks and scores in video sequences."""
raise NotImplementedError(
"Please use the corresponding methods in SAM2VideoPredictor for inference."
"See notebooks/video_predictor_example.ipynb for an example."
)
def _build_sam_heads(self):
"""Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
"""Build SAM-style prompt encoder and mask decoder for image segmentation tasks."""
self.sam_prompt_embed_dim = self.hidden_dim
self.sam_image_embedding_size = self.image_size // self.backbone_stride
@ -545,7 +545,7 @@ class SAM2Model(torch.nn.Module):
)
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
"""Process mask inputs directly as output, bypassing SAM encoder/decoder."""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
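The scale and bias in the comment above map a binary mask value to a logit of -10 or +10, which a sigmoid turns into a probability very close to 0 or 1. A scalar sketch of that arithmetic follows. It assumes the mapping is `value * out_scale + out_bias`, which matches the comment but is not shown in this hunk, and the helper names are hypothetical.

```python
import math


def mask_to_logit(value, out_scale=20.0, out_bias=-10.0):
    """Map a binary mask value (0 or 1) to a logit of -10 or +10."""
    return float(value) * out_scale + out_bias


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


for v in (0, 1):
    logit = mask_to_logit(v)
    print(v, logit, sigmoid(logit))
```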
@ -592,7 +592,7 @@ class SAM2Model(torch.nn.Module):
)
def forward_image(self, img_batch: torch.Tensor):
"""Processes image batch through encoder to extract multi-level features for SAM model."""
"""Process image batch through encoder to extract multi-level features for SAM model."""
backbone_out = self.image_encoder(img_batch)
if self.use_high_res_features_in_sam:
# precompute projected level 0 and level 1 features in SAM decoder
@ -602,7 +602,7 @@ class SAM2Model(torch.nn.Module):
return backbone_out
def _prepare_backbone_features(self, backbone_out):
"""Prepares and flattens visual features from the image backbone output for further processing."""
"""Prepare and flatten visual features from the image backbone output for further processing."""
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
@ -627,7 +627,7 @@ class SAM2Model(torch.nn.Module):
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
):
"""Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
"""Prepare memory-conditioned features by fusing current frame's visual features with previous memories."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
@ -788,7 +788,7 @@ class SAM2Model(torch.nn.Module):
object_score_logits,
is_mask_from_pts,
):
"""Encodes frame features and masks into a new memory representation for video segmentation."""
"""Encode frame features and masks into a new memory representation for video segmentation."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
@ -838,7 +838,7 @@ class SAM2Model(torch.nn.Module):
track_in_reverse,
prev_sam_mask_logits,
):
"""Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
"""Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
if len(current_vision_feats) > 1:
@ -893,9 +893,7 @@ class SAM2Model(torch.nn.Module):
object_score_logits,
current_out,
):
"""Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be
used in future frames).
"""
"""Run memory encoder on predicted mask to encode it into a new memory feature for future frames."""
if run_mem_encoder and self.num_maskmem > 0:
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
@ -932,7 +930,7 @@ class SAM2Model(torch.nn.Module):
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
prev_sam_mask_logits=None,
):
"""Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
"""Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
current_out, sam_outputs, _, _ = self._track_step(
frame_idx,
is_init_cond_frame,
@ -970,7 +968,7 @@ class SAM2Model(torch.nn.Module):
return current_out
def _use_multimask(self, is_init_cond_frame, point_inputs):
"""Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
"""Determine whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
return (
self.multimask_output_in_sam
@ -980,7 +978,7 @@ class SAM2Model(torch.nn.Module):
@staticmethod
def _apply_non_overlapping_constraints(pred_masks):
"""Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
"""Apply non-overlapping constraints to masks, keeping the highest scoring object per location."""
batch_size = pred_masks.size(0)
if batch_size == 1:
return pred_masks
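Keeping only the highest-scoring object per location can be sketched without tensors, using one list of logits per object. The suppression value of -10.0 is an assumption chosen to match the logit convention used elsewhere in this file, and `apply_non_overlapping` is a hypothetical stand-in for the tensor implementation.

```python
def apply_non_overlapping(pred_masks, suppressed_max=-10.0):
    """pred_masks: list of per-object logit lists, all of the same length."""
    if len(pred_masks) == 1:
        return pred_masks  # a single object cannot overlap with anything
    out = [list(row) for row in pred_masks]
    for p in range(len(pred_masks[0])):
        # Index of the object with the highest logit at this location.
        best = max(range(len(pred_masks)), key=lambda i: pred_masks[i][p])
        for i in range(len(pred_masks)):
            if i != best:
                # Push every other object's logit down to at most suppressed_max.
                out[i][p] = min(out[i][p], suppressed_max)
    return out


print(apply_non_overlapping([[5.0, -3.0], [2.0, 4.0]]))
# [[5.0, -10.0], [-10.0, 4.0]]
```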
@ -1001,12 +999,7 @@ class SAM2Model(torch.nn.Module):
self.binarize_mask_from_pts_for_mem_enc = binarize
def set_imgsz(self, imgsz):
"""
Set image size to make model compatible with different image sizes.
Args:
imgsz (Tuple[int, int]): The size of the input image.
"""
"""Set image size to make model compatible with different image sizes."""
self.image_size = imgsz[0]
self.sam_prompt_encoder.input_image_size = imgsz
self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16

View file

@ -27,7 +27,7 @@ class Conv2d_BN(torch.nn.Sequential):
Attributes:
c (torch.nn.Conv2d): 2D convolution layer.
1 (torch.nn.BatchNorm2d): Batch normalization layer.
bn (torch.nn.BatchNorm2d): Batch normalization layer.
Methods:
__init__: Initializes the Conv2d_BN with specified parameters.
@ -265,9 +265,9 @@ class ConvLayer(nn.Module):
dim (int): The dimensionality of the input and output.
input_resolution (Tuple[int, int]): The resolution of the input image.
depth (int): The number of MBConv layers in the block.
activation (Callable): Activation function applied after each convolution.
activation (nn.Module): Activation function applied after each convolution.
drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
downsample (Optional[nn.Module]): Function for downsampling the output. None to skip downsampling.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
conv_expand_ratio (float): Expansion ratio for the MBConv layers.
@ -413,12 +413,9 @@ class Attention(torch.nn.Module):
Args:
dim (int): The dimensionality of the input and output.
key_dim (int): The dimensionality of the keys and queries.
num_heads (int): Number of attention heads. Default is 8.
attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
Raises:
AssertionError: If 'resolution' is not a tuple of length 2.
num_heads (int): Number of attention heads.
attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
resolution (Tuple[int, int]): Spatial resolution of the input feature map.
Examples:
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
@ -821,22 +818,20 @@ class TinyViT(nn.Module):
attention and convolution blocks, and a classification head.
Args:
img_size (int): Size of the input image. Default is 224.
in_chans (int): Number of input channels. Default is 3.
num_classes (int): Number of classes for classification. Default is 1000.
img_size (int): Size of the input image.
in_chans (int): Number of input channels.
num_classes (int): Number of classes for classification.
embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
Default is (96, 192, 384, 768).
depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
depths (Tuple[int, int, int, int]): Number of blocks in each stage.
num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
Default is (3, 6, 12, 24).
window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
drop_rate (float): Dropout rate. Default is 0.0.
drop_path_rate (float): Stochastic depth rate. Default is 0.1.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
local_conv_size (int): Kernel size for local convolutions. Default is 3.
layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
window_sizes (Tuple[int, int, int, int]): Window sizes for each stage.
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
drop_rate (float): Dropout rate.
drop_path_rate (float): Stochastic depth rate.
use_checkpoint (bool): Whether to use checkpointing to save memory.
mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
local_conv_size (int): Kernel size for local convolutions.
layer_lr_decay (float): Layer-wise learning rate decay factor.
Examples:
>>> model = TinyViT(img_size=224, num_classes=1000)
@ -992,12 +987,7 @@ class TinyViT(nn.Module):
return self.forward_features(x)
def set_imgsz(self, imgsz=[1024, 1024]):
"""
Set image size to make model compatible with different image sizes.
Args:
imgsz (Tuple[int, int]): The size of the input image.
"""
"""Set image size to make model compatible with different image sizes."""
imgsz = [s // 4 for s in imgsz]
self.patches_resolution = imgsz
for i, layer in enumerate(self.layers):

View file

@ -701,9 +701,6 @@ class SAM2Predictor(Predictor):
- The method supports batched inference for multiple objects when points or bboxes are provided.
- Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
- When both bboxes and points are provided, they are merged into a single 'points' input for the model.
References:
- SAM2 Paper: [Add link to SAM2 paper when available]
"""
features = self.get_im_features(im) if self.features is None else self.features

View file

@ -19,11 +19,7 @@ from ultralytics.utils.downloads import attempt_download_asset, is_url
def check_class_names(names):
"""
Check class names.
Map imagenet class codes to human-readable names if required. Convert lists to dicts.
"""
"""Check class names and convert to dict format if needed."""
if isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
if isinstance(names, dict):
@ -78,8 +74,23 @@ class AutoBackend(nn.Module):
| IMX | *_imx_model/ |
| RKNN | *_rknn_model/ |
This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy
models across various platforms.
Attributes:
model (torch.nn.Module): The loaded YOLO model.
device (torch.device): The device (CPU or GPU) on which the model is loaded.
task (str): The type of task the model performs (detect, segment, classify, pose).
names (Dict): A dictionary of class names that the model can detect.
stride (int): The model stride, typically 32 for YOLO models.
fp16 (bool): Whether the model uses half-precision (FP16) inference.
Methods:
forward: Run inference on an input image.
from_numpy: Convert numpy array to tensor.
warmup: Warm up the model with a dummy input.
_model_type: Determine the model type from file path.
Examples:
>>> model = AutoBackend(weights="yolov8n.pt", device="cuda")
>>> results = model(img)
"""
@torch.no_grad()
@ -101,7 +112,7 @@ class AutoBackend(nn.Module):
weights (str | torch.nn.Module): Path to the model weights file or a module instance. Defaults to 'yolo11n.pt'.
device (torch.device): Device to run the model on. Defaults to CPU.
dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
data (str | Path | optional): Path to the additional data.yaml file containing class names.
fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
batch (int): Batch-size to assume for inference.
fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
@ -539,12 +550,12 @@ class AutoBackend(nn.Module):
Args:
im (torch.Tensor): The image tensor to perform inference on.
augment (bool): whether to perform data augmentation during inference, defaults to False
visualize (bool): whether to visualize the output predictions, defaults to False
embed (list, optional): A list of feature vectors/embeddings to return.
augment (bool): Whether to perform data augmentation during inference. Defaults to False.
visualize (bool): Whether to visualize the output predictions. Defaults to False.
embed (List, optional): A list of feature vectors/embeddings to return.
Returns:
(tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
(torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model.
"""
b, ch, h, w = im.shape # batch, channel, height, width
if self.fp16 and im.dtype != torch.float16:
@ -776,10 +787,13 @@ class AutoBackend(nn.Module):
def _model_type(p="path/to/model.pt"):
"""
Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml,
saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
saved_model, pb, tflite, edgetpu, tfjs, ncnn, mnn, imx or paddle.
Args:
p (str): path to the model file. Defaults to path/to/model.pt
p (str): Path to the model file. Defaults to path/to/model.pt
Returns:
(List[bool]): List of booleans indicating the model type.
Examples:
>>> model = AutoBackend(weights="path/to/model.onnx")
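The "list of booleans" return value has one flag per supported format, set when the file name matches that format's suffix. The sketch below shows the idea with a deliberately small, illustrative suffix table. The real table covers every format in the list above, and `model_type_flags` is a hypothetical name.

```python
from pathlib import Path

# Illustrative subset only; not the library's full suffix table.
SUFFIXES = [".pt", ".torchscript", ".onnx"]


def model_type_flags(p="path/to/model.pt"):
    """Return one boolean per known suffix, True where the file name matches."""
    name = Path(p).name
    return [name.endswith(s) for s in SUFFIXES]


pt, jit, onnx = model_type_flags("path/to/model.onnx")
print(pt, jit, onnx)  # False False True
```

Unpacking the list into named flags, as in the last lines, lets the caller branch on the detected backend.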

File diff suppressed because it is too large

View file

@ -119,10 +119,10 @@ class BaseModel(torch.nn.Module):
Args:
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
augment (bool): Augment image during prediction, defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
profile (bool): Print the computation time of each layer if True.
visualize (bool): Save the feature maps of the model if True.
augment (bool): Augment image during prediction.
embed (List, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): The last output of the model.
@ -137,9 +137,9 @@ class BaseModel(torch.nn.Module):
Args:
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
profile (bool): Print the computation time of each layer if True.
visualize (bool): Save the feature maps of the model if True.
embed (List, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): The last output of the model.
@ -170,13 +170,12 @@ class BaseModel(torch.nn.Module):
def _profile_one_layer(self, m, x, dt):
"""
Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to
the provided list.
Profile the computation time and FLOPs of a single layer of the model on a given input.
Args:
m (torch.nn.Module): The layer to be profiled.
x (torch.Tensor): The input data to the layer.
dt (list): A list to store the computation time of the layer.
dt (List): A list to store the computation time of the layer.
"""
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
@ -192,8 +191,8 @@ class BaseModel(torch.nn.Module):
def fuse(self, verbose=True):
"""
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
computation efficiency.
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
efficiency.
Returns:
(torch.nn.Module): The fused model is returned.
@ -225,7 +224,7 @@ class BaseModel(torch.nn.Module):
Check if the model has less than a certain threshold of BatchNorm layers.
Args:
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
thresh (int, optional): The threshold number of BatchNorm layers.
Returns:
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
@ -235,21 +234,21 @@ class BaseModel(torch.nn.Module):
def info(self, detailed=False, verbose=True, imgsz=640):
"""
Prints model information.
Print model information.
Args:
detailed (bool): if True, prints out detailed information about the model. Defaults to False
verbose (bool): if True, prints out the model information. Defaults to False
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
detailed (bool): If True, prints out detailed information about the model.
verbose (bool): If True, prints out the model information.
imgsz (int): The size of the image that the model will be trained on.
"""
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
def _apply(self, fn):
"""
Applies a function to all the tensors in the model that are not parameters or registered buffers.
Apply a function to all tensors in the model that are not parameters or registered buffers.
Args:
fn (function): the function to apply to the model
fn (function): The function to apply to the model.
Returns:
(BaseModel): An updated BaseModel object.
@ -264,11 +263,11 @@ class BaseModel(torch.nn.Module):
def load(self, weights, verbose=True):
"""
Load the weights into the model.
Load weights into the model.
Args:
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
verbose (bool, optional): Whether to log the transfer progress.
"""
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
csd = model.float().state_dict() # checkpoint state_dict as FP32
@ -282,8 +281,8 @@ class BaseModel(torch.nn.Module):
Compute loss.
Args:
batch (dict): Batch to compute loss on
preds (torch.Tensor | List[torch.Tensor]): Predictions.
batch (dict): Batch to compute loss on.
preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
"""
if getattr(self, "criterion", None) is None:
self.criterion = self.init_criterion()
@ -300,7 +299,15 @@ class DetectionModel(BaseModel):
"""YOLO detection model."""
def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes
"""Initialize the YOLO detection model with the given config and parameters."""
"""
Initialize the YOLO detection model with the given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
ch (int): Number of input channels.
nc (int, optional): Number of classes.
verbose (bool): Whether to display model information.
"""
super().__init__()
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
if self.yaml["backbone"][0][2] == "Silence":
@ -327,7 +334,7 @@ class DetectionModel(BaseModel):
m.inplace = self.inplace
def _forward(x):
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
"""Perform a forward pass through the model, handling different Detect subclass types accordingly."""
if self.end2end:
return self.forward(x)["one2many"]
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
@ -345,7 +352,15 @@ class DetectionModel(BaseModel):
LOGGER.info("")
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference and train outputs."""
"""
Perform augmentations on input image x and return augmented inference and train outputs.
Args:
x (torch.Tensor): Input image tensor.
Returns:
(torch.Tensor): Augmented inference output.
"""
if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
return self._predict_once(x)
@ -363,7 +378,19 @@ class DetectionModel(BaseModel):
@staticmethod
def _descale_pred(p, flips, scale, img_size, dim=1):
"""De-scale predictions following augmented inference (inverse operation)."""
"""
De-scale predictions following augmented inference (inverse operation).
Args:
p (torch.Tensor): Predictions tensor.
flips (int): Flip type (0=none, 2=ud, 3=lr).
scale (float): Scale factor.
img_size (tuple): Original image size (height, width).
dim (int): Dimension to split at.
Returns:
(torch.Tensor): De-scaled predictions.
"""
p[:, :4] /= scale # de-scale
x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
if flips == 2:
@ -373,7 +400,15 @@ class DetectionModel(BaseModel):
return torch.cat((x, y, wh, cls), dim)
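
The inverse operation described in the `_descale_pred` docstring can be illustrated on a single box without tensors. This is a minimal sketch, not the library code: `descale_box` is a hypothetical helper, the box is assumed to be `(x, y, w, h)` in pixels, and the flip codes follow the docstring (0=none, 2=ud, 3=lr).

```python
def descale_box(box, flips, scale, img_size):
    """Undo test-time scaling and flipping for one (x, y, w, h) box.

    Hypothetical single-box analogue of the tensor version in the diff.
    img_size is (height, width) of the original image.
    """
    x, y, w, h = (v / scale for v in box)  # de-scale all four box values
    if flips == 2:  # up-down flip: mirror the y centre
        y = img_size[0] - y
    elif flips == 3:  # left-right flip: mirror the x centre
        x = img_size[1] - x
    return (x, y, w, h)


# A box predicted on a half-size, left-right flipped copy of a 480x640 image.
restored = descale_box((50.0, 40.0, 10.0, 20.0), flips=3, scale=0.5, img_size=(480, 640))
print(restored)
```

Only the centre coordinate is mirrored; width and height are unaffected by a flip, which is why the split in the original code keeps `wh` untouched.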
def _clip_augmented(self, y):
"""Clip YOLO augmented inference tails."""
"""
Clip YOLO augmented inference tails.
Args:
y (List[torch.Tensor]): List of detection tensors.
Returns:
(List[torch.Tensor]): Clipped detection tensors.
"""
nl = self.model[-1].nl # number of detection layers (P3-P5)
g = sum(4**x for x in range(nl)) # grid points
e = 1 # exclude layer count
@ -392,7 +427,15 @@ class OBBModel(DetectionModel):
"""YOLO Oriented Bounding Box (OBB) model."""
def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLO OBB model with given config and parameters."""
"""
Initialize YOLO OBB model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
ch (int): Number of input channels.
nc (int, optional): Number of classes.
verbose (bool): Whether to display model information.
"""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
@ -404,7 +447,15 @@ class SegmentationModel(DetectionModel):
"""YOLO segmentation model."""
def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 segmentation model with given config and parameters."""
"""
Initialize YOLOv8 segmentation model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
ch (int): Number of input channels.
nc (int, optional): Number of classes.
verbose (bool): Whether to display model information.
"""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
@ -416,7 +467,16 @@ class PoseModel(DetectionModel):
"""YOLO pose model."""
def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
"""Initialize YOLOv8 Pose model."""
"""
Initialize YOLOv8 Pose model.
Args:
cfg (str | dict): Model configuration file path or dictionary.
ch (int): Number of input channels.
nc (int, optional): Number of classes.
data_kpt_shape (tuple): Shape of keypoints data.
verbose (bool): Whether to display model information.
"""
if not isinstance(cfg, dict):
cfg = yaml_model_load(cfg) # load model YAML
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
@ -433,12 +493,28 @@ class ClassificationModel(BaseModel):
"""YOLO classification model."""
def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
"""
Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
Args:
cfg (str | dict): Model configuration file path or dictionary.
ch (int): Number of input channels.
nc (int, optional): Number of classes.
verbose (bool): Whether to display model information.
"""
super().__init__()
self._from_yaml(cfg, ch, nc, verbose)
def _from_yaml(self, cfg, ch, nc, verbose):
"""Set YOLOv8 model configurations and define the model architecture."""
"""
Set YOLOv8 model configurations and define the model architecture.
Args:
cfg (str | dict): Model configuration file path or dictionary.
ch (int): Number of input channels.
nc (int, optional): Number of classes.
verbose (bool): Whether to display model information.
"""
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
# Define model
@ -455,7 +531,13 @@ class ClassificationModel(BaseModel):
@staticmethod
def reshape_outputs(model, nc):
"""Update a TorchVision classification model to class count 'n' if required."""
"""
Update a TorchVision classification model to class count 'n' if required.
Args:
model (torch.nn.Module): Model to update.
nc (int): New number of classes.
"""
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
if isinstance(m, Classify): # YOLO Classify() head
if m.linear.out_features != nc:
@ -500,10 +582,10 @@ class RTDETRDetectionModel(DetectionModel):
Initialize the RTDETRDetectionModel.
Args:
cfg (str): Configuration file name or path.
cfg (str | dict): Configuration file name or path.
ch (int): Number of input channels.
nc (int, optional): Number of classes. Defaults to None.
verbose (bool, optional): Print additional information during initialization. Defaults to True.
nc (int, optional): Number of classes.
verbose (bool): Print additional information during initialization.
"""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
@ -519,7 +601,7 @@ class RTDETRDetectionModel(DetectionModel):
Args:
batch (dict): Dictionary containing image and label data.
preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None.
preds (torch.Tensor, optional): Precomputed model predictions.
Returns:
(tuple): A tuple containing the total loss and main three losses in a tensor.
@ -564,11 +646,11 @@ class RTDETRDetectionModel(DetectionModel):
Args:
x (torch.Tensor): The input tensor.
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
batch (dict, optional): Ground truth data for evaluation. Defaults to None.
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
profile (bool): If True, profile the computation time for each layer.
visualize (bool): If True, save feature maps for visualization.
batch (dict, optional): Ground truth data for evaluation.
augment (bool): If True, perform data augmentation during inference.
embed (List, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): Model's output tensor.
@ -596,13 +678,28 @@ class WorldModel(DetectionModel):
"""YOLOv8 World Model."""
def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 world model with given config and parameters."""
"""
Initialize YOLOv8 world model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
ch (int): Number of input channels.
nc (int, optional): Number of classes.
verbose (bool): Whether to display model information.
"""
self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder
self.clip_model = None # CLIP model placeholder
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def set_classes(self, text, batch=80, cache_clip_model=True):
"""Set classes in advance so that model could do offline-inference without clip model."""
"""
Set classes in advance so that model could do offline-inference without clip model.
Args:
text (List[str]): List of class names.
batch (int): Batch size for processing text tokens.
cache_clip_model (bool): Whether to cache the CLIP model.
"""
try:
import clip
except ImportError:
@ -628,11 +725,11 @@ class WorldModel(DetectionModel):
Args:
x (torch.Tensor): The input tensor.
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
profile (bool): If True, profile the computation time for each layer.
visualize (bool): If True, save feature maps for visualization.
txt_feats (torch.Tensor, optional): The text features, use it if it's given.
augment (bool): If True, perform data augmentation during inference.
embed (List, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): Model's output tensor.
@ -671,7 +768,7 @@ class WorldModel(DetectionModel):
Args:
batch (dict): Batch to compute loss on.
preds (torch.Tensor | List[torch.Tensor]): Predictions.
preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
"""
if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
@ -689,7 +786,18 @@ class Ensemble(torch.nn.ModuleList):
super().__init__()
def forward(self, x, augment=False, profile=False, visualize=False):
"""Function generates the YOLO network's final layer."""
"""
Generate the YOLO network's final layer.
Args:
x (torch.Tensor): Input tensor.
augment (bool): Whether to augment the input.
profile (bool): Whether to profile the model.
visualize (bool): Whether to visualize the features.
Returns:
(tuple): Tuple containing the concatenated predictions and None.
"""
y = [module(x, augment, profile, visualize)[0] for module in self]
# y = torch.stack(y).max(0)[0] # max ensemble
# y = torch.stack(y).mean(0) # mean ensemble
@ -765,7 +873,16 @@ class SafeUnpickler(pickle.Unpickler):
"""Custom Unpickler that replaces unknown classes with SafeClass."""
def find_class(self, module, name):
"""Attempt to find a class, returning SafeClass if not among safe modules."""
"""
Attempt to find a class, returning SafeClass if not among safe modules.
Args:
module (str): Module name.
name (str): Class name.
Returns:
(type): Found class or SafeClass.
"""
safe_modules = (
"torch",
"collections",
@ -791,13 +908,13 @@ def torch_safe_load(weight, safe_only=False):
weight (str): The file path of the PyTorch model.
safe_only (bool): If True, replace unknown classes with SafeClass during loading.
Returns:
ckpt (dict): The loaded model checkpoint.
file (str): The loaded filename.
Examples:
>>> from ultralytics.nn.tasks import torch_safe_load
>>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
Returns:
ckpt (dict): The loaded model checkpoint.
file (str): The loaded filename
"""
from ultralytics.utils.downloads import attempt_download_asset
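
The `safe_only` behaviour documented above (unknown classes replaced during loading) relies on overriding `pickle.Unpickler.find_class`. The following is a stdlib-only sketch of that technique under stated assumptions: `PlaceholderClass`, `RestrictedUnpickler` and the one-entry allow-list are hypothetical names chosen for illustration, not the ones used in `ultralytics.nn.tasks`.

```python
import io
import pickle
from collections import OrderedDict
from fractions import Fraction


class PlaceholderClass:
    """Hypothetical stand-in that swallows any constructor arguments."""

    def __init__(self, *args, **kwargs):
        pass


class RestrictedUnpickler(pickle.Unpickler):
    """Resolve classes only from an allow-list; substitute a placeholder otherwise."""

    SAFE_MODULES = ("collections",)  # illustrative allow-list, deliberately tiny

    def find_class(self, module, name):
        if module.split(".")[0] in self.SAFE_MODULES:
            return super().find_class(module, name)
        return PlaceholderClass  # unknown class: never import it


payload = pickle.dumps({"kept": OrderedDict(x=1), "replaced": Fraction(1, 2)})
loaded = RestrictedUnpickler(io.BytesIO(payload)).load()
print(type(loaded["kept"]).__name__, type(loaded["replaced"]).__name__)
```

Plain containers such as `dict` and `list` are rebuilt by pickle opcodes and never reach `find_class`, so only class lookups are filtered.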
@ -858,7 +975,18 @@ def torch_safe_load(weight, safe_only=False):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
"""
Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
Args:
weights (str | List[str]): Model weights path(s).
device (torch.device, optional): Device to load model to.
inplace (bool): Whether to do inplace operations.
fuse (bool): Whether to fuse model.
Returns:
(torch.nn.Module): Loaded model.
"""
ensemble = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt, w = torch_safe_load(w) # load ckpt
@ -896,7 +1024,18 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
"""Loads a single model weights."""
"""
Load the weights of a single model.
Args:
weight (str): Model weight path.
device (torch.device, optional): Device to load model to.
inplace (bool): Whether to do inplace operations.
fuse (bool): Whether to fuse model.
Returns:
(tuple): Tuple containing the model and checkpoint.
"""
ckpt, weight = torch_safe_load(weight) # load ckpt
args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
@ -922,7 +1061,17 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
"""Parse a YOLO model.yaml dictionary into a PyTorch model."""
"""
Parse a YOLO model.yaml dictionary into a PyTorch model.
Args:
d (dict): Model dictionary.
ch (int): Input channels.
verbose (bool): Whether to print model details.
Returns:
(tuple): Tuple containing the PyTorch model and sorted list of output layers.
"""
import ast
# Args
@ -1086,7 +1235,15 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
def yaml_model_load(path):
"""Load a YOLOv8 model from a YAML file."""
"""
Load a YOLOv8 model from a YAML file.
Args:
path (str | Path): Path to the YAML file.
Returns:
(dict): Model dictionary.
"""
path = Path(path)
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
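
The stem check and `re.sub` call above rewrite legacy P6 names such as `yolov8n6` to the `-p6` form. A small standalone sketch of just that renaming step follows; `normalize_p6_stem` is a hypothetical helper name, and the rest of `yaml_model_load` (warning, file lookup, YAML parsing) is not reproduced.

```python
import re
from pathlib import Path

# Same stem set as the generator expression in the diff: yolov5/yolov8 with a trailing 6.
P6_STEMS = {f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)}


def normalize_p6_stem(path):
    """Return the path with a legacy '...n6' stem rewritten to '...n-p6'."""
    path = Path(path)
    if path.stem in P6_STEMS:
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        path = path.with_name(new_stem + path.suffix)
    return path


print(normalize_p6_stem("yolov8n6.yaml"))
print(normalize_p6_stem("yolov8n.yaml"))
```

The optional third group is empty for these stems; Python 3.5+ substitutes an unmatched group with an empty string, so `\3` is safe in the replacement.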
@ -1103,15 +1260,13 @@ def yaml_model_load(path):
def guess_model_scale(model_path):
"""
Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The function
uses regular expression matching to find the pattern of the model scale in the YAML file name, which is denoted by
n, s, m, l, or x. The function returns the size character of the model scale as a string.
Extract the size character n, s, m, l, or x of the model's scale from the model path.
Args:
model_path (str | Path): The path to the YOLO model's YAML file.
Returns:
(str): The size character of the model's scale, which can be n, s, m, l, or x.
(str): The size character of the model's scale (n, s, m, l, or x).
"""
try:
return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # returns n, s, m, l, or x
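
For reference, the regular expression in `guess_model_scale` can be exercised on its own. This sketch uses a hypothetical function name, `scale_from_path`, and assumes that an empty string is an acceptable result when nothing matches; the fallback branch of the original function lies outside the hunk shown.

```python
import re
from pathlib import Path


def scale_from_path(model_path):
    """Extract the scale character (n, s, m, l or x) from a model file name."""
    match = re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem)
    return match.group(1) if match else ""  # assumption: "" when no scale is found


for name in ("yolo11n.yaml", "yolov8x-seg.yaml", "rtdetr-l.yaml"):
    print(name, repr(scale_from_path(name)))
```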
@ -1127,10 +1282,7 @@ def guess_model_task(model):
model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
Returns:
(str): Task of the model ('detect', 'segment', 'classify', 'pose').
Raises:
SyntaxError: If the task of the model could not be determined.
(str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').
"""
def cfg2task(cfg):


@ -304,17 +304,24 @@ def plt_settings(rcparams=None, backend="Agg"):
"""
Decorator to temporarily set rc parameters and the backend for a plotting function.
Example:
decorator: @plt_settings({"font.size": 12})
context manager: with plt_settings({"font.size": 12}):
Args:
rcparams (dict): Dictionary of rc parameters to set.
rcparams (dict, optional): Dictionary of rc parameters to set.
backend (str, optional): Name of the backend to use. Defaults to 'Agg'.
Returns:
(Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
(Callable): Decorated function with temporarily set rc parameters and backend.
Examples:
>>> @plt_settings({"font.size": 12})
... def plot_function():
... plt.figure()
... plt.plot([1, 2, 3])
... plt.show()
>>> with plt_settings({"font.size": 12}):
... plt.figure()
... plt.plot([1, 2, 3])
... plt.show()
"""
if rcparams is None:
rcparams = {"font.size": 11}
@ -357,6 +364,9 @@ def set_logging(name="LOGGING_NAME", verbose=True):
name (str): Name of the logger. Defaults to "LOGGING_NAME".
verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. Defaults to True.
Returns:
(logging.Logger): Configured logger object.
Examples:
>>> set_logging(name="ultralytics", verbose=True)
>>> logger = logging.getLogger("ultralytics")
@ -376,7 +386,7 @@ def set_logging(name="LOGGING_NAME", verbose=True):
class CustomFormatter(logging.Formatter):
def format(self, record):
"""Sets up logging with UTF-8 encoding and configurable verbosity."""
"""Format log records with UTF-8 encoding for Windows compatibility."""
return emojis(super().format(record))
try:
@ -420,9 +430,10 @@ def emojis(string=""):
class ThreadingLocked:
"""
A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator
to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able
to execute the function.
A decorator class for ensuring thread-safe execution of a function or method.
This class can be used as a decorator to make sure that if the decorated function is called from multiple threads,
only one thread at a time will be able to execute the function.
Attributes:
lock (threading.Lock): A lock object used to manage access to the decorated function.
@ -435,7 +446,7 @@ class ThreadingLocked:
"""
def __init__(self):
"""Initializes the decorator class for thread-safe execution of a function or method."""
"""Initialize the decorator class with a threading lock."""
self.lock = threading.Lock()
def __call__(self, f):
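
The pattern the `ThreadingLocked` docstring describes, one lock per decorator instance that serialises calls, can be sketched as follows. `LockedCall` is a hypothetical stand-in written for this note; the body of the real `__call__` is outside the hunk and may differ.

```python
import threading
from functools import wraps


class LockedCall:
    """Hypothetical decorator class: one lock shared by all calls to the wrapped function."""

    def __init__(self):
        self.lock = threading.Lock()

    def __call__(self, f):
        @wraps(f)
        def decorated(*args, **kwargs):
            with self.lock:  # only one thread at a time runs f
                return f(*args, **kwargs)

        return decorated


counter = {"value": 0}


@LockedCall()
def increment(n):
    """Non-atomic read-modify-write loop that would race without the lock."""
    for _ in range(n):
        counter["value"] += 1


threads = [threading.Thread(target=increment, args=(1000,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter["value"])
```

Because the lock lives on the decorator instance, two functions decorated with separate `LockedCall()` instances do not block each other.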
@ -536,8 +547,7 @@ DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
def read_device_model() -> str:
"""
Reads the device model information from the system and caches it for quick access. Used by is_jetson() and
is_raspberrypi().
Reads the device model information from the system and caches it for quick access.
Returns:
(str): Kernel release information.
@ -619,7 +629,7 @@ def is_docker() -> bool:
def is_raspberrypi() -> bool:
"""
Determines if the Python environment is running on a Raspberry Pi by checking the device model information.
Determines if the Python environment is running on a Raspberry Pi.
Returns:
(bool): True if running on a Raspberry Pi, False otherwise.
@ -629,7 +639,7 @@ def is_raspberrypi() -> bool:
def is_jetson() -> bool:
"""
Determines if the Python environment is running on an NVIDIA Jetson device by checking the device model information.
Determines if the Python environment is running on an NVIDIA Jetson device.
Returns:
(bool): True if running on an NVIDIA Jetson device, False otherwise.
@ -709,8 +719,7 @@ def is_github_action_running() -> bool:
def get_git_dir():
"""
Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
the current file is not part of a git repository, returns None.
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
Returns:
(Path | None): Git root directory if found or None if not found.
@ -722,8 +731,7 @@ def get_git_dir():
def is_git_dir():
"""
Determines whether the current file is part of a git repository. If the current file is not part of a git
repository, returns None.
Determines whether the current file is part of a git repository.
Returns:
(bool): True if current file is part of a git repository.
@ -1004,8 +1012,10 @@ def threaded(func):
def set_sentry():
"""
Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
sync=True in settings. Run 'yolo settings' to see and update settings.
Initialize the Sentry SDK for error tracking and reporting.
Only used if sentry_sdk package is installed and sync=True in settings. Run 'yolo settings' to see and update
settings.
Conditions required to send errors (ALL conditions must be met or no errors will be reported):
- sentry_sdk package is installed
@ -1016,11 +1026,6 @@ def set_sentry():
- running with rank -1 or 0
- online environment
- CLI used to run package (checked with 'yolo' as the name of the main CLI command)
The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError exceptions and to exclude
events with 'out of memory' in their exception message.
Additionally, the function sets custom tags and user information for Sentry events.
"""
if (
not SETTINGS["sync"]


@ -182,10 +182,10 @@ def check_version(
Args:
current (str): Current version or package name to get version from.
required (str): Required version or range (in pip-style format).
name (str, optional): Name to be used in warning message.
hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
verbose (bool, optional): If True, print warning message if requirement is not met.
msg (str, optional): Extra message to display if verbose.
name (str): Name to be used in warning message.
hard (bool): If True, raise an AssertionError if the requirement is not met.
verbose (bool): If True, print warning message if requirement is not met.
msg (str): Extra message to display if verbose.
Returns:
(bool): True if requirement is met, False otherwise.
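
To make "pip-style format" concrete, here is a rough stdlib-only sketch of checking a version against a comma-separated requirement string. It is not the `check_version` implementation: `meets_requirement` and `to_tuple` are hypothetical helpers, versions are assumed to be purely numeric (no pre-release tags), a bare version is assumed to mean `>=`, and operators such as `~=` are not handled.

```python
import operator
import re

OPS = {
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
}


def to_tuple(version, width=3):
    """Convert '8.0.44' to (8, 0, 44), padding short versions with zeros."""
    parts = [int(p) for p in re.findall(r"\d+", version)][:width]
    return tuple(parts + [0] * (width - len(parts)))


def meets_requirement(current, required):
    """Return True if `current` satisfies every constraint in `required`."""
    cur = to_tuple(current)
    for spec in required.split(","):
        spec = spec.strip()
        if not spec:
            continue
        op, ver = re.match(r"([^0-9]*)([\d.]+)", spec).groups()
        compare = OPS[op.strip() or ">="]  # assumption: no operator means '>='
        if not compare(cur, to_tuple(ver)):
            return False
    return True


print(meets_requirement("8.0.44", ">=8.0.0,<9.0.0"))
print(meets_requirement("1.13.0", ">=2.0"))
```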
@ -307,7 +307,7 @@ def check_font(font="Arial.ttf"):
font (str): Path or name of font.
Returns:
file (Path): Resolved font file path.
(Path): Resolved font file path.
"""
from matplotlib import font_manager


@ -26,7 +26,7 @@ class VarifocalLoss(nn.Module):
@staticmethod
def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss."""
"""Compute varifocal loss between predictions and ground truth."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with autocast(enabled=False):
loss = (
@ -41,12 +41,12 @@ class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self):
"""Initializer for FocalLoss class with no parameters."""
"""Initialize FocalLoss class with no parameters."""
super().__init__()
@staticmethod
def forward(pred, label, gamma=1.5, alpha=0.25):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
"""Calculate focal loss with modulating factors for class imbalance."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
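
The "modulating factors" in the new `FocalLoss.forward` docstring refer to the standard focal-loss weighting of binary cross-entropy. A scalar sketch of that formulation is given below, in plain Python rather than PyTorch. The lines of the real method past this hunk are not shown in the diff, so the alpha handling here (applied only when `alpha > 0`) is an assumption.

```python
import math


def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def focal_loss(logit, label, gamma=1.5, alpha=0.25):
    """Scalar focal loss: BCE scaled by (1 - p_t)**gamma and an alpha class weight."""
    loss = bce_with_logits(logit, label)
    p = 1.0 / (1.0 + math.exp(-logit))  # predicted probability
    p_t = label * p + (1.0 - label) * (1.0 - p)  # probability assigned to the true class
    loss *= (1.0 - p_t) ** gamma  # down-weight easy, well-classified examples
    if alpha > 0:  # assumption: alpha weighting is skipped when alpha <= 0
        loss *= label * alpha + (1.0 - label) * (1.0 - alpha)
    return loss


# An easy positive (logit 3) contributes far less than a hard one (logit -3).
print(focal_loss(3.0, 1.0), focal_loss(-3.0, 1.0))
```

With `gamma=0` and `alpha=0` the function reduces to plain binary cross-entropy, which is a quick sanity check on the weighting terms.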
@ -63,20 +63,15 @@ class FocalLoss(nn.Module):
class DFLoss(nn.Module):
"""Criterion class for computing DFL losses during training."""
"""Criterion class for computing Distribution Focal Loss (DFL)."""
def __init__(self, reg_max=16) -> None:
"""Initialize the DFL module."""
"""Initialize the DFL module with regularization maximum."""
super().__init__()
self.reg_max = reg_max
def __call__(self, pred_dist, target):
"""
Return sum of left and right DFL losses.
Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
https://ieeexplore.ieee.org/document/9792391
"""
"""Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
target = target.clamp_(0, self.reg_max - 1 - 0.01)
tl = target.long() # target left
tr = tl + 1 # target right
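
"Sum of left and right DFL losses" means the continuous target is split between its two neighbouring integer bins, and the cross-entropy at each bin is weighted by the target's closeness to it. The sketch below shows this for one target in plain Python. The clamp and the `tl`/`tr` split mirror the lines in the hunk; the weighting and cross-entropy steps follow the usual Distribution Focal Loss definition and are an assumption about the code that is not shown.

```python
import math


def cross_entropy(logits, index):
    """Cross-entropy of a softmax over `logits` against a single class index."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[index]


def dfl(logits, target, reg_max=16):
    """Distribution focal loss for one continuous target over reg_max bins."""
    target = min(max(target, 0.0), reg_max - 1 - 0.01)  # same clamp as in the diff
    tl = int(target)  # target left bin
    tr = tl + 1  # target right bin
    wl = tr - target  # weight left: larger when target is near tl
    wr = 1.0 - wl  # weight right
    return cross_entropy(logits, tl) * wl + cross_entropy(logits, tr) * wr


uniform = [0.0] * 16
print(dfl(uniform, 3.3))
```

The `- 0.01` in the clamp keeps `tr` inside the valid bin range even for targets at or beyond `reg_max - 1`.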
@ -89,7 +84,7 @@ class DFLoss(nn.Module):
class BboxLoss(nn.Module):
"""Criterion class for computing training losses during training."""
"""Criterion class for computing training losses for bounding boxes."""
def __init__(self, reg_max=16):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
@ -97,7 +92,7 @@ class BboxLoss(nn.Module):
self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
"""Compute IoU and DFL losses for bounding boxes."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
@ -114,14 +109,14 @@ class BboxLoss(nn.Module):
class RotatedBboxLoss(BboxLoss):
"""Criterion class for computing training losses during training."""
"""Criterion class for computing training losses for rotated bounding boxes."""
def __init__(self, reg_max):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__(reg_max)
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
"""Compute IoU and DFL losses for rotated bounding boxes."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
@ -138,15 +133,15 @@ class RotatedBboxLoss(BboxLoss):
class KeypointLoss(nn.Module):
"""Criterion class for computing training losses."""
"""Criterion class for computing keypoint losses."""
def __init__(self, sigmas) -> None:
"""Initialize the KeypointLoss class."""
"""Initialize the KeypointLoss class with keypoint sigmas."""
super().__init__()
self.sigmas = sigmas
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
"""Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
@ -155,10 +150,10 @@ class KeypointLoss(nn.Module):
class v8DetectionLoss:
"""Criterion class for computing training losses."""
"""Criterion class for computing training losses for YOLOv8 object detection."""
def __init__(self, model, tal_topk=10): # model must be de-paralleled
"""Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
"""Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
@ -178,7 +173,7 @@ class v8DetectionLoss:
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
"""Preprocess targets by converting to tensor format and scaling coordinates."""
nl, ne = targets.shape
if nl == 0:
out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
@ -261,15 +256,15 @@ class v8DetectionLoss:
class v8SegmentationLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
"""Criterion class for computing training losses for YOLOv8 segmentation."""
def __init__(self, model): # model must be de-paralleled
"""Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
"""Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
super().__init__(model)
self.overlap = model.args.overlap_mask
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
"""Calculate and return the combined loss for detection and segmentation."""
loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
@ -444,10 +439,10 @@ class v8SegmentationLoss(v8DetectionLoss):
class v8PoseLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
"""Criterion class for computing training losses for YOLOv8 pose estimation."""
def __init__(self, model): # model must be de-paralleled
"""Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
"""Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
super().__init__(model)
self.kpt_shape = model.model[-1].kpt_shape
self.bce_pose = nn.BCEWithLogitsLoss()
@ -457,7 +452,7 @@ class v8PoseLoss(v8DetectionLoss):
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
def __call__(self, preds, batch):
"""Calculate the total loss and detach it."""
"""Calculate the total loss and detach it for pose estimation."""
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
@ -524,7 +519,7 @@ class v8PoseLoss(v8DetectionLoss):
@staticmethod
def kpts_decode(anchor_points, pred_kpts):
"""Decodes predicted keypoints to image coordinates."""
"""Decode predicted keypoints to image coordinates."""
y = pred_kpts.clone()
y[..., :2] *= 2.0
y[..., 0] += anchor_points[:, [0]] - 0.5
@ -599,7 +594,7 @@ class v8PoseLoss(v8DetectionLoss):
class v8ClassificationLoss:
"""Criterion class for computing training losses."""
"""Criterion class for computing training losses for classification."""
def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
@ -613,13 +608,13 @@ class v8OBBLoss(v8DetectionLoss):
"""Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
def __init__(self, model):
"""Initializes v8OBBLoss with model, assigner, and rotated bbox loss; note model must be de-paralleled."""
"""Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
super().__init__(model)
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
"""Preprocess targets for oriented bounding box detection."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 6, device=self.device)
else:
@ -636,7 +631,7 @@ class v8OBBLoss(v8DetectionLoss):
return out
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
"""Calculate and return the loss for oriented bounding box detection."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
batch_size = pred_angle.shape[0] # batch size
@ -726,7 +721,7 @@ class v8OBBLoss(v8DetectionLoss):
class E2EDetectLoss:
"""Criterion class for computing training losses."""
"""Criterion class for computing training losses for end-to-end detection."""
def __init__(self, model):
"""Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""

View file

@ -25,7 +25,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
eps (float, optional): A small value to avoid division by zero.
Returns:
(np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
@ -57,7 +57,7 @@ def box_iou(box1, box2, eps=1e-7):
Args:
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
eps (float, optional): A small value to avoid division by zero.
Returns:
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
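
The N x M result described in the `box_iou` docstring can be illustrated with a small dependency-free sketch. It is not the library implementation: the helper name `pairwise_iou`, the (x1, y1, x2, y2) box format, and the placement of `eps` in the denominator are assumptions made for illustration.

```python
def pairwise_iou(boxes1, boxes2, eps=1e-7):
    """Return an N x M nested list of IoU values (hypothetical helper, xyxy boxes assumed)."""
    out = []
    for ax1, ay1, ax2, ay2 in boxes1:
        area_a = (ax2 - ax1) * (ay2 - ay1)
        row = []
        for bx1, by1, bx2, by2 in boxes2:
            area_b = (bx2 - bx1) * (by2 - by1)
            # Width and height of the overlap, clamped at zero for disjoint boxes
            iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
            ih = max(0.0, min(ay2, by2) - max(ay1, by1))
            inter = iw * ih
            # eps keeps the division defined when both boxes have zero area
            row.append(inter / (area_a + area_b - inter + eps))
        out.append(row)
    return out


# One reference box against an identical, a partially overlapping, and a disjoint box
iou = pairwise_iou([[0, 0, 2, 2]], [[0, 0, 2, 2], [1, 1, 3, 3], [5, 5, 6, 6]])
```

Here `iou[0]` holds roughly 1.0 for the identical box, 1/7 for the partial overlap (intersection 1, union 7), and 0.0 for the disjoint box.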
@ -73,7 +73,7 @@ def box_iou(box1, box2, eps=1e-7):
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
"""
Calculates the Intersection over Union (IoU) between bounding boxes.
Calculate the Intersection over Union (IoU) between bounding boxes.
This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
@ -84,11 +84,11 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
(x1, y1, x2, y2) format. Defaults to True.
GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
(x1, y1, x2, y2) format.
GIoU (bool, optional): If True, calculate Generalized IoU.
DIoU (bool, optional): If True, calculate Distance IoU.
CIoU (bool, optional): If True, calculate Complete IoU.
eps (float, optional): A small value to avoid division by zero.
Returns:
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
@ -143,7 +143,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
product of image width and height.
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
product of image width and height.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
eps (float, optional): A small value to avoid division by zero.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
@ -162,7 +162,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
sigma (list): A list containing 17 values representing keypoint scales.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
eps (float, optional): A small value to avoid division by zero.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
@ -177,7 +177,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
def _get_covariance_matrix(boxes):
"""
Generating covariance matrix from obbs.
Generate covariance matrix from oriented bounding boxes.
Args:
boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
@ -199,20 +199,18 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
"""
Calculate probabilistic IoU between oriented bounding boxes.
Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
Args:
obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
CIoU (bool, optional): If True, calculate CIoU.
eps (float, optional): Small value to avoid division by zero.
Returns:
(torch.Tensor): OBB similarities, shape (N,).
Note:
OBB format: [center_x, center_y, width, height, rotation_angle].
If CIoU is True, returns CIoU instead of IoU.
Notes:
- OBB format: [center_x, center_y, width, height, rotation_angle].
- Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
"""
x1, y1 = obb1[..., :2].split(1, dim=-1)
x2, y2 = obb2[..., :2].split(1, dim=-1)
@ -243,15 +241,18 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
def batch_probiou(obb1, obb2, eps=1e-7):
"""
Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
Calculate the probabilistic IoU between oriented bounding boxes.
Args:
obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
eps (float, optional): A small value to avoid division by zero.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing obb similarities.
References:
https://arxiv.org/pdf/2106.06072v1.pdf
"""
obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
@ -277,16 +278,16 @@ def batch_probiou(obb1, obb2, eps=1e-7):
def smooth_bce(eps=0.1):
"""
Computes smoothed positive and negative Binary Cross-Entropy targets.
This function calculates positive and negative label smoothing BCE targets based on a given epsilon value.
For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441.
Compute smoothed positive and negative Binary Cross-Entropy targets.
Args:
eps (float, optional): The epsilon value for label smoothing. Defaults to 0.1.
eps (float, optional): The epsilon value for label smoothing.
Returns:
(tuple): A tuple containing the positive and negative label smoothing BCE targets.
References:
https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
"""
return 1.0 - 0.5 * eps, 0.5 * eps
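
As a quick worked example of the return expression above, the sketch below copies the one-line body of `smooth_bce` and evaluates it for two epsilon values. The variable names are only illustrative.

```python
def smooth_bce(eps=0.1):
    """Return (positive, negative) label-smoothed BCE targets."""
    return 1.0 - 0.5 * eps, 0.5 * eps


# With the default eps, the targets move from (1, 0) to about (0.95, 0.05)
pos, neg = smooth_bce(0.1)

# eps=0 disables smoothing and recovers the hard targets
hard_pos, hard_neg = smooth_bce(0.0)
```

The two targets always sum to 1, so smoothing shifts both labels toward 0.5 by the same amount, `0.5 * eps`.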
@ -304,7 +305,15 @@ class ConfusionMatrix:
"""
def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
"""Initialize attributes for the YOLO model."""
"""
Initialize a ConfusionMatrix instance.
Args:
nc (int): Number of classes.
conf (float, optional): Confidence threshold for detections.
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
task (str, optional): Type of task, either 'detect' or 'classify'.
"""
self.task = task
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
self.nc = nc # number of classes
@ -382,11 +391,16 @@ class ConfusionMatrix:
self.matrix[dc, self.nc] += 1 # predicted background
def matrix(self):
"""Returns the confusion matrix."""
"""Return the confusion matrix."""
return self.matrix
def tp_fp(self):
"""Returns true positives and false positives."""
"""
Return true positives and false positives.
Returns:
(tuple): True positives and false positives.
"""
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
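
The diagonal and row-sum arithmetic in `tp_fp` can be traced on a toy matrix. This is a pure-Python sketch of the two lines shown above, not the class method itself; the 3x3 counts are made up, and it assumes `sum(1)` sums across each row as in NumPy.

```python
def tp_fp(matrix):
    """Return per-class true and false positives from a square confusion matrix."""
    n = len(matrix)
    tp = [matrix[i][i] for i in range(n)]            # diagonal entries
    fp = [sum(matrix[i]) - tp[i] for i in range(n)]  # row sum minus diagonal
    return tp, fp


# Made-up counts for three classes
toy = [
    [5, 1, 0],
    [2, 3, 1],
    [0, 0, 4],
]
tp, fp = tp_fp(toy)
```

For this matrix the row sums are 6, 6, and 4, so `tp` is `[5, 3, 4]` and `fp` is `[1, 3, 0]`.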
@ -454,7 +468,17 @@ def smooth(y, f=0.05):
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
"""Plots a precision-recall curve."""
"""
Plot precision-recall curve.
Args:
px (np.ndarray): X values for the PR curve.
py (np.ndarray): Y values for the PR curve.
ap (np.ndarray): Average precision values.
save_dir (Path, optional): Path to save the plot.
names (dict, optional): Dictionary mapping class indices to class names.
on_plot (callable, optional): Function to call after plot is saved.
"""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
@ -479,7 +503,18 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
"""Plots a metric-confidence curve."""
"""
Plot metric-confidence curve.
Args:
px (np.ndarray): X values for the metric-confidence curve.
py (np.ndarray): Y values for the metric-confidence curve.
save_dir (Path, optional): Path to save the plot.
names (dict, optional): Dictionary mapping class indices to class names.
xlabel (str, optional): X-axis label.
ylabel (str, optional): Y-axis label.
on_plot (callable, optional): Function to call after plot is saved.
"""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
@ -538,33 +573,33 @@ def ap_per_class(
tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
):
"""
Computes the average precision per class for object detection evaluation.
Compute the average precision per class for object detection evaluation.
Args:
tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
conf (np.ndarray): Array of confidence scores of the detections.
pred_cls (np.ndarray): Array of predicted classes of the detections.
target_cls (np.ndarray): Array of true classes of the detections.
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty tuple.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
plot (bool, optional): Whether to plot PR curves or not.
on_plot (func, optional): A callback to pass plots path and data when they are rendered.
save_dir (Path, optional): Directory to save the PR curves.
names (dict, optional): Dict of class names to plot PR curves.
eps (float, optional): A small value to avoid division by zero.
prefix (str, optional): A prefix string for saving the plot files.
Returns:
tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
x (np.ndarray): X-axis values for the curves. Shape: (1000,).
prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.
fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.
p (np.ndarray): Precision values at threshold given by max F1 metric for each class.
r (np.ndarray): Recall values at threshold given by max F1 metric for each class.
f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.
ap (np.ndarray): Average precision for each class at different IoU thresholds.
unique_classes (np.ndarray): An array of unique classes that have data.
p_curve (np.ndarray): Precision curves for each class.
r_curve (np.ndarray): Recall curves for each class.
f1_curve (np.ndarray): F1-score curves for each class.
x (np.ndarray): X-axis values for the curves.
prec_values (np.ndarray): Precision values at mAP@0.5 for each class.
"""
# Sort by objectness
i = np.argsort(-conf)
@ -651,7 +686,7 @@ class Metric(SimpleClass):
"""
def __init__(self) -> None:
"""Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model."""
"""Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model."""
self.p = [] # (nc, )
self.r = [] # (nc, )
self.f1 = [] # (nc, )
@ -662,7 +697,7 @@ class Metric(SimpleClass):
@property
def ap50(self):
"""
Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
Returns:
(np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
@ -672,7 +707,7 @@ class Metric(SimpleClass):
@property
def ap(self):
"""
Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
Returns:
(np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
@ -682,7 +717,7 @@ class Metric(SimpleClass):
@property
def mp(self):
"""
Returns the Mean Precision of all classes.
Return the Mean Precision of all classes.
Returns:
(float): The mean precision of all classes.
@ -692,7 +727,7 @@ class Metric(SimpleClass):
@property
def mr(self):
"""
Returns the Mean Recall of all classes.
Return the Mean Recall of all classes.
Returns:
(float): The mean recall of all classes.
@ -702,7 +737,7 @@ class Metric(SimpleClass):
@property
def map50(self):
"""
Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.
Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
Returns:
(float): The mAP at an IoU threshold of 0.5.
@ -712,7 +747,7 @@ class Metric(SimpleClass):
@property
def map75(self):
"""
Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.
Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
Returns:
(float): The mAP at an IoU threshold of 0.75.
@ -722,7 +757,7 @@ class Metric(SimpleClass):
@property
def map(self):
"""
Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
Returns:
(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
@ -730,41 +765,42 @@ class Metric(SimpleClass):
return self.all_ap.mean() if len(self.all_ap) else 0.0
def mean_results(self):
"""Mean of results, return mp, mr, map50, map."""
"""Return the mean results as [mp, mr, map50, map]."""
return [self.mp, self.mr, self.map50, self.map]
def class_result(self, i):
"""Class-aware result, return p[i], r[i], ap50[i], ap[i]."""
"""Return the class-aware results (p[i], r[i], ap50[i], ap[i]) for class index i."""
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
@property
def maps(self):
"""MAP of each class."""
"""Return mAP of each class."""
maps = np.zeros(self.nc) + self.map
for i, c in enumerate(self.ap_class_index):
maps[c] = self.ap[i]
return maps
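
The `maps` property fills every class with the overall mAP and then overwrites the classes that have data. A minimal list-based sketch of that logic follows; the function name `per_class_maps` and the sample numbers are hypothetical.

```python
def per_class_maps(nc, mean_map, ap, ap_class_index):
    """Return per-class mAP, falling back to the mean for classes without data."""
    maps = [mean_map] * nc              # default: overall mAP for every class
    for i, c in enumerate(ap_class_index):
        maps[c] = ap[i]                 # overwrite classes that were evaluated
    return maps


# Four classes, but only classes 0 and 2 appeared in the evaluation data
result = per_class_maps(nc=4, mean_map=0.5, ap=[0.6, 0.4], ap_class_index=[0, 2])
```

Classes 1 and 3 keep the fallback value 0.5, while classes 0 and 2 receive their own AP.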
def fitness(self):
"""Model fitness as a weighted combination of metrics."""
"""Return model fitness as a weighted combination of metrics."""
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
return (np.array(self.mean_results()) * w).sum()
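
To make the weighting in `fitness` concrete, the sketch below reproduces the dot product with plain Python. The weights are copied from the method above; the sample metric values are made up.

```python
def fitness(mean_results, weights=(0.0, 0.0, 0.1, 0.9)):
    """Return the weighted sum of [P, R, mAP@0.5, mAP@0.5:0.95]."""
    return sum(m * w for m, w in zip(mean_results, weights))


# Made-up metrics: P=0.8, R=0.7, mAP@0.5=0.6, mAP@0.5:0.95=0.4
score = fitness([0.8, 0.7, 0.6, 0.4])
```

Precision and recall carry zero weight, so the score here is 0.1 * 0.6 + 0.9 * 0.4, about 0.42, and changing P or R alone leaves it unchanged.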
def update(self, results):
"""
Updates the evaluation metrics of the model with a new set of results.
Update the evaluation metrics with a new set of results.
Args:
results (tuple): A tuple containing the following evaluation metrics:
- p (list): Precision for each class. Shape: (nc,).
- r (list): Recall for each class. Shape: (nc,).
- f1 (list): F1 score for each class. Shape: (nc,).
- all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
- ap_class_index (list): Index of class for each AP score. Shape: (nc,).
Side Effects:
Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based
on the values provided in the `results` tuple.
results (tuple): A tuple containing evaluation metrics:
- p (list): Precision for each class.
- r (list): Recall for each class.
- f1 (list): F1 score for each class.
- all_ap (list): AP scores for all classes and all IoU thresholds.
- ap_class_index (list): Index of class for each AP score.
- p_curve (list): Precision curve for each class.
- r_curve (list): Recall curve for each class.
- f1_curve (list): F1 curve for each class.
- px (list): X values for the curves.
- prec_values (list): Precision values for each class.
"""
(
self.p,
@ -781,12 +817,12 @@ class Metric(SimpleClass):
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return []
@property
def curves_results(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return [
[self.px, self.prec_values, "Recall", "Precision"],
[self.px, self.f1_curve, "Confidence", "F1"],
@ -797,36 +833,26 @@ class Metric(SimpleClass):
class DetMetrics(SimpleClass):
"""
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
object detection model.
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
Attributes:
save_dir (Path): A path to the directory where the output plots will be saved.
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
names (dict of str): A dict of strings that represents the names of the classes.
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
Methods:
process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
keys: Returns a list of keys for accessing the computed detection metrics.
mean_results: Returns a list of mean values for the computed detection metrics.
class_result(i): Returns a list of values for the computed detection metrics for a specific class.
maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
fitness: Computes the fitness score based on the computed detection metrics.
ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
curves: TODO
curves_results: TODO
plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
names (dict): A dictionary of class names.
box (Metric): An instance of the Metric class for storing detection results.
speed (dict): A dictionary for storing execution times of different parts of the detection process.
task (str): The task type, set to 'detect'.
"""
def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
"""
Initialize a DetMetrics instance with a save directory, plot flag, and class names.
Args:
save_dir (Path, optional): Directory to save plots.
plot (bool, optional): Whether to plot precision-recall curves.
names (dict, optional): Dictionary mapping class indices to names.
"""
self.save_dir = save_dir
self.plot = plot
self.names = names
@ -835,7 +861,16 @@ class DetMetrics(SimpleClass):
self.task = "detect"
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
"""Process predicted results for object detection and update metrics."""
"""
Process predicted results for object detection and update metrics.
Args:
tp (np.ndarray): True positive array.
conf (np.ndarray): Confidence array.
pred_cls (np.ndarray): Predicted class indices array.
target_cls (np.ndarray): Target class indices array.
on_plot (callable, optional): Function to call after plots are generated.
"""
results = ap_per_class(
tp,
conf,
@ -851,7 +886,7 @@ class DetMetrics(SimpleClass):
@property
def keys(self):
"""Returns a list of keys for accessing specific metrics."""
"""Return a list of keys for accessing specific metrics."""
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def mean_results(self):
@ -864,32 +899,32 @@ class DetMetrics(SimpleClass):
@property
def maps(self):
"""Returns mean Average Precision (mAP) scores per class."""
"""Return mean Average Precision (mAP) scores per class."""
return self.box.maps
@property
def fitness(self):
"""Returns the fitness of box object."""
"""Return the fitness of box object."""
return self.box.fitness()
@property
def ap_class_index(self):
"""Returns the average precision index per class."""
"""Return the average precision index per class."""
return self.box.ap_class_index
@property
def results_dict(self):
"""Returns dictionary of computed performance metrics and statistics."""
"""Return dictionary of computed performance metrics and statistics."""
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
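
The `results_dict` expression pairs each metric key with its mean value and appends the fitness score. A standalone sketch with made-up values shows the resulting mapping; the key strings are copied from `keys` above, and the numbers are illustrative only.

```python
keys = ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
mean_results = [0.8, 0.7, 0.6, 0.4]  # made-up [mp, mr, map50, map]
fitness = 0.42                       # made-up fitness score

# Both operands of + must be lists for the concatenation to work
results = dict(zip(keys + ["fitness"], mean_results + [fitness]))
```

The result has five entries, one per metric key plus `"fitness"`, in the same order as `keys`.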
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
@property
def curves_results(self):
"""Returns dictionary of computed performance metrics and statistics."""
"""Return a list of curve results (x values, y values, and axis labels) for the box metrics."""
return self.box.curves_results
@ -897,31 +932,25 @@ class SegmentMetrics(SimpleClass):
"""
Calculates and aggregates detection and segmentation metrics over a given set of classes.
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
names (list): List of class names.
names (dict): Dictionary of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
speed (dict): Dictionary to store the time taken in different phases of inference.
Methods:
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
class_result(i): Returns the detection and segmentation metrics of class `i`.
maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
fitness: Returns the fitness scores, which are a single weighted combination of metrics.
ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
task (str): The task type, set to 'segment'.
"""
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
"""
Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
Args:
save_dir (Path, optional): Directory to save plots.
plot (bool, optional): Whether to plot precision-recall curves.
names (dict, optional): Dictionary mapping class indices to names.
"""
self.save_dir = save_dir
self.plot = plot
self.names = names
@ -932,15 +961,15 @@ class SegmentMetrics(SimpleClass):
def process(self, tp, tp_m, conf, pred_cls, target_cls, on_plot=None):
"""
Processes the detection and segmentation metrics over the given set of predictions.
Process the detection and segmentation metrics over the given set of predictions.
Args:
tp (list): List of True Positive boxes.
tp_m (list): List of True Positive masks.
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
tp (np.ndarray): True positive array for boxes.
tp_m (np.ndarray): True positive array for masks.
conf (np.ndarray): Confidence array.
pred_cls (np.ndarray): Predicted class indices array.
target_cls (np.ndarray): Target class indices array.
on_plot (callable, optional): Function to call after plots are generated.
"""
results_mask = ap_per_class(
tp_m,
@ -971,7 +1000,7 @@ class SegmentMetrics(SimpleClass):
@property
def keys(self):
"""Returns a list of keys for accessing metrics."""
"""Return a list of keys for accessing metrics."""
return [
"metrics/precision(B)",
"metrics/recall(B)",
@ -988,32 +1017,36 @@ class SegmentMetrics(SimpleClass):
return self.box.mean_results() + self.seg.mean_results()
def class_result(self, i):
"""Returns classification results for a specified class index."""
"""Return box and mask results for a specified class index."""
return self.box.class_result(i) + self.seg.class_result(i)
@property
def maps(self):
"""Returns mAP scores for object detection and semantic segmentation models."""
"""Return mAP scores for object detection and semantic segmentation models."""
return self.box.maps + self.seg.maps
@property
def fitness(self):
"""Get the fitness score for both segmentation and bounding box models."""
"""Return the fitness score for both segmentation and bounding box models."""
return self.seg.fitness() + self.box.fitness()
@property
def ap_class_index(self):
"""Boxes and masks have the same ap_class_index."""
"""
Return the class indices.
Boxes and masks have the same ap_class_index.
"""
return self.box.ap_class_index
@property
def results_dict(self):
"""Returns results of object detection model for evaluation."""
"""Return results of object detection model for evaluation."""
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return [
"Precision-Recall(B)",
"F1-Confidence(B)",
@ -1027,7 +1060,7 @@ class SegmentMetrics(SimpleClass):
@property
def curves_results(self):
"""Returns dictionary of computed performance metrics and statistics."""
"""Return a list of curve results for the box metrics followed by the mask metrics."""
return self.box.curves_results + self.seg.curves_results
@ -1035,18 +1068,14 @@ class PoseMetrics(SegmentMetrics):
"""
Calculates and aggregates detection and pose metrics over a given set of classes.
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
names (list): List of class names.
plot (bool): Whether to save the detection and pose plots.
names (dict): Dictionary of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.
pose (Metric): An instance of the Metric class to calculate pose metrics.
speed (dict): Dictionary to store the time taken in different phases of inference.
task (str): The task type, set to 'pose'.
Methods:
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
@ -1059,7 +1088,14 @@ class PoseMetrics(SegmentMetrics):
"""
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""
"""
Initialize the PoseMetrics class with directory path, class names, and plotting options.
Args:
save_dir (Path, optional): Directory to save plots.
plot (bool, optional): Whether to plot precision-recall curves.
names (dict, optional): Dictionary mapping class indices to names.
"""
super().__init__(save_dir, plot, names)
self.save_dir = save_dir
self.plot = plot
@ -1071,15 +1107,15 @@ class PoseMetrics(SegmentMetrics):
def process(self, tp, tp_p, conf, pred_cls, target_cls, on_plot=None):
"""
Processes the detection and pose metrics over the given set of predictions.
Process the detection and pose metrics over the given set of predictions.
Args:
tp (list): List of True Positive boxes.
tp_p (list): List of True Positive keypoints.
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
tp (np.ndarray): True positive array for boxes.
tp_p (np.ndarray): True positive array for keypoints.
conf (np.ndarray): Confidence array.
pred_cls (np.ndarray): Predicted class indices array.
target_cls (np.ndarray): Target class indices array.
on_plot (callable, optional): Function to call after plots are generated.
"""
results_pose = ap_per_class(
tp_p,
@ -1110,7 +1146,7 @@ class PoseMetrics(SegmentMetrics):
@property
def keys(self):
"""Returns list of evaluation metric keys."""
"""Return list of evaluation metric keys."""
return [
"metrics/precision(B)",
"metrics/recall(B)",
@ -1132,17 +1168,17 @@ class PoseMetrics(SegmentMetrics):
@property
def maps(self):
"""Returns the mean average precision (mAP) per class for both box and pose detections."""
"""Return the mean average precision (mAP) per class for both box and pose detections."""
return self.box.maps + self.pose.maps
@property
def fitness(self):
"""Computes classification metrics and speed using the `targets` and `pred` inputs."""
"""Return combined fitness score for pose and box detection."""
return self.pose.fitness() + self.box.fitness()
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return [
"Precision-Recall(B)",
"F1-Confidence(B)",
@ -1156,7 +1192,7 @@ class PoseMetrics(SegmentMetrics):
@property
def curves_results(self):
"""Returns dictionary of computed performance metrics and statistics."""
"""Return dictionary of computed performance metrics and statistics."""
return self.box.curves_results + self.pose.curves_results
@ -1167,13 +1203,8 @@ class ClassifyMetrics(SimpleClass):
Attributes:
top1 (float): The top-1 accuracy.
top5 (float): The top-5 accuracy.
speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
fitness (float): The fitness of the model, which is equal to top-5 accuracy.
results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
keys (List[str]): A list of keys for the results_dict.
Methods:
process(targets, pred): Processes the targets and predictions to compute classification metrics.
speed (dict): A dictionary containing the time taken for each step in the pipeline.
task (str): The task type, set to 'classify'.
"""
def __init__(self) -> None:
@ -1184,7 +1215,13 @@ class ClassifyMetrics(SimpleClass):
self.task = "classify"
def process(self, targets, pred):
"""Target classes and predicted classes."""
"""
Process target classes and predicted classes to compute metrics.
Args:
targets (torch.Tensor): Target classes.
pred (torch.Tensor): Predicted classes.
"""
pred, targets = torch.cat(pred), torch.cat(targets)
correct = (targets[:, None] == pred).float()
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
@ -1192,35 +1229,54 @@ class ClassifyMetrics(SimpleClass):
@property
def fitness(self):
"""Returns mean of top-1 and top-5 accuracies as fitness score."""
"""Return mean of top-1 and top-5 accuracies as fitness score."""
return (self.top1 + self.top5) / 2
@property
def results_dict(self):
"""Returns a dictionary with model's performance metrics and fitness score."""
"""Return a dictionary with model's performance metrics and fitness score."""
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
@property
def keys(self):
"""Returns a list of keys for the results_dict property."""
"""Return a list of keys for the results_dict property."""
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return []
@property
def curves_results(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return []
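The accuracy, `fitness`, and `results_dict` logic of `ClassifyMetrics` shown above can be sketched in plain Python. This is a minimal illustration rather than the library's implementation: `topk_accuracy` is a hypothetical helper, and it assumes each prediction row lists class indices sorted by descending confidence, mirroring the `targets[:, None] == pred` comparison in `process`.

```python
def topk_accuracy(targets, preds):
    """Return (top1, top5) accuracy; hypothetical helper for illustration only."""
    # preds[i] holds up to five class indices for sample i, best guess first.
    top1_hits = [float(p[0] == t) for t, p in zip(targets, preds)]
    top5_hits = [float(t in p[:5]) for t, p in zip(targets, preds)]
    n = len(targets)
    return sum(top1_hits) / n, sum(top5_hits) / n


targets = [2, 0, 1, 3]
preds = [
    [2, 1, 0, 3, 4],  # correct at rank 1
    [1, 0, 2, 3, 4],  # correct at rank 2
    [1, 2, 3, 4, 0],  # correct at rank 1
    [0, 1, 2, 4, 5],  # class 3 is missing from the top five
]
top1, top5 = topk_accuracy(targets, preds)
fitness = (top1 + top5) / 2  # same averaging as the `fitness` property

keys = ["metrics/accuracy_top1", "metrics/accuracy_top5"]
results = dict(zip(keys + ["fitness"], [top1, top5, fitness]))
print(results)
```

Pairing `keys + ["fitness"]` with the values through `zip` keeps the key list as the single place that defines the order of the reported metrics.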
class OBBMetrics(SimpleClass):
"""Metrics for evaluating oriented bounding box (OBB) detection, see https://arxiv.org/pdf/2106.06072.pdf."""
"""
Metrics for evaluating oriented bounding box (OBB) detection.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection plots.
names (dict): Dictionary of class names.
box (Metric): An instance of the Metric class for storing detection results.
speed (dict): A dictionary for storing execution times of different parts of the detection process.
References:
https://arxiv.org/pdf/2106.06072.pdf
"""
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
"""Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
"""
Initialize an OBBMetrics instance with directory, plotting, and class names.
Args:
save_dir (Path, optional): Directory to save plots.
plot (bool, optional): Whether to plot precision-recall curves.
names (dict, optional): Dictionary mapping class indices to names.
"""
self.save_dir = save_dir
self.plot = plot
self.names = names
@ -1228,7 +1284,16 @@ class OBBMetrics(SimpleClass):
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
"""Process predicted results for object detection and update metrics."""
"""
Process predicted results for object detection and update metrics.
Args:
tp (np.ndarray): True positive array.
conf (np.ndarray): Confidence array.
pred_cls (np.ndarray): Predicted class indices array.
target_cls (np.ndarray): Target class indices array.
on_plot (callable, optional): Function to call after plots are generated.
"""
results = ap_per_class(
tp,
conf,
@ -1244,7 +1309,7 @@ class OBBMetrics(SimpleClass):
@property
def keys(self):
"""Returns a list of keys for accessing specific metrics."""
"""Return a list of keys for accessing specific metrics."""
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
def mean_results(self):
@ -1257,30 +1322,30 @@ class OBBMetrics(SimpleClass):
@property
def maps(self):
"""Returns mean Average Precision (mAP) scores per class."""
"""Return mean Average Precision (mAP) scores per class."""
return self.box.maps
@property
def fitness(self):
"""Returns the fitness of box object."""
"""Return the fitness of box object."""
return self.box.fitness()
@property
def ap_class_index(self):
"""Returns the average precision index per class."""
"""Return the average precision index per class."""
return self.box.ap_class_index
@property
def results_dict(self):
"""Returns dictionary of computed performance metrics and statistics."""
"""Return dictionary of computed performance metrics and statistics."""
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
@property
def curves(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return []
@property
def curves_results(self):
"""Returns a list of curves for accessing specific metrics curves."""
"""Return a list of curves for accessing specific metrics curves."""
return []


@ -18,6 +18,11 @@ class Profile(contextlib.ContextDecorator):
"""
YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
Attributes:
t (float): Accumulated time.
device (torch.device): Device used for model inference.
cuda (bool): Whether CUDA is being used.
Examples:
>>> from ultralytics.utils.ops import Profile
>>> with Profile(device=device) as dt:
@ -30,8 +35,8 @@ class Profile(contextlib.ContextDecorator):
Initialize the Profile class.
Args:
t (float): Initial time. Defaults to 0.0.
device (torch.device): Devices used for model inference. Defaults to None (cpu).
t (float): Initial time.
device (torch.device): Device used for model inference.
"""
self.t = t
self.device = device
@ -63,12 +68,12 @@ def segment2box(segment, width=640, height=640):
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
Args:
segment (torch.Tensor): the segment label
width (int): the width of the image. Defaults to 640
height (int): The height of the image. Defaults to 640
segment (torch.Tensor): The segment label.
width (int): The width of the image.
height (int): The height of the image.
Returns:
(np.ndarray): the minimum and maximum x and y values of the segment.
(np.ndarray): The minimum and maximum x and y values of the segment.
"""
x, y = segment.T # segment xy
# any 3 out of 4 sides are outside the image, clip coordinates first, https://github.com/ultralytics/ultralytics/pull/18294
@ -87,21 +92,20 @@ def segment2box(segment, width=640, height=640):
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
"""
Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
specified in (img1_shape) to the shape of a different image (img0_shape).
Rescale bounding boxes from img1_shape to img0_shape.
Args:
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
img0_shape (tuple): the shape of the target image, in the format of (height, width).
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
boxes (torch.Tensor): The bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2).
img0_shape (tuple): The shape of the target image, in the format of (height, width).
ratio_pad (tuple): A tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
calculated based on the size difference between the two images.
padding (bool): If True, assume the boxes are based on an image augmented in YOLO style. If False, do regular
rescaling.
xywh (bool): The box format is xywh or not, default=False.
xywh (bool): Whether the box format is xywh.
Returns:
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
(torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2).
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
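The `gain` line above is the core of the rescaling. A minimal pure-Python sketch of the idea for a single xyxy box follows. `rescale_box` is a hypothetical name, and the sketch assumes the padding is split evenly between both sides and is not rounded, so it may differ from the library function in edge cases.

```python
def rescale_box(img1_shape, box, img0_shape):
    """Map one xyxy box from the padded img1_shape back to img0_shape (illustrative sketch)."""
    # Shapes are (height, width), as in the docstring above.
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
    # Assumption: padding is centered, so half of the leftover space sits on each side.
    pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
    pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
    x1, y1, x2, y2 = box
    # Remove the padding offset, then undo the resize.
    x1, x2 = (x1 - pad_x) / gain, (x2 - pad_x) / gain
    y1, y2 = (y1 - pad_y) / gain, (y2 - pad_y) / gain
    # Clip to the bounds of the target image.
    h, w = img0_shape
    return (min(max(x1, 0), w), min(max(y1, 0), h), min(max(x2, 0), w), min(max(y2, 0), h))


# A 1280x720 image resized to fit 640x384 leaves 12 px of padding above and below.
print(rescale_box((384, 640), (100, 112, 200, 212), (720, 1280)))
```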
@ -146,8 +150,8 @@ def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
Args:
boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
scores (torch.Tensor): Confidence scores, shape (N,).
threshold (float, optional): IoU threshold. Defaults to 0.45.
use_triu (bool, optional): Whether to use `torch.triu` operator. It'd be useful for disable it
threshold (float): IoU threshold.
use_triu (bool): Whether to use the `torch.triu` operator. It is useful to disable it
when exporting obb models to some formats that do not support `torch.triu`.
Returns:
@ -210,7 +214,7 @@ def non_max_suppression(
list contains the apriori labels for a given image. The list should be in the format
output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
max_det (int): The maximum number of boxes to keep after NMS.
nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
nc (int): The number of classes output by the model. Any indices after this will be considered masks.
max_time_img (float): The maximum time (seconds) for processing one image.
max_nms (int): The maximum number of boxes into torchvision.ops.nms().
max_wh (int): The maximum box width and height in pixels.
@ -333,7 +337,7 @@ def clip_boxes(boxes, shape):
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
Args:
boxes (torch.Tensor): The bounding boxes to clip.
boxes (torch.Tensor | numpy.ndarray): The bounding boxes to clip.
shape (tuple): The shape of the image.
Returns:
@ -359,7 +363,7 @@ def clip_coords(coords, shape):
shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
Returns:
(torch.Tensor | numpy.ndarray): Clipped coordinates
(torch.Tensor | numpy.ndarray): Clipped coordinates.
"""
if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x
@ -451,10 +455,11 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
Args:
x (np.ndarray | torch.Tensor): The bounding box coordinates.
w (int): Width of the image. Defaults to 640
h (int): Height of the image. Defaults to 640
padw (int): Padding width. Defaults to 0
padh (int): Padding height. Defaults to 0
w (int): Width of the image.
h (int): Height of the image.
padw (int): Padding width.
padh (int): Padding height.
Returns:
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
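As a worked example of the conversion this docstring describes, the sketch below turns one normalized (x_center, y_center, width, height) box into pixel xyxy corners. `xywhn_to_xyxy` is a hypothetical single-box helper written for illustration. It operates on a plain tuple rather than an array or tensor.

```python
def xywhn_to_xyxy(box, w=640, h=640, padw=0, padh=0):
    """Convert one normalized xywh box to pixel xyxy corners (illustrative sketch)."""
    xc, yc, bw, bh = box
    x1 = w * (xc - bw / 2) + padw  # top-left x
    y1 = h * (yc - bh / 2) + padh  # top-left y
    x2 = w * (xc + bw / 2) + padw  # bottom-right x
    y2 = h * (yc + bh / 2) + padh  # bottom-right y
    return (x1, y1, x2, y2)


# A centered box covering half the width and height of a 640x480 image.
print(xywhn_to_xyxy((0.5, 0.5, 0.5, 0.5), w=640, h=480))
```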
@ -475,10 +480,10 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
Args:
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
w (int): The width of the image. Defaults to 640
h (int): The height of the image. Defaults to 640
clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
eps (float): The minimum value of the box's width and height. Defaults to 0.0
w (int): The width of the image.
h (int): The height of the image.
clip (bool): If True, the boxes will be clipped to the image boundaries.
eps (float): The minimum value of the box's width and height.
Returns:
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
@ -598,13 +603,13 @@ def xywhr2xyxyxyxy(x):
def ltwh2xyxy(x):
"""
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
Args:
x (np.ndarray | torch.Tensor): the input image
x (np.ndarray | torch.Tensor): The input bounding box coordinates in [x1, y1, w, h] format.
Returns:
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
(np.ndarray | torch.Tensor): The xyxy coordinates of the bounding boxes.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 2] = x[..., 2] + x[..., 0] # width
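The same arithmetic for a single box, as a minimal sketch: the bottom-right corner is the top-left corner plus the box size. `ltwh_to_xyxy` is a hypothetical tuple-based helper, not the array version in the diff.

```python
def ltwh_to_xyxy(box):
    """Convert one (x1, y1, w, h) box to (x1, y1, x2, y2) (illustrative sketch)."""
    x1, y1, w, h = box
    return (x1, y1, x1 + w, y1 + h)  # x2 = x1 + width, y2 = y1 + height


print(ltwh_to_xyxy((10, 20, 30, 40)))
```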
@ -614,13 +619,13 @@ def ltwh2xyxy(x):
def segments2boxes(segments):
"""
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
Args:
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
segments (List): List of segments, each segment is a list of points, each point is a list of x, y coordinates.
Returns:
(np.ndarray): the xywh coordinates of the bounding boxes.
(np.ndarray): The xywh coordinates of the bounding boxes.
"""
boxes = []
for s in segments:
@ -634,11 +639,11 @@ def resample_segments(segments, n=1000):
Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
Args:
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
n (int): number of points to resample the segment to. Defaults to 1000
segments (List): A list of (n,2) arrays, where n is the number of points in the segment.
n (int): Number of points to resample the segment to.
Returns:
segments (list): the resampled segments.
segments (List): The resampled segments.
"""
for i, s in enumerate(segments):
if len(s) == n:
@ -655,14 +660,14 @@ def resample_segments(segments, n=1000):
def crop_mask(masks, boxes):
"""
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.
Crop masks to bounding boxes.
Args:
masks (torch.Tensor): [n, h, w] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
masks (torch.Tensor): [n, h, w] tensor of masks.
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form.
Returns:
(torch.Tensor): The masks are being cropped to the bounding box.
(torch.Tensor): Cropped masks.
"""
_, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
@ -681,7 +686,7 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.
upsample (bool): A flag to indicate whether to upsample the mask to the original image size.
Returns:
(torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
@ -707,16 +712,16 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
def process_mask_native(protos, masks_in, bboxes, shape):
"""
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
Apply masks to bounding boxes using the output of the mask head with native upsampling.
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
protos (torch.Tensor): [mask_dim, mask_h, mask_w].
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
shape (tuple): The size of the input image (h,w).
Returns:
masks (torch.Tensor): The returned masks with dimensions [h, w, n].
(torch.Tensor): The returned masks with dimensions [h, w, n].
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
@ -734,6 +739,9 @@ def scale_masks(masks, shape, padding=True):
shape (tuple): Height and width.
padding (bool): If True, assume the boxes are based on an image augmented in YOLO style. If False, do regular
rescaling.
Returns:
(torch.Tensor): Rescaled masks.
"""
mh, mw = masks.shape[2:]
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
@ -755,10 +763,10 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
Args:
img1_shape (tuple): The shape of the image that the coords are from.
coords (torch.Tensor): the coords to be scaled of shape n,2.
img0_shape (tuple): the shape of the image that the segmentation is being applied to.
ratio_pad (tuple): the ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
coords (torch.Tensor): The coords to be scaled of shape n,2.
img0_shape (tuple): The shape of the image that the segmentation is being applied to.
ratio_pad (tuple): The ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1].
padding (bool): If True, assume the boxes are based on an image augmented in YOLO style. If False, do regular
rescaling.
@ -805,14 +813,14 @@ def regularize_rboxes(rboxes):
def masks2segments(masks, strategy="all"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy).
Convert masks to segments.
Args:
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
strategy (str): 'all' or 'largest'. Defaults to all
masks (torch.Tensor): The output of the model, which is a tensor of shape (batch_size, 160, 160).
strategy (str): 'all' or 'largest'.
Returns:
segments (List): list of segment masks
(List): List of segment masks.
"""
from ultralytics.data.converter import merge_multi_segment
@ -852,10 +860,10 @@ def clean_str(s):
Cleans a string by replacing special characters with '_' character.
Args:
s (str): a string needing special characters replaced
s (str): A string needing special characters replaced.
Returns:
(str): a string with special characters replaced by an underscore _
(str): A string with special characters replaced by an underscore _.
"""
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
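A short usage sketch of the substitution above, reusing the same pattern. The input strings are made up for illustration.

```python
import re


def clean_str(s):
    """Replace each special character in the set with an underscore."""
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


print(clean_str("run@2025:exp#1"))
print(clean_str("plain_name"))  # no character from the set, so the string is unchanged
```

Note that characters such as `.`, `-`, and `/` are not in the set, so paths and file extensions pass through untouched.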


@ -25,9 +25,9 @@ class Colors:
RGB values.
Attributes:
palette (list of tuple): List of RGB color values.
palette (List[Tuple]): List of RGB color values.
n (int): The number of colors in the palette.
pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
Examples:
>>> from ultralytics.utils.plotting import Colors
@ -142,13 +142,13 @@ class Colors:
)
def __call__(self, i, bgr=False):
"""Converts hex color codes to RGB values."""
"""Return the palette color at index i (wrapping around the palette), in BGR order if bgr=True."""
c = self.palette[int(i) % self.n]
return (c[2], c[1], c[0]) if bgr else c
@staticmethod
def hex2rgb(h):
"""Converts hex color codes to RGB values (i.e. default PIL order)."""
"""Convert hex color codes to RGB values (i.e. default PIL order)."""
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
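The two methods above can be exercised together in a small standalone sketch. The three hex values and the module-level `palette` and `color` names are illustrative assumptions, not the class's actual palette or API.

```python
def hex2rgb(h):
    """Convert a '#RRGGBB' string to an (R, G, B) tuple."""
    return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))


# Illustrative three-color palette; the real class defines its own hex list.
palette = [hex2rgb(f"#{c}") for c in ("FF3838", "00C2FF", "3DDB86")]


def color(i, bgr=False):
    """Look up a color by index, wrapping around, optionally in BGR order."""
    c = palette[int(i) % len(palette)]
    return (c[2], c[1], c[0]) if bgr else c


print(hex2rgb("#FF3838"))
print(color(4))  # index 4 wraps to palette entry 1
print(color(0, bgr=True))  # channels reversed for OpenCV-style drawing
```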
@ -160,13 +160,15 @@ class Annotator:
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
Attributes:
im (Image.Image or numpy array): The image to annotate.
im (Image.Image or np.ndarray): The image to annotate.
pil (bool): Whether to use PIL or cv2 for drawing annotations.
font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
lw (float): Line width for drawing.
skeleton (List[List[int]]): Skeleton structure for keypoints.
limb_color (List[int]): Color palette for limbs.
kpt_color (List[int]): Color palette for keypoints.
dark_colors (set): Set of colors considered dark for text contrast.
light_colors (set): Set of colors considered light for text contrast.
Examples:
>>> from ultralytics.utils.plotting import Annotator
@ -256,7 +258,7 @@ class Annotator:
txt_color (tuple, optional): The color of the text (R, G, B).
Returns:
txt_color (tuple): Text color for label
(tuple): Text color for label.
Examples:
>>> from ultralytics.utils.plotting import Annotator
@ -273,14 +275,14 @@ class Annotator:
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
"""
Draws a bounding box to image with label.
Draw a bounding box on an image with a given label.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
label (str, optional): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
rotated (bool, optional): Variable used to check if task is OBB
rotated (bool, optional): Whether the task is oriented bounding box detection.
Examples:
>>> from ultralytics.utils.plotting import Annotator
@ -340,11 +342,11 @@ class Annotator:
Plot masks on image.
Args:
masks (tensor): Predicted masks on cuda, shape: [n, h, w]
colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
masks (torch.Tensor): Predicted masks on cuda, shape: [n, h, w]
colors (List[List[int]]): Colors for predicted masks, [[r, g, b] * n]
im_gpu (torch.Tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.
retina_masks (bool, optional): Whether to use high resolution masks or not.
"""
if self.pil:
# Convert to numpy first
@ -377,11 +379,11 @@ class Annotator:
Args:
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 5.
kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
shape (tuple, optional): Image shape (h, w).
radius (int, optional): Keypoint radius.
kpt_line (bool, optional): Draw lines between keypoints.
conf_thres (float, optional): Confidence threshold.
kpt_color (tuple, optional): Keypoint color (B, G, R).
Note:
- `kpt_line=True` currently only supports human pose plotting.
@ -436,7 +438,16 @@ class Annotator:
self.draw.rectangle(xy, fill, outline, width)
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
"""Adds text to an image using PIL or cv2."""
"""
Add text to an image using PIL or cv2.
Args:
xy (List[int]): Top-left coordinates for text placement.
text (str): Text to be drawn.
txt_color (tuple, optional): Text color (R, G, B).
anchor (str, optional): Text anchor position ('top' or 'bottom').
box_style (bool, optional): Whether to draw text with a background box.
"""
if anchor == "bottom": # start y from font bottom
w, h = self.font.getsize(text) # text width, height
xy[1] += 1 - h
@ -492,7 +503,7 @@ class Annotator:
@staticmethod
def get_bbox_dimension(bbox=None):
"""
Calculate the area of a bounding box.
Calculate the dimensions and area of a bounding box.
Args:
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
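A minimal sketch of the calculation the updated summary line describes. The return order (width, height, area) is an assumption, since the hunk does not show the function body, and `bbox_dimension` is a hypothetical standalone name.

```python
def bbox_dimension(bbox):
    """Return (width, height, area) for an (x_min, y_min, x_max, y_max) box (illustrative sketch)."""
    x_min, y_min, x_max, y_max = bbox
    width = x_max - x_min
    height = y_max - y_min
    return width, height, width * height


print(bbox_dimension((10, 20, 110, 70)))
```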
@ -517,7 +528,16 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
"""Plot training labels including class histograms and box statistics."""
"""
Plot training labels including class histograms and box statistics.
Args:
boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
cls (np.ndarray): Class indices.
names (Dict, optional): Dictionary mapping class indices to class names.
save_dir (Path, optional): Directory to save the plot.
on_plot (Callable, optional): Function to call after plot is saved.
"""
import pandas # scope for faster 'import ultralytics'
import seaborn # scope for faster 'import ultralytics'
@ -580,16 +600,16 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
Args:
xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format.
im (numpy.ndarray): The input image.
file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'.
gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02.
pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10.
square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False.
BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False.
save (bool, optional): If True, the cropped image will be saved to disk. Defaults to True.
im (np.ndarray): The input image.
file (Path, optional): The path where the cropped image will be saved.
gain (float, optional): A multiplicative factor to increase the size of the bounding box.
pad (int, optional): The number of pixels to add to the width and height of the bounding box.
square (bool, optional): If True, the bounding box will be transformed into a square.
BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB.
save (bool, optional): If True, the cropped image will be saved to disk.
Returns:
(numpy.ndarray): The cropped image.
(np.ndarray): The cropped image.
Examples:
>>> from ultralytics.utils.plotting import save_one_box
@ -653,7 +673,7 @@ def plot_images(
conf_thres: Confidence threshold for displaying detections.
Returns:
np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.
(np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
Note:
This function supports both tensor and numpy array inputs. It will automatically
@ -789,13 +809,12 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
Args:
file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
file (str, optional): Path to the CSV file containing the training results.
dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
segment (bool, optional): Flag to indicate if the data is for segmentation.
pose (bool, optional): Flag to indicate if the data is for pose estimation.
classify (bool, optional): Flag to indicate if the data is for classification.
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
Defaults to None.
Examples:
>>> from ultralytics.utils.plotting import plot_results
@ -845,15 +864,15 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
"""
Plots a scatter plot with points colored based on a 2D histogram.
Plot a scatter plot with points colored based on a 2D histogram.
Args:
v (array-like): Values for the x-axis.
f (array-like): Values for the y-axis.
bins (int, optional): Number of bins for the histogram. Defaults to 20.
cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
bins (int, optional): Number of bins for the histogram.
cmap (str, optional): Colormap for the scatter plot.
alpha (float, optional): Alpha for the scatter plot.
edgecolors (str, optional): Edge colors for the scatter plot.
Examples:
>>> v = np.random.rand(100)
@ -880,7 +899,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
Args:
csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
csv_file (str, optional): Path to the CSV file containing the tuning results.
Examples:
>>> plot_tune_results("path/to/tune_results.csv")
@ -959,8 +978,8 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detec
x (torch.Tensor): Features to be visualized.
module_type (str): Module type.
stage (int): Module stage within the model.
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
n (int, optional): Maximum number of feature maps to plot.
save_dir (Path, optional): Directory to save results.
"""
for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads
if m in module_type:

@ -90,12 +90,12 @@ def autocast(enabled: bool, device: str = "cuda"):
Returns:
(torch.amp.autocast): The appropriate autocast context manager.
Note:
Notes:
- For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
- For older versions, it uses `torch.cuda.autocast`.
Examples:
>>> with autocast(amp=True):
>>> with autocast(enabled=True):
... # Your mixed precision operations here
... pass
"""
@ -130,7 +130,7 @@ def get_gpu_info(index):
def select_device(device="", batch=0, newline=False, verbose=True):
"""
Selects the appropriate PyTorch device based on the provided arguments.
Select the appropriate PyTorch device based on the provided arguments.
The function takes a string specifying the device or a torch.device object and returns a torch.device object
representing the selected device. The function also validates the number of available devices and raises an
@ -299,7 +299,18 @@ def fuse_deconv_and_bn(deconv, bn):
def model_info(model, detailed=False, verbose=True, imgsz=640):
"""Print and return detailed model information layer by layer."""
"""
Print and return detailed model information layer by layer.
Args:
model (nn.Module): Model to analyze.
detailed (bool, optional): Whether to print detailed layer information. Defaults to False.
verbose (bool, optional): Whether to print model information. Defaults to True.
imgsz (int | List, optional): Input image size. Defaults to 640.
Returns:
(Tuple[int, int, int, float]): Number of layers, parameters, gradients, and GFLOPs.
"""
if not verbose:
return
n_p = get_num_params(model) # number of parameters
@ -343,6 +354,12 @@ def model_info_for_loggers(trainer):
"""
Return model info dict with useful model information.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.
Returns:
(dict): Dictionary containing model parameters, GFLOPs, and inference speeds.
Examples:
YOLOv8n info for loggers
>>> results = {
@ -368,7 +385,16 @@ def model_info_for_loggers(trainer):
def get_flops(model, imgsz=640):
"""Return a YOLO model's FLOPs."""
"""
Return a YOLO model's FLOPs.
Args:
model (nn.Module): The model to calculate FLOPs for.
imgsz (int | List[int], optional): Input image size. Defaults to 640.
Returns:
(float): The model's FLOPs in billions.
"""
if not thop:
return 0.0 # if not installed return 0.0 GFLOPs
@ -392,7 +418,16 @@ def get_flops(model, imgsz=640):
def get_flops_with_torch_profiler(model, imgsz=640):
"""Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
"""
Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
Args:
model (nn.Module): The model to calculate FLOPs for.
imgsz (int | List[int], optional): Input image size. Defaults to 640.
Returns:
(float): The model's FLOPs in billions.
"""
if not TORCH_2_0: # torch profiler implemented in torch>=2.0
return 0.0
model = de_parallel(model)
@ -430,7 +465,18 @@ def initialize_weights(model):
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
"""Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
"""
Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
Args:
img (torch.Tensor): Input image tensor.
ratio (float, optional): Scaling ratio. Defaults to 1.0.
same_shape (bool, optional): Whether to maintain the same shape. Defaults to False.
gs (int, optional): Grid size for padding. Defaults to 32.
Returns:
(torch.Tensor): Scaled and padded image tensor.
"""
if ratio == 1.0:
return img
h, w = img.shape[2:]
@ -442,7 +488,15 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
def copy_attr(a, b, include=(), exclude=()):
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
"""
Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
Args:
a (object): Destination object to copy attributes to.
b (object): Source object to copy attributes from.
include (tuple, optional): Attributes to include. If empty, all attributes are included. Defaults to ().
exclude (tuple, optional): Attributes to exclude. Defaults to ().
"""
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith("_") or k in exclude:
continue
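A minimal sketch of the filtering rules documented for `copy_attr` above. The hunk cuts off before the assignment step, so the `setattr` call here is an assumption about how the copy completes. The `Source` and `Target` classes are hypothetical and exist only for illustration.

```python
def copy_attr(a, b, include=(), exclude=()):
    """Copy public attributes from b to a, honoring the include/exclude filters."""
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        setattr(a, k, v)  # assumed final step, not shown in the hunk


class Source:  # hypothetical source object
    def __init__(self):
        self.names = {0: "person"}
        self.stride = 32
        self._private = "skipped because of the leading underscore"


class Target:  # hypothetical destination object
    pass


target = Target()
copy_attr(target, Source(), exclude=("stride",))
print(vars(target))
```

Only `names` is copied: `stride` is excluded explicitly and `_private` is skipped by the underscore rule.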
@ -451,7 +505,12 @@ def copy_attr(a, b, include=(), exclude=()):
def get_latest_opset():
"""Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
"""
Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.
Returns:
(int): The ONNX opset version.
"""
if TORCH_1_13:
# If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
@ -461,27 +520,69 @@ def get_latest_opset():
def intersect_dicts(da, db, exclude=()):
"""Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
"""
Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
Args:
da (dict): First dictionary.
db (dict): Second dictionary.
exclude (tuple, optional): Keys to exclude. Defaults to ().
Returns:
(dict): Dictionary of intersecting keys with matching shapes.
"""
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
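To illustrate the three conditions in `intersect_dicts` (key present in both dictionaries, no excluded substring in the key, matching shapes), here is a small sketch that avoids a PyTorch dependency. `FakeTensor` is a hypothetical stand-in that only carries a `shape` attribute, and the key names are made up for the example.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor: only the .shape attribute matters here.
FakeTensor = namedtuple("FakeTensor", "shape")


def intersect_dicts(da, db, exclude=()):
    """Same one-line logic as the function shown in the diff."""
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}


checkpoint = {
    "conv.weight": FakeTensor((16, 3, 3, 3)),
    "head.weight": FakeTensor((80, 16)),  # shape differs from the model below
    "anchors": FakeTensor((3, 2)),  # dropped via the exclude substring
}
model_state = {
    "conv.weight": FakeTensor((16, 3, 3, 3)),
    "head.weight": FakeTensor((2, 16)),
    "anchors": FakeTensor((3, 2)),
}

matched = intersect_dicts(checkpoint, model_state, exclude=("anchor",))
print(list(matched))
```

Note that `exclude` is matched as a substring of each key, so `"anchor"` also removes `"anchors"`.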
def is_parallel(model):
"""Returns True if model is of type DP or DDP."""
"""
Returns True if model is of type DP or DDP.
Args:
model (nn.Module): Model to check.
Returns:
(bool): True if model is DataParallel or DistributedDataParallel.
"""
return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
def de_parallel(model):
"""De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
"""
De-parallelize a model: returns single-GPU model if model is of type DP or DDP.
Args:
model (nn.Module): Model to de-parallelize.
Returns:
(nn.Module): De-parallelized model.
"""
return model.module if is_parallel(model) else model
def one_cycle(y1=0.0, y2=1.0, steps=100):
"""Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
"""
Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
Args:
y1 (float, optional): Initial value. Defaults to 0.0.
y2 (float, optional): Final value. Defaults to 1.0.
steps (int, optional): Number of steps. Defaults to 100.
Returns:
(function): Lambda function for computing the sinusoidal ramp.
"""
return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
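The ramp returned by `one_cycle` can be checked numerically. The sketch below repeats the one-line definition from the diff and evaluates it at the start, midpoint, and end of the schedule; the values `1.0` and `0.01` are arbitrary example endpoints, not defaults taken from the library.

```python
import math


def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Same expression as in the diff: half-cosine ramp from y1 to y2 over `steps`."""
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1


# Example: decay a multiplier from 1.0 down to 0.01 over 100 steps.
lf = one_cycle(1.0, 0.01, steps=100)
for step in (0, 50, 100):
    print(step, round(lf(step), 4))
```

The ramp starts at `y1`, passes through the midpoint of `y1` and `y2` halfway, and ends at `y2`, changing slowly near both ends.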
def init_seeds(seed=0, deterministic=False):
"""Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
"""
Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
Args:
seed (int, optional): Random seed. Defaults to 0.
deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@ -510,16 +611,30 @@ def unset_deterministic():
class ModelEMA:
"""
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
average of everything in the model state_dict (parameters and buffers).
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.
Keeps a moving average of everything in the model state_dict (parameters and buffers).
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
To disable EMA set the `enabled` attribute to `False`.
Attributes:
ema (nn.Module): Copy of the model in evaluation mode.
updates (int): Number of EMA updates.
decay (function): Decay function that determines the EMA weight.
enabled (bool): Whether EMA is enabled.
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
"""Initialize EMA for 'model' with given arguments."""
"""
Initialize EMA for 'model' with given arguments.
Args:
model (nn.Module): Model to create EMA for.
decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
tau (int, optional): EMA decay time constant. Defaults to 2000.
updates (int, optional): Initial number of updates. Defaults to 0.
"""
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
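The decay ramp defined above can be explored without a model. This sketch reuses the same expression and applies it to a single float. The update rule `ema = d * ema + (1 - d) * value` is the conventional EMA step and is an assumption here, since the body of `update` is not visible in this diff. `make_decay` is a hypothetical helper name.

```python
import math


def make_decay(decay=0.9999, tau=2000):
    """Return the ramped decay function used in the constructor above."""
    return lambda x: decay * (1 - math.exp(-x / tau))


decay_fn = make_decay()

# Early updates get a small decay, so the average follows the model closely.
for updates in (1, 100, 2000, 20000):
    print(updates, round(decay_fn(updates), 6))

# Assumed conventional EMA step applied to one scalar "weight".
ema_value, model_value = 0.0, 1.0
for updates in range(1, 4):
    d = decay_fn(updates)
    ema_value = d * ema_value + (1 - d) * model_value
```

With `tau=2000`, the decay reaches roughly 63% of its maximum after 2000 updates, which is why the comment in the source describes the ramp as helping early epochs.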
@ -528,7 +643,12 @@ class ModelEMA:
self.enabled = True
def update(self, model):
"""Update EMA parameters."""
"""
Update EMA parameters.
Args:
model (nn.Module): Model to update EMA from.
"""
if self.enabled:
self.updates += 1
d = self.decay(self.updates)
@ -541,7 +661,14 @@ class ModelEMA:
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
"""Updates attributes and saves stripped model with optimizer removed."""
"""
Updates attributes and saves stripped model with optimizer removed.
Args:
model (nn.Module): Model to update attributes from.
include (tuple, optional): Attributes to include. Defaults to ().
exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer").
"""
if self.enabled:
copy_attr(self.ema, model, include, exclude)
@ -551,9 +678,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
Strip optimizer from 'f' to finalize training, optionally save as 's'.
Args:
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'.
s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.
Returns:
(dict): The combined checkpoint dictionary.
@ -563,9 +690,6 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
>>> from ultralytics.utils.torch_utils import strip_optimizer
>>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
>>> strip_optimizer(f)
Note:
Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
"""
try:
x = torch.load(f, map_location=torch.device("cpu"))
@ -613,7 +737,11 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
"""
Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
Args:
state_dict (dict): Optimizer state dictionary.
Returns:
(dict): Converted optimizer state dictionary with FP16 tensors.
"""
for state in state_dict["state"].values():
for k, v in state.items():
@ -653,6 +781,16 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
"""
Ultralytics speed, memory and FLOPs profiler.
Args:
input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile.
ops (nn.Module | List[nn.Module]): Model or list of operations to profile.
n (int, optional): Number of iterations to average. Defaults to 10.
device (str | torch.device, optional): Device to profile on. Defaults to None.
max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0.
Returns:
(List): Profile results for each operation.
Examples:
>>> from ultralytics.utils.torch_utils import profile
>>> input = torch.randn(16, 3, 640, 640)
@ -721,7 +859,15 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
class EarlyStopping:
"""Early stopping class that stops training when a specified number of epochs have passed without improvement."""
"""
Early stopping class that stops training when a specified number of epochs have passed without improvement.
Attributes:
best_fitness (float): Best fitness value observed.
best_epoch (int): Epoch where best fitness was observed.
patience (int): Number of epochs to wait after fitness stops improving before stopping.
possible_stop (bool): Flag indicating if stopping may occur next epoch.
"""
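The four attributes listed above suggest a simple patience counter. The following is a sketch under stated assumptions, not the library's implementation: `PatienceStopper` and its `step` method are hypothetical names, and the comparison rules are one plausible reading of the attribute descriptions.

```python
class PatienceStopper:
    """Hypothetical patience-based stopper mirroring the documented attributes."""

    def __init__(self, patience=50):
        self.best_fitness = 0.0  # best fitness value observed
        self.best_epoch = 0  # epoch where best fitness was observed
        self.patience = patience  # epochs to wait without improvement
        self.possible_stop = False  # True if stopping may occur next epoch

    def step(self, epoch, fitness):
        """Record this epoch's fitness and return True when training should stop."""
        if fitness > self.best_fitness:
            self.best_fitness = fitness
            self.best_epoch = epoch
        stalled = epoch - self.best_epoch  # epochs since the last improvement
        self.possible_stop = stalled >= self.patience - 1
        return stalled >= self.patience


stopper = PatienceStopper(patience=2)
history = [0.10, 0.30, 0.25, 0.28, 0.27]  # made-up fitness values
stop_epoch = None
for epoch, fitness in enumerate(history):
    if stopper.step(epoch, fitness):
        stop_epoch = epoch
        break
print(stop_epoch, stopper.best_epoch, stopper.best_fitness)
```

Here the best fitness occurs at epoch 1, so with a patience of 2 the loop stops at epoch 3.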
def __init__(self, patience=50):
"""
@ -770,11 +916,12 @@ class FXModel(nn.Module):
"""
A custom model class for torch.fx compatibility.
This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph manipulation.
It copies attributes from an existing model and explicitly sets the model attribute to ensure proper copying.
This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
copying.
Args:
model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
Attributes:
model (nn.Module): The original model's layers.
"""
def __init__(self, model):
@ -782,7 +929,7 @@ class FXModel(nn.Module):
Initialize the FXModel.
Args:
model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
model (nn.Module): The original model to wrap for torch.fx compatibility.
"""
super().__init__()
copy_attr(self, model)
@ -793,7 +940,8 @@ class FXModel(nn.Module):
"""
Forward pass through the model.
This method performs the forward pass through the model, handling the dependencies between layers and saving intermediate outputs.
This method performs the forward pass through the model, handling the dependencies between layers and saving
intermediate outputs.
Args:
x (torch.Tensor): The input tensor to the model.