Update Google-style docstrings (#22565)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-11-03 05:12:30 +09:00 committed by GitHub
parent b9a1365450
commit 0aef7a9a51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
159 changed files with 1931 additions and 3432 deletions

View file

@ -41,8 +41,7 @@ This example illustrates a Google-style docstring. Ensure that both input and ou
```python
def example_function(arg1, arg2=4):
"""
Example function demonstrating Google-style docstrings.
"""Example function demonstrating Google-style docstrings.
Args:
arg1 (int): The first argument.
@ -65,8 +64,7 @@ This example includes both a Google-style docstring and [type hints](https://doc
```python
def example_function(arg1: int, arg2: int = 4) -> bool:
"""
Example function demonstrating Google-style docstrings.
"""Example function demonstrating Google-style docstrings.
Args:
arg1: The first argument.

View file

@ -262,8 +262,7 @@ def remove_macros():
def remove_comments_and_empty_lines(content: str, file_type: str) -> str:
"""
Remove comments and empty lines from a string of code, preserving newlines and URLs.
"""Remove comments and empty lines from a string of code, preserving newlines and URLs.
Args:
content (str): Code content to process.

View file

@ -401,8 +401,7 @@ import ros_numpy
def pointcloud2_to_array(pointcloud2: PointCloud2) -> tuple:
"""
Convert a ROS PointCloud2 message to a numpy array.
"""Convert a ROS PointCloud2 message to a numpy array.
Args:
pointcloud2 (PointCloud2): the PointCloud2 message
@ -472,8 +471,7 @@ for index, class_id in enumerate(classes):
def pointcloud2_to_array(pointcloud2: PointCloud2) -> tuple:
"""
Convert a ROS PointCloud2 message to a numpy array.
"""Convert a ROS PointCloud2 message to a numpy array.
Args:
pointcloud2 (PointCloud2): the PointCloud2 message

View file

@ -58,8 +58,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
```python
def example_function(arg1, arg2=4):
"""
Example function demonstrating Google-style docstrings.
"""Example function demonstrating Google-style docstrings.
Args:
arg1 (int): The first argument.
@ -81,8 +80,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
```python
def example_function(arg1, arg2=4):
"""
Example function demonstrating Google-style docstrings.
"""Example function demonstrating Google-style docstrings.
Args:
arg1 (int): The first argument.
@ -104,8 +102,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
```python
def example_function(arg1, arg2=4):
"""
Example function demonstrating Google-style docstrings.
"""Example function demonstrating Google-style docstrings.
Args:
arg1 (int): The first argument.
@ -146,8 +143,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
```python
def example_function(arg1: int, arg2: int = 4) -> bool:
"""
Example function demonstrating Google-style docstrings.
"""Example function demonstrating Google-style docstrings.
Args:
arg1: The first argument.

View file

@ -359,8 +359,7 @@ Finally, after all threads have completed their task, the windows displaying the
def run_tracker_in_thread(model_name, filename):
"""
Run YOLO tracker in its own thread for concurrent processing.
"""Run YOLO tracker in its own thread for concurrent processing.
Args:
model_name (str): The YOLO11 model object.
@ -449,8 +448,7 @@ To run object tracking on multiple video streams simultaneously, you can use Pyt
def run_tracker_in_thread(model_name, filename):
"""
Run YOLO tracker in its own thread for concurrent processing.
"""Run YOLO tracker in its own thread for concurrent processing.
Args:
model_name (str): The YOLO11 model object.

View file

@ -13,8 +13,7 @@ import yaml
def download_file(url: str, local_path: str) -> str:
"""
Download a file from a URL to a local path.
"""Download a file from a URL to a local path.
Args:
url (str): URL of the file to download.
@ -34,8 +33,7 @@ def download_file(url: str, local_path: str) -> str:
class RTDETR:
"""
RT-DETR (Real-Time Detection Transformer) object detection model for ONNX inference and visualization.
"""RT-DETR (Real-Time Detection Transformer) object detection model for ONNX inference and visualization.
This class implements the RT-DETR model for object detection tasks, supporting ONNX model inference and
visualization of detection results with bounding boxes and class labels.
@ -77,16 +75,15 @@ class RTDETR:
iou_thres: float = 0.5,
class_names: str | None = None,
):
"""
Initialize the RT-DETR object detection model.
"""Initialize the RT-DETR object detection model.
Args:
model_path (str): Path to the ONNX model file.
img_path (str): Path to the input image.
conf_thres (float, optional): Confidence threshold for filtering detections.
iou_thres (float, optional): IoU threshold for non-maximum suppression.
class_names (Optional[str], optional): Path to a YAML file containing class names.
If None, uses COCO dataset classes.
class_names (Optional[str], optional): Path to a YAML file containing class names. If None, uses COCO
dataset classes.
"""
self.model_path = model_path
self.img_path = img_path
@ -157,8 +154,7 @@ class RTDETR:
)
def preprocess(self) -> np.ndarray:
"""
Preprocess the input image for model inference.
"""Preprocess the input image for model inference.
Loads the image, converts color space from BGR to RGB, resizes to model input dimensions, and normalizes pixel
values to [0, 1] range.
@ -190,12 +186,11 @@ class RTDETR:
return image_data
def bbox_cxcywh_to_xyxy(self, boxes: np.ndarray) -> np.ndarray:
"""
Convert bounding boxes from center format to corner format.
"""Convert bounding boxes from center format to corner format.
Args:
boxes (np.ndarray): Array of shape (N, 4) where each row represents a bounding box in
(center_x, center_y, width, height) format.
boxes (np.ndarray): Array of shape (N, 4) where each row represents a bounding box in (center_x, center_y,
width, height) format.
Returns:
(np.ndarray): Array of shape (N, 4) with bounding boxes in (x_min, y_min, x_max, y_max) format.
@ -214,8 +209,7 @@ class RTDETR:
return np.column_stack((x_min, y_min, x_max, y_max))
def postprocess(self, model_output: list[np.ndarray]) -> np.ndarray:
"""
Postprocess model output to extract and visualize detections.
"""Postprocess model output to extract and visualize detections.
Applies confidence thresholding, converts bounding box format, scales coordinates to original image dimensions,
and draws detection annotations.
@ -255,8 +249,7 @@ class RTDETR:
return self.img
def main(self) -> np.ndarray:
"""
Execute the complete object detection pipeline on the input image.
"""Execute the complete object detection pipeline on the input image.
Performs preprocessing, ONNX model inference, and postprocessing to generate annotated detection results.

View file

@ -55,8 +55,7 @@ selected_center = None
def get_center(x1: int, y1: int, x2: int, y2: int) -> tuple[int, int]:
"""
Calculate the center point of a bounding box.
"""Calculate the center point of a bounding box.
Args:
x1 (int): Top-left X coordinate.
@ -72,8 +71,7 @@ def get_center(x1: int, y1: int, x2: int, y2: int) -> tuple[int, int]:
def extend_line_from_edge(mid_x: int, mid_y: int, direction: str, img_shape: tuple[int, int, int]) -> tuple[int, int]:
"""
Calculate the endpoint to extend a line from the center toward an image edge.
"""Calculate the endpoint to extend a line from the center toward an image edge.
Args:
mid_x (int): X-coordinate of the midpoint.
@ -99,8 +97,7 @@ def extend_line_from_edge(mid_x: int, mid_y: int, direction: str, img_shape: tup
def draw_tracking_scope(im, bbox: tuple, color: tuple) -> None:
"""
Draw tracking scope lines extending from the bounding box to image edges.
"""Draw tracking scope lines extending from the bounding box to image edges.
Args:
im (np.ndarray): Image array to draw on.
@ -119,8 +116,7 @@ def draw_tracking_scope(im, bbox: tuple, color: tuple) -> None:
def click_event(event: int, x: int, y: int, flags: int, param) -> None:
"""
Handle mouse click events to select an object for focused tracking.
"""Handle mouse click events to select an object for focused tracking.
Args:
event (int): OpenCV mouse event type.

View file

@ -19,8 +19,7 @@ from ultralytics.utils.torch_utils import select_device
class TorchVisionVideoClassifier:
"""
Video classifier using pretrained TorchVision models for action recognition.
"""Video classifier using pretrained TorchVision models for action recognition.
This class provides an interface for video classification using various pretrained models from TorchVision's video
model collection, supporting models like S3D, R3D, Swin3D, and MViT architectures.
@ -72,8 +71,7 @@ class TorchVisionVideoClassifier:
}
def __init__(self, model_name: str, device: str | torch.device = ""):
"""
Initialize the VideoClassifier with the specified model name and device.
"""Initialize the VideoClassifier with the specified model name and device.
Args:
model_name (str): The name of the model to use. Must be one of the available models.
@ -87,8 +85,7 @@ class TorchVisionVideoClassifier:
@staticmethod
def available_model_names() -> list[str]:
"""
Get the list of available model names.
"""Get the list of available model names.
Returns:
(list[str]): List of available model names that can be used with this classifier.
@ -98,8 +95,7 @@ class TorchVisionVideoClassifier:
def preprocess_crops_for_video_cls(
self, crops: list[np.ndarray], input_size: list[int] | None = None
) -> torch.Tensor:
"""
Preprocess a list of crops for video classification.
"""Preprocess a list of crops for video classification.
Args:
crops (list[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C).
@ -124,8 +120,7 @@ class TorchVisionVideoClassifier:
return torch.stack(processed_crops).unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device)
def __call__(self, sequences: torch.Tensor) -> torch.Tensor:
"""
Perform inference on the given sequences.
"""Perform inference on the given sequences.
Args:
sequences (torch.Tensor): The input sequences for the model with dimensions (B, T, C, H, W) for batched
@ -138,8 +133,7 @@ class TorchVisionVideoClassifier:
return self.model(sequences)
def postprocess(self, outputs: torch.Tensor) -> tuple[list[str], list[float]]:
"""
Postprocess the model's batch output.
"""Postprocess the model's batch output.
Args:
outputs (torch.Tensor): The model's output logits.
@ -161,8 +155,7 @@ class TorchVisionVideoClassifier:
class HuggingFaceVideoClassifier:
"""
Zero-shot video classifier using Hugging Face transformer models.
"""Zero-shot video classifier using Hugging Face transformer models.
This class provides an interface for zero-shot video classification using Hugging Face models, supporting custom
label sets and various transformer architectures for video understanding.
@ -195,8 +188,7 @@ class HuggingFaceVideoClassifier:
device: str | torch.device = "",
fp16: bool = False,
):
"""
Initialize the HuggingFaceVideoClassifier with the specified model name.
"""Initialize the HuggingFaceVideoClassifier with the specified model name.
Args:
labels (list[str]): List of labels for zero-shot classification.
@ -216,8 +208,7 @@ class HuggingFaceVideoClassifier:
def preprocess_crops_for_video_cls(
self, crops: list[np.ndarray], input_size: list[int] | None = None
) -> torch.Tensor:
"""
Preprocess a list of crops for video classification.
"""Preprocess a list of crops for video classification.
Args:
crops (list[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C).
@ -247,8 +238,7 @@ class HuggingFaceVideoClassifier:
return output
def __call__(self, sequences: torch.Tensor) -> torch.Tensor:
"""
Perform inference on the given sequences.
"""Perform inference on the given sequences.
Args:
sequences (torch.Tensor): Batched input video frames with shape (B, T, H, W, C).
@ -266,8 +256,7 @@ class HuggingFaceVideoClassifier:
return outputs.logits_per_video
def postprocess(self, outputs: torch.Tensor) -> tuple[list[list[str]], list[list[float]]]:
"""
Postprocess the model's batch output.
"""Postprocess the model's batch output.
Args:
outputs (torch.Tensor): The model's output logits.
@ -294,8 +283,7 @@ class HuggingFaceVideoClassifier:
def crop_and_pad(frame: np.ndarray, box: list[float], margin_percent: int) -> np.ndarray:
"""
Crop box with margin and take square crop from frame.
"""Crop box with margin and take square crop from frame.
Args:
frame (np.ndarray): The input frame to crop from.
@ -338,8 +326,7 @@ def run(
video_classifier_model: str = "microsoft/xclip-base-patch32",
labels: list[str] | None = None,
) -> None:
"""
Run action recognition on a video source using YOLO for object detection and a video classifier.
"""Run action recognition on a video source using YOLO for object detection and a video classifier.
Args:
weights (str): Path to the YOLO model weights.

View file

@ -14,8 +14,7 @@ from ultralytics.utils.checks import check_requirements, check_yaml
class YOLOv8:
"""
YOLOv8 object detection model class for handling ONNX inference and visualization.
"""YOLOv8 object detection model class for handling ONNX inference and visualization.
This class provides functionality to load a YOLOv8 ONNX model, perform inference on images, and visualize the
detection results with bounding boxes and labels.
@ -47,8 +46,7 @@ class YOLOv8:
"""
def __init__(self, onnx_model: str, input_image: str, confidence_thres: float, iou_thres: float):
"""
Initialize an instance of the YOLOv8 class.
"""Initialize an instance of the YOLOv8 class.
Args:
onnx_model (str): Path to the ONNX model.
@ -68,8 +66,7 @@ class YOLOv8:
self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
def letterbox(self, img: np.ndarray, new_shape: tuple[int, int] = (640, 640)) -> tuple[np.ndarray, tuple[int, int]]:
"""
Resize and reshape images while maintaining aspect ratio by adding padding.
"""Resize and reshape images while maintaining aspect ratio by adding padding.
Args:
img (np.ndarray): Input image to be resized.
@ -126,8 +123,7 @@ class YOLOv8:
cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
def preprocess(self) -> tuple[np.ndarray, tuple[int, int]]:
"""
Preprocess the input image before performing inference.
"""Preprocess the input image before performing inference.
This method reads the input image, converts its color space, applies letterboxing to maintain aspect ratio,
normalizes pixel values, and prepares the image data for model input.
@ -160,8 +156,7 @@ class YOLOv8:
return image_data, pad
def postprocess(self, input_image: np.ndarray, output: list[np.ndarray], pad: tuple[int, int]) -> np.ndarray:
"""
Perform post-processing on the model's output to extract and visualize detections.
"""Perform post-processing on the model's output to extract and visualize detections.
This method processes the raw model output to extract bounding boxes, scores, and class IDs. It applies
non-maximum suppression to filter overlapping detections and draws the results on the input image.
@ -234,8 +229,7 @@ class YOLOv8:
return input_image
def main(self) -> np.ndarray:
"""
Perform inference using an ONNX model and return the output image with drawn detections.
"""Perform inference using an ONNX model and return the output image with drawn detections.
Returns:
(np.ndarray): The output image with drawn detections.

View file

@ -18,8 +18,7 @@ colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))
def draw_bounding_box(
img: np.ndarray, class_id: int, confidence: float, x: int, y: int, x_plus_w: int, y_plus_h: int
) -> None:
"""
Draw bounding boxes on the input image based on the provided arguments.
"""Draw bounding boxes on the input image based on the provided arguments.
Args:
img (np.ndarray): The input image to draw the bounding box on.
@ -37,8 +36,7 @@ def draw_bounding_box(
def main(onnx_model: str, input_image: str) -> list[dict[str, Any]]:
"""
Load ONNX model, perform inference, draw bounding boxes, and display the output image.
"""Load ONNX model, perform inference, draw bounding boxes, and display the output image.
Args:
onnx_model (str): Path to the ONNX model.

View file

@ -40,8 +40,7 @@ counting_regions = [
def mouse_callback(event: int, x: int, y: int, flags: int, param: Any) -> None:
"""
Handle mouse events for region manipulation in the video frame.
"""Handle mouse events for region manipulation in the video frame.
This function enables interactive region selection and dragging functionality for counting regions. It responds to
mouse button down, move, and up events to allow users to select and reposition counting regions in real-time.
@ -97,8 +96,7 @@ def run(
track_thickness: int = 2,
region_thickness: int = 2,
) -> None:
"""
Run object detection and counting within specified regions using YOLO and ByteTrack.
"""Run object detection and counting within specified regions using YOLO and ByteTrack.
This function performs real-time object detection, tracking, and counting within user-defined polygonal or
rectangular regions. It supports interactive region manipulation, multiple counting areas, and both live viewing and

View file

@ -11,8 +11,7 @@ from ultralytics.utils.files import increment_path
class SAHIInference:
"""
Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.
"""Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.
This class integrates SAHI (Slicing Aided Hyper Inference) with YOLO11 models to perform efficient object detection
on large images by slicing them into smaller pieces, running inference on each slice, and then merging the results.
@ -36,8 +35,7 @@ class SAHIInference:
self.detection_model = None
def load_model(self, weights: str, device: str) -> None:
"""
Load a YOLO11 model with specified weights for object detection using SAHI.
"""Load a YOLO11 model with specified weights for object detection using SAHI.
Args:
weights (str): Path to the model weights file.
@ -63,8 +61,7 @@ class SAHIInference:
slice_width: int = 512,
slice_height: int = 512,
) -> None:
"""
Run object detection on a video using YOLO11 and SAHI.
"""Run object detection on a video using YOLO11 and SAHI.
The function processes each frame of the video, applies sliced inference using SAHI, and optionally displays
and/or saves the results with bounding boxes and labels.
@ -123,8 +120,7 @@ class SAHIInference:
@staticmethod
def parse_opt() -> argparse.Namespace:
"""
Parse command line arguments for the inference process.
"""Parse command line arguments for the inference process.
Returns:
(argparse.Namespace): Parsed command line arguments.

View file

@ -15,8 +15,7 @@ from ultralytics.utils.checks import check_yaml
class YOLOv8Seg:
"""
YOLOv8 segmentation model for performing instance segmentation using ONNX Runtime.
"""YOLOv8 segmentation model for performing instance segmentation using ONNX Runtime.
This class implements a YOLOv8 instance segmentation model using ONNX Runtime for inference. It handles
preprocessing of input images, running inference with the ONNX model, and postprocessing the results to generate
@ -43,15 +42,14 @@ class YOLOv8Seg:
"""
def __init__(self, onnx_model: str, conf: float = 0.25, iou: float = 0.7, imgsz: int | tuple[int, int] = 640):
"""
Initialize the instance segmentation model using an ONNX model.
"""Initialize the instance segmentation model using an ONNX model.
Args:
onnx_model (str): Path to the ONNX model file.
conf (float, optional): Confidence threshold for filtering detections.
iou (float, optional): IoU threshold for non-maximum suppression.
imgsz (int | tuple[int, int], optional): Input image size of the model. Can be an integer for square
input or a tuple for rectangular input.
imgsz (int | tuple[int, int], optional): Input image size of the model. Can be an integer for square input
or a tuple for rectangular input.
"""
self.session = ort.InferenceSession(
onnx_model,
@ -66,8 +64,7 @@ class YOLOv8Seg:
self.iou = iou
def __call__(self, img: np.ndarray) -> list[Results]:
"""
Run inference on the input image using the ONNX model.
"""Run inference on the input image using the ONNX model.
Args:
img (np.ndarray): The original input image in BGR format.
@ -81,8 +78,7 @@ class YOLOv8Seg:
return self.postprocess(img, prep_img, outs)
def letterbox(self, img: np.ndarray, new_shape: tuple[int, int] = (640, 640)) -> np.ndarray:
"""
Resize and pad image while maintaining aspect ratio.
"""Resize and pad image while maintaining aspect ratio.
Args:
img (np.ndarray): Input image in BGR format.
@ -109,16 +105,15 @@ class YOLOv8Seg:
return img
def preprocess(self, img: np.ndarray, new_shape: tuple[int, int]) -> np.ndarray:
"""
Preprocess the input image before feeding it into the model.
"""Preprocess the input image before feeding it into the model.
Args:
img (np.ndarray): The input image in BGR format.
new_shape (tuple[int, int]): The target shape for resizing as (height, width).
Returns:
(np.ndarray): Preprocessed image ready for model inference, with shape (1, 3, height, width) and
normalized to [0, 1].
(np.ndarray): Preprocessed image ready for model inference, with shape (1, 3, height, width) and normalized
to [0, 1].
"""
img = self.letterbox(img, new_shape)
img = img[..., ::-1].transpose([2, 0, 1])[None] # BGR to RGB, BHWC to BCHW
@ -127,8 +122,7 @@ class YOLOv8Seg:
return img
def postprocess(self, img: np.ndarray, prep_img: np.ndarray, outs: list) -> list[Results]:
"""
Post-process model predictions to extract meaningful results.
"""Post-process model predictions to extract meaningful results.
Args:
img (np.ndarray): The original input image.
@ -152,8 +146,7 @@ class YOLOv8Seg:
def process_mask(
self, protos: torch.Tensor, masks_in: torch.Tensor, bboxes: torch.Tensor, shape: tuple[int, int]
) -> torch.Tensor:
"""
Process prototype masks with predicted mask coefficients to generate instance segmentation masks.
"""Process prototype masks with predicted mask coefficients to generate instance segmentation masks.
Args:
protos (torch.Tensor): Prototype masks with shape (mask_dim, mask_h, mask_w).

View file

@ -19,8 +19,7 @@ except ImportError:
class YOLOv8TFLite:
"""
A YOLOv8 object detection class using TensorFlow Lite for efficient inference.
"""A YOLOv8 object detection class using TensorFlow Lite for efficient inference.
This class handles model loading, preprocessing, inference, and visualization of detection results for YOLOv8 models
converted to TensorFlow Lite format.
@ -56,8 +55,7 @@ class YOLOv8TFLite:
"""
def __init__(self, model: str, conf: float = 0.25, iou: float = 0.45, metadata: str | None = None):
"""
Initialize the YOLOv8TFLite detector.
"""Initialize the YOLOv8TFLite detector.
Args:
model (str): Path to the TFLite model file.
@ -94,8 +92,7 @@ class YOLOv8TFLite:
def letterbox(
self, img: np.ndarray, new_shape: tuple[int, int] = (640, 640)
) -> tuple[np.ndarray, tuple[float, float]]:
"""
Resize and pad image while maintaining aspect ratio.
"""Resize and pad image while maintaining aspect ratio.
Args:
img (np.ndarray): Input image with shape (H, W, C).
@ -123,8 +120,7 @@ class YOLOv8TFLite:
return img, (top / img.shape[0], left / img.shape[1])
def draw_detections(self, img: np.ndarray, box: np.ndarray, score: np.float32, class_id: int) -> None:
"""
Draw bounding boxes and labels on the input image based on detected objects.
"""Draw bounding boxes and labels on the input image based on detected objects.
Args:
img (np.ndarray): The input image to draw detections on.
@ -161,8 +157,7 @@ class YOLOv8TFLite:
cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
def preprocess(self, img: np.ndarray) -> tuple[np.ndarray, tuple[float, float]]:
"""
Preprocess the input image before performing inference.
"""Preprocess the input image before performing inference.
Args:
img (np.ndarray): The input image to be preprocessed with shape (H, W, C).
@ -178,8 +173,7 @@ class YOLOv8TFLite:
return img / 255, pad # Normalize to [0, 1]
def postprocess(self, img: np.ndarray, outputs: np.ndarray, pad: tuple[float, float]) -> np.ndarray:
"""
Process model outputs to extract and visualize detections.
"""Process model outputs to extract and visualize detections.
Args:
img (np.ndarray): The original input image.
@ -216,8 +210,7 @@ class YOLOv8TFLite:
return img
def detect(self, img_path: str) -> np.ndarray:
"""
Perform object detection on an input image.
"""Perform object detection on an input image.
Args:
img_path (str): Path to the input image file.

View file

@ -10,8 +10,7 @@ def pytest_addoption(parser):
def pytest_collection_modifyitems(config, items):
"""
Modify the list of test items to exclude tests marked as slow if the --slow option is not specified.
"""Modify the list of test items to exclude tests marked as slow if the --slow option is not specified.
Args:
config: The pytest configuration object that provides access to command-line options.
@ -23,8 +22,7 @@ def pytest_collection_modifyitems(config, items):
def pytest_sessionstart(session):
"""
Initialize session configurations for pytest.
"""Initialize session configurations for pytest.
This function is automatically called by pytest after the 'Session' object has been created but before performing
test collection. It sets the initial seeds for the test session.
@ -38,8 +36,7 @@ def pytest_sessionstart(session):
def pytest_terminal_summary(terminalreporter, exitstatus, config):
"""
Cleanup operations after pytest session.
"""Cleanup operations after pytest session.
This function is automatically called by pytest at the end of the entire test session. It removes certain files and
directories used during testing.

View file

@ -175,8 +175,7 @@ def test_youtube():
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
@pytest.mark.parametrize("model", MODELS)
def test_track_stream(model, tmp_path):
"""
Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
"""Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
Note imgsz=160 required for tracking for higher confidence and better matches.
"""

View file

@ -242,12 +242,11 @@ CFG_BOOL_KEYS = frozenset(
def cfg2dict(cfg: str | Path | dict | SimpleNamespace) -> dict:
"""
Convert a configuration object to a dictionary.
"""Convert a configuration object to a dictionary.
Args:
cfg (str | Path | dict | SimpleNamespace): Configuration object to be converted. Can be a file path,
a string, a dictionary, or a SimpleNamespace object.
cfg (str | Path | dict | SimpleNamespace): Configuration object to be converted. Can be a file path, a string, a
dictionary, or a SimpleNamespace object.
Returns:
(dict): Configuration object in dictionary format.
@ -279,8 +278,7 @@ def cfg2dict(cfg: str | Path | dict | SimpleNamespace) -> dict:
def get_cfg(
cfg: str | Path | dict | SimpleNamespace = DEFAULT_CFG_DICT, overrides: dict | None = None
) -> SimpleNamespace:
"""
Load and merge configuration data from a file or dictionary, with optional overrides.
"""Load and merge configuration data from a file or dictionary, with optional overrides.
Args:
cfg (str | Path | dict | SimpleNamespace): Configuration data source. Can be a file path, dictionary, or
@ -327,8 +325,7 @@ def get_cfg(
def check_cfg(cfg: dict, hard: bool = True) -> None:
"""
Check configuration argument types and values for the Ultralytics library.
"""Check configuration argument types and values for the Ultralytics library.
This function validates the types and values of configuration arguments, ensuring correctness and converting them if
necessary. It checks for specific key types defined in global variables such as `CFG_FLOAT_KEYS`,
@ -389,14 +386,13 @@ def check_cfg(cfg: dict, hard: bool = True) -> None:
def get_save_dir(args: SimpleNamespace, name: str | None = None) -> Path:
"""
Return the directory path for saving outputs, derived from arguments or default settings.
"""Return the directory path for saving outputs, derived from arguments or default settings.
Args:
args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task',
'mode', and 'save_dir'.
name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name'
or the 'args.mode'.
args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task', 'mode',
and 'save_dir'.
name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name' or the
'args.mode'.
Returns:
(Path): Directory path where outputs should be saved.
@ -421,8 +417,7 @@ def get_save_dir(args: SimpleNamespace, name: str | None = None) -> Path:
def _handle_deprecation(custom: dict) -> dict:
"""
Handle deprecated configuration keys by mapping them to current equivalents with deprecation warnings.
"""Handle deprecated configuration keys by mapping them to current equivalents with deprecation warnings.
Args:
custom (dict): Configuration dictionary potentially containing deprecated keys.
@ -465,8 +460,7 @@ def _handle_deprecation(custom: dict) -> dict:
def check_dict_alignment(base: dict, custom: dict, e: Exception | None = None) -> None:
"""
Check alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
"""Check alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
messages for mismatched keys.
Args:
@ -505,8 +499,7 @@ def check_dict_alignment(base: dict, custom: dict, e: Exception | None = None) -
def merge_equals_args(args: list[str]) -> list[str]:
"""
Merge arguments around isolated '=' in a list of strings and join fragments with brackets.
"""Merge arguments around isolated '=' in a list of strings and join fragments with brackets.
This function handles the following cases:
1. ['arg', '=', 'val'] becomes ['arg=val']
@ -565,15 +558,14 @@ def merge_equals_args(args: list[str]) -> list[str]:
def handle_yolo_hub(args: list[str]) -> None:
"""
Handle Ultralytics HUB command-line interface (CLI) commands for authentication.
"""Handle Ultralytics HUB command-line interface (CLI) commands for authentication.
This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a
script with arguments related to HUB authentication.
Args:
args (list[str]): A list of command line arguments. The first argument should be either 'login'
or 'logout'. For 'login', an optional second argument can be the API key.
args (list[str]): A list of command line arguments. The first argument should be either 'login' or 'logout'. For
'login', an optional second argument can be the API key.
Examples:
$ yolo login YOUR_API_KEY
@ -595,8 +587,7 @@ def handle_yolo_hub(args: list[str]) -> None:
def handle_yolo_settings(args: list[str]) -> None:
"""
Handle YOLO settings command-line interface (CLI) commands.
"""Handle YOLO settings command-line interface (CLI) commands.
This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be
called when executing a script with arguments related to YOLO settings management.
@ -638,8 +629,7 @@ def handle_yolo_settings(args: list[str]) -> None:
def handle_yolo_solutions(args: list[str]) -> None:
"""
Process YOLO solutions arguments and run the specified computer vision solutions pipeline.
"""Process YOLO solutions arguments and run the specified computer vision solutions pipeline.
Args:
args (list[str]): Command-line arguments for configuring and running the Ultralytics YOLO solutions.
@ -748,8 +738,7 @@ def handle_yolo_solutions(args: list[str]) -> None:
def parse_key_value_pair(pair: str = "key=value") -> tuple:
"""
Parse a key-value pair string into separate key and value components.
"""Parse a key-value pair string into separate key and value components.
Args:
pair (str): A string containing a key-value pair in the format "key=value".
@ -782,8 +771,7 @@ def parse_key_value_pair(pair: str = "key=value") -> tuple:
def smart_value(v: str) -> Any:
"""
Convert a string representation of a value to its appropriate Python type.
"""Convert a string representation of a value to its appropriate Python type.
This function attempts to convert a given string into a Python object of the most appropriate type. It handles
conversions to None, bool, int, float, and other types that can be evaluated safely.
@ -792,8 +780,8 @@ def smart_value(v: str) -> Any:
v (str): The string representation of the value to be converted.
Returns:
(Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion
is applicable.
(Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion is
applicable.
Examples:
>>> smart_value("42")
@ -827,8 +815,7 @@ def smart_value(v: str) -> Any:
def entrypoint(debug: str = "") -> None:
"""
Ultralytics entrypoint function for parsing and executing command-line arguments.
"""Ultralytics entrypoint function for parsing and executing command-line arguments.
This function serves as the main entry point for the Ultralytics CLI, parsing command-line arguments and executing
the corresponding tasks such as training, validation, prediction, exporting models, and more.
@ -1000,8 +987,7 @@ def entrypoint(debug: str = "") -> None:
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg() -> None:
"""
Copy the default configuration file and create a new one with '_copy' appended to its name.
"""Copy the default configuration file and create a new one with '_copy' appended to its name.
This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it with '_copy'
appended to its name in the current working directory. It provides a convenient way to create a custom configuration

View file

@ -19,8 +19,7 @@ def auto_annotate(
classes: list[int] | None = None,
output_dir: str | Path | None = None,
) -> None:
"""
Automatically annotate images using a YOLO object detection model and a SAM segmentation model.
"""Automatically annotate images using a YOLO object detection model and a SAM segmentation model.
This function processes images in a specified directory, detects objects using a YOLO model, and then generates
segmentation masks using a SAM model. The resulting annotations are saved as text files in YOLO format.
@ -35,8 +34,8 @@ def auto_annotate(
imgsz (int): Input image resize dimension.
max_det (int): Maximum number of detections per image.
classes (list[int], optional): Filter predictions to specified class IDs, returning only relevant detections.
output_dir (str | Path, optional): Directory to save the annotated results. If None, creates a default
directory based on the input data path.
output_dir (str | Path, optional): Directory to save the annotated results. If None, creates a default directory
based on the input data path.
Examples:
>>> from ultralytics.data.annotator import auto_annotate

View file

@ -26,8 +26,7 @@ DEFAULT_STD = (1.0, 1.0, 1.0)
class BaseTransform:
"""
Base class for image transformations in the Ultralytics library.
"""Base class for image transformations in the Ultralytics library.
This class serves as a foundation for implementing various image processing operations, designed to be compatible
with both classification and semantic segmentation tasks.
@ -45,8 +44,7 @@ class BaseTransform:
"""
def __init__(self) -> None:
"""
Initialize the BaseTransform object.
"""Initialize the BaseTransform object.
This constructor sets up the base transformation object, which can be extended for specific image processing
tasks. It is designed to be compatible with both classification and semantic segmentation.
@ -57,15 +55,14 @@ class BaseTransform:
pass
def apply_image(self, labels):
"""
Apply image transformations to labels.
"""Apply image transformations to labels.
This method is intended to be overridden by subclasses to implement specific image transformation
logic. In its base form, it returns the input labels unchanged.
Args:
labels (Any): The input labels to be transformed. The exact type and structure of labels may
vary depending on the specific implementation.
labels (Any): The input labels to be transformed. The exact type and structure of labels may vary depending
on the specific implementation.
Returns:
(Any): The transformed labels. In the base implementation, this is identical to the input.
@ -80,8 +77,7 @@ class BaseTransform:
pass
def apply_instances(self, labels):
"""
Apply transformations to object instances in labels.
"""Apply transformations to object instances in labels.
This method is responsible for applying various transformations to object instances within the given
labels. It is designed to be overridden by subclasses to implement specific instance transformation
@ -101,8 +97,7 @@ class BaseTransform:
pass
def apply_semantic(self, labels):
"""
Apply semantic segmentation transformations to an image.
"""Apply semantic segmentation transformations to an image.
This method is intended to be overridden by subclasses to implement specific semantic segmentation
transformations. In its base form, it does not perform any operations.
@ -121,16 +116,15 @@ class BaseTransform:
pass
def __call__(self, labels):
"""
Apply all label transformations to an image, instances, and semantic masks.
"""Apply all label transformations to an image, instances, and semantic masks.
This method orchestrates the application of various transformations defined in the BaseTransform class to the
input labels. It sequentially calls the apply_image and apply_instances methods to process the image and object
instances, respectively.
Args:
labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for
the image data, and 'instances' for object instances.
labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for the image
data, and 'instances' for object instances.
Returns:
(dict): The input labels dictionary with transformed image and instances.
@ -146,8 +140,7 @@ class BaseTransform:
class Compose:
"""
A class for composing multiple image transformations.
"""A class for composing multiple image transformations.
Attributes:
transforms (list[Callable]): A list of transformation functions to be applied sequentially.
@ -169,8 +162,7 @@ class Compose:
"""
def __init__(self, transforms):
"""
Initialize the Compose object with a list of transforms.
"""Initialize the Compose object with a list of transforms.
Args:
transforms (list[Callable]): A list of callable transform objects to be applied sequentially.
@ -183,14 +175,13 @@ class Compose:
self.transforms = transforms if isinstance(transforms, list) else [transforms]
def __call__(self, data):
"""
Apply a series of transformations to input data.
"""Apply a series of transformations to input data.
This method sequentially applies each transformation in the Compose object's transforms to the input data.
Args:
data (Any): The input data to be transformed. This can be of any type, depending on the
transformations in the list.
data (Any): The input data to be transformed. This can be of any type, depending on the transformations in
the list.
Returns:
(Any): The transformed data after applying all transformations in sequence.
@ -205,8 +196,7 @@ class Compose:
return data
def append(self, transform):
"""
Append a new transform to the existing list of transforms.
"""Append a new transform to the existing list of transforms.
Args:
transform (BaseTransform): The transformation to be added to the composition.
@ -218,8 +208,7 @@ class Compose:
self.transforms.append(transform)
def insert(self, index, transform):
"""
Insert a new transform at a specified index in the existing list of transforms.
"""Insert a new transform at a specified index in the existing list of transforms.
Args:
index (int): The index at which to insert the new transform.
@ -234,8 +223,7 @@ class Compose:
self.transforms.insert(index, transform)
def __getitem__(self, index: list | int) -> Compose:
"""
Retrieve a specific transform or a set of transforms using indexing.
"""Retrieve a specific transform or a set of transforms using indexing.
Args:
index (int | list[int]): Index or list of indices of the transforms to retrieve.
@ -256,8 +244,7 @@ class Compose:
return Compose([self.transforms[i] for i in index]) if isinstance(index, list) else self.transforms[index]
def __setitem__(self, index: list | int, value: list | int) -> None:
"""
Set one or more transforms in the composition using indexing.
"""Set one or more transforms in the composition using indexing.
Args:
index (int | list[int]): Index or list of indices to set transforms at.
@ -283,8 +270,7 @@ class Compose:
self.transforms[i] = v
def tolist(self):
"""
Convert the list of transforms to a standard Python list.
"""Convert the list of transforms to a standard Python list.
Returns:
(list): A list containing all the transform objects in the Compose instance.
@ -299,8 +285,7 @@ class Compose:
return self.transforms
def __repr__(self):
"""
Return a string representation of the Compose object.
"""Return a string representation of the Compose object.
Returns:
(str): A string representation of the Compose object, including the list of transforms.
@ -318,8 +303,7 @@ class Compose:
class BaseMixTransform:
"""
Base class for mix transformations like Cutmix, MixUp and Mosaic.
"""Base class for mix transformations like Cutmix, MixUp and Mosaic.
This class provides a foundation for implementing mix transformations on datasets. It handles the probability-based
application of transforms and manages the mixing of multiple images and labels.
@ -349,8 +333,7 @@ class BaseMixTransform:
"""
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
"""
Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
"""Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
This class serves as a base for implementing mix transformations in image processing pipelines.
@ -369,8 +352,7 @@ class BaseMixTransform:
self.p = p
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
"""Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
This method determines whether to apply the mix transform based on a probability factor. If applied, it selects
additional images, applies pre-transforms if specified, and then performs the mix transform.
@ -409,8 +391,7 @@ class BaseMixTransform:
return labels
def _mix_transform(self, labels: dict[str, Any]):
"""
Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.
"""Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.
This method should be implemented by subclasses to perform specific mix transformations like CutMix, MixUp or
Mosaic. It modifies the input label dictionary in-place with the augmented data.
@ -430,8 +411,7 @@ class BaseMixTransform:
raise NotImplementedError
def get_indexes(self):
"""
Get a list of shuffled indexes for mosaic augmentation.
"""Get a list of shuffled indexes for mosaic augmentation.
Returns:
(list[int]): A list of shuffled indexes from the dataset.
@ -445,15 +425,14 @@ class BaseMixTransform:
@staticmethod
def _update_label_text(labels: dict[str, Any]) -> dict[str, Any]:
"""
Update label text and class IDs for mixed labels in image augmentation.
"""Update label text and class IDs for mixed labels in image augmentation.
This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels, creating
a unified set of text labels and updating class IDs accordingly.
Args:
labels (dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields,
and optionally a 'mix_labels' field with additional label dictionaries.
labels (dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields, and
optionally a 'mix_labels' field with additional label dictionaries.
Returns:
(dict[str, Any]): The updated labels dictionary with unified text labels and updated class IDs.
@ -490,8 +469,7 @@ class BaseMixTransform:
class Mosaic(BaseMixTransform):
"""
Mosaic augmentation for image datasets.
"""Mosaic augmentation for image datasets.
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
augmentation is applied to a dataset with a given probability.
@ -520,8 +498,7 @@ class Mosaic(BaseMixTransform):
"""
def __init__(self, dataset, imgsz: int = 640, p: float = 1.0, n: int = 4):
"""
Initialize the Mosaic augmentation object.
"""Initialize the Mosaic augmentation object.
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
augmentation is applied to a dataset with a given probability.
@ -546,15 +523,14 @@ class Mosaic(BaseMixTransform):
self.buffer_enabled = self.dataset.cache != "ram"
def get_indexes(self):
"""
Return a list of random indexes from the dataset for mosaic augmentation.
"""Return a list of random indexes from the dataset for mosaic augmentation.
This method selects random image indexes either from a buffer or from the entire dataset, depending on the
'buffer' parameter. It is used to choose images for creating mosaic augmentations.
Returns:
(list[int]): A list of random image indexes. The length of the list is n-1, where n is the number
of images used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
(list[int]): A list of random image indexes. The length of the list is n-1, where n is the number of images
used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
Examples:
>>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
@ -567,8 +543,7 @@ class Mosaic(BaseMixTransform):
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply mosaic augmentation to the input image and labels.
"""Apply mosaic augmentation to the input image and labels.
This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute. It
ensures that rectangular annotations are not present and that there are other images available for mosaic
@ -596,8 +571,7 @@ class Mosaic(BaseMixTransform):
) # This code is modified for mosaic3 method.
def _mosaic3(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Create a 1x3 image mosaic by combining three images.
"""Create a 1x3 image mosaic by combining three images.
This method arranges three images in a horizontal layout, with the main image in the center and two additional
images on either side. It's part of the Mosaic augmentation technique used in object detection.
@ -655,8 +629,7 @@ class Mosaic(BaseMixTransform):
return final_labels
def _mosaic4(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Create a 2x2 image mosaic from four input images.
"""Create a 2x2 image mosaic from four input images.
This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also updates the
corresponding labels for each image in the mosaic.
@ -714,8 +687,7 @@ class Mosaic(BaseMixTransform):
return final_labels
def _mosaic9(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Create a 3x3 image mosaic from the input image and eight additional images.
"""Create a 3x3 image mosaic from the input image and eight additional images.
This method combines nine images into a single mosaic image. The input image is placed at the center, and eight
additional images from the dataset are placed around it in a 3x3 grid pattern.
@ -788,8 +760,7 @@ class Mosaic(BaseMixTransform):
@staticmethod
def _update_labels(labels, padw: int, padh: int) -> dict[str, Any]:
"""
Update label coordinates with padding values.
"""Update label coordinates with padding values.
This method adjusts the bounding box coordinates of object instances in the labels by adding padding
values. It also denormalizes the coordinates if they were previously normalized.
@ -814,8 +785,7 @@ class Mosaic(BaseMixTransform):
return labels
def _cat_labels(self, mosaic_labels: list[dict[str, Any]]) -> dict[str, Any]:
"""
Concatenate and process labels for mosaic augmentation.
"""Concatenate and process labels for mosaic augmentation.
This method combines labels from multiple images used in mosaic augmentation, clips instances to the mosaic
border, and removes zero-area boxes.
@ -866,8 +836,7 @@ class Mosaic(BaseMixTransform):
class MixUp(BaseMixTransform):
"""
Apply MixUp augmentation to image datasets.
"""Apply MixUp augmentation to image datasets.
This class implements the MixUp augmentation technique as described in the paper [mixup: Beyond Empirical Risk
Minimization](https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight.
@ -888,8 +857,7 @@ class MixUp(BaseMixTransform):
"""
def __init__(self, dataset, pre_transform=None, p: float = 0.0) -> None:
"""
Initialize the MixUp augmentation object.
"""Initialize the MixUp augmentation object.
MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel values
and labels. This implementation is designed for use with the Ultralytics YOLO framework.
@ -907,8 +875,7 @@ class MixUp(BaseMixTransform):
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply MixUp augmentation to the input labels.
"""Apply MixUp augmentation to the input labels.
This method implements the MixUp augmentation technique as described in the paper "mixup: Beyond Empirical Risk
Minimization" (https://arxiv.org/abs/1710.09412).
@ -932,8 +899,7 @@ class MixUp(BaseMixTransform):
class CutMix(BaseMixTransform):
"""
Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
"""Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from
another image, and adjusts the labels proportionally to the area of the mixed region.
@ -957,8 +923,7 @@ class CutMix(BaseMixTransform):
"""
def __init__(self, dataset, pre_transform=None, p: float = 0.0, beta: float = 1.0, num_areas: int = 3) -> None:
"""
Initialize the CutMix augmentation object.
"""Initialize the CutMix augmentation object.
Args:
dataset (Any): The dataset to which CutMix augmentation will be applied.
@ -972,8 +937,7 @@ class CutMix(BaseMixTransform):
self.num_areas = num_areas
def _rand_bbox(self, width: int, height: int) -> tuple[int, int, int, int]:
"""
Generate random bounding box coordinates for the cut region.
"""Generate random bounding box coordinates for the cut region.
Args:
width (int): Width of the image.
@ -1002,8 +966,7 @@ class CutMix(BaseMixTransform):
return x1, y1, x2, y2
def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply CutMix augmentation to the input labels.
"""Apply CutMix augmentation to the input labels.
Args:
labels (dict[str, Any]): A dictionary containing the original image and label information.
@ -1050,8 +1013,7 @@ class CutMix(BaseMixTransform):
class RandomPerspective:
"""
Implement random perspective and affine transformations on images and corresponding annotations.
"""Implement random perspective and affine transformations on images and corresponding annotations.
This class applies random rotations, translations, scaling, shearing, and perspective transformations to images and
their associated bounding boxes, segments, and keypoints. It can be used as part of an augmentation pipeline for
@ -1093,8 +1055,7 @@ class RandomPerspective:
border: tuple[int, int] = (0, 0),
pre_transform=None,
):
"""
Initialize RandomPerspective object with transformation parameters.
"""Initialize RandomPerspective object with transformation parameters.
This class implements random perspective and affine transformations on images and corresponding bounding boxes,
segments, and keypoints. Transformations include rotation, translation, scaling, and shearing.
@ -1122,8 +1083,7 @@ class RandomPerspective:
self.pre_transform = pre_transform
def affine_transform(self, img: np.ndarray, border: tuple[int, int]) -> tuple[np.ndarray, np.ndarray, float]:
"""
Apply a sequence of affine transformations centered around the image center.
"""Apply a sequence of affine transformations centered around the image center.
This function performs a series of geometric transformations on the input image, including translation,
perspective change, rotation, scaling, and shearing. The transformations are applied in a specific order to
@ -1186,15 +1146,14 @@ class RandomPerspective:
return img, M, s
def apply_bboxes(self, bboxes: np.ndarray, M: np.ndarray) -> np.ndarray:
"""
Apply affine transformation to bounding boxes.
"""Apply affine transformation to bounding boxes.
This function applies an affine transformation to a set of bounding boxes using the provided transformation
matrix.
Args:
bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number
of bounding boxes.
bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number of bounding
boxes.
M (np.ndarray): Affine transformation matrix with shape (3, 3).
Returns:
@ -1220,8 +1179,7 @@ class RandomPerspective:
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
def apply_segments(self, segments: np.ndarray, M: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Apply affine transformations to segments and generate new bounding boxes.
"""Apply affine transformations to segments and generate new bounding boxes.
This function applies affine transformations to input segments and generates new bounding boxes based on the
transformed segments. It clips the transformed segments to fit within the new bounding boxes.
@ -1256,16 +1214,15 @@ class RandomPerspective:
return bboxes, segments
def apply_keypoints(self, keypoints: np.ndarray, M: np.ndarray) -> np.ndarray:
"""
Apply affine transformation to keypoints.
"""Apply affine transformation to keypoints.
This method transforms the input keypoints using the provided affine transformation matrix. It handles
perspective rescaling if necessary and updates the visibility of keypoints that fall outside the image
boundaries after transformation.
Args:
keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances,
17 is the number of keypoints per instance, and 3 represents (x, y, visibility).
keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances, 17 is
the number of keypoints per instance, and 3 represents (x, y, visibility).
M (np.ndarray): 3x3 affine transformation matrix.
Returns:
@ -1290,8 +1247,7 @@ class RandomPerspective:
return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply random perspective and affine transformations to an image and its associated labels.
"""Apply random perspective and affine transformations to an image and its associated labels.
This method performs a series of transformations including rotation, translation, scaling, shearing, and
perspective distortion on the input image and adjusts the corresponding bounding boxes, segments, and keypoints
@ -1378,29 +1334,27 @@ class RandomPerspective:
area_thr: float = 0.1,
eps: float = 1e-16,
) -> np.ndarray:
"""
Compute candidate boxes for further processing based on size and aspect ratio criteria.
"""Compute candidate boxes for further processing based on size and aspect ratio criteria.
This method compares boxes before and after augmentation to determine if they meet specified thresholds for
width, height, aspect ratio, and area. It's used to filter out boxes that have been overly distorted or reduced
by the augmentation process.
Args:
box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where n is the
number of boxes. Format is [x1, y1, x2, y2] in absolute coordinates.
box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is
[x1, y1, x2, y2] in absolute coordinates.
wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either
dimension are rejected.
ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this
value are rejected.
area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than
this value are rejected.
            box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where N is the number of boxes. Format
                is [x1, y1, x2, y2] in absolute coordinates.
box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is [x1, y1, x2, y2] in
absolute coordinates.
wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either dimension are
rejected.
ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this value are rejected.
area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than this value are
rejected.
eps (float): Small epsilon value to prevent division by zero.
Returns:
(np.ndarray): Boolean array of shape (n) indicating which boxes are candidates.
True values correspond to boxes that meet all criteria.
            (np.ndarray): Boolean array of shape (N,) indicating which boxes are candidates. True values correspond to
                boxes that meet all criteria.
Examples:
>>> random_perspective = RandomPerspective()
@ -1417,8 +1371,7 @@ class RandomPerspective:
class RandomHSV:
"""
Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.
"""Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.
This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.
@ -1441,8 +1394,7 @@ class RandomHSV:
"""
def __init__(self, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5) -> None:
"""
Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
"""Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
This class applies random adjustments to the HSV channels of an image within specified limits.
@ -1460,15 +1412,14 @@ class RandomHSV:
self.vgain = vgain
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply random HSV augmentation to an image within predefined limits.
"""Apply random HSV augmentation to an image within predefined limits.
This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels. The
adjustments are made within the limits set by hgain, sgain, and vgain during initialization.
Args:
labels (dict[str, Any]): A dictionary containing image data and metadata. Must include an 'img' key with
the image as a numpy array.
labels (dict[str, Any]): A dictionary containing image data and metadata. Must include an 'img' key with the
image as a numpy array.
Returns:
(dict[str, Any]): A dictionary containing the mixed image and adjusted labels.
@ -1500,8 +1451,7 @@ class RandomHSV:
class RandomFlip:
"""
Apply a random horizontal or vertical flip to an image with a given probability.
"""Apply a random horizontal or vertical flip to an image with a given probability.
This class performs random image flipping and updates corresponding instance annotations such as bounding boxes and
keypoints.
@ -1522,8 +1472,7 @@ class RandomFlip:
"""
def __init__(self, p: float = 0.5, direction: str = "horizontal", flip_idx: list[int] | None = None) -> None:
"""
Initialize the RandomFlip class with probability and direction.
"""Initialize the RandomFlip class with probability and direction.
This class applies a random horizontal or vertical flip to an image with a given probability. It also updates
any instances (bounding boxes, keypoints, etc.) accordingly.
@ -1548,8 +1497,7 @@ class RandomFlip:
self.flip_idx = flip_idx
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.
"""Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.
This method randomly flips the input image either horizontally or vertically based on the initialized
probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to match the
@ -1594,8 +1542,7 @@ class RandomFlip:
class LetterBox:
"""
Resize image and padding for detection, instance segmentation, pose.
"""Resize image and padding for detection, instance segmentation, pose.
This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates corresponding
labels and bounding boxes.
@ -1629,8 +1576,7 @@ class LetterBox:
padding_value: int = 114,
interpolation: int = cv2.INTER_LINEAR,
):
"""
Initialize LetterBox object for resizing and padding images.
"""Initialize LetterBox object for resizing and padding images.
This class is designed to resize and pad images for object detection, instance segmentation, and pose estimation
tasks. It supports various resizing modes including auto-sizing, scale-fill, and letterboxing.
@ -1668,8 +1614,7 @@ class LetterBox:
self.interpolation = interpolation
def __call__(self, labels: dict[str, Any] | None = None, image: np.ndarray = None) -> dict[str, Any] | np.ndarray:
"""
Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.
"""Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.
This method applies letterboxing to the input image, which involves resizing the image while maintaining its
aspect ratio and adding padding to fit the new shape. It also updates any associated labels accordingly.
@ -1748,8 +1693,7 @@ class LetterBox:
@staticmethod
def _update_labels(labels: dict[str, Any], ratio: tuple[float, float], padw: float, padh: float) -> dict[str, Any]:
"""
Update labels after applying letterboxing to an image.
"""Update labels after applying letterboxing to an image.
This method modifies the bounding box coordinates of instances in the labels to account for resizing and padding
applied during letterboxing.
@ -1778,8 +1722,7 @@ class LetterBox:
class CopyPaste(BaseMixTransform):
"""
CopyPaste class for applying Copy-Paste augmentation to image datasets.
"""CopyPaste class for applying Copy-Paste augmentation to image datasets.
This class implements the Copy-Paste augmentation technique as described in the paper "Simple Copy-Paste is a Strong
Data Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It combines objects from
@ -1878,8 +1821,7 @@ class CopyPaste(BaseMixTransform):
class Albumentations:
"""
Albumentations transformations for image augmentation.
"""Albumentations transformations for image augmentation.
This class applies various image transformations using the Albumentations library. It includes operations such as
Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes
@ -1904,8 +1846,7 @@ class Albumentations:
"""
def __init__(self, p: float = 1.0) -> None:
"""
Initialize the Albumentations transform object for YOLO bbox formatted parameters.
"""Initialize the Albumentations transform object for YOLO bbox formatted parameters.
This class applies various image augmentations using the Albumentations library, including Blur, Median Blur,
conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and
@ -2018,8 +1959,7 @@ class Albumentations:
LOGGER.info(f"{prefix}{e}")
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Apply Albumentations transformations to input labels.
"""Apply Albumentations transformations to input labels.
This method applies a series of image augmentations using the Albumentations library. It can perform both
spatial and non-spatial transformations on the input image and its corresponding labels.
@ -2075,8 +2015,7 @@ class Albumentations:
class Format:
"""
A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
"""A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader.
@ -2116,8 +2055,7 @@ class Format:
batch_idx: bool = True,
bgr: float = 0.0,
):
"""
Initialize the Format class with given parameters for image and instance annotation formatting.
"""Initialize the Format class with given parameters for image and instance annotation formatting.
This class standardizes image and instance annotations for object detection, instance segmentation, and pose
estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`.
@ -2160,8 +2098,7 @@ class Format:
self.bgr = bgr
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Format image annotations for object detection, instance segmentation, and pose estimation tasks.
"""Format image annotations for object detection, instance segmentation, and pose estimation tasks.
This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch
DataLoader. It processes the input labels dictionary, converting annotations to the specified format and
@ -2229,8 +2166,7 @@ class Format:
return labels
def _format_img(self, img: np.ndarray) -> torch.Tensor:
"""
Format an image for YOLO from a Numpy array to a PyTorch tensor.
"""Format an image for YOLO from a Numpy array to a PyTorch tensor.
This function performs the following operations:
1. Ensures the image has 3 dimensions (adds a channel dimension if needed).
@ -2262,8 +2198,7 @@ class Format:
def _format_segments(
self, instances: Instances, cls: np.ndarray, w: int, h: int
) -> tuple[np.ndarray, Instances, np.ndarray]:
"""
Convert polygon segments to bitmap masks.
"""Convert polygon segments to bitmap masks.
Args:
instances (Instances): Object containing segment information.
@ -2297,8 +2232,7 @@ class LoadVisualPrompt:
"""Create visual prompts from bounding boxes or masks for model input."""
def __init__(self, scale_factor: float = 1 / 8) -> None:
"""
Initialize the LoadVisualPrompt with a scale factor.
"""Initialize the LoadVisualPrompt with a scale factor.
Args:
scale_factor (float): Factor to scale the input image dimensions.
@ -2306,8 +2240,7 @@ class LoadVisualPrompt:
self.scale_factor = scale_factor
def make_mask(self, boxes: torch.Tensor, h: int, w: int) -> torch.Tensor:
"""
Create binary masks from bounding boxes.
"""Create binary masks from bounding boxes.
Args:
boxes (torch.Tensor): Bounding boxes in xyxy format, shape: (N, 4).
@ -2324,8 +2257,7 @@ class LoadVisualPrompt:
return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Process labels to create visual prompts.
"""Process labels to create visual prompts.
Args:
labels (dict[str, Any]): Dictionary containing image data and annotations.
@ -2351,8 +2283,7 @@ class LoadVisualPrompt:
bboxes: np.ndarray | torch.Tensor = None,
masks: np.ndarray | torch.Tensor = None,
) -> torch.Tensor:
"""
Generate visual masks based on bounding boxes or masks.
"""Generate visual masks based on bounding boxes or masks.
Args:
category (int | np.ndarray | torch.Tensor): The category labels for the objects.
@ -2393,8 +2324,7 @@ class LoadVisualPrompt:
class RandomLoadText:
"""
Randomly sample positive and negative texts and update class indices accordingly.
"""Randomly sample positive and negative texts and update class indices accordingly.
This class is responsible for sampling texts from a given set of class texts, including both positive (present in
the image) and negative (not present in the image) samples. It updates the class indices to reflect the sampled
@ -2426,20 +2356,19 @@ class RandomLoadText:
padding: bool = False,
padding_value: list[str] = [""],
) -> None:
"""
Initialize the RandomLoadText class for randomly sampling positive and negative texts.
"""Initialize the RandomLoadText class for randomly sampling positive and negative texts.
This class is designed to randomly sample positive texts and negative texts, and update the class indices
accordingly to the number of samples. It can be used for text-based object detection tasks.
Args:
prompt_format (str): Format string for the prompt. The format string should
contain a single pair of curly braces {} where the text will be inserted.
neg_samples (tuple[int, int]): A range to randomly sample negative texts. The first integer
specifies the minimum number of negative samples, and the second integer specifies the maximum.
prompt_format (str): Format string for the prompt. The format string should contain a single pair of curly
braces {} where the text will be inserted.
neg_samples (tuple[int, int]): A range to randomly sample negative texts. The first integer specifies the
minimum number of negative samples, and the second integer specifies the maximum.
max_samples (int): The maximum number of different text samples in one image.
padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always
be equal to max_samples.
padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always be equal to
max_samples.
padding_value (str): The padding text to use when padding is True.
Attributes:
@ -2465,8 +2394,7 @@ class RandomLoadText:
self.padding_value = padding_value
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
"""
Randomly sample positive and negative texts and update class indices accordingly.
"""Randomly sample positive and negative texts and update class indices accordingly.
This method samples positive texts based on the existing class labels in the image, and randomly selects
negative texts from the remaining classes. It then updates the class indices to match the new sampled text
@ -2532,8 +2460,7 @@ class RandomLoadText:
def v8_transforms(dataset, imgsz: int, hyp: IterableSimpleNamespace, stretch: bool = False):
"""
Apply a series of image transformations for training.
"""Apply a series of image transformations for training.
This function creates a composition of image augmentation techniques to prepare images for YOLO training. It
includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.
@ -2608,8 +2535,7 @@ def classify_transforms(
interpolation: str = "BILINEAR",
crop_fraction: float | None = None,
):
"""
Create a composition of image transforms for classification tasks.
"""Create a composition of image transforms for classification tasks.
This function generates a sequence of torchvision transforms suitable for preprocessing images for classification
models during evaluation or inference. The transforms include resizing, center cropping, conversion to tensor, and
@ -2668,8 +2594,7 @@ def classify_augmentations(
erasing: float = 0.0,
interpolation: str = "BILINEAR",
):
"""
Create a composition of image augmentation transforms for classification tasks.
"""Create a composition of image augmentation transforms for classification tasks.
This function generates a set of image transformations suitable for training classification models. It includes
options for resizing, flipping, color jittering, auto augmentation, and random erasing.
@ -2757,8 +2682,7 @@ def classify_augmentations(
# NOTE: keep this class for backward compatibility
class ClassifyLetterBox:
"""
A class for resizing and padding images for classification tasks.
"""A class for resizing and padding images for classification tasks.
This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). It
resizes and pads images to a specified size while maintaining the original aspect ratio.
@ -2781,15 +2705,14 @@ class ClassifyLetterBox:
"""
def __init__(self, size: int | tuple[int, int] = (640, 640), auto: bool = False, stride: int = 32):
"""
Initialize the ClassifyLetterBox object for image preprocessing.
"""Initialize the ClassifyLetterBox object for image preprocessing.
This class is designed to be part of a transformation pipeline for image classification tasks. It resizes and
pads images to a specified size while maintaining the original aspect ratio.
Args:
size (int | tuple[int, int]): Target size for the letterboxed image. If an int, a square image of
(size, size) is created. If a tuple, it should be (height, width).
size (int | tuple[int, int]): Target size for the letterboxed image. If an int, a square image of (size,
size) is created. If a tuple, it should be (height, width).
auto (bool): If True, automatically calculates the short side based on stride.
stride (int): The stride value, used when 'auto' is True.
@ -2812,8 +2735,7 @@ class ClassifyLetterBox:
self.stride = stride # used with auto
def __call__(self, im: np.ndarray) -> np.ndarray:
"""
Resize and pad an image using the letterbox method.
"""Resize and pad an image using the letterbox method.
This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio,
then pads the resized image to match the target size.
@ -2822,8 +2744,8 @@ class ClassifyLetterBox:
im (np.ndarray): Input image as a numpy array with shape (H, W, C).
Returns:
(np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are
the target height and width respectively.
(np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are the
target height and width respectively.
Examples:
>>> letterbox = ClassifyLetterBox(size=(640, 640))
@ -2848,8 +2770,7 @@ class ClassifyLetterBox:
# NOTE: keep this class for backward compatibility
class CenterCrop:
"""
Apply center cropping to images for classification tasks.
"""Apply center cropping to images for classification tasks.
This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect
ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
@ -2870,15 +2791,14 @@ class CenterCrop:
"""
def __init__(self, size: int | tuple[int, int] = (640, 640)):
"""
Initialize the CenterCrop object for image preprocessing.
"""Initialize the CenterCrop object for image preprocessing.
This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
It performs a center crop on input images to a specified size.
Args:
size (int | tuple[int, int]): The desired output size of the crop. If size is an int, a square crop
(size, size) is made. If size is a sequence like (h, w), it is used as the output size.
size (int | tuple[int, int]): The desired output size of the crop. If size is an int, a square crop (size,
size) is made. If size is a sequence like (h, w), it is used as the output size.
Returns:
(None): This method initializes the object and does not return anything.
@ -2894,15 +2814,14 @@ class CenterCrop:
self.h, self.w = (size, size) if isinstance(size, int) else size
def __call__(self, im: Image.Image | np.ndarray) -> np.ndarray:
"""
Apply center cropping to an input image.
"""Apply center cropping to an input image.
This method resizes and crops the center of the image using a letterbox method. It maintains the aspect ratio of
the original image while fitting it into the specified dimensions.
Args:
im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a
PIL Image object.
im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a PIL Image
object.
Returns:
(np.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C).
@ -2923,8 +2842,7 @@ class CenterCrop:
# NOTE: keep this class for backward compatibility
class ToTensor:
"""
Convert an image from a numpy array to a PyTorch tensor.
"""Convert an image from a numpy array to a PyTorch tensor.
This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
@ -2947,8 +2865,7 @@ class ToTensor:
"""
def __init__(self, half: bool = False):
"""
Initialize the ToTensor object for converting images to PyTorch tensors.
"""Initialize the ToTensor object for converting images to PyTorch tensors.
This class is designed to be used as part of a transformation pipeline for image preprocessing in the
Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option for
@ -2968,8 +2885,7 @@ class ToTensor:
self.half = half
def __call__(self, im: np.ndarray) -> torch.Tensor:
"""
Transform an image from a numpy array to a PyTorch tensor.
"""Transform an image from a numpy array to a PyTorch tensor.
This method converts the input image from a numpy array to a PyTorch tensor, applying optional half-precision
conversion and normalization. The image is transposed from HWC to CHW format and the color channels are reversed
@ -2979,8 +2895,8 @@ class ToTensor:
im (np.ndarray): Input image as a numpy array with shape (H, W, C) in RGB order.
Returns:
(torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized
to [0, 1] with shape (C, H, W) in RGB order.
(torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized to [0, 1] with
shape (C, H, W) in RGB order.
Examples:
>>> transform = ToTensor(half=True)

View file

@ -21,8 +21,7 @@ from ultralytics.utils.patches import imread
class BaseDataset(Dataset):
"""
Base dataset class for loading and processing image data.
"""Base dataset class for loading and processing image data.
This class provides core functionality for loading images, caching, and preparing data for training and inference in
object detection tasks.
@ -86,8 +85,7 @@ class BaseDataset(Dataset):
fraction: float = 1.0,
channels: int = 3,
):
"""
Initialize BaseDataset with given configuration and options.
"""Initialize BaseDataset with given configuration and options.
Args:
img_path (str | list[str]): Path to the folder containing images or list of image paths.
@ -148,8 +146,7 @@ class BaseDataset(Dataset):
self.transforms = self.build_transforms(hyp=hyp)
def get_img_files(self, img_path: str | list[str]) -> list[str]:
"""
Read image files from the specified path.
"""Read image files from the specified path.
Args:
img_path (str | list[str]): Path or list of paths to image directories or files.
@ -186,8 +183,7 @@ class BaseDataset(Dataset):
return im_files
def update_labels(self, include_class: list[int] | None) -> None:
"""
Update labels to include only specified classes.
"""Update labels to include only specified classes.
Args:
include_class (list[int], optional): List of classes to include. If None, all classes are included.
@ -210,8 +206,7 @@ class BaseDataset(Dataset):
self.labels[i]["cls"][:, 0] = 0
def load_image(self, i: int, rect_mode: bool = True) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
"""
Load an image from dataset index 'i'.
"""Load an image from dataset index 'i'.
Args:
i (int): Index of the image to load.
@ -286,8 +281,7 @@ class BaseDataset(Dataset):
np.save(f.as_posix(), imread(self.im_files[i]), allow_pickle=False)
def check_cache_disk(self, safety_margin: float = 0.5) -> bool:
"""
Check if there's enough disk space for caching images.
"""Check if there's enough disk space for caching images.
Args:
safety_margin (float): Safety margin factor for disk space calculation.
@ -322,8 +316,7 @@ class BaseDataset(Dataset):
return True
def check_cache_ram(self, safety_margin: float = 0.5) -> bool:
"""
Check if there's enough RAM for caching images.
"""Check if there's enough RAM for caching images.
Args:
safety_margin (float): Safety margin factor for RAM calculation.
@ -381,8 +374,7 @@ class BaseDataset(Dataset):
return self.transforms(self.get_image_and_label(index))
def get_image_and_label(self, index: int) -> dict[str, Any]:
"""
Get and return label information from the dataset.
"""Get and return label information from the dataset.
Args:
index (int): Index of the image to retrieve.
@ -410,8 +402,7 @@ class BaseDataset(Dataset):
return label
def build_transforms(self, hyp: dict[str, Any] | None = None):
"""
Users can customize augmentations here.
"""Users can customize augmentations here.
Examples:
>>> if self.augment:
@ -424,8 +415,7 @@ class BaseDataset(Dataset):
raise NotImplementedError
def get_labels(self) -> list[dict[str, Any]]:
"""
Users can customize their own format here.
"""Users can customize their own format here.
Examples:
Ensure output is a dictionary with the following keys:

View file

@ -35,8 +35,7 @@ from ultralytics.utils.torch_utils import TORCH_2_0
class InfiniteDataLoader(dataloader.DataLoader):
"""
Dataloader that reuses workers for infinite iteration.
"""Dataloader that reuses workers for infinite iteration.
This dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency
for training loops that need to iterate through the dataset multiple times without recreating workers.
@ -94,8 +93,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
class _RepeatSampler:
"""
Sampler that repeats forever for infinite iteration.
"""Sampler that repeats forever for infinite iteration.
This sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration over a
dataset without recreating the sampler.
@ -115,8 +113,7 @@ class _RepeatSampler:
class ContiguousDistributedSampler(torch.utils.data.Sampler):
"""
Distributed sampler that assigns contiguous batch-aligned chunks of the dataset to each GPU.
"""Distributed sampler that assigns contiguous batch-aligned chunks of the dataset to each GPU.
Unlike PyTorch's DistributedSampler which distributes samples in a round-robin fashion (GPU 0 gets indices
[0,2,4,...], GPU 1 gets [1,3,5,...]), this sampler gives each GPU contiguous batches of the dataset (GPU 0 gets
@ -132,8 +129,8 @@ class ContiguousDistributedSampler(torch.utils.data.Sampler):
num_replicas (int, optional): Number of distributed processes. Defaults to world size.
batch_size (int, optional): Batch size used by dataloader. Defaults to dataset batch size.
rank (int, optional): Rank of current process. Defaults to current rank.
shuffle (bool, optional): Whether to shuffle indices within each rank's chunk. Defaults to False.
When True, shuffling is deterministic and controlled by set_epoch() for reproducibility.
shuffle (bool, optional): Whether to shuffle indices within each rank's chunk. Defaults to False. When True,
shuffling is deterministic and controlled by set_epoch() for reproducibility.
Examples:
>>> # For validation with size-grouped images
@ -202,8 +199,7 @@ class ContiguousDistributedSampler(torch.utils.data.Sampler):
return end_idx - start_idx
def set_epoch(self, epoch):
"""
Set the epoch for this sampler to ensure different shuffling patterns across epochs.
"""Set the epoch for this sampler to ensure different shuffling patterns across epochs.
Args:
epoch (int): Epoch number to use as the random seed for shuffling.
@ -289,8 +285,7 @@ def build_dataloader(
drop_last: bool = False,
pin_memory: bool = True,
):
"""
Create and return an InfiniteDataLoader or DataLoader for training or validation.
"""Create and return an InfiniteDataLoader or DataLoader for training or validation.
Args:
dataset (Dataset): Dataset to load data from.
@ -337,8 +332,7 @@ def build_dataloader(
def check_source(source):
"""
Check the type of input source and return corresponding flag values.
"""Check the type of input source and return corresponding flag values.
Args:
source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.
@ -386,8 +380,7 @@ def check_source(source):
def load_inference_source(source=None, batch: int = 1, vid_stride: int = 1, buffer: bool = False, channels: int = 3):
"""
Load an inference source for object detection and apply necessary transformations.
"""Load an inference source for object detection and apply necessary transformations.
Args:
source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.

View file

@ -21,12 +21,11 @@ from ultralytics.utils.files import increment_path
def coco91_to_coco80_class() -> list[int]:
"""
Convert 91-index COCO class IDs to 80-index COCO class IDs.
"""Convert 91-index COCO class IDs to 80-index COCO class IDs.
Returns:
(list[int]): A list of 91 class IDs where the index represents the 80-index class ID and the value
is the corresponding 91-index class ID.
(list[int]): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
corresponding 91-index class ID.
"""
return [
0,
@ -124,8 +123,7 @@ def coco91_to_coco80_class() -> list[int]:
def coco80_to_coco91_class() -> list[int]:
r"""
Convert 80-index (val2014) to 91-index (paper).
r"""Convert 80-index (val2014) to 91-index (paper).
Returns:
(list[int]): A list of 80 class IDs where each value is the corresponding 91-index class ID.
@ -236,8 +234,7 @@ def convert_coco(
cls91to80: bool = True,
lvis: bool = False,
):
"""
Convert COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
"""Convert COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
Args:
labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
@ -348,8 +345,7 @@ def convert_coco(
def convert_segment_masks_to_yolo_seg(masks_dir: str, output_dir: str, classes: int):
"""
Convert a dataset of segmentation mask images to the YOLO segmentation format.
"""Convert a dataset of segmentation mask images to the YOLO segmentation format.
This function takes the directory containing the binary format mask images and converts them into YOLO segmentation
format. The converted masks are saved in the specified output directory.
@ -424,8 +420,7 @@ def convert_segment_masks_to_yolo_seg(masks_dir: str, output_dir: str, classes:
def convert_dota_to_yolo_obb(dota_root_path: str):
"""
Convert DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.
"""Convert DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.
The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the
associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.
@ -517,8 +512,7 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
def min_index(arr1: np.ndarray, arr2: np.ndarray):
"""
Find a pair of indexes with the shortest distance between two arrays of 2D points.
"""Find a pair of indexes with the shortest distance between two arrays of 2D points.
Args:
arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.
@ -533,14 +527,14 @@ def min_index(arr1: np.ndarray, arr2: np.ndarray):
def merge_multi_segment(segments: list[list]):
"""
Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.
"""Merge multiple segments into one list by connecting the coordinates with the minimum distance between each
segment.
This function connects these coordinates with a thin line to merge all segments into one.
Args:
segments (list[list]): Original segmentations in COCO's JSON file.
Each element is a list of coordinates, like [segmentation1, segmentation2,...].
segments (list[list]): Original segmentations in COCO's JSON file. Each element is a list of coordinates, like
[segmentation1, segmentation2,...].
Returns:
s (list[np.ndarray]): A list of connected segments represented as NumPy arrays.
@ -584,14 +578,13 @@ def merge_multi_segment(segments: list[list]):
def yolo_bbox2segment(im_dir: str | Path, save_dir: str | Path | None = None, sam_model: str = "sam_b.pt", device=None):
"""
Convert existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) in
YOLO format. Generate segmentation data using SAM auto-annotator as needed.
"""Convert existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB)
in YOLO format. Generate segmentation data using SAM auto-annotator as needed.
Args:
im_dir (str | Path): Path to image directory to convert.
save_dir (str | Path, optional): Path to save the generated labels, labels will be saved
into `labels-segment` in the same directory level of `im_dir` if save_dir is None.
save_dir (str | Path, optional): Path to save the generated labels, labels will be saved into `labels-segment`
in the same directory level of `im_dir` if save_dir is None.
sam_model (str): Segmentation model to use for intermediate segmentation data.
device (int | str, optional): The specific device to run SAM models.
@ -648,8 +641,7 @@ def yolo_bbox2segment(im_dir: str | Path, save_dir: str | Path | None = None, sa
def create_synthetic_coco_dataset():
"""
Create a synthetic COCO dataset with random images based on filenames from label lists.
"""Create a synthetic COCO dataset with random images based on filenames from label lists.
This function downloads COCO labels, reads image filenames from label list files, creates synthetic images for
train2017 and val2017 subsets, and organizes them in the COCO dataset structure. It uses multithreading to generate
@ -704,8 +696,7 @@ def create_synthetic_coco_dataset():
def convert_to_multispectral(path: str | Path, n_channels: int = 10, replace: bool = False, zip: bool = False):
"""
Convert RGB images to multispectral images by interpolating across wavelength bands.
"""Convert RGB images to multispectral images by interpolating across wavelength bands.
This function takes RGB images and interpolates them to create multispectral images with a specified number of
channels. It can process either a single image or a directory of images.
@ -756,8 +747,7 @@ def convert_to_multispectral(path: str | Path, n_channels: int = 10, replace: bo
async def convert_ndjson_to_yolo(ndjson_path: str | Path, output_path: str | Path | None = None) -> Path:
"""
Convert NDJSON dataset format to Ultralytics YOLO11 dataset structure.
"""Convert NDJSON dataset format to Ultralytics YOLO11 dataset structure.
This function converts datasets stored in NDJSON (Newline Delimited JSON) format to the standard YOLO format with
separate directories for images and labels. It supports parallel processing for efficient conversion of large
@ -769,8 +759,8 @@ async def convert_ndjson_to_yolo(ndjson_path: str | Path, output_path: str | Pat
Args:
ndjson_path (Union[str, Path]): Path to the input NDJSON file containing dataset information.
output_path (Optional[Union[str, Path]], optional): Directory where the converted YOLO dataset
will be saved. If None, uses the parent directory of the NDJSON file. Defaults to None.
output_path (Optional[Union[str, Path]], optional): Directory where the converted YOLO dataset will be saved. If
None, uses the parent directory of the NDJSON file. Defaults to None.
Returns:
(Path): Path to the generated data.yaml file that can be used for YOLO training.

View file

@ -47,8 +47,7 @@ DATASET_CACHE_VERSION = "1.0.3"
class YOLODataset(BaseDataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
"""Dataset class for loading object detection and/or segmentation labels in YOLO format.
This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
(OBB) tasks using the YOLO format.
@ -73,8 +72,7 @@ class YOLODataset(BaseDataset):
"""
def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
"""
Initialize the YOLODataset.
"""Initialize the YOLODataset.
Args:
data (dict, optional): Dataset configuration dictionary.
@ -90,8 +88,7 @@ class YOLODataset(BaseDataset):
super().__init__(*args, channels=self.data.get("channels", 3), **kwargs)
def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
"""
Cache dataset labels, check images and read shapes.
"""Cache dataset labels, check images and read shapes.
Args:
path (Path): Path where to save the cache file.
@ -158,8 +155,7 @@ class YOLODataset(BaseDataset):
return x
def get_labels(self) -> list[dict]:
"""
Return dictionary of labels for YOLO training.
"""Return dictionary of labels for YOLO training.
This method loads labels from disk or cache, verifies their integrity, and prepares them for training.
@ -208,8 +204,7 @@ class YOLODataset(BaseDataset):
return labels
def build_transforms(self, hyp: dict | None = None) -> Compose:
"""
Build and append transforms to the list.
"""Build and append transforms to the list.
Args:
hyp (dict, optional): Hyperparameters for transforms.
@ -240,8 +235,7 @@ class YOLODataset(BaseDataset):
return transforms
def close_mosaic(self, hyp: dict) -> None:
"""
Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.
"""Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.
Args:
hyp (dict): Hyperparameters for transforms.
@ -253,8 +247,7 @@ class YOLODataset(BaseDataset):
self.transforms = self.build_transforms(hyp)
def update_labels_info(self, label: dict) -> dict:
"""
Update label format for different tasks.
"""Update label format for different tasks.
Args:
label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@ -287,8 +280,7 @@ class YOLODataset(BaseDataset):
@staticmethod
def collate_fn(batch: list[dict]) -> dict:
"""
Collate data samples into batches.
"""Collate data samples into batches.
Args:
batch (list[dict]): List of dictionaries containing sample data.
@ -317,8 +309,7 @@ class YOLODataset(BaseDataset):
class YOLOMultiModalDataset(YOLODataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
"""Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
This class extends YOLODataset to add text information for multi-modal model training, enabling models to process
both image and text data.
@ -334,8 +325,7 @@ class YOLOMultiModalDataset(YOLODataset):
"""
def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
"""
Initialize a YOLOMultiModalDataset.
"""Initialize a YOLOMultiModalDataset.
Args:
data (dict, optional): Dataset configuration dictionary.
@ -346,8 +336,7 @@ class YOLOMultiModalDataset(YOLODataset):
super().__init__(*args, data=data, task=task, **kwargs)
def update_labels_info(self, label: dict) -> dict:
"""
Add text information for multi-modal model training.
"""Add text information for multi-modal model training.
Args:
label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@ -363,8 +352,7 @@ class YOLOMultiModalDataset(YOLODataset):
return labels
def build_transforms(self, hyp: dict | None = None) -> Compose:
"""
Enhance data transformations with optional text augmentation for multi-modal training.
"""Enhance data transformations with optional text augmentation for multi-modal training.
Args:
hyp (dict, optional): Hyperparameters for transforms.
@ -388,8 +376,7 @@ class YOLOMultiModalDataset(YOLODataset):
@property
def category_names(self):
"""
Return category names for the dataset.
"""Return category names for the dataset.
Returns:
(set[str]): List of class names.
@ -418,8 +405,7 @@ class YOLOMultiModalDataset(YOLODataset):
class GroundingDataset(YOLODataset):
"""
Dataset class for object detection tasks using annotations from a JSON file in grounding format.
"""Dataset class for object detection tasks using annotations from a JSON file in grounding format.
This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than the standard
YOLO format text files.
@ -438,8 +424,7 @@ class GroundingDataset(YOLODataset):
"""
def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs):
"""
Initialize a GroundingDataset for object detection.
"""Initialize a GroundingDataset for object detection.
Args:
json_file (str): Path to the JSON file containing annotations.
@ -454,8 +439,7 @@ class GroundingDataset(YOLODataset):
super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
def get_img_files(self, img_path: str) -> list:
"""
The image files would be read in `get_labels` function, return empty list here.
"""The image files would be read in `get_labels` function, return empty list here.
Args:
img_path (str): Path to the directory containing images.
@ -466,21 +450,19 @@ class GroundingDataset(YOLODataset):
return []
def verify_labels(self, labels: list[dict[str, Any]]) -> None:
"""
Verify the number of instances in the dataset matches expected counts.
"""Verify the number of instances in the dataset matches expected counts.
This method checks if the total number of bounding box instances in the provided labels matches the expected
count for known datasets. It performs validation against a predefined set of datasets with known instance
counts.
Args:
labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary
contains dataset annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor
containing bounding box coordinates.
labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary contains dataset
annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor containing bounding
box coordinates.
Raises:
AssertionError: If the actual instance count doesn't match the expected count
for a recognized dataset.
AssertionError: If the actual instance count doesn't match the expected count for a recognized dataset.
Notes:
For unrecognized datasets (those not in the predefined expected_counts),
@ -501,8 +483,7 @@ class GroundingDataset(YOLODataset):
LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'")
def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]:
"""
Load annotations from a JSON file, filter, and normalize bounding boxes for each image.
"""Load annotations from a JSON file, filter, and normalize bounding boxes for each image.
Args:
path (Path): Path where to save the cache file.
@ -592,8 +573,7 @@ class GroundingDataset(YOLODataset):
return x
def get_labels(self) -> list[dict]:
"""
Load labels from cache or generate them from JSON file.
"""Load labels from cache or generate them from JSON file.
Returns:
(list[dict]): List of label dictionaries, each containing information about an image and its annotations.
@ -614,8 +594,7 @@ class GroundingDataset(YOLODataset):
return labels
def build_transforms(self, hyp: dict | None = None) -> Compose:
"""
Configure augmentations for training with optional text loading.
"""Configure augmentations for training with optional text loading.
Args:
hyp (dict, optional): Hyperparameters for transforms.
@ -661,8 +640,7 @@ class GroundingDataset(YOLODataset):
class YOLOConcatDataset(ConcatDataset):
"""
Dataset as a concatenation of multiple datasets.
"""Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same collation
function.
@ -678,8 +656,7 @@ class YOLOConcatDataset(ConcatDataset):
@staticmethod
def collate_fn(batch: list[dict]) -> dict:
"""
Collate data samples into batches.
"""Collate data samples into batches.
Args:
batch (list[dict]): List of dictionaries containing sample data.
@ -690,8 +667,7 @@ class YOLOConcatDataset(ConcatDataset):
return YOLODataset.collate_fn(batch)
def close_mosaic(self, hyp: dict) -> None:
"""
Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
"""Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
Args:
hyp (dict): Hyperparameters for transforms.
@ -712,8 +688,7 @@ class SemanticDataset(BaseDataset):
class ClassificationDataset:
"""
Dataset class for image classification tasks extending torchvision ImageFolder functionality.
"""Dataset class for image classification tasks extending torchvision ImageFolder functionality.
This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
@ -735,8 +710,7 @@ class ClassificationDataset:
"""
def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
"""
Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.
"""Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.
Args:
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
@ -787,8 +761,7 @@ class ClassificationDataset:
)
def __getitem__(self, i: int) -> dict:
"""
Return subset of data and targets corresponding to given indices.
"""Return subset of data and targets corresponding to given indices.
Args:
i (int): Index of the sample to retrieve.
@ -816,8 +789,7 @@ class ClassificationDataset:
return len(self.samples)
def verify_images(self) -> list[tuple]:
"""
Verify all images in dataset.
"""Verify all images in dataset.
Returns:
(list): List of valid samples after verification.

View file

@ -25,8 +25,7 @@ from ultralytics.utils.patches import imread
@dataclass
class SourceTypes:
"""
Class to represent various types of input sources for predictions.
"""Class to represent various types of input sources for predictions.
This class uses dataclass to define boolean flags for different types of input sources that can be used for making
predictions with YOLO models.
@ -52,8 +51,7 @@ class SourceTypes:
class LoadStreams:
"""
Stream Loader for various types of video streams.
"""Stream Loader for various types of video streams.
Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video streams
simultaneously, making it suitable for real-time video analysis tasks.
@ -94,8 +92,7 @@ class LoadStreams:
"""
def __init__(self, sources: str = "file.streams", vid_stride: int = 1, buffer: bool = False, channels: int = 3):
"""
Initialize stream loader for multiple video sources, supporting various stream types.
"""Initialize stream loader for multiple video sources, supporting various stream types.
Args:
sources (str): Path to streams file or single stream URL.
@ -227,8 +224,7 @@ class LoadStreams:
class LoadScreenshots:
"""
Ultralytics screenshot dataloader for capturing and processing screen images.
"""Ultralytics screenshot dataloader for capturing and processing screen images.
This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with `yolo
predict source=screen`.
@ -259,8 +255,7 @@ class LoadScreenshots:
"""
def __init__(self, source: str, channels: int = 3):
"""
Initialize screenshot capture with specified screen and region parameters.
"""Initialize screenshot capture with specified screen and region parameters.
Args:
source (str): Screen capture source string in format "screen_num left top width height".
@ -307,8 +302,7 @@ class LoadScreenshots:
class LoadImagesAndVideos:
"""
A class for loading and processing images and videos for YOLO object detection.
"""A class for loading and processing images and videos for YOLO object detection.
This class manages the loading and pre-processing of image and video data from various sources, including single
image files, video files, and lists of image and video paths.
@ -347,8 +341,7 @@ class LoadImagesAndVideos:
"""
def __init__(self, path: str | Path | list, batch: int = 1, vid_stride: int = 1, channels: int = 3):
"""
Initialize dataloader for images and videos, supporting various input formats.
"""Initialize dataloader for images and videos, supporting various input formats.
Args:
path (str | Path | list): Path to images/videos, directory, or list of paths.
@ -490,8 +483,7 @@ class LoadImagesAndVideos:
class LoadPilAndNumpy:
"""
Load images from PIL and Numpy arrays for batch processing.
"""Load images from PIL and Numpy arrays for batch processing.
This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic
validation and format conversion to ensure that the images are in the required format for downstream processing.
@ -517,8 +509,7 @@ class LoadPilAndNumpy:
"""
def __init__(self, im0: Image.Image | np.ndarray | list, channels: int = 3):
"""
Initialize a loader for PIL and Numpy images, converting inputs to a standardized format.
"""Initialize a loader for PIL and Numpy images, converting inputs to a standardized format.
Args:
im0 (PIL.Image.Image | np.ndarray | list): Single image or list of images in PIL or numpy format.
@ -564,8 +555,7 @@ class LoadPilAndNumpy:
class LoadTensor:
"""
A class for loading and processing tensor data for object detection tasks.
"""A class for loading and processing tensor data for object detection tasks.
This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for further
processing in object detection pipelines.
@ -588,8 +578,7 @@ class LoadTensor:
"""
def __init__(self, im0: torch.Tensor) -> None:
"""
Initialize LoadTensor object for processing torch.Tensor image data.
"""Initialize LoadTensor object for processing torch.Tensor image data.
Args:
im0 (torch.Tensor): Input tensor with shape (B, C, H, W).
@ -656,8 +645,7 @@ def autocast_list(source: list[Any]) -> list[Image.Image | np.ndarray]:
def get_best_youtube_url(url: str, method: str = "pytube") -> str | None:
"""
Retrieve the URL of the best quality MP4 video stream from a given YouTube video.
"""Retrieve the URL of the best quality MP4 video stream from a given YouTube video.
Args:
url (str): The URL of the YouTube video.

View file

@ -11,8 +11,7 @@ from ultralytics.utils import DATASETS_DIR, LOGGER, TQDM
def split_classify_dataset(source_dir: str | Path, train_ratio: float = 0.8) -> Path:
"""
Split classification dataset into train and val directories in a new directory.
"""Split classification dataset into train and val directories in a new directory.
Creates a new directory '{source_dir}_split' with train/val subdirectories, preserving the original class structure
with an 80/20 split by default.
@ -101,8 +100,8 @@ def autosplit(
weights: tuple[float, float, float] = (0.9, 0.1, 0.0),
annotated_only: bool = False,
) -> None:
"""
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
"""Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt
files.
Args:
path (Path): Path to images directory.

View file

@ -18,8 +18,7 @@ from ultralytics.utils.checks import check_requirements
def bbox_iof(polygon1: np.ndarray, bbox2: np.ndarray, eps: float = 1e-6) -> np.ndarray:
"""
Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.
"""Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.
Args:
polygon1 (np.ndarray): Polygon coordinates with shape (N, 8).
@ -65,8 +64,7 @@ def bbox_iof(polygon1: np.ndarray, bbox2: np.ndarray, eps: float = 1e-6) -> np.n
def load_yolo_dota(data_root: str, split: str = "train") -> list[dict[str, Any]]:
"""
Load DOTA dataset annotations and image information.
"""Load DOTA dataset annotations and image information.
Args:
data_root (str): Data root directory.
@ -107,8 +105,7 @@ def get_windows(
im_rate_thr: float = 0.6,
eps: float = 0.01,
) -> np.ndarray:
"""
Get the coordinates of sliding windows for image cropping.
"""Get the coordinates of sliding windows for image cropping.
Args:
im_size (tuple[int, int]): Original image size, (H, W).
@ -175,8 +172,7 @@ def crop_and_save(
lb_dir: str,
allow_background_images: bool = True,
) -> None:
"""
Crop images and save new labels for each window.
"""Crop images and save new labels for each window.
Args:
anno (dict[str, Any]): Annotation dict, including 'filepath', 'label', 'ori_size' as its keys.
@ -226,8 +222,7 @@ def split_images_and_labels(
crop_sizes: tuple[int, ...] = (1024,),
gaps: tuple[int, ...] = (200,),
) -> None:
"""
Split both images and labels for a given dataset split.
"""Split both images and labels for a given dataset split.
Args:
data_root (str): Root directory of the dataset.
@ -265,8 +260,7 @@ def split_images_and_labels(
def split_trainval(
data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: tuple[float, ...] = (1.0,)
) -> None:
"""
Split train and val sets of DOTA dataset with multiple scaling rates.
"""Split train and val sets of DOTA dataset with multiple scaling rates.
Args:
data_root (str): Root directory of the dataset.
@ -304,8 +298,7 @@ def split_trainval(
def split_test(
data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: tuple[float, ...] = (1.0,)
) -> None:
"""
Split test set of DOTA dataset, labels are not included within this set.
"""Split test set of DOTA dataset, labels are not included within this set.
Args:
data_root (str): Root directory of the dataset.

View file

@ -51,8 +51,7 @@ def img2label_paths(img_paths: list[str]) -> list[str]:
def check_file_speeds(
files: list[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = ""
):
"""
Check dataset file access speed and provide performance feedback.
"""Check dataset file access speed and provide performance feedback.
This function tests the access speed of dataset files by measuring ping (stat call) time and read speed. It samples
up to 5 files from the provided list and warns if access times exceed the threshold.
@ -251,8 +250,7 @@ def verify_image_label(args: tuple) -> list:
def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[int, str]):
"""
Visualize YOLO annotations (bounding boxes and class labels) on an image.
"""Visualize YOLO annotations (bounding boxes and class labels) on an image.
This function reads an image and its corresponding annotation file in YOLO format, then draws bounding boxes around
detected objects and labels them with their respective class names. The bounding box colors are assigned based on
@ -297,13 +295,12 @@ def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[
def polygon2mask(
imgsz: tuple[int, int], polygons: list[np.ndarray], color: int = 1, downsample_ratio: int = 1
) -> np.ndarray:
"""
Convert a list of polygons to a binary mask of the specified image size.
"""Convert a list of polygons to a binary mask of the specified image size.
Args:
imgsz (tuple[int, int]): The size of the image as (height, width).
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
number of polygons, and M is the number of points such that M % 2 = 0.
color (int, optional): The color value to fill in the polygons on the mask.
downsample_ratio (int, optional): Factor by which to downsample the mask.
@ -322,13 +319,12 @@ def polygon2mask(
def polygons2masks(
imgsz: tuple[int, int], polygons: list[np.ndarray], color: int, downsample_ratio: int = 1
) -> np.ndarray:
"""
Convert a list of polygons to a set of binary masks of the specified image size.
"""Convert a list of polygons to a set of binary masks of the specified image size.
Args:
imgsz (tuple[int, int]): The size of the image as (height, width).
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
number of polygons, and M is the number of points such that M % 2 = 0.
color (int): The color value to fill in the polygons on the masks.
downsample_ratio (int, optional): Factor by which to downsample each mask.
@ -368,8 +364,7 @@ def polygons2masks_overlap(
def find_dataset_yaml(path: Path) -> Path:
"""
Find and return the YAML file associated with a Detect, Segment or Pose dataset.
"""Find and return the YAML file associated with a Detect, Segment or Pose dataset.
This function searches for a YAML file at the root level of the provided directory first, and if not found, it
performs a recursive search. It prefers YAML files that have the same stem as the provided path.
@ -389,8 +384,7 @@ def find_dataset_yaml(path: Path) -> Path:
def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]:
"""
Download, verify, and/or unzip a dataset if not found locally.
"""Download, verify, and/or unzip a dataset if not found locally.
This function checks the availability of a specified dataset, and if not found, it has the option to download and
unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
@ -484,8 +478,7 @@ def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]
def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
"""
Check a classification dataset such as Imagenet.
"""Check a classification dataset such as Imagenet.
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information. If the
dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
@ -581,8 +574,7 @@ def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
class HUBDatasetStats:
"""
A class for generating HUB dataset JSON and `-hub` dataset directory.
"""A class for generating HUB dataset JSON and `-hub` dataset directory.
Args:
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
@ -748,10 +740,9 @@ class HUBDatasetStats:
def compress_one_image(f: str, f_new: str | None = None, max_dim: int = 1920, quality: int = 50):
"""
Compress a single image file to reduced size while preserving its aspect ratio and quality using either the Python
Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
resized.
"""Compress a single image file to reduced size while preserving its aspect ratio and quality using either the
Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it
will not be resized.
Args:
f (str): The path to the input image file.

View file

@ -192,8 +192,7 @@ def best_onnx_opset(onnx, cuda=False) -> int:
def validate_args(format, passed_args, valid_args):
"""
Validate arguments based on the export format.
"""Validate arguments based on the export format.
Args:
format (str): The export format.
@ -238,8 +237,7 @@ def try_export(inner_func):
class Exporter:
"""
A class for exporting YOLO models to various formats.
"""A class for exporting YOLO models to various formats.
This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,
TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export
@ -289,8 +287,7 @@ class Exporter:
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the Exporter class.
"""Initialize the Exporter class.
Args:
cfg (str, optional): Path to a configuration file.
@ -1359,8 +1356,7 @@ class IOSDetectModel(torch.nn.Module):
"""Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
def __init__(self, model, im, mlprogram=True):
"""
Initialize the IOSDetectModel class with a YOLO model and example image.
"""Initialize the IOSDetectModel class with a YOLO model and example image.
Args:
model (torch.nn.Module): The YOLO model to wrap.
@ -1394,8 +1390,7 @@ class NMSModel(torch.nn.Module):
"""Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
def __init__(self, model, args):
"""
Initialize the NMSModel.
"""Initialize the NMSModel.
Args:
model (torch.nn.Module): The model to wrap with NMS postprocessing.
@ -1408,15 +1403,14 @@ class NMSModel(torch.nn.Module):
self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
def forward(self, x):
"""
Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
"""Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
Args:
x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
Returns:
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
number of detections after NMS.
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the number
of detections after NMS.
"""
from functools import partial

View file

@ -27,8 +27,7 @@ from ultralytics.utils import (
class Model(torch.nn.Module):
"""
A base class for implementing YOLO models, unifying APIs across different model types.
"""A base class for implementing YOLO models, unifying APIs across different model types.
This class provides a common interface for various operations related to YOLO models, such as training, validation,
prediction, exporting, and benchmarking. It handles different types of models, including those loaded from local
@ -85,19 +84,17 @@ class Model(torch.nn.Module):
task: str | None = None,
verbose: bool = False,
) -> None:
"""
Initialize a new instance of the YOLO model class.
"""Initialize a new instance of the YOLO model class.
This constructor sets up the model based on the provided model path or name. It handles various types of model
sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
important attributes of the model and prepares it for operations like training, prediction, or export.
Args:
model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a
model name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a model
name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
operations.
verbose (bool): If True, enables verbose output during the model's initialization and subsequent operations.
Raises:
FileNotFoundError: If the specified model file does not exist or is inaccessible.
@ -160,22 +157,21 @@ class Model(torch.nn.Module):
stream: bool = False,
**kwargs: Any,
) -> list:
"""
Alias for the predict method, enabling the model instance to be callable for predictions.
"""Alias for the predict method, enabling the model instance to be callable for predictions.
This method simplifies the process of making predictions by allowing the model instance to be called directly
with the required arguments.
Args:
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of
the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch tensor, or
a list/tuple of these.
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch tensor, or a list/tuple
of these.
stream (bool): If True, treat the input source as a continuous stream for predictions.
**kwargs (Any): Additional keyword arguments to configure the prediction process.
Returns:
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
Results object.
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
object.
Examples:
>>> model = YOLO("yolo11n.pt")
@ -187,8 +183,7 @@ class Model(torch.nn.Module):
@staticmethod
def is_triton_model(model: str) -> bool:
"""
Check if the given model string is a Triton Server URL.
"""Check if the given model string is a Triton Server URL.
This static method determines whether the provided model string represents a valid Triton Server URL by parsing
its components using urllib.parse.urlsplit().
@ -212,8 +207,7 @@ class Model(torch.nn.Module):
@staticmethod
def is_hub_model(model: str) -> bool:
"""
Check if the provided model is an Ultralytics HUB model.
"""Check if the provided model is an Ultralytics HUB model.
This static method determines whether the given model string represents a valid Ultralytics HUB model
identifier.
@ -235,8 +229,7 @@ class Model(torch.nn.Module):
return model.startswith(f"{HUB_WEB_ROOT}/models/")
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
"""
Initialize a new model and infer the task type from model definitions.
"""Initialize a new model and infer the task type from model definitions.
Creates a new model instance based on the provided configuration file. Loads the model configuration, infers the
task type if not specified, and initializes the model using the appropriate class from the task map.
@ -244,8 +237,8 @@ class Model(torch.nn.Module):
Args:
cfg (str): Path to the model configuration file in YAML format.
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of
creating a new one.
model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of creating
a new one.
verbose (bool): If True, displays model information during loading.
Raises:
@ -269,8 +262,7 @@ class Model(torch.nn.Module):
self.model_name = cfg
def _load(self, weights: str, task=None) -> None:
"""
Load a model from a checkpoint file or initialize it from a weights file.
"""Load a model from a checkpoint file or initialize it from a weights file.
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets up the
model, task, and related attributes based on the loaded weights.
@ -307,8 +299,7 @@ class Model(torch.nn.Module):
self.model_name = weights
def _check_is_pytorch_model(self) -> None:
"""
Check if the model is a PyTorch model and raise TypeError if it's not.
"""Check if the model is a PyTorch model and raise TypeError if it's not.
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that certain
operations that require a PyTorch model are only performed on compatible model types.
@ -335,8 +326,7 @@ class Model(torch.nn.Module):
)
def reset_weights(self) -> Model:
"""
Reset the model's weights to their initial state.
"""Reset the model's weights to their initial state.
This method iterates through all modules in the model and resets their parameters if they have a
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
@ -361,8 +351,7 @@ class Model(torch.nn.Module):
return self
def load(self, weights: str | Path = "yolo11n.pt") -> Model:
"""
Load parameters from the specified weights file into the model.
"""Load parameters from the specified weights file into the model.
This method supports loading weights from a file or directly from a weights object. It matches parameters by
name and shape and transfers them to the model.
@ -389,8 +378,7 @@ class Model(torch.nn.Module):
return self
def save(self, filename: str | Path = "saved_model.pt") -> None:
"""
Save the current model state to a file.
"""Save the current model state to a file.
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as the
date, Ultralytics version, license information, and a link to the documentation.
@ -421,8 +409,7 @@ class Model(torch.nn.Module):
torch.save({**self.ckpt, **updates}, filename)
def info(self, detailed: bool = False, verbose: bool = True):
"""
Display model information.
"""Display model information.
This method provides an overview or detailed information about the model, depending on the arguments
passed. It can control the verbosity of the output and return the information as a list.
@ -432,8 +419,8 @@ class Model(torch.nn.Module):
verbose (bool): If True, prints the information. If False, returns the information as a list.
Returns:
(list[str]): A list of strings containing various types of information about the model, including
model summary, layer details, and parameter counts. Empty if verbose is True.
(list[str]): A list of strings containing various types of information about the model, including model
summary, layer details, and parameter counts. Empty if verbose is True.
Examples:
>>> model = Model("yolo11n.pt")
@ -444,8 +431,7 @@ class Model(torch.nn.Module):
return self.model.info(detailed=detailed, verbose=verbose)
def fuse(self) -> None:
"""
Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
"""Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers into a
single layer. This fusion can significantly improve inference speed by reducing the number of operations and
@ -469,15 +455,14 @@ class Model(torch.nn.Module):
stream: bool = False,
**kwargs: Any,
) -> list:
"""
Generate image embeddings based on the provided source.
"""Generate image embeddings based on the provided source.
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
source. It allows customization of the embedding process through various keyword arguments.
Args:
source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for
generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for generating
embeddings. Can be a file path, URL, PIL image, numpy array, etc.
stream (bool): If True, predictions are streamed.
**kwargs (Any): Additional keyword arguments for configuring the embedding process.
@ -501,25 +486,24 @@ class Model(torch.nn.Module):
predictor=None,
**kwargs: Any,
) -> list[Results]:
"""
Perform predictions on the given image source using the YOLO model.
"""Perform predictions on the given image source using the YOLO model.
This method facilitates the prediction process, allowing various configurations through keyword arguments. It
supports predictions with custom predictors or the default predictor method. The method handles different types
of image sources and can operate in a streaming mode.
Args:
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL images,
numpy arrays, and torch tensors.
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
torch tensors.
stream (bool): If True, treats the input source as a continuous stream for predictions.
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
If None, the method uses a default predictor.
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions. If
None, the method uses a default predictor.
**kwargs (Any): Additional keyword arguments for configuring the prediction process.
Returns:
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
Results object.
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
object.
Examples:
>>> model = YOLO("yolo11n.pt")
@ -562,8 +546,7 @@ class Model(torch.nn.Module):
persist: bool = False,
**kwargs: Any,
) -> list[Results]:
"""
Conduct object tracking on the specified input source using the registered trackers.
"""Conduct object tracking on the specified input source using the registered trackers.
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
various input sources such as file paths or video streams, and supports customization through keyword arguments.
@ -604,8 +587,7 @@ class Model(torch.nn.Module):
validator=None,
**kwargs: Any,
):
"""
Validate the model using a specified dataset and validation configuration.
"""Validate the model using a specified dataset and validation configuration.
This method facilitates the model validation process, allowing for customization through various settings. It
supports validation with a custom validator or the default validation approach. The method combines default
@ -636,8 +618,7 @@ class Model(torch.nn.Module):
return validator.metrics
def benchmark(self, data=None, format="", verbose=False, **kwargs: Any):
"""
Benchmark the model across various export formats to evaluate performance.
"""Benchmark the model across various export formats to evaluate performance.
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. It
uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured using
@ -655,8 +636,8 @@ class Model(torch.nn.Module):
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
Returns:
(dict): A dictionary containing the results of the benchmarking process, including metrics for
different export formats.
(dict): A dictionary containing the results of the benchmarking process, including metrics for different
export formats.
Raises:
AssertionError: If the model is not a PyTorch model.
@ -690,8 +671,7 @@ class Model(torch.nn.Module):
self,
**kwargs: Any,
) -> str:
"""
Export the model to a different format suitable for deployment.
"""Export the model to a different format suitable for deployment.
This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
@ -738,8 +718,7 @@ class Model(torch.nn.Module):
trainer=None,
**kwargs: Any,
):
"""
Train the model using the specified dataset and training configuration.
"""Train the model using the specified dataset and training configuration.
This method facilitates model training with a range of customizable settings. It supports training with a custom
trainer or the default training approach. The method handles scenarios such as resuming training from a
@ -811,8 +790,7 @@ class Model(torch.nn.Module):
*args: Any,
**kwargs: Any,
):
"""
Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
"""Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. When Ray Tune
is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. Otherwise, it uses
@ -853,16 +831,15 @@ class Model(torch.nn.Module):
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
def _apply(self, fn) -> Model:
"""
Apply a function to model tensors that are not parameters or registered buffers.
"""Apply a function to model tensors that are not parameters or registered buffers.
This method extends the functionality of the parent class's _apply method by additionally resetting the
predictor and updating the device in the model's overrides. It's typically used for operations like moving the
model to a different device or changing its precision.
Args:
fn (Callable): A function to be applied to the model's tensors. This is typically a method like
to(), cpu(), cuda(), half(), or float().
fn (Callable): A function to be applied to the model's tensors. This is typically a method like to(), cpu(),
cuda(), half(), or float().
Returns:
(Model): The model instance with the function applied and updated attributes.
@ -882,8 +859,7 @@ class Model(torch.nn.Module):
@property
def names(self) -> dict[int, str]:
"""
Retrieve the class names associated with the loaded model.
"""Retrieve the class names associated with the loaded model.
This property returns the class names if they are defined in the model. It checks the class names for validity
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
@ -913,8 +889,7 @@ class Model(torch.nn.Module):
@property
def device(self) -> torch.device:
"""
Get the device on which the model's parameters are allocated.
"""Get the device on which the model's parameters are allocated.
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
applicable only to models that are instances of torch.nn.Module.
@ -937,8 +912,7 @@ class Model(torch.nn.Module):
@property
def transforms(self):
"""
Retrieve the transformations applied to the input data of the loaded model.
"""Retrieve the transformations applied to the input data of the loaded model.
This property returns the transformations if they are defined in the model. The transforms typically include
preprocessing steps like resizing, normalization, and data augmentation that are applied to input data before it
@ -958,18 +932,17 @@ class Model(torch.nn.Module):
return self.model.transforms if hasattr(self.model, "transforms") else None
def add_callback(self, event: str, func) -> None:
"""
Add a callback function for a specified event.
"""Add a callback function for a specified event.
This method allows registering custom callback functions that are triggered on specific events during model
operations such as training or inference. Callbacks provide a way to extend and customize the behavior of the
model at various stages of its lifecycle.
Args:
event (str): The name of the event to attach the callback to. Must be a valid event name recognized
by the Ultralytics framework.
func (Callable): The callback function to be registered. This function will be called when the
specified event occurs.
event (str): The name of the event to attach the callback to. Must be a valid event name recognized by the
Ultralytics framework.
func (Callable): The callback function to be registered. This function will be called when the specified
event occurs.
Raises:
ValueError: If the event name is not recognized or is invalid.
@ -984,8 +957,7 @@ class Model(torch.nn.Module):
self.callbacks[event].append(func)
def clear_callback(self, event: str) -> None:
"""
Clear all callback functions registered for a specified event.
"""Clear all callback functions registered for a specified event.
This method removes all custom and default callback functions associated with the given event. It resets the
callback list for the specified event to an empty list, effectively removing all registered callbacks for that
@ -1012,8 +984,7 @@ class Model(torch.nn.Module):
self.callbacks[event] = []
def reset_callbacks(self) -> None:
"""
Reset all callbacks to their default functions.
"""Reset all callbacks to their default functions.
This method reinstates the default callback functions for all events, removing any custom callbacks that were
previously added. It iterates through all default callback events and replaces the current callbacks with the
@ -1036,8 +1007,7 @@ class Model(torch.nn.Module):
@staticmethod
def _reset_ckpt_args(args: dict[str, Any]) -> dict[str, Any]:
"""
Reset specific arguments when loading a PyTorch model checkpoint.
"""Reset specific arguments when loading a PyTorch model checkpoint.
This method filters the input arguments dictionary to retain only a specific set of keys that are considered
important for model loading. It's used to ensure that only relevant arguments are preserved when loading a model
@ -1064,8 +1034,7 @@ class Model(torch.nn.Module):
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def _smart_load(self, key: str):
"""
Intelligently load the appropriate module based on the model task.
"""Intelligently load the appropriate module based on the model task.
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) based
on the current task of the model and the provided key. It uses the task_map dictionary to determine the
@ -1094,8 +1063,7 @@ class Model(torch.nn.Module):
@property
def task_map(self) -> dict:
"""
Provide a mapping from model tasks to corresponding classes for different modes.
"""Provide a mapping from model tasks to corresponding classes for different modes.
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) to a
nested dictionary. The nested dictionary contains mappings for different operational modes (model, trainer,
@ -1119,8 +1087,7 @@ class Model(torch.nn.Module):
raise NotImplementedError("Please provide task map for your model!")
def eval(self):
"""
Sets the model to evaluation mode.
"""Sets the model to evaluation mode.
This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
that behave differently during training and evaluation. In evaluation mode, these layers use running statistics
@ -1138,8 +1105,7 @@ class Model(torch.nn.Module):
return self
def __getattr__(self, name):
"""
Enable accessing model attributes directly through the Model class.
"""Enable accessing model attributes directly through the Model class.
This method provides a way to access attributes of the underlying model directly through the Model class
instance. It first checks if the requested attribute is 'model', in which case it returns the model from

View file

@ -68,8 +68,7 @@ Example:
class BasePredictor:
"""
A base class for creating predictors.
"""A base class for creating predictors.
This class provides the foundation for prediction functionality, handling model setup, inference, and result
processing across various input sources.
@ -115,8 +114,7 @@ class BasePredictor:
overrides: dict[str, Any] | None = None,
_callbacks: dict[str, list[callable]] | None = None,
):
"""
Initialize the BasePredictor class.
"""Initialize the BasePredictor class.
Args:
cfg (str | dict): Path to a configuration file or a configuration dictionary.
@ -151,8 +149,7 @@ class BasePredictor:
callbacks.add_integration_callbacks(self)
def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
"""
Prepare input image before inference.
"""Prepare input image before inference.
Args:
im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
@ -185,8 +182,7 @@ class BasePredictor:
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
"""
Pre-transform input image before inference.
"""Pre-transform input image before inference.
Args:
im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].
@ -209,8 +205,7 @@ class BasePredictor:
return preds
def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
"""
Perform inference on an image or stream.
"""Perform inference on an image or stream.
Args:
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
@ -230,8 +225,7 @@ class BasePredictor:
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
def predict_cli(self, source=None, model=None):
"""
Method used for Command Line Interface (CLI) prediction.
"""Method used for Command Line Interface (CLI) prediction.
This function is designed to run predictions using the CLI. It sets up the source and model, then processes the
inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
@ -251,12 +245,11 @@ class BasePredictor:
pass
def setup_source(self, source):
"""
Set up source and inference mode.
"""Set up source and inference mode.
Args:
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor):
Source for inference.
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
inference.
"""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.dataset = load_inference_source(
@ -282,8 +275,7 @@ class BasePredictor:
@smart_inference_mode()
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""
Stream real-time inference on camera feed and save results to file.
"""Stream real-time inference on camera feed and save results to file.
Args:
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
@ -388,8 +380,7 @@ class BasePredictor:
self.run_callbacks("on_predict_end")
def setup_model(self, model, verbose: bool = True):
"""
Initialize YOLO model with given parameters and set it to evaluation mode.
"""Initialize YOLO model with given parameters and set it to evaluation mode.
Args:
model (str | Path | torch.nn.Module, optional): Model to load or use.
@ -413,8 +404,7 @@ class BasePredictor:
self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
"""
Write inference results to a file or directory.
"""Write inference results to a file or directory.
Args:
i (int): Index of the current image in the batch.
@ -464,8 +454,7 @@ class BasePredictor:
return string
def save_predicted_images(self, save_path: Path, frame: int = 0):
"""
Save video predictions as mp4 or images as jpg at specified path.
"""Save video predictions as mp4 or images as jpg at specified path.
Args:
save_path (Path): Path to save the results.

View file

@ -21,8 +21,7 @@ from ultralytics.utils.plotting import Annotator, colors, save_one_box
class BaseTensor(SimpleClass):
"""
Base tensor class with additional methods for easy manipulation and device handling.
"""Base tensor class with additional methods for easy manipulation and device handling.
This class provides a foundation for tensor-like objects with device management capabilities, supporting both
PyTorch tensors and NumPy arrays. It includes methods for moving data between devices and converting between tensor
@ -49,8 +48,7 @@ class BaseTensor(SimpleClass):
"""
def __init__(self, data: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
"""
Initialize BaseTensor with prediction data and the original shape of the image.
"""Initialize BaseTensor with prediction data and the original shape of the image.
Args:
data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints.
@ -68,8 +66,7 @@ class BaseTensor(SimpleClass):
@property
def shape(self) -> tuple[int, ...]:
"""
Return the shape of the underlying data tensor.
"""Return the shape of the underlying data tensor.
Returns:
(tuple[int, ...]): The shape of the data tensor.
@ -83,8 +80,7 @@ class BaseTensor(SimpleClass):
return self.data.shape
def cpu(self):
"""
Return a copy of the tensor stored in CPU memory.
"""Return a copy of the tensor stored in CPU memory.
Returns:
(BaseTensor): A new BaseTensor object with the data tensor moved to CPU memory.
@ -101,8 +97,7 @@ class BaseTensor(SimpleClass):
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)
def numpy(self):
"""
Return a copy of the tensor as a numpy array.
"""Return a copy of the tensor as a numpy array.
Returns:
(np.ndarray): A numpy array containing the same data as the original tensor.
@ -118,12 +113,11 @@ class BaseTensor(SimpleClass):
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
def cuda(self):
"""
Move the tensor to GPU memory.
"""Move the tensor to GPU memory.
Returns:
(BaseTensor): A new BaseTensor instance with the data moved to GPU memory if it's not already a
numpy array, otherwise returns self.
(BaseTensor): A new BaseTensor instance with the data moved to GPU memory if it's not already a numpy array,
otherwise returns self.
Examples:
>>> import torch
@ -137,8 +131,7 @@ class BaseTensor(SimpleClass):
return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)
def to(self, *args, **kwargs):
"""
Return a copy of the tensor with the specified device and dtype.
"""Return a copy of the tensor with the specified device and dtype.
Args:
*args (Any): Variable length argument list to be passed to torch.Tensor.to().
@ -155,8 +148,7 @@ class BaseTensor(SimpleClass):
return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)
def __len__(self) -> int:
"""
Return the length of the underlying data tensor.
"""Return the length of the underlying data tensor.
Returns:
(int): The number of elements in the first dimension of the data tensor.
@ -170,8 +162,7 @@ class BaseTensor(SimpleClass):
return len(self.data)
def __getitem__(self, idx):
"""
Return a new BaseTensor instance containing the specified indexed elements of the data tensor.
"""Return a new BaseTensor instance containing the specified indexed elements of the data tensor.
Args:
idx (int | list[int] | torch.Tensor): Index or indices to select from the data tensor.
@ -190,8 +181,7 @@ class BaseTensor(SimpleClass):
class Results(SimpleClass, DataExportMixin):
"""
A class for storing and manipulating inference results.
"""A class for storing and manipulating inference results.
This class provides comprehensive functionality for handling inference results from various Ultralytics models,
including detection, segmentation, classification, and pose estimation. It supports visualization, data export, and
@ -249,8 +239,7 @@ class Results(SimpleClass, DataExportMixin):
obb: torch.Tensor | None = None,
speed: dict[str, float] | None = None,
) -> None:
"""
Initialize the Results class for storing and manipulating inference results.
"""Initialize the Results class for storing and manipulating inference results.
Args:
orig_img (np.ndarray): The original image as a numpy array.
@ -290,8 +279,7 @@ class Results(SimpleClass, DataExportMixin):
self._keys = "boxes", "masks", "probs", "keypoints", "obb"
def __getitem__(self, idx):
"""
Return a Results object for a specific index of inference results.
"""Return a Results object for a specific index of inference results.
Args:
idx (int | slice): Index or slice to retrieve from the Results object.
@ -307,12 +295,11 @@ class Results(SimpleClass, DataExportMixin):
return self._apply("__getitem__", idx)
def __len__(self) -> int:
"""
Return the number of detections in the Results object.
"""Return the number of detections in the Results object.
Returns:
(int): The number of detections, determined by the length of the first non-empty
attribute in (masks, probs, keypoints, or obb).
(int): The number of detections, determined by the length of the first non-empty attribute in (masks, probs,
keypoints, or obb).
Examples:
>>> results = Results(orig_img, path, names, boxes=torch.rand(5, 4))
@ -332,15 +319,14 @@ class Results(SimpleClass, DataExportMixin):
obb: torch.Tensor | None = None,
keypoints: torch.Tensor | None = None,
):
"""
Update the Results object with new detection data.
"""Update the Results object with new detection data.
This method allows updating the boxes, masks, probabilities, and oriented bounding boxes (OBB) of the Results
object. It ensures that boxes are clipped to the original image shape.
Args:
boxes (torch.Tensor | None): A tensor of shape (N, 6) containing bounding box coordinates and
confidence scores. The format is (x1, y1, x2, y2, conf, class).
boxes (torch.Tensor | None): A tensor of shape (N, 6) containing bounding box coordinates and confidence
scores. The format is (x1, y1, x2, y2, conf, class).
masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.
probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.
obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.
@ -363,8 +349,7 @@ class Results(SimpleClass, DataExportMixin):
self.keypoints = Keypoints(keypoints, self.orig_shape)
def _apply(self, fn: str, *args, **kwargs):
"""
Apply a function to all non-empty attributes and return a new Results object with modified attributes.
"""Apply a function to all non-empty attributes and return a new Results object with modified attributes.
This method is internally called by methods like .to(), .cuda(), .cpu(), etc.
@ -390,8 +375,7 @@ class Results(SimpleClass, DataExportMixin):
return r
def cpu(self):
"""
Return a copy of the Results object with all its tensors moved to CPU memory.
"""Return a copy of the Results object with all its tensors moved to CPU memory.
This method creates a new Results object with all tensor attributes (boxes, masks, probs, keypoints, obb)
transferred to CPU memory. It's useful for moving data from GPU to CPU for further processing or saving.
@ -407,8 +391,7 @@ class Results(SimpleClass, DataExportMixin):
return self._apply("cpu")
def numpy(self):
"""
Convert all tensors in the Results object to numpy arrays.
"""Convert all tensors in the Results object to numpy arrays.
Returns:
(Results): A new Results object with all tensors converted to numpy arrays.
@ -426,8 +409,7 @@ class Results(SimpleClass, DataExportMixin):
return self._apply("numpy")
def cuda(self):
"""
Move all tensors in the Results object to GPU memory.
"""Move all tensors in the Results object to GPU memory.
Returns:
(Results): A new Results object with all tensors moved to CUDA device.
@ -441,8 +423,7 @@ class Results(SimpleClass, DataExportMixin):
return self._apply("cuda")
def to(self, *args, **kwargs):
"""
Move all tensors in the Results object to the specified device and dtype.
"""Move all tensors in the Results object to the specified device and dtype.
Args:
*args (Any): Variable length argument list to be passed to torch.Tensor.to().
@ -460,8 +441,7 @@ class Results(SimpleClass, DataExportMixin):
return self._apply("to", *args, **kwargs)
def new(self):
"""
Create a new Results object with the same image, path, names, and speed attributes.
"""Create a new Results object with the same image, path, names, and speed attributes.
Returns:
(Results): A new Results object with copied attributes from the original instance.
@ -493,8 +473,7 @@ class Results(SimpleClass, DataExportMixin):
color_mode: str = "class",
txt_color: tuple[int, int, int] = (255, 255, 255),
) -> np.ndarray:
"""
Plot detection results on an input RGB image.
"""Plot detection results on an input RGB image.
Args:
conf (bool): Whether to plot detection confidence scores.
@ -613,8 +592,7 @@ class Results(SimpleClass, DataExportMixin):
return annotator.im if pil else annotator.result()
def show(self, *args, **kwargs):
"""
Display the image with annotated inference results.
"""Display the image with annotated inference results.
This method plots the detection results on the original image and displays it. It's a convenient way to
visualize the model's predictions directly.
@ -632,15 +610,14 @@ class Results(SimpleClass, DataExportMixin):
self.plot(show=True, *args, **kwargs)
def save(self, filename: str | None = None, *args, **kwargs) -> str:
"""
Save annotated inference results image to file.
"""Save annotated inference results image to file.
This method plots the detection results on the original image and saves the annotated image to a file. It
utilizes the `plot` method to generate the annotated image and then saves it to the specified filename.
Args:
filename (str | Path | None): The filename to save the annotated image. If None, a default filename
is generated based on the original image path.
filename (str | Path | None): The filename to save the annotated image. If None, a default filename is
generated based on the original image path.
*args (Any): Variable length argument list to be passed to the `plot` method.
**kwargs (Any): Arbitrary keyword arguments to be passed to the `plot` method.
@ -661,15 +638,14 @@ class Results(SimpleClass, DataExportMixin):
return filename
def verbose(self) -> str:
"""
Return a log string for each task in the results, detailing detection and classification outcomes.
"""Return a log string for each task in the results, detailing detection and classification outcomes.
This method generates a human-readable string summarizing the detection and classification results. It includes
the number of detections for each class and the top probabilities for classification tasks.
Returns:
(str): A formatted string containing a summary of the results. For detection tasks, it includes the
number of detections per class. For classification tasks, it includes the top 5 class probabilities.
(str): A formatted string containing a summary of the results. For detection tasks, it includes the number
of detections per class. For classification tasks, it includes the top 5 class probabilities.
Examples:
>>> results = model("path/to/image.jpg")
@ -693,8 +669,7 @@ class Results(SimpleClass, DataExportMixin):
return "".join(f"{n} {self.names[i]}{'s' * (n > 1)}, " for i, n in enumerate(counts) if n > 0)
def save_txt(self, txt_file: str | Path, save_conf: bool = False) -> str:
"""
Save detection results to a text file.
"""Save detection results to a text file.
Args:
txt_file (str | Path): Path to the output text file.
@ -750,8 +725,7 @@ class Results(SimpleClass, DataExportMixin):
return str(txt_file)
def save_crop(self, save_dir: str | Path, file_name: str | Path = Path("im.jpg")):
"""
Save cropped detection images to specified directory.
"""Save cropped detection images to specified directory.
This method saves cropped images of detected objects to a specified directory. Each crop is saved in a
subdirectory named after the object's class, with the filename based on the input file_name.
@ -786,8 +760,7 @@ class Results(SimpleClass, DataExportMixin):
)
def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, Any]]:
"""
Convert inference results to a summarized dictionary with optional normalization for box coordinates.
"""Convert inference results to a summarized dictionary with optional normalization for box coordinates.
This method creates a list of detection dictionaries, each containing information about a single detection or
classification result. For classification tasks, it returns the top class and its
@ -853,8 +826,7 @@ class Results(SimpleClass, DataExportMixin):
class Boxes(BaseTensor):
"""
A class for managing and manipulating detection boxes.
"""A class for managing and manipulating detection boxes.
This class provides comprehensive functionality for handling detection boxes, including their coordinates,
confidence scores, class labels, and optional tracking IDs. It supports various box formats and offers methods for
@ -890,17 +862,15 @@ class Boxes(BaseTensor):
"""
def __init__(self, boxes: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
"""
Initialize the Boxes class with detection box data and the original image shape.
"""Initialize the Boxes class with detection box data and the original image shape.
This class manages detection boxes, providing easy access and manipulation of box coordinates, confidence
scores, class identifiers, and optional tracking IDs. It supports multiple formats for box coordinates,
including both absolute and normalized forms.
Args:
boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape
(num_boxes, 6) or (num_boxes, 7). Columns should contain [x1, y1, x2, y2, (optional) track_id,
confidence, class].
boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape (num_boxes, 6) or
(num_boxes, 7). Columns should contain [x1, y1, x2, y2, (optional) track_id, confidence, class].
orig_shape (tuple[int, int]): The original image shape as (height, width). Used for normalization.
Attributes:
@ -926,12 +896,11 @@ class Boxes(BaseTensor):
@property
def xyxy(self) -> torch.Tensor | np.ndarray:
"""
Return bounding boxes in [x1, y1, x2, y2] format.
"""Return bounding boxes in [x1, y1, x2, y2] format.
Returns:
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (n, 4) containing bounding box
coordinates in [x1, y1, x2, y2] format, where n is the number of boxes.
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (n, 4) containing bounding box coordinates in
[x1, y1, x2, y2] format, where n is the number of boxes.
Examples:
>>> results = model("image.jpg")
@ -943,12 +912,11 @@ class Boxes(BaseTensor):
@property
def conf(self) -> torch.Tensor | np.ndarray:
"""
Return the confidence scores for each detection box.
"""Return the confidence scores for each detection box.
Returns:
(torch.Tensor | np.ndarray): A 1D tensor or array containing confidence scores for each detection,
with shape (N,) where N is the number of detections.
(torch.Tensor | np.ndarray): A 1D tensor or array containing confidence scores for each detection, with
shape (N,) where N is the number of detections.
Examples:
>>> boxes = Boxes(torch.tensor([[10, 20, 30, 40, 0.9, 0]]), orig_shape=(100, 100))
@ -960,12 +928,11 @@ class Boxes(BaseTensor):
@property
def cls(self) -> torch.Tensor | np.ndarray:
"""
Return the class ID tensor representing category predictions for each bounding box.
"""Return the class ID tensor representing category predictions for each bounding box.
Returns:
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class IDs for each detection box.
The shape is (N,), where N is the number of boxes.
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class IDs for each detection box. The
shape is (N,), where N is the number of boxes.
Examples:
>>> results = model("image.jpg")
@ -977,12 +944,11 @@ class Boxes(BaseTensor):
@property
def id(self) -> torch.Tensor | np.ndarray | None:
"""
Return the tracking IDs for each detection box if available.
"""Return the tracking IDs for each detection box if available.
Returns:
(torch.Tensor | None): A tensor containing tracking IDs for each box if tracking is enabled,
otherwise None. Shape is (N,) where N is the number of boxes.
(torch.Tensor | None): A tensor containing tracking IDs for each box if tracking is enabled, otherwise None.
Shape is (N,) where N is the number of boxes.
Examples:
>>> results = model.track("path/to/video.mp4")
@ -1003,13 +969,12 @@ class Boxes(BaseTensor):
@property
@lru_cache(maxsize=2)
def xywh(self) -> torch.Tensor | np.ndarray:
"""
Convert bounding boxes from [x1, y1, x2, y2] format to [x, y, width, height] format.
"""Convert bounding boxes from [x1, y1, x2, y2] format to [x, y, width, height] format.
Returns:
(torch.Tensor | np.ndarray): Boxes in [x_center, y_center, width, height] format, where x_center,
y_center are the coordinates of the center point of the bounding box, width, height are the dimensions
of the bounding box and the shape of the returned tensor is (N, 4), where N is the number of boxes.
(torch.Tensor | np.ndarray): Boxes in [x_center, y_center, width, height] format, where x_center, y_center
are the coordinates of the center point of the bounding box, width, height are the dimensions of the
bounding box and the shape of the returned tensor is (N, 4), where N is the number of boxes.
Examples:
>>> boxes = Boxes(torch.tensor([[100, 50, 150, 100], [200, 150, 300, 250]]), orig_shape=(480, 640))
@ -1023,15 +988,14 @@ class Boxes(BaseTensor):
@property
@lru_cache(maxsize=2)
def xyxyn(self) -> torch.Tensor | np.ndarray:
"""
Return normalized bounding box coordinates relative to the original image size.
"""Return normalized bounding box coordinates relative to the original image size.
This property calculates and returns the bounding box coordinates in [x1, y1, x2, y2] format, normalized to the
range [0, 1] based on the original image dimensions.
Returns:
(torch.Tensor | np.ndarray): Normalized bounding box coordinates with shape (N, 4), where N is
the number of boxes. Each row contains [x1, y1, x2, y2] values normalized to [0, 1].
(torch.Tensor | np.ndarray): Normalized bounding box coordinates with shape (N, 4), where N is the number of
boxes. Each row contains [x1, y1, x2, y2] values normalized to [0, 1].
Examples:
>>> boxes = Boxes(torch.tensor([[100, 50, 300, 400, 0.9, 0]]), orig_shape=(480, 640))
@ -1047,16 +1011,15 @@ class Boxes(BaseTensor):
@property
@lru_cache(maxsize=2)
def xywhn(self) -> torch.Tensor | np.ndarray:
"""
Return normalized bounding boxes in [x, y, width, height] format.
"""Return normalized bounding boxes in [x, y, width, height] format.
This property calculates and returns the normalized bounding box coordinates in the format [x_center, y_center,
width, height], where all values are relative to the original image dimensions.
Returns:
(torch.Tensor | np.ndarray): Normalized bounding boxes with shape (N, 4), where N is the
number of boxes. Each row contains [x_center, y_center, width, height] values normalized to [0, 1] based
on the original image dimensions.
(torch.Tensor | np.ndarray): Normalized bounding boxes with shape (N, 4), where N is the number of boxes.
Each row contains [x_center, y_center, width, height] values normalized to [0, 1] based on the original
image dimensions.
Examples:
>>> boxes = Boxes(torch.tensor([[100, 50, 150, 100, 0.9, 0]]), orig_shape=(480, 640))
@ -1071,8 +1034,7 @@ class Boxes(BaseTensor):
class Masks(BaseTensor):
"""
A class for storing and manipulating detection masks.
"""A class for storing and manipulating detection masks.
This class extends BaseTensor and provides functionality for handling segmentation masks, including methods for
converting between pixel and normalized coordinates.
@ -1098,8 +1060,7 @@ class Masks(BaseTensor):
"""
def __init__(self, masks: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
"""
Initialize the Masks class with detection mask data and the original image shape.
"""Initialize the Masks class with detection mask data and the original image shape.
Args:
masks (torch.Tensor | np.ndarray): Detection masks with shape (num_masks, height, width).
@ -1119,15 +1080,14 @@ class Masks(BaseTensor):
@property
@lru_cache(maxsize=1)
def xyn(self) -> list[np.ndarray]:
"""
Return normalized xy-coordinates of the segmentation masks.
"""Return normalized xy-coordinates of the segmentation masks.
This property calculates and caches the normalized xy-coordinates of the segmentation masks. The coordinates are
normalized relative to the original image shape.
Returns:
(list[np.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates
of a single segmentation mask. Each array has shape (N, 2), where N is the number of points in the
(list[np.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates of a
single segmentation mask. Each array has shape (N, 2), where N is the number of points in the
mask contour.
Examples:
@ -1144,16 +1104,14 @@ class Masks(BaseTensor):
@property
@lru_cache(maxsize=1)
def xy(self) -> list[np.ndarray]:
"""
Return the [x, y] pixel coordinates for each segment in the mask tensor.
"""Return the [x, y] pixel coordinates for each segment in the mask tensor.
This property calculates and returns a list of pixel coordinates for each segmentation mask in the Masks object.
The coordinates are scaled to match the original image dimensions.
Returns:
(list[np.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel
coordinates for a single segmentation mask. Each array has shape (N, 2), where N is the number of points
in the segment.
(list[np.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel coordinates for a
single segmentation mask. Each array has shape (N, 2), where N is the number of points in the segment.
Examples:
>>> results = model("image.jpg")
@ -1169,8 +1127,7 @@ class Masks(BaseTensor):
class Keypoints(BaseTensor):
"""
A class for storing and manipulating detection keypoints.
"""A class for storing and manipulating detection keypoints.
This class encapsulates functionality for handling keypoint data, including coordinate manipulation, normalization,
and confidence values. It supports keypoint detection results with optional visibility information.
@ -1201,8 +1158,7 @@ class Keypoints(BaseTensor):
"""
def __init__(self, keypoints: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
"""
Initialize the Keypoints object with detection keypoints and original image dimensions.
"""Initialize the Keypoints object with detection keypoints and original image dimensions.
This method processes the input keypoints tensor, handling both 2D and 3D formats. For 3D tensors (x, y,
confidence), it masks out low-confidence keypoints by setting their coordinates to zero.
@ -1226,12 +1182,11 @@ class Keypoints(BaseTensor):
@property
@lru_cache(maxsize=1)
def xy(self) -> torch.Tensor | np.ndarray:
"""
Return x, y coordinates of keypoints.
"""Return x, y coordinates of keypoints.
Returns:
(torch.Tensor): A tensor containing the x, y coordinates of keypoints with shape (N, K, 2), where N is
the number of detections and K is the number of keypoints per detection.
(torch.Tensor): A tensor containing the x, y coordinates of keypoints with shape (N, K, 2), where N is the
number of detections and K is the number of keypoints per detection.
Examples:
>>> results = model("image.jpg")
@ -1250,8 +1205,7 @@ class Keypoints(BaseTensor):
@property
@lru_cache(maxsize=1)
def xyn(self) -> torch.Tensor | np.ndarray:
"""
Return normalized coordinates (x, y) of keypoints relative to the original image size.
"""Return normalized coordinates (x, y) of keypoints relative to the original image size.
Returns:
(torch.Tensor | np.ndarray): A tensor or array of shape (N, K, 2) containing normalized keypoint
@ -1272,13 +1226,11 @@ class Keypoints(BaseTensor):
@property
@lru_cache(maxsize=1)
def conf(self) -> torch.Tensor | np.ndarray | None:
"""
Return confidence values for each keypoint.
"""Return confidence values for each keypoint.
Returns:
(torch.Tensor | None): A tensor containing confidence scores for each keypoint if available,
otherwise None. Shape is (num_detections, num_keypoints) for batched data or (num_keypoints,) for
single detection.
(torch.Tensor | None): A tensor containing confidence scores for each keypoint if available, otherwise None.
Shape is (num_detections, num_keypoints) for batched data or (num_keypoints,) for single detection.
Examples:
>>> keypoints = Keypoints(torch.rand(1, 17, 3), orig_shape=(640, 640)) # 1 detection, 17 keypoints
@ -1289,8 +1241,7 @@ class Keypoints(BaseTensor):
class Probs(BaseTensor):
"""
A class for storing and manipulating classification probabilities.
"""A class for storing and manipulating classification probabilities.
This class extends BaseTensor and provides methods for accessing and manipulating classification probabilities,
including top-1 and top-5 predictions.
@ -1323,16 +1274,15 @@ class Probs(BaseTensor):
"""
def __init__(self, probs: torch.Tensor | np.ndarray, orig_shape: tuple[int, int] | None = None) -> None:
"""
Initialize the Probs class with classification probabilities.
"""Initialize the Probs class with classification probabilities.
This class stores and manages classification probabilities, providing easy access to top predictions and their
confidences.
Args:
probs (torch.Tensor | np.ndarray): A 1D tensor or array of classification probabilities.
orig_shape (tuple | None): The original image shape as (height, width). Not used in this class but kept
for consistency with other result classes.
orig_shape (tuple | None): The original image shape as (height, width). Not used in this class but kept for
consistency with other result classes.
Attributes:
data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.
@ -1357,8 +1307,7 @@ class Probs(BaseTensor):
@property
@lru_cache(maxsize=1)
def top1(self) -> int:
"""
Return the index of the class with the highest probability.
"""Return the index of the class with the highest probability.
Returns:
(int): Index of the class with the highest probability.
@ -1373,8 +1322,7 @@ class Probs(BaseTensor):
@property
@lru_cache(maxsize=1)
def top5(self) -> list[int]:
"""
Return the indices of the top 5 class probabilities.
"""Return the indices of the top 5 class probabilities.
Returns:
(list[int]): A list containing the indices of the top 5 class probabilities, sorted in descending order.
@ -1389,8 +1337,7 @@ class Probs(BaseTensor):
@property
@lru_cache(maxsize=1)
def top1conf(self) -> torch.Tensor | np.ndarray:
"""
Return the confidence score of the highest probability class.
"""Return the confidence score of the highest probability class.
This property retrieves the confidence score (probability) of the class with the highest predicted probability
from the classification results.
@ -1409,16 +1356,15 @@ class Probs(BaseTensor):
@property
@lru_cache(maxsize=1)
def top5conf(self) -> torch.Tensor | np.ndarray:
"""
Return confidence scores for the top 5 classification predictions.
"""Return confidence scores for the top 5 classification predictions.
This property retrieves the confidence scores corresponding to the top 5 class probabilities predicted by the
model. It provides a quick way to access the most likely class predictions along with their associated
confidence levels.
Returns:
(torch.Tensor | np.ndarray): A tensor or array containing the confidence scores for the
top 5 predicted classes, sorted in descending order of probability.
(torch.Tensor | np.ndarray): A tensor or array containing the confidence scores for the top 5 predicted
classes, sorted in descending order of probability.
Examples:
>>> results = model("image.jpg")
@ -1430,8 +1376,7 @@ class Probs(BaseTensor):
class OBB(BaseTensor):
"""
A class for storing and manipulating Oriented Bounding Boxes (OBB).
"""A class for storing and manipulating Oriented Bounding Boxes (OBB).
This class provides functionality to handle oriented bounding boxes, including conversion between different formats,
normalization, and access to various properties of the boxes. It supports both tracking and non-tracking scenarios.
@ -1463,16 +1408,15 @@ class OBB(BaseTensor):
"""
def __init__(self, boxes: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
"""
Initialize an OBB (Oriented Bounding Box) instance with oriented bounding box data and original image shape.
"""Initialize an OBB (Oriented Bounding Box) instance with oriented bounding box data and original image shape.
This class stores and manipulates Oriented Bounding Boxes (OBB) for object detection tasks. It provides various
properties and methods to access and transform the OBB data.
Args:
boxes (torch.Tensor | np.ndarray): A tensor or numpy array containing the detection boxes,
with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values.
If present, the third last column contains track IDs, and the fifth column contains rotation.
boxes (torch.Tensor | np.ndarray): A tensor or numpy array containing the detection boxes, with shape
(num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values. If present,
the third last column contains track IDs, and the fifth column contains rotation.
orig_shape (tuple[int, int]): Original image size, in the format (height, width).
Attributes:
@ -1500,8 +1444,7 @@ class OBB(BaseTensor):
@property
def xywhr(self) -> torch.Tensor | np.ndarray:
"""
Return boxes in [x_center, y_center, width, height, rotation] format.
"""Return boxes in [x_center, y_center, width, height, rotation] format.
Returns:
(torch.Tensor | np.ndarray): A tensor or numpy array containing the oriented bounding boxes with format
@ -1518,15 +1461,14 @@ class OBB(BaseTensor):
@property
def conf(self) -> torch.Tensor | np.ndarray:
"""
Return the confidence scores for Oriented Bounding Boxes (OBBs).
"""Return the confidence scores for Oriented Bounding Boxes (OBBs).
This property retrieves the confidence values associated with each OBB detection. The confidence score
represents the model's certainty in the detection.
Returns:
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (N,) containing confidence scores
for N detections, where each score is in the range [0, 1].
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (N,) containing confidence scores for N
detections, where each score is in the range [0, 1].
Examples:
>>> results = model("image.jpg")
@ -1538,12 +1480,11 @@ class OBB(BaseTensor):
@property
def cls(self) -> torch.Tensor | np.ndarray:
"""
Return the class values of the oriented bounding boxes.
"""Return the class values of the oriented bounding boxes.
Returns:
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class values for each oriented
bounding box. The shape is (N,), where N is the number of boxes.
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class values for each oriented bounding
box. The shape is (N,), where N is the number of boxes.
Examples:
>>> results = model("image.jpg")
@ -1556,12 +1497,11 @@ class OBB(BaseTensor):
@property
def id(self) -> torch.Tensor | np.ndarray | None:
"""
Return the tracking IDs of the oriented bounding boxes (if available).
"""Return the tracking IDs of the oriented bounding boxes (if available).
Returns:
(torch.Tensor | np.ndarray | None): A tensor or numpy array containing the tracking IDs for each
oriented bounding box. Returns None if tracking IDs are not available.
(torch.Tensor | np.ndarray | None): A tensor or numpy array containing the tracking IDs for each oriented
bounding box. Returns None if tracking IDs are not available.
Examples:
>>> results = model("image.jpg", tracker=True) # Run inference with tracking
@ -1576,12 +1516,11 @@ class OBB(BaseTensor):
@property
@lru_cache(maxsize=2)
def xyxyxyxy(self) -> torch.Tensor | np.ndarray:
"""
Convert OBB format to 8-point (xyxyxyxy) coordinate format for rotated bounding boxes.
"""Convert OBB format to 8-point (xyxyxyxy) coordinate format for rotated bounding boxes.
Returns:
(torch.Tensor | np.ndarray): Rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), where N is
the number of boxes. Each box is represented by 4 points (x, y), starting from the top-left corner and
(torch.Tensor | np.ndarray): Rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), where N is the
number of boxes. Each box is represented by 4 points (x, y), starting from the top-left corner and
moving clockwise.
Examples:
@ -1595,8 +1534,7 @@ class OBB(BaseTensor):
@property
@lru_cache(maxsize=2)
def xyxyxyxyn(self) -> torch.Tensor | np.ndarray:
"""
Convert rotated bounding boxes to normalized xyxyxyxy format.
"""Convert rotated bounding boxes to normalized xyxyxyxy format.
Returns:
(torch.Tensor | np.ndarray): Normalized rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2),
@ -1617,16 +1555,15 @@ class OBB(BaseTensor):
@property
@lru_cache(maxsize=2)
def xyxy(self) -> torch.Tensor | np.ndarray:
"""
Convert oriented bounding boxes (OBB) to axis-aligned bounding boxes in xyxy format.
"""Convert oriented bounding boxes (OBB) to axis-aligned bounding boxes in xyxy format.
This property calculates the minimal enclosing rectangle for each oriented bounding box and returns it in xyxy
format (x1, y1, x2, y2). This is useful for operations that require axis-aligned bounding boxes, such as IoU
calculation with non-rotated boxes.
Returns:
(torch.Tensor | np.ndarray): Axis-aligned bounding boxes in xyxy format with shape (N, 4), where N
is the number of boxes. Each row contains [x1, y1, x2, y2] coordinates.
(torch.Tensor | np.ndarray): Axis-aligned bounding boxes in xyxy format with shape (N, 4), where N is the
number of boxes. Each row contains [x1, y1, x2, y2] coordinates.
Examples:
>>> import torch

View file

@ -63,8 +63,7 @@ from ultralytics.utils.torch_utils import (
class BaseTrainer:
"""
A base class for creating trainers.
"""A base class for creating trainers.
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
and various training utilities. It supports both single-GPU and multi-GPU distributed training.
@ -114,8 +113,7 @@ class BaseTrainer:
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the BaseTrainer class.
"""Initialize the BaseTrainer class.
Args:
cfg (str, optional): Path to a configuration file.
@ -620,8 +618,7 @@ class BaseTrainer:
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
def get_dataset(self):
"""
Get train and validation datasets from data dictionary.
"""Get train and validation datasets from data dictionary.
Returns:
(dict): A dictionary containing the training/validation/test dataset and category names.
@ -656,8 +653,7 @@ class BaseTrainer:
return data
def setup_model(self):
"""
Load, create, or download model for any task.
"""Load, create, or download model for any task.
Returns:
(dict): Optional checkpoint to resume training from.
@ -690,8 +686,7 @@ class BaseTrainer:
return batch
def validate(self):
"""
Run validation on val set using self.validator.
"""Run validation on val set using self.validator.
Returns:
metrics (dict): Dictionary of validation metrics.
@ -726,8 +721,7 @@ class BaseTrainer:
raise NotImplementedError("build_dataset function not implemented in trainer")
def label_loss_items(self, loss_items=None, prefix="train"):
"""
Return a loss dict with labeled training loss items tensor.
"""Return a loss dict with labeled training loss items tensor.
Notes:
This is not needed for classification but necessary for segmentation & detection
@ -895,18 +889,16 @@ class BaseTrainer:
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""
Construct an optimizer for the given model.
"""Construct an optimizer for the given model.
Args:
model (torch.nn.Module): The model for which to build an optimizer.
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
based on the number of iterations.
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
number of iterations.
lr (float, optional): The learning rate for the optimizer.
momentum (float, optional): The momentum factor for the optimizer.
decay (float, optional): The weight decay for the optimizer.
iterations (float, optional): The number of iterations, which determines the optimizer if
name is 'auto'.
iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
Returns:
(torch.optim.Optimizer): The constructed optimizer.

View file

@ -34,8 +34,7 @@ from ultralytics.utils.plotting import plot_tune_results
class Tuner:
"""
A class for hyperparameter tuning of YOLO models.
"""A class for hyperparameter tuning of YOLO models.
The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
search space and retraining the model to evaluate their performance. Supports both local CSV storage and distributed
@ -83,8 +82,7 @@ class Tuner:
"""
def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
"""
Initialize the Tuner with configurations.
"""Initialize the Tuner with configurations.
Args:
args (dict): Configuration for hyperparameter evolution.
@ -142,8 +140,7 @@ class Tuner:
)
def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
"""
Create MongoDB client with exponential backoff retry on connection failures.
"""Create MongoDB client with exponential backoff retry on connection failures.
Args:
uri (str): MongoDB connection string with credentials and cluster information.
@ -183,8 +180,7 @@ class Tuner:
time.sleep(wait_time)
def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
"""
Initialize MongoDB connection for distributed tuning.
"""Initialize MongoDB connection for distributed tuning.
Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines. Each worker
saves results to a shared collection and reads the latest best hyperparameters from all workers for evolution.
@ -205,8 +201,7 @@ class Tuner:
LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
def _get_mongodb_results(self, n: int = 5) -> list:
"""
Get top N results from MongoDB sorted by fitness.
"""Get top N results from MongoDB sorted by fitness.
Args:
n (int): Number of top results to retrieve.
@ -220,8 +215,7 @@ class Tuner:
return []
def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
"""
Save results to MongoDB with proper type conversion.
"""Save results to MongoDB with proper type conversion.
Args:
fitness (float): Fitness score achieved with these hyperparameters.
@ -243,8 +237,7 @@ class Tuner:
LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
def _sync_mongodb_to_csv(self):
"""
Sync MongoDB results to CSV for plotting compatibility.
"""Sync MongoDB results to CSV for plotting compatibility.
Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
the existing plotting functions to work seamlessly with distributed MongoDB data.
@ -287,8 +280,7 @@ class Tuner:
mutation: float = 0.5,
sigma: float = 0.2,
) -> dict[str, float]:
"""
Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
"""Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
Args:
parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
@ -348,8 +340,7 @@ class Tuner:
return hyp
def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
"""
Execute the hyperparameter evolution process when the Tuner instance is called.
"""Execute the hyperparameter evolution process when the Tuner instance is called.
This method iterates through the specified number of iterations, performing the following steps:
1. Sync MongoDB results to CSV (if using distributed mode)

View file

@ -41,8 +41,7 @@ from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_
class BaseValidator:
"""
A base class for creating validators.
"""A base class for creating validators.
This class provides the foundation for validation processes, including model evaluation, metric computation, and
result visualization.
@ -62,8 +61,8 @@ class BaseValidator:
nc (int): Number of classes.
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
jdict (list): List to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
batch processing times in milliseconds.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
processing times in milliseconds.
save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
@ -93,8 +92,7 @@ class BaseValidator:
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
"""
Initialize a BaseValidator instance.
"""Initialize a BaseValidator instance.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
@ -131,8 +129,7 @@ class BaseValidator:
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
"""
Execute validation process, running inference on dataloader and computing performance metrics.
"""Execute validation process, running inference on dataloader and computing performance metrics.
Args:
trainer (object, optional): Trainer object that contains the model to validate.
@ -269,8 +266,7 @@ class BaseValidator:
def match_predictions(
self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
) -> torch.Tensor:
"""
Match predictions to ground truth objects using IoU.
"""Match predictions to ground truth objects using IoU.
Args:
pred_classes (torch.Tensor): Predicted class indices of shape (N,).

View file

@ -23,15 +23,14 @@ __all__ = (
def login(api_key: str | None = None, save: bool = True) -> bool:
"""
Log in to the Ultralytics HUB API using the provided API key.
"""Log in to the Ultralytics HUB API using the provided API key.
The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
environment variable if successfully authenticated.
Args:
api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from
SETTINGS or HUB_API_KEY environment variable.
api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from SETTINGS
or HUB_API_KEY environment variable.
save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
Returns:
@ -91,8 +90,7 @@ def export_fmts_hub():
def export_model(model_id: str = "", format: str = "torchscript"):
"""
Export a model to a specified format for deployment via the Ultralytics HUB API.
"""Export a model to a specified format for deployment via the Ultralytics HUB API.
Args:
model_id (str): The ID of the model to export. An empty string will use the default model.
@ -117,13 +115,11 @@ def export_model(model_id: str = "", format: str = "torchscript"):
def get_export(model_id: str = "", format: str = "torchscript"):
"""
Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
"""Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
Args:
model_id (str): The ID of the model to retrieve from Ultralytics HUB.
format (str): The export format to retrieve. Must be one of the supported formats returned by
export_fmts_hub().
format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().
Returns:
(dict): JSON response containing the exported model information.
@ -148,8 +144,7 @@ def get_export(model_id: str = "", format: str = "torchscript"):
def check_dataset(path: str, task: str) -> None:
"""
Check HUB dataset Zip file for errors before upload.
"""Check HUB dataset Zip file for errors before upload.
Args:
path (str): Path to data.zip (with data.yaml inside data.zip).

View file

@ -7,8 +7,7 @@ API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
class Auth:
"""
Manages authentication processes including API key handling, cookie-based authentication, and header generation.
"""Manages authentication processes including API key handling, cookie-based authentication, and header generation.
The class supports different methods of authentication:
1. Directly using an API key.
@ -37,8 +36,7 @@ class Auth:
id_token = api_key = model_key = False
def __init__(self, api_key: str = "", verbose: bool = False):
"""
Initialize Auth class and authenticate user.
"""Initialize Auth class and authenticate user.
Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
authentication.
@ -82,8 +80,7 @@ class Auth:
LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
def request_api_key(self, max_attempts: int = 3) -> bool:
"""
Prompt the user to input their API key.
"""Prompt the user to input their API key.
Args:
max_attempts (int): Maximum number of authentication attempts.
@ -102,8 +99,7 @@ class Auth:
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
def authenticate(self) -> bool:
"""
Attempt to authenticate with the server using either id_token or API key.
"""Attempt to authenticate with the server using either id_token or API key.
Returns:
(bool): True if authentication is successful, False otherwise.
@ -123,8 +119,7 @@ class Auth:
return False
def auth_with_cookies(self) -> bool:
"""
Attempt to fetch authentication via cookies and set id_token.
"""Attempt to fetch authentication via cookies and set id_token.
User must be logged in to HUB and running in a supported browser.
@ -145,8 +140,7 @@ class Auth:
return False
def get_auth_header(self):
"""
Get the authentication header for making API requests.
"""Get the authentication header for making API requests.
Returns:
(dict | None): The authentication header if id_token or API key is set, None otherwise.

View file

@ -8,8 +8,7 @@ import time
class GCPRegions:
"""
A class for managing and analyzing Google Cloud Platform (GCP) regions.
"""A class for managing and analyzing Google Cloud Platform (GCP) regions.
This class provides functionality to initialize, categorize, and analyze GCP regions based on their geographical
location, tier classification, and network latency.
@ -82,8 +81,7 @@ class GCPRegions:
@staticmethod
def _ping_region(region: str, attempts: int = 1) -> tuple[str, float, float, float, float]:
"""
Ping a specified GCP region and measure network latency statistics.
"""Ping a specified GCP region and measure network latency statistics.
Args:
region (str): The GCP region identifier to ping (e.g., 'us-central1').
@ -126,8 +124,7 @@ class GCPRegions:
tier: int | None = None,
attempts: int = 1,
) -> list[tuple[str, float, float, float, float]]:
"""
Determine the GCP regions with the lowest latency based on ping tests.
"""Determine the GCP regions with the lowest latency based on ping tests.
Args:
top (int, optional): Number of top regions to return.
@ -136,8 +133,8 @@ class GCPRegions:
attempts (int, optional): Number of ping attempts per region.
Returns:
(list[tuple[str, float, float, float, float]]): List of tuples containing region information and
latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
(list[tuple[str, float, float, float, float]]): List of tuples containing region information and latency
statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
Examples:
>>> regions = GCPRegions()

View file

@ -19,8 +19,7 @@ AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version_
class HUBTrainingSession:
"""
HUB training session for Ultralytics HUB YOLO models.
"""HUB training session for Ultralytics HUB YOLO models.
This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
model creation, metrics tracking, and checkpoint uploading.
@ -45,12 +44,11 @@ class HUBTrainingSession:
"""
def __init__(self, identifier: str):
"""
Initialize the HUBTrainingSession with the provided model identifier.
"""Initialize the HUBTrainingSession with the provided model identifier.
Args:
identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string
or a model key with specific format.
identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string or a
model key with specific format.
Raises:
ValueError: If the provided model identifier is invalid.
@ -93,8 +91,7 @@ class HUBTrainingSession:
@classmethod
def create_session(cls, identifier: str, args: dict[str, Any] | None = None):
"""
Create an authenticated HUBTrainingSession or return None.
"""Create an authenticated HUBTrainingSession or return None.
Args:
identifier (str): Model identifier used to initialize the HUB training session.
@ -114,8 +111,7 @@ class HUBTrainingSession:
return None
def load_model(self, model_id: str):
"""
Load an existing model from Ultralytics HUB using the provided model identifier.
"""Load an existing model from Ultralytics HUB using the provided model identifier.
Args:
model_id (str): The identifier of the model to load.
@ -140,8 +136,7 @@ class HUBTrainingSession:
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def create_model(self, model_args: dict[str, Any]):
"""
Initialize a HUB training session with the specified model arguments.
"""Initialize a HUB training session with the specified model arguments.
Args:
model_args (dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
@ -186,8 +181,7 @@ class HUBTrainingSession:
@staticmethod
def _parse_identifier(identifier: str):
"""
Parse the given identifier to determine the type and extract relevant components.
"""Parse the given identifier to determine the type and extract relevant components.
The method supports different identifier formats:
- A HUB model URL https://hub.ultralytics.com/models/MODEL
@ -218,8 +212,7 @@ class HUBTrainingSession:
return api_key, model_id, filename
def _set_train_args(self):
"""
Initialize training arguments and create a model entry on the Ultralytics HUB.
"""Initialize training arguments and create a model entry on the Ultralytics HUB.
This method sets up training arguments based on the model's state and updates them with any additional arguments
provided. It handles different states of the model, such as whether it's resumable, pretrained, or requires
@ -261,8 +254,7 @@ class HUBTrainingSession:
*args,
**kwargs,
):
"""
Execute request_func with retries, timeout handling, optional threading, and progress tracking.
"""Execute request_func with retries, timeout handling, optional threading, and progress tracking.
Args:
request_func (callable): The function to execute.
@ -342,8 +334,7 @@ class HUBTrainingSession:
return status_code in retry_codes
def _get_failure_message(self, response, retry: int, timeout: int) -> str:
"""
Generate a retry message based on the response status code.
"""Generate a retry message based on the response status code.
Args:
response (requests.Response): The HTTP response object.
@ -379,8 +370,7 @@ class HUBTrainingSession:
map: float = 0.0,
final: bool = False,
) -> None:
"""
Upload a model checkpoint to Ultralytics HUB.
"""Upload a model checkpoint to Ultralytics HUB.
Args:
epoch (int): The current training epoch.

View file

@ -21,8 +21,7 @@ HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/h
def request_with_credentials(url: str) -> Any:
"""
Make an AJAX request with cookies attached in a Google Colab environment.
"""Make an AJAX request with cookies attached in a Google Colab environment.
Args:
url (str): The URL to make the request to.
@ -62,8 +61,7 @@ def request_with_credentials(url: str) -> Any:
def requests_with_progress(method: str, url: str, **kwargs):
"""
Make an HTTP request using the specified method and URL, with an optional progress bar.
"""Make an HTTP request using the specified method and URL, with an optional progress bar.
Args:
method (str): The HTTP method to use (e.g. 'GET', 'POST').
@ -106,8 +104,7 @@ def smart_request(
progress: bool = False,
**kwargs,
):
"""
Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
"""Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
Args:
method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.

View file

@ -12,8 +12,7 @@ from .val import FastSAMValidator
class FastSAM(Model):
"""
FastSAM model interface for segment anything tasks.
"""FastSAM model interface for segment anything tasks.
This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.
@ -53,15 +52,14 @@ class FastSAM(Model):
texts: list | None = None,
**kwargs: Any,
):
"""
Perform segmentation prediction on image or video source.
"""Perform segmentation prediction on image or video source.
Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these prompts
and passes them to the parent class predict method for processing.
Args:
source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
or numpy array.
source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image, or
numpy array.
stream (bool): Whether to enable real-time streaming mode for video inputs.
bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].

View file

@ -13,8 +13,7 @@ from .utils import adjust_bboxes_to_image_border
class FastSAMPredictor(SegmentationPredictor):
"""
FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
"""FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
@ -33,8 +32,7 @@ class FastSAMPredictor(SegmentationPredictor):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the FastSAMPredictor with configuration and callbacks.
"""Initialize the FastSAMPredictor with configuration and callbacks.
This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
@ -49,8 +47,7 @@ class FastSAMPredictor(SegmentationPredictor):
self.prompts = {}
def postprocess(self, preds, img, orig_imgs):
"""
Apply postprocessing to FastSAM predictions and handle prompts.
"""Apply postprocessing to FastSAM predictions and handle prompts.
Args:
preds (list[torch.Tensor]): Raw predictions from the model.
@ -77,8 +74,7 @@ class FastSAMPredictor(SegmentationPredictor):
return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
"""
Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
"""Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
Args:
results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
@ -151,8 +147,7 @@ class FastSAMPredictor(SegmentationPredictor):
return prompt_results
def _clip_inference(self, images, texts):
"""
Perform CLIP inference to calculate similarity between images and text prompts.
"""Perform CLIP inference to calculate similarity between images and text prompts.
Args:
images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.

View file

@ -2,8 +2,7 @@
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
"""Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.

View file

@ -4,8 +4,7 @@ from ultralytics.models.yolo.segment import SegmentationValidator
class FastSAMValidator(SegmentationValidator):
"""
Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
"""Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
@ -23,8 +22,7 @@ class FastSAMValidator(SegmentationValidator):
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
"""
Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
"""Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.

View file

@ -18,8 +18,7 @@ from .val import NASValidator
class NAS(Model):
"""
YOLO-NAS model for object detection.
"""YOLO-NAS model for object detection.
This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. It
is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
@ -48,8 +47,7 @@ class NAS(Model):
super().__init__(model, task="detect")
def _load(self, weights: str, task=None) -> None:
"""
Load an existing NAS model weights or create a new NAS model with pretrained weights.
"""Load an existing NAS model weights or create a new NAS model with pretrained weights.
Args:
weights (str): Path to the model weights file or model name.
@ -83,8 +81,7 @@ class NAS(Model):
self.model.eval()
def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
"""
Log model information.
"""Log model information.
Args:
detailed (bool): Show detailed information about model.

View file

@ -7,8 +7,7 @@ from ultralytics.utils import ops
class NASPredictor(DetectionPredictor):
"""
Ultralytics YOLO NAS Predictor for object detection.
"""Ultralytics YOLO NAS Predictor for object detection.
This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the raw
predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and scaling the
@ -33,8 +32,7 @@ class NASPredictor(DetectionPredictor):
"""
def postprocess(self, preds_in, img, orig_imgs):
"""
Postprocess NAS model predictions to generate final detection results.
"""Postprocess NAS model predictions to generate final detection results.
This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
post-processing operations to generate the final detection results compatible with Ultralytics result

View file

@ -9,8 +9,7 @@ __all__ = ["NASValidator"]
class NASValidator(DetectionValidator):
"""
Ultralytics YOLO NAS Validator for object detection.
"""Ultralytics YOLO NAS Validator for object detection.
Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,

View file

@ -19,8 +19,7 @@ from .val import RTDETRValidator
class RTDETR(Model):
"""
Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
"""Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
selection, and adaptable inference speed.
@ -39,8 +38,7 @@ class RTDETR(Model):
"""
def __init__(self, model: str = "rtdetr-l.pt") -> None:
"""
Initialize the RT-DETR model with the given pre-trained model file.
"""Initialize the RT-DETR model with the given pre-trained model file.
Args:
model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
@ -50,8 +48,7 @@ class RTDETR(Model):
@property
def task_map(self) -> dict:
"""
Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
"""Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
Returns:
(dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.

View file

@ -9,8 +9,7 @@ from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
"""
RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
"""RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy. It
supports key features like efficient hybrid encoding and IoU-aware query selection.
@ -34,21 +33,20 @@ class RTDETRPredictor(BasePredictor):
"""
def postprocess(self, preds, img, orig_imgs):
"""
Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
"""Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
The method filters detections based on confidence and class if specified in `self.args`. It converts model
predictions to Results objects containing properly scaled bounding boxes.
Args:
preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
bounding boxes and scores.
preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
and scores.
img (torch.Tensor): Processed input images with shape (N, 3, H, W).
orig_imgs (list | torch.Tensor): Original, unprocessed images.
Returns:
results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
confidence scores, and class labels.
results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
scores, and class labels.
"""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
@ -75,15 +73,14 @@ class RTDETRPredictor(BasePredictor):
return results
def pre_transform(self, im):
"""
Pre-transform input images before feeding them into the model for inference.
"""Pre-transform input images before feeding them into the model for inference.
The input images are letterboxed to a square aspect ratio with scale-fill enabled. The input size must be
square (640).
Args:
im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
[(H, W, 3) x N] for list.
im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
list.
Returns:
(list): List of pre-transformed images ready for model inference.

View file

@ -12,8 +12,7 @@ from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
"""
Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
"""Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
@ -43,8 +42,7 @@ class RTDETRTrainer(DetectionTrainer):
"""
def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
"""
Initialize and return an RT-DETR model for object detection tasks.
"""Initialize and return an RT-DETR model for object detection tasks.
Args:
cfg (dict, optional): Model configuration.
@ -60,8 +58,7 @@ class RTDETRTrainer(DetectionTrainer):
return model
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
"""
Build and return an RT-DETR dataset for training or validation.
"""Build and return an RT-DETR dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.

View file

@ -16,8 +16,7 @@ __all__ = ("RTDETRValidator",) # tuple or list
class RTDETRDataset(YOLODataset):
"""
Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
"""Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
real-time detection and tracking tasks.
@ -40,8 +39,7 @@ class RTDETRDataset(YOLODataset):
"""
def __init__(self, *args, data=None, **kwargs):
"""
Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
"""Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
model, building upon the base YOLODataset functionality.
@ -54,8 +52,7 @@ class RTDETRDataset(YOLODataset):
super().__init__(*args, data=data, **kwargs)
def load_image(self, i, rect_mode=False):
"""
Load one image from dataset index 'i'.
"""Load one image from dataset index 'i'.
Args:
i (int): Index of the image to load.
@ -73,8 +70,7 @@ class RTDETRDataset(YOLODataset):
return super().load_image(i=i, rect_mode=rect_mode)
def build_transforms(self, hyp=None):
"""
Build transformation pipeline for the dataset.
"""Build transformation pipeline for the dataset.
Args:
hyp (dict, optional): Hyperparameters for transformations.
@ -105,8 +101,7 @@ class RTDETRDataset(YOLODataset):
class RTDETRValidator(DetectionValidator):
"""
RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
"""RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
the RT-DETR (Real-Time DETR) object detection model.
The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
@ -132,8 +127,7 @@ class RTDETRValidator(DetectionValidator):
"""
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build an RTDETR Dataset.
"""Build an RTDETR Dataset.
Args:
img_path (str): Path to the folder containing images.
@ -159,8 +153,7 @@ class RTDETRValidator(DetectionValidator):
def postprocess(
self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
) -> list[dict[str, torch.Tensor]]:
"""
Apply Non-maximum suppression to prediction outputs.
"""Apply Non-maximum suppression to prediction outputs.
Args:
preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
@ -191,12 +184,11 @@ class RTDETRValidator(DetectionValidator):
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Serialize YOLO predictions to COCO json format.
"""Serialize YOLO predictions to COCO json format.
Args:
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
"""
path = Path(pbatch["im_file"])

View file

@ -14,8 +14,7 @@ import torch
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
) -> torch.Tensor:
"""
Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
"""Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
Args:
boxes (torch.Tensor): Bounding boxes in XYXY format.
@ -42,8 +41,7 @@ def is_box_near_crop_edge(
def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
"""
Yield batches of data from input arguments with specified batch size for efficient processing.
"""Yield batches of data from input arguments with specified batch size for efficient processing.
This function takes a batch size and any number of iterables, then yields batches of elements from those
iterables. All input iterables must have the same length.
@ -71,8 +69,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
"""
Compute the stability score for a batch of masks.
"""Compute the stability score for a batch of masks.
The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
low values.
@ -117,8 +114,7 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
def generate_crop_boxes(
im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
) -> tuple[list[list[int]], list[int]]:
"""
Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
"""Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
Args:
im_size (tuple[int, ...]): Height and width of the input image.
@ -198,8 +194,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w:
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
"""
Remove small disconnected regions or holes in a mask based on area threshold and mode.
"""Remove small disconnected regions or holes in a mask based on area threshold and mode.
Args:
mask (np.ndarray): Binary mask to process.
@ -236,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tup
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculate bounding boxes in XYXY format around binary masks.
"""Calculate bounding boxes in XYXY format around binary masks.
Args:
masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).

View file

@ -127,8 +127,7 @@ def _build_sam(
checkpoint=None,
mobile_sam=False,
):
"""
Build a Segment Anything Model (SAM) with specified encoder parameters.
"""Build a Segment Anything Model (SAM) with specified encoder parameters.
Args:
encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
@ -224,8 +223,7 @@ def _build_sam2(
encoder_window_spec=[8, 4, 16, 8],
checkpoint=None,
):
"""
Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
"""Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
Args:
encoder_embed_dim (int, optional): Embedding dimension for the encoder.
@ -326,8 +324,7 @@ sam_model_map = {
def build_sam(ckpt="sam_b.pt"):
"""
Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
"""Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
Args:
ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.

View file

@ -25,8 +25,7 @@ from .predict import Predictor, SAM2Predictor
class SAM(Model):
"""
SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
"""SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable
segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or
@ -49,8 +48,7 @@ class SAM(Model):
"""
def __init__(self, model: str = "sam_b.pt") -> None:
"""
Initialize the SAM (Segment Anything Model) instance.
"""Initialize the SAM (Segment Anything Model) instance.
Args:
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
@ -68,8 +66,7 @@ class SAM(Model):
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
"""
Load the specified weights into the SAM model.
"""Load the specified weights into the SAM model.
Args:
weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
@ -84,12 +81,11 @@ class SAM(Model):
self.model = build_sam(weights)
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""
Perform segmentation prediction on the given image or video source.
"""Perform segmentation prediction on the given image or video source.
Args:
source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or
a np.ndarray object.
source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or a
np.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
points (list[list[float]] | None): List of points for prompted segmentation.
@ -111,15 +107,14 @@ class SAM(Model):
return super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""
Perform segmentation prediction on the given image or video source.
"""Perform segmentation prediction on the given image or video source.
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for
segmentation tasks.
Args:
source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image
object, or a np.ndarray object.
source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image object, or a
np.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
points (list[list[float]] | None): List of points for prompted segmentation.
@ -137,8 +132,7 @@ class SAM(Model):
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def info(self, detailed: bool = False, verbose: bool = True):
"""
Log information about the SAM model.
"""Log information about the SAM model.
Args:
detailed (bool): If True, displays detailed information about the model layers and operations.
@ -156,8 +150,7 @@ class SAM(Model):
@property
def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
"""
Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
"""Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
Returns:
(dict[str, dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding

View file

@ -17,8 +17,7 @@ from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis,
class DropPath(nn.Module):
"""
Implements stochastic depth regularization for neural networks during training.
"""Implements stochastic depth regularization for neural networks during training.
Attributes:
drop_prob (float): Probability of dropping a path during training.
@ -52,15 +51,14 @@ class DropPath(nn.Module):
class MaskDownSampler(nn.Module):
"""
A mask downsampling and embedding module for efficient processing of input masks.
"""A mask downsampling and embedding module for efficient processing of input masks.
This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks while
expanding their channel dimensions using convolutional layers, layer normalization, and activation functions.
Attributes:
encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
activation functions for downsampling and embedding masks.
encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and activation
functions for downsampling and embedding masks.
Methods:
forward: Downsamples and encodes input mask to embed_dim channels.
@ -111,8 +109,7 @@ class MaskDownSampler(nn.Module):
class CXBlock(nn.Module):
"""
ConvNeXt Block for efficient feature extraction in convolutional neural networks.
"""ConvNeXt Block for efficient feature extraction in convolutional neural networks.
This block implements a modified version of the ConvNeXt architecture, offering improved performance and flexibility
in feature extraction.
@ -147,8 +144,7 @@ class CXBlock(nn.Module):
layer_scale_init_value: float = 1e-6,
use_dwconv: bool = True,
):
"""
Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
"""Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
This block implements a modified version of the ConvNeXt architecture, offering improved performance and
flexibility in feature extraction.
@ -205,8 +201,7 @@ class CXBlock(nn.Module):
class Fuser(nn.Module):
"""
A module for fusing features through multiple layers of a neural network.
"""A module for fusing features through multiple layers of a neural network.
This class applies a series of identical layers to an input tensor, optionally projecting the input first.
@ -227,8 +222,7 @@ class Fuser(nn.Module):
"""
def __init__(self, layer: nn.Module, num_layers: int, dim: int | None = None, input_projection: bool = False):
"""
Initialize the Fuser module for feature fusion through multiple layers.
"""Initialize the Fuser module for feature fusion through multiple layers.
This module creates a sequence of identical layers and optionally applies an input projection.
@ -261,8 +255,7 @@ class Fuser(nn.Module):
class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
"""
A two-way attention block for performing self-attention and cross-attention in both directions.
"""A two-way attention block for performing self-attention and cross-attention in both directions.
This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse inputs,
cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from dense to sparse
@ -298,8 +291,7 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
"""Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from
@ -324,8 +316,7 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
class SAM2TwoWayTransformer(TwoWayTransformer):
"""
A Two-Way Transformer module for simultaneous attention to image and query points.
"""A Two-Way Transformer module for simultaneous attention to image and query points.
This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an input
image using queries with supplied positional embeddings. It is particularly useful for tasks like object detection,
@ -361,8 +352,7 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
activation: type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
Initialize a SAM2TwoWayTransformer instance.
"""Initialize a SAM2TwoWayTransformer instance.
This transformer decoder attends to an input image using queries with supplied positional embeddings. It is
designed for tasks like object detection, image segmentation, and point cloud processing.
@ -402,8 +392,7 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
class RoPEAttention(Attention):
"""
Implements rotary position encoding for attention mechanisms in transformer architectures.
"""Implements rotary position encoding for attention mechanisms in transformer architectures.
This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance the
positional awareness of the attention mechanism.
@ -500,8 +489,7 @@ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.T
class MultiScaleAttention(nn.Module):
"""
Implements multiscale self-attention with optional query pooling for efficient feature extraction.
"""Implements multiscale self-attention with optional query pooling for efficient feature extraction.
This class provides a flexible implementation of multiscale attention, allowing for optional downsampling of query
features through pooling. It's designed to enhance the model's ability to capture multiscale information in visual
@ -580,8 +568,7 @@ class MultiScaleAttention(nn.Module):
class MultiScaleBlock(nn.Module):
"""
A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
"""A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
This class implements a multiscale attention mechanism with optional window partitioning and downsampling, designed
for use in vision transformer architectures.
@ -695,8 +682,7 @@ class MultiScaleBlock(nn.Module):
class PositionEmbeddingSine(nn.Module):
"""
A module for generating sinusoidal positional embeddings for 2D inputs like images.
"""A module for generating sinusoidal positional embeddings for 2D inputs like images.
This class implements sinusoidal position encoding for 2D spatial positions, which can be used in transformer-based
models for computer vision tasks.
@ -810,8 +796,7 @@ class PositionEmbeddingSine(nn.Module):
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""Positional encoding using random spatial frequencies.
This class generates positional embeddings for input coordinates using random spatial frequencies. It is
particularly useful for transformer-based models that require position information.
@ -877,8 +862,7 @@ class PositionEmbeddingRandom(nn.Module):
class Block(nn.Module):
"""
Transformer block with support for window attention and residual propagation.
"""Transformer block with support for window attention and residual propagation.
This class implements a transformer block that can use either global or windowed self-attention, followed by a
feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
@ -916,8 +900,7 @@ class Block(nn.Module):
window_size: int = 0,
input_size: tuple[int, int] | None = None,
) -> None:
"""
Initialize a transformer block with optional window attention and relative positional embeddings.
"""Initialize a transformer block with optional window attention and relative positional embeddings.
This constructor sets up a transformer block that can use either global or windowed self-attention, followed by
a feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
@ -977,8 +960,7 @@ class Block(nn.Module):
class REAttention(nn.Module):
"""
Relative Position Attention module for efficient self-attention in transformer architectures.
"""Relative Position Attention module for efficient self-attention in transformer architectures.
This class implements a multi-head attention mechanism with relative positional embeddings, designed for use in
vision transformer models. It supports optional query pooling and window partitioning for efficient processing of
@ -1013,8 +995,7 @@ class REAttention(nn.Module):
rel_pos_zero_init: bool = True,
input_size: tuple[int, int] | None = None,
) -> None:
"""
Initialize a Relative Position Attention module for transformer-based architectures.
"""Initialize a Relative Position Attention module for transformer-based architectures.
This module implements multi-head attention with optional relative positional encodings, designed specifically
for vision tasks in transformer models.
@ -1069,8 +1050,7 @@ class REAttention(nn.Module):
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding module for vision transformer architectures.
"""Image to Patch Embedding module for vision transformer architectures.
This module converts an input image into a sequence of patch embeddings using a convolutional layer. It is commonly
used as the first layer in vision transformer architectures to transform image data into a suitable format for
@ -1098,8 +1078,7 @@ class PatchEmbed(nn.Module):
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Initialize the PatchEmbed module for converting image patches to embeddings.
"""Initialize the PatchEmbed module for converting image patches to embeddings.
This module is typically used as the first layer in vision transformer architectures to transform image data
into a suitable format for subsequent transformer blocks.

View file

@ -9,8 +9,7 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
class MaskDecoder(nn.Module):
"""
Decoder module for generating masks and their associated quality scores using a transformer architecture.
"""Decoder module for generating masks and their associated quality scores using a transformer architecture.
This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
generate mask predictions along with their quality scores.
@ -47,8 +46,7 @@ class MaskDecoder(nn.Module):
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Initialize the MaskDecoder module for generating masks and their associated quality scores.
"""Initialize the MaskDecoder module for generating masks and their associated quality scores.
Args:
transformer_dim (int): Channel dimension for the transformer module.
@ -94,8 +92,7 @@ class MaskDecoder(nn.Module):
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
"""Predict masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): Embeddings from the image encoder.
@ -172,8 +169,7 @@ class MaskDecoder(nn.Module):
class SAM2MaskDecoder(nn.Module):
"""
Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
"""Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
This class extends the functionality of the MaskDecoder, incorporating additional features such as high-resolution
feature processing, dynamic multimask output, and object score prediction.
@ -233,8 +229,7 @@ class SAM2MaskDecoder(nn.Module):
pred_obj_scores_mlp: bool = False,
use_multimask_token_for_obj_ptr: bool = False,
) -> None:
"""
Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
"""Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
This decoder extends the functionality of MaskDecoder, incorporating additional features such as high-resolution
feature processing, dynamic multimask output, and object score prediction.
@ -319,8 +314,7 @@ class SAM2MaskDecoder(nn.Module):
repeat_image: bool,
high_res_features: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
"""Predict masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
@ -458,8 +452,7 @@ class SAM2MaskDecoder(nn.Module):
return torch.where(area_u > 0, area_i / area_u, 1.0)
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
Dynamically select the most stable mask output based on stability scores and IoU predictions.
"""Dynamically select the most stable mask output based on stability scores and IoU predictions.
This method is used when outputting a single mask. If the stability score from the current single-mask output
(based on output token 0) falls below a threshold, it instead selects from multi-mask outputs (based on output
@ -467,8 +460,8 @@ class SAM2MaskDecoder(nn.Module):
tracking scenarios.
Args:
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
batch size, N is number of masks (typically 4), and H, W are mask dimensions.
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is batch size, N
is number of masks (typically 4), and H, W are mask dimensions.
all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
Returns:

View file

@ -21,8 +21,7 @@ from .blocks import (
class ImageEncoderViT(nn.Module):
"""
An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
"""An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
This class processes images by splitting them into patches, applying transformer blocks, and generating a final
encoded representation through a neck module.
@ -64,8 +63,7 @@ class ImageEncoderViT(nn.Module):
window_size: int = 0,
global_attn_indexes: tuple[int, ...] = (),
) -> None:
"""
Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
"""Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
Args:
img_size (int): Input image size, assumed to be square.
@ -156,8 +154,7 @@ class ImageEncoderViT(nn.Module):
class PromptEncoder(nn.Module):
"""
Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
"""Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
Attributes:
embed_dim (int): Dimension of the embeddings.
@ -193,8 +190,7 @@ class PromptEncoder(nn.Module):
mask_in_chans: int,
activation: type[nn.Module] = nn.GELU,
) -> None:
"""
Initialize the PromptEncoder module for encoding various types of prompts.
"""Initialize the PromptEncoder module for encoding various types of prompts.
Args:
embed_dim (int): The dimension of the embeddings.
@ -236,15 +232,14 @@ class PromptEncoder(nn.Module):
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Return the dense positional encoding used for encoding point prompts.
"""Return the dense positional encoding used for encoding point prompts.
Generate a positional encoding for a dense set of points matching the shape of the image
encoding. The encoding is used to provide spatial information to the model when processing point prompts.
Returns:
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
height and width of the image embedding size, respectively.
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the height and
width of the image embedding size, respectively.
Examples:
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
@ -306,12 +301,11 @@ class PromptEncoder(nn.Module):
boxes: torch.Tensor | None,
masks: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Embed different types of prompts, returning both sparse and dense embeddings.
"""Embed different types of prompts, returning both sparse and dense embeddings.
Args:
points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
tensor contains coordinates of shape (B, N, 2), and the second tensor contains labels of shape (B, N).
points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first tensor
contains coordinates of shape (B, N, 2), and the second tensor contains labels of shape (B, N).
boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
@ -353,8 +347,7 @@ class PromptEncoder(nn.Module):
class MemoryEncoder(nn.Module):
"""
Encode pixel features and masks into a memory representation for efficient image segmentation.
"""Encode pixel features and masks into a memory representation for efficient image segmentation.
This class processes pixel-level features and masks, fusing them to generate encoded memory representations suitable
for downstream tasks in image segmentation models like SAM (Segment Anything Model).
@ -384,8 +377,7 @@ class MemoryEncoder(nn.Module):
out_dim,
in_dim=256, # in_dim of pix_feats
):
"""
Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
"""Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
@ -438,8 +430,7 @@ class MemoryEncoder(nn.Module):
class ImageEncoder(nn.Module):
"""
Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
"""Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
This class combines a trunk network for feature extraction with a neck network for feature refinement and positional
encoding generation. It can optionally discard the lowest resolution features.
@ -468,8 +459,7 @@ class ImageEncoder(nn.Module):
neck: nn.Module,
scalp: int = 0,
):
"""
Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
"""Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
This encoder combines a trunk network for feature extraction with a neck network for feature refinement and
positional encoding generation. It can optionally discard the lowest resolution features.
@ -512,8 +502,7 @@ class ImageEncoder(nn.Module):
class FpnNeck(nn.Module):
"""
A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
"""A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to ViT
positional embedding interpolation.
@ -549,8 +538,7 @@ class FpnNeck(nn.Module):
fuse_type: str = "sum",
fpn_top_down_levels: list[int] | None = None,
):
"""
Initialize a modified Feature Pyramid Network (FPN) neck.
"""Initialize a modified Feature Pyramid Network (FPN) neck.
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to
ViT positional embedding interpolation.
@ -602,8 +590,7 @@ class FpnNeck(nn.Module):
self.fpn_top_down_levels = list(fpn_top_down_levels)
def forward(self, xs: list[torch.Tensor]):
"""
Perform forward pass through the Feature Pyramid Network (FPN) neck.
"""Perform forward pass through the Feature Pyramid Network (FPN) neck.
This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
@ -612,8 +599,8 @@ class FpnNeck(nn.Module):
xs (list[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
Returns:
out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape
(B, d_model, H, W).
out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape (B, d_model, H,
W).
pos (list[torch.Tensor]): List of positional encodings corresponding to each output feature map.
Examples:
@ -655,8 +642,7 @@ class FpnNeck(nn.Module):
class Hiera(nn.Module):
"""
Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
"""Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for efficient
multiscale feature extraction. It uses a series of transformer blocks organized into stages, with optional pooling
@ -714,8 +700,7 @@ class Hiera(nn.Module):
),
return_interm_layers=True, # return feats from every stage
):
"""
Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
"""Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction in
image processing tasks. It uses a series of transformer blocks organized into stages, with optional pooling and
@ -816,8 +801,7 @@ class Hiera(nn.Module):
return pos_embed
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""
Perform forward pass through Hiera model, extracting multiscale features from input images.
"""Perform forward pass through Hiera model, extracting multiscale features from input images.
Args:
x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.

View file

@ -11,8 +11,7 @@ from .blocks import RoPEAttention
class MemoryAttentionLayer(nn.Module):
"""
Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
"""Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
This class combines self-attention, cross-attention, and feedforward components to process input tensors and
generate memory-based attention outputs.
@ -61,8 +60,7 @@ class MemoryAttentionLayer(nn.Module):
pos_enc_at_cross_attn_keys: bool = True,
pos_enc_at_cross_attn_queries: bool = False,
):
"""
Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
"""Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
Args:
d_model (int): Dimensionality of the model.
@ -145,8 +143,7 @@ class MemoryAttentionLayer(nn.Module):
query_pos: torch.Tensor | None = None,
num_k_exclude_rope: int = 0,
) -> torch.Tensor:
"""
Process input tensors through self-attention, cross-attention, and feedforward network layers.
"""Process input tensors through self-attention, cross-attention, and feedforward network layers.
Args:
tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
@ -168,8 +165,7 @@ class MemoryAttentionLayer(nn.Module):
class MemoryAttention(nn.Module):
"""
Memory attention module for processing sequential data with self and cross-attention mechanisms.
"""Memory attention module for processing sequential data with self and cross-attention mechanisms.
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
processing sequential data, particularly useful in transformer-like architectures.
@ -206,8 +202,7 @@ class MemoryAttention(nn.Module):
num_layers: int,
batch_first: bool = True, # Do layers expect batch first input?
):
"""
Initialize MemoryAttention with specified layers and normalization for sequential data processing.
"""Initialize MemoryAttention with specified layers and normalization for sequential data processing.
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
processing sequential data, particularly useful in transformer-like architectures.
@ -247,8 +242,7 @@ class MemoryAttention(nn.Module):
memory_pos: torch.Tensor | None = None, # pos_enc for cross-attention inputs
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
) -> torch.Tensor:
"""
Process inputs through attention layers, applying self and cross-attention with positional encoding.
"""Process inputs through attention layers, applying self and cross-attention with positional encoding.
Args:
curr (torch.Tensor): Self-attention input tensor, representing the current state.

View file

@ -23,8 +23,7 @@ NO_OBJ_SCORE = -1024.0
class SAMModel(nn.Module):
"""
Segment Anything Model (SAM) for object segmentation tasks.
"""Segment Anything Model (SAM) for object segmentation tasks.
This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
prompts.
@ -61,8 +60,7 @@ class SAMModel(nn.Module):
pixel_mean: list[float] = (123.675, 116.28, 103.53),
pixel_std: list[float] = (58.395, 57.12, 57.375),
) -> None:
"""
Initialize the SAMModel class to predict object masks from an image and input prompts.
"""Initialize the SAMModel class to predict object masks from an image and input prompts.
Args:
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
@ -98,8 +96,7 @@ class SAMModel(nn.Module):
class SAM2Model(torch.nn.Module):
"""
SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
"""SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
consistency and efficient tracking of objects across frames.
@ -136,24 +133,24 @@ class SAM2Model(torch.nn.Module):
use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
first frame.
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
conditioning frames.
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
frame.
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
frames.
multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
memory encoder during evaluation.
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
encoder during evaluation.
sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
with clicks during evaluation.
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
prompt encoder and mask decoder on frames with mask input.
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
clicks during evaluation.
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
encoder and mask decoder on frames with mask input.
Methods:
forward_image: Process image batch through encoder to extract multi-level features.
@ -208,8 +205,7 @@ class SAM2Model(torch.nn.Module):
sam_mask_decoder_extra_args=None,
compile_image_encoder: bool = False,
):
"""
Initialize the SAM2Model for video object segmentation with memory-based tracking.
"""Initialize the SAM2Model for video object segmentation with memory-based tracking.
Args:
image_encoder (nn.Module): Visual encoder for extracting image features.
@ -220,35 +216,35 @@ class SAM2Model(torch.nn.Module):
backbone_stride (int): Stride of the image backbone output.
sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
with clicks during evaluation.
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
clicks during evaluation.
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
prompt encoder and mask decoder on frames with mask input.
max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
first frame.
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
frame.
use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
conditioning frames.
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
frames.
multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
memory encoder during evaluation.
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
encoder during evaluation.
use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
cross-attention.
add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
the encoder.
add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
encoder.
proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
encoding in object pointers.
use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
in the object pointers.
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
during evaluation.
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
evaluation.
pred_obj_scores (bool): Whether to predict if there is an object in the frame.
pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
@ -428,25 +424,23 @@ class SAM2Model(torch.nn.Module):
high_res_features=None,
multimask_output=False,
):
"""
Forward pass through SAM prompt encoders and mask heads.
"""Forward pass through SAM prompt encoders and mask heads.
This method processes image features and optional point/mask inputs to generate object masks and scores.
Args:
backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
pixel-unit coordinates in (x, y) format for P input points.
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
0 means negative clicks, and -1 means padding.
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
same spatial size as the image.
high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes
(B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps for
SAM decoder.
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
output only 1 mask and its IoU estimate.
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
(x, y) format for P input points.
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
clicks, and -1 means padding.
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
size as the image.
high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
mask and its IoU estimate.
Returns:
low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.

View file

@ -22,8 +22,7 @@ from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
"""
A sequential container that performs 2D convolution followed by batch normalization.
"""A sequential container that performs 2D convolution followed by batch normalization.
This module combines a 2D convolution layer with batch normalization, providing a common building block for
convolutional neural networks. The batch normalization weights and biases are initialized to specific values for
@ -52,8 +51,7 @@ class Conv2d_BN(torch.nn.Sequential):
groups: int = 1,
bn_weight_init: float = 1,
):
"""
Initialize a sequential container with 2D convolution followed by batch normalization.
"""Initialize a sequential container with 2D convolution followed by batch normalization.
Args:
a (int): Number of input channels.
@ -74,8 +72,7 @@ class Conv2d_BN(torch.nn.Sequential):
class PatchEmbed(nn.Module):
"""
Embed images into patches and project them into a specified embedding dimension.
"""Embed images into patches and project them into a specified embedding dimension.
This module converts input images into patch embeddings using a sequence of convolutional layers, effectively
downsampling the spatial dimensions while increasing the channel dimension.
@ -97,8 +94,7 @@ class PatchEmbed(nn.Module):
"""
def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
"""
Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
"""Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
Args:
in_chans (int): Number of input channels.
@ -125,8 +121,7 @@ class PatchEmbed(nn.Module):
class MBConv(nn.Module):
"""
Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
"""Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution, and
projection phases, along with residual connections for improved gradient flow.
@ -153,8 +148,7 @@ class MBConv(nn.Module):
"""
def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
"""
Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
"""Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
Args:
in_chans (int): Number of input channels.
@ -195,8 +189,7 @@ class MBConv(nn.Module):
class PatchMerging(nn.Module):
"""
Merge neighboring patches in the feature map and project to a new dimension.
"""Merge neighboring patches in the feature map and project to a new dimension.
This class implements a patch merging operation that combines spatial information and adjusts the feature dimension
using a series of convolutional layers with batch normalization. It effectively reduces spatial resolution while
@ -221,8 +214,7 @@ class PatchMerging(nn.Module):
"""
def __init__(self, input_resolution: tuple[int, int], dim: int, out_dim: int, activation):
"""
Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
"""Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
Args:
input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
@ -259,8 +251,7 @@ class PatchMerging(nn.Module):
class ConvLayer(nn.Module):
"""
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
"""Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
This layer optionally applies downsample operations to the output and supports gradient checkpointing for memory
efficiency during training.
@ -293,8 +284,7 @@ class ConvLayer(nn.Module):
out_dim: int | None = None,
conv_expand_ratio: float = 4.0,
):
"""
Initialize the ConvLayer with the given dimensions and settings.
"""Initialize the ConvLayer with the given dimensions and settings.
This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and optionally
applies downsampling to the output.
@ -345,8 +335,7 @@ class ConvLayer(nn.Module):
class MLP(nn.Module):
"""
Multi-layer Perceptron (MLP) module for transformer architectures.
"""Multi-layer Perceptron (MLP) module for transformer architectures.
This module applies layer normalization, two fully-connected layers with an activation function in between, and
dropout. It is commonly used in transformer-based architectures for processing token embeddings.
@ -376,8 +365,7 @@ class MLP(nn.Module):
activation=nn.GELU,
drop: float = 0.0,
):
"""
Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
"""Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
Args:
in_features (int): Number of input features.
@ -406,8 +394,7 @@ class MLP(nn.Module):
class Attention(torch.nn.Module):
"""
Multi-head attention module with spatial awareness and trainable attention biases.
"""Multi-head attention module with spatial awareness and trainable attention biases.
This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
biases based on spatial resolution. It includes trainable attention biases for each unique offset between spatial
@ -444,8 +431,7 @@ class Attention(torch.nn.Module):
attn_ratio: float = 4,
resolution: tuple[int, int] = (14, 14),
):
"""
Initialize the Attention module for multi-head attention with spatial awareness.
"""Initialize the Attention module for multi-head attention with spatial awareness.
This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
biases based on spatial resolution. It includes trainable attention biases for each unique offset between
@ -521,8 +507,7 @@ class Attention(torch.nn.Module):
class TinyViTBlock(nn.Module):
"""
TinyViT Block that applies self-attention and a local convolution to the input.
"""TinyViT Block that applies self-attention and a local convolution to the input.
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
convolutions to process input features efficiently. It supports windowed attention for computational efficiency and
@ -559,8 +544,7 @@ class TinyViTBlock(nn.Module):
local_conv_size: int = 3,
activation=nn.GELU,
):
"""
Initialize a TinyViT block with self-attention and local convolution.
"""Initialize a TinyViT block with self-attention and local convolution.
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
convolutions to process input features efficiently.
@ -644,8 +628,7 @@ class TinyViTBlock(nn.Module):
return x + self.drop_path(self.mlp(x))
def extra_repr(self) -> str:
"""
Return a string representation of the TinyViTBlock's parameters.
"""Return a string representation of the TinyViTBlock's parameters.
This method provides a formatted string containing key information about the TinyViTBlock, including its
dimension, input resolution, number of attention heads, window size, and MLP ratio.
@ -665,8 +648,7 @@ class TinyViTBlock(nn.Module):
class BasicLayer(nn.Module):
"""
A basic TinyViT layer for one stage in a TinyViT architecture.
"""A basic TinyViT layer for one stage in a TinyViT architecture.
This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks and an optional
downsampling operation. It processes features at a specific resolution and dimensionality within the overall
@ -704,8 +686,7 @@ class BasicLayer(nn.Module):
activation=nn.GELU,
out_dim: int | None = None,
):
"""
Initialize a BasicLayer in the TinyViT architecture.
"""Initialize a BasicLayer in the TinyViT architecture.
This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to process
feature maps at a specific resolution and dimensionality within the TinyViT model.
@ -770,8 +751,7 @@ class BasicLayer(nn.Module):
class TinyViT(nn.Module):
"""
TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
"""TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
This class implements the TinyViT model, which combines elements of vision transformers and convolutional neural
networks for improved efficiency and performance on vision tasks. It features hierarchical processing with patch
@ -815,8 +795,7 @@ class TinyViT(nn.Module):
local_conv_size: int = 3,
layer_lr_decay: float = 1.0,
):
"""
Initialize the TinyViT model.
"""Initialize the TinyViT model.
This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of attention and
convolution blocks, and a classification head.

View file

@ -11,8 +11,7 @@ from ultralytics.nn.modules import MLPBlock
class TwoWayTransformer(nn.Module):
"""
A Two-Way Transformer module for simultaneous attention to image and query points.
"""A Two-Way Transformer module for simultaneous attention to image and query points.
This class implements a specialized transformer decoder that attends to an input image using queries with supplied
positional embeddings. It's useful for tasks like object detection, image segmentation, and point cloud processing.
@ -47,8 +46,7 @@ class TwoWayTransformer(nn.Module):
activation: type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
"""Initialize a Two-Way Transformer for simultaneous attention to image and query points.
Args:
depth (int): Number of layers in the transformer.
@ -86,8 +84,7 @@ class TwoWayTransformer(nn.Module):
image_pe: torch.Tensor,
point_embedding: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Process image and point embeddings through the Two-Way Transformer.
"""Process image and point embeddings through the Two-Way Transformer.
Args:
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
@ -126,8 +123,7 @@ class TwoWayTransformer(nn.Module):
class TwoWayAttentionBlock(nn.Module):
"""
A two-way attention block for simultaneous attention to image and query points.
"""A two-way attention block for simultaneous attention to image and query points.
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense inputs to
@ -166,8 +162,7 @@ class TwoWayAttentionBlock(nn.Module):
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
"""Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
This block implements a specialized transformer layer with four main components: self-attention on sparse
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of
@ -199,8 +194,7 @@ class TwoWayAttentionBlock(nn.Module):
def forward(
self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply two-way attention to process query and key embeddings in a transformer block.
"""Apply two-way attention to process query and key embeddings in a transformer block.
Args:
queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
@ -244,8 +238,7 @@ class TwoWayAttentionBlock(nn.Module):
class Attention(nn.Module):
"""
An attention layer with downscaling capability for embedding size after projection.
"""An attention layer with downscaling capability for embedding size after projection.
This class implements a multi-head attention mechanism with the option to downsample the internal dimension of
queries, keys, and values.
@ -281,8 +274,7 @@ class Attention(nn.Module):
downsample_rate: int = 1,
kv_in_dim: int | None = None,
) -> None:
"""
Initialize the Attention module with specified dimensions and settings.
"""Initialize the Attention module with specified dimensions and settings.
Args:
embedding_dim (int): Dimensionality of input embeddings.
@ -320,8 +312,7 @@ class Attention(nn.Module):
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Apply multi-head attention to query, key, and value tensors with optional downsampling.
"""Apply multi-head attention to query, key, and value tensors with optional downsampling.
Args:
q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).

View file

@ -9,8 +9,7 @@ import torch.nn.functional as F
def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
"""
Select the closest conditioning frames to a given frame index.
"""Select the closest conditioning frames to a given frame index.
Args:
frame_idx (int): Current frame index.
@ -62,8 +61,7 @@ def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any
def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
"""
Generate 1D sinusoidal positional embeddings for given positions and dimensions.
"""Generate 1D sinusoidal positional embeddings for given positions and dimensions.
Args:
pos_inds (torch.Tensor): Position indices for which to generate embeddings.
@ -89,8 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
def init_t_xy(end_x: int, end_y: int):
"""
Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
"""Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
tensor and corresponding x and y coordinate tensors.
@ -117,8 +114,7 @@ def init_t_xy(end_x: int, end_y: int):
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
"""
Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
"""Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
frequency components for the x and y dimensions.
@ -150,8 +146,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""
Reshape frequency tensor for broadcasting with input tensor.
"""Reshape frequency tensor for broadcasting with input tensor.
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function
is typically used in positional encoding operations.
@ -179,8 +174,7 @@ def apply_rotary_enc(
freqs_cis: torch.Tensor,
repeat_freqs_k: bool = False,
):
"""
Apply rotary positional encoding to query and key tensors.
"""Apply rotary positional encoding to query and key tensors.
This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
components. RoPE is a technique that injects relative position information into self-attention mechanisms.
@ -188,10 +182,10 @@ def apply_rotary_enc(
Args:
xq (torch.Tensor): Query tensor to encode with positional information.
xk (torch.Tensor): Key tensor to encode with positional information.
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
last two dimensions of xq.
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
to match key sequence length.
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
two dimensions of xq.
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
key sequence length.
Returns:
xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
@ -220,8 +214,7 @@ def apply_rotary_enc(
def window_partition(x: torch.Tensor, window_size: int):
"""
Partition input tensor into non-overlapping windows with padding if needed.
"""Partition input tensor into non-overlapping windows with padding if needed.
Args:
x (torch.Tensor): Input tensor with shape (B, H, W, C).
@ -251,8 +244,7 @@ def window_partition(x: torch.Tensor, window_size: int):
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
"""
Unpartition windowed sequences into original sequences and remove padding.
"""Unpartition windowed sequences into original sequences and remove padding.
This function reverses the windowing process, reconstructing the original input from windowed segments and removing
any padding that was added during the windowing process.
@ -266,8 +258,8 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
Returns:
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
are the original height and width, and C is the number of channels.
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the
original height and width, and C is the number of channels.
Examples:
>>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
@ -289,18 +281,16 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Extract relative positional embeddings based on query and key sizes.
"""Extract relative positional embeddings based on query and key sizes.
Args:
q_size (int): Size of the query.
k_size (int): Size of the key.
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
distance and C is the embedding dimension.
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative distance
and C is the embedding dimension.
Returns:
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
k_size, C).
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).
Examples:
>>> q_size, k_size = 8, 16
@ -338,8 +328,7 @@ def add_decomposed_rel_pos(
q_size: tuple[int, int],
k_size: tuple[int, int],
) -> torch.Tensor:
"""
Add decomposed Relative Positional Embeddings to the attention map.
"""Add decomposed Relative Positional Embeddings to the attention map.
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
@ -354,8 +343,8 @@ def add_decomposed_rel_pos(
k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
Returns:
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
(B, q_h * q_w, k_h * k_w).
(torch.Tensor): Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h *
k_w).
Examples:
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8

View file

@ -38,8 +38,7 @@ from .amg import (
class Predictor(BasePredictor):
"""
Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
"""Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image segmentation
tasks. It supports various input prompts like points, bounding boxes, and masks for fine-grained control over
@ -81,8 +80,7 @@ class Predictor(BasePredictor):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the Predictor with configuration, overrides, and callbacks.
"""Initialize the Predictor with configuration, overrides, and callbacks.
Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True for
@ -109,8 +107,7 @@ class Predictor(BasePredictor):
self.segment_all = False
def preprocess(self, im):
"""
Preprocess the input image for model inference.
"""Preprocess the input image for model inference.
This method prepares the input image by applying transformations and normalization. It supports both
torch.Tensor and list of np.ndarray as input formats.
@ -142,8 +139,7 @@ class Predictor(BasePredictor):
return im
def pre_transform(self, im):
"""
Perform initial transformations on the input image for preprocessing.
"""Perform initial transformations on the input image for preprocessing.
This method applies transformations such as resizing to prepare the image for further preprocessing. Currently,
batched inference is not supported; hence the list length should be 1.
@ -169,8 +165,7 @@ class Predictor(BasePredictor):
return [letterbox(image=x) for x in im]
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
"""
Perform image segmentation inference based on the given input cues, using the currently loaded image.
"""Perform image segmentation inference based on the given input cues, using the currently loaded image.
This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder,
and mask decoder for real-time and promptable segmentation tasks.
@ -208,8 +203,7 @@ class Predictor(BasePredictor):
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
"""
Perform image segmentation inference based on input cues using SAM's specialized architecture.
"""Perform image segmentation inference based on input cues using SAM's specialized architecture.
This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. It
processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
@ -248,8 +242,7 @@ class Predictor(BasePredictor):
masks=None,
multimask_output=False,
):
"""
Perform inference on image features using the SAM model.
"""Perform inference on image features using the SAM model.
Args:
features (torch.Tensor): Extracted image features with shape (B, C, H, W) from the SAM model image encoder.
@ -281,8 +274,7 @@ class Predictor(BasePredictor):
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
"""
Prepare and transform the input prompts for processing based on the destination shape.
"""Prepare and transform the input prompts for processing based on the destination shape.
Args:
dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
@ -346,8 +338,7 @@ class Predictor(BasePredictor):
stability_score_offset=0.95,
crop_nms_thresh=0.7,
):
"""
Perform image segmentation using the Segment Anything Model (SAM).
"""Perform image segmentation using the Segment Anything Model (SAM).
This method segments an entire image into constituent parts by leveraging SAM's advanced architecture and
real-time performance capabilities. It can optionally work on image crops for finer segmentation.
@ -445,8 +436,7 @@ class Predictor(BasePredictor):
return pred_masks, pred_scores, pred_bboxes
def setup_model(self, model=None, verbose=True):
"""
Initialize the Segment Anything Model (SAM) for inference.
"""Initialize the Segment Anything Model (SAM) for inference.
This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
parameters for image normalization and other Ultralytics compatibility settings.
@ -484,8 +474,7 @@ class Predictor(BasePredictor):
return build_sam(self.args.model)
def postprocess(self, preds, img, orig_imgs):
"""
Post-process SAM's inference outputs to generate object detection masks and bounding boxes.
"""Post-process SAM's inference outputs to generate object detection masks and bounding boxes.
This method scales masks and boxes to the original image size and applies a threshold to the mask
predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
@ -499,8 +488,8 @@ class Predictor(BasePredictor):
orig_imgs (list[np.ndarray] | torch.Tensor): The original, unprocessed images.
Returns:
(list[Results]): List of Results objects containing detection masks, bounding boxes, and other
metadata for each processed image.
(list[Results]): List of Results objects containing detection masks, bounding boxes, and other metadata for
each processed image.
Examples:
>>> predictor = Predictor()
@ -537,15 +526,14 @@ class Predictor(BasePredictor):
return results
def setup_source(self, source):
"""
Set up the data source for inference.
"""Set up the data source for inference.
This method configures the data source from which images will be fetched for inference. It supports various
input types such as image files, directories, video files, and other compatible data sources.
Args:
source (str | Path | None): The path or identifier for the image data source. Can be a file path,
directory path, URL, or other supported source types.
source (str | Path | None): The path or identifier for the image data source. Can be a file path, directory
path, URL, or other supported source types.
Examples:
>>> predictor = Predictor()
@ -562,16 +550,15 @@ class Predictor(BasePredictor):
super().setup_source(source)
def set_image(self, image):
"""
Preprocess and set a single image for inference.
"""Preprocess and set a single image for inference.
This method prepares the model for inference on a single image by setting up the model if not already
initialized, configuring the data source, and preprocessing the image for feature extraction. It ensures that
only one image is set at a time and extracts image features for subsequent use.
Args:
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
an image read by cv2.
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing an image read by
cv2.
Raises:
AssertionError: If more than one image is attempted to be set.
@ -613,8 +600,7 @@ class Predictor(BasePredictor):
@staticmethod
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
"""
Remove small disconnected regions and holes from segmentation masks.
"""Remove small disconnected regions and holes from segmentation masks.
This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). It
removes small disconnected regions and holes from the input masks, and then performs Non-Maximum Suppression
@ -675,8 +661,7 @@ class Predictor(BasePredictor):
masks=None,
multimask_output=False,
):
"""
Perform prompts preprocessing and inference on provided image features using the SAM model.
"""Perform prompts preprocessing and inference on provided image features using the SAM model.
Args:
features (torch.Tensor | dict[str, Any]): Extracted image features from the SAM/SAM2 model image encoder.
@ -714,8 +699,7 @@ class Predictor(BasePredictor):
class SAM2Predictor(Predictor):
"""
SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
"""SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
This class extends the base Predictor class to implement SAM2-specific functionality for image segmentation tasks.
It provides methods for model initialization, feature extraction, and prompt-based inference.
@ -755,8 +739,7 @@ class SAM2Predictor(Predictor):
return build_sam(self.args.model)
def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
"""
Prepare and transform the input prompts for processing based on the destination shape.
"""Prepare and transform the input prompts for processing based on the destination shape.
Args:
dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
@ -790,8 +773,7 @@ class SAM2Predictor(Predictor):
return points, labels, masks
def set_image(self, image):
"""
Preprocess and set a single image for inference using the SAM2 model.
"""Preprocess and set a single image for inference using the SAM2 model.
This method initializes the model if not already done, configures the data source to the specified image, and
preprocesses the image for feature extraction. It supports setting only one image at a time.
@ -847,8 +829,7 @@ class SAM2Predictor(Predictor):
multimask_output=False,
img_idx=-1,
):
"""
Perform inference on image features using the SAM2 model.
"""Perform inference on image features using the SAM2 model.
Args:
features (torch.Tensor | dict[str, Any]): Extracted image features with shape (B, C, H, W) from the SAM2
@ -892,8 +873,7 @@ class SAM2Predictor(Predictor):
class SAM2VideoPredictor(SAM2Predictor):
"""
SAM2VideoPredictor to handle user interactions with videos and manage inference states.
"""SAM2VideoPredictor to handle user interactions with videos and manage inference states.
This class extends the functionality of SAM2Predictor to support video processing and maintains the state of
inference operations. It includes configurations for managing non-overlapping masks, clearing memory for
@ -929,8 +909,7 @@ class SAM2VideoPredictor(SAM2Predictor):
# fill_hole_area = 8 # not used
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the predictor with configuration and optional overrides.
"""Initialize the predictor with configuration and optional overrides.
This constructor initializes the SAM2VideoPredictor with a given configuration, applies any specified overrides,
and sets up the inference state along with certain flags that control the behavior of the predictor.
@ -953,8 +932,7 @@ class SAM2VideoPredictor(SAM2Predictor):
self.callbacks["on_predict_start"].append(self.init_state)
def get_model(self):
"""
Retrieve and configure the model with binarization enabled.
"""Retrieve and configure the model with binarization enabled.
Notes:
This method overrides the base class implementation to set the binarize flag to True.
@ -964,10 +942,9 @@ class SAM2VideoPredictor(SAM2Predictor):
return model
def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
"""
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
mask decoder for real-time and promptable segmentation tasks.
"""Perform image segmentation inference based on the given input cues, using the currently loaded image. This
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
encoder, and mask decoder for real-time and promptable segmentation tasks.
Args:
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
@ -1037,8 +1014,7 @@ class SAM2VideoPredictor(SAM2Predictor):
return pred_masks, torch.ones(pred_masks.shape[0], dtype=pred_masks.dtype, device=pred_masks.device)
def postprocess(self, preds, img, orig_imgs):
"""
Post-process the predictions to apply non-overlapping constraints if required.
"""Post-process the predictions to apply non-overlapping constraints if required.
This method extends the post-processing functionality by applying non-overlapping constraints to the predicted
masks if the `non_overlap_masks` flag is set to True. This ensures that the masks do not overlap, which can be
@ -1072,8 +1048,7 @@ class SAM2VideoPredictor(SAM2Predictor):
masks=None,
frame_idx=0,
):
"""
Add new points or masks to a specific frame for a given object ID.
"""Add new points or masks to a specific frame for a given object ID.
This method updates the inference state with new prompts (points or masks) for a specified object and frame
index. It ensures that the prompts are either points or masks, but not both, and updates the internal state
@ -1169,8 +1144,7 @@ class SAM2VideoPredictor(SAM2Predictor):
@smart_inference_mode()
def propagate_in_video_preflight(self):
"""
Prepare inference_state and consolidate temporary outputs before tracking.
"""Prepare inference_state and consolidate temporary outputs before tracking.
This method marks the start of tracking, disallowing the addition of new objects until the session is reset. It
consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. Additionally,
@ -1240,8 +1214,7 @@ class SAM2VideoPredictor(SAM2Predictor):
@staticmethod
def init_state(predictor):
"""
Initialize an inference state for the predictor.
"""Initialize an inference state for the predictor.
This function sets up the initial state required for performing inference on video data. It includes
initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata
@ -1287,8 +1260,7 @@ class SAM2VideoPredictor(SAM2Predictor):
predictor.inference_state = inference_state
def get_im_features(self, im, batch=1):
"""
Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
"""Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
Args:
im (torch.Tensor): The input image tensor.
@ -1315,8 +1287,7 @@ class SAM2VideoPredictor(SAM2Predictor):
return vis_feats, vis_pos_embed, feat_sizes
def _obj_id_to_idx(self, obj_id):
"""
Map client-side object id to model-side object index.
"""Map client-side object id to model-side object index.
Args:
obj_id (int): The unique identifier of the object provided by the client side.
@ -1378,8 +1349,7 @@ class SAM2VideoPredictor(SAM2Predictor):
run_mem_encoder,
prev_sam_mask_logits=None,
):
"""
Run tracking on a single frame based on current inputs and previous memory.
"""Run tracking on a single frame based on current inputs and previous memory.
Args:
output_dict (dict): The dictionary containing the output states of the tracking process.
@ -1442,8 +1412,7 @@ class SAM2VideoPredictor(SAM2Predictor):
return current_out
def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
"""
Cache and manage the positional encoding for mask memory across frames and objects.
"""Cache and manage the positional encoding for mask memory across frames and objects.
This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for mask memory, which is
constant across frames and objects, thus reducing the amount of redundant information stored during an inference
@ -1452,8 +1421,8 @@ class SAM2VideoPredictor(SAM2Predictor):
batch size.
Args:
out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory.
Should be a list of tensors or None.
out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory. Should be a list
of tensors or None.
Returns:
(list[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
@ -1486,8 +1455,7 @@ class SAM2VideoPredictor(SAM2Predictor):
is_cond=False,
run_mem_encoder=False,
):
"""
Consolidate per-object temporary outputs into a single output for all objects.
"""Consolidate per-object temporary outputs into a single output for all objects.
This method combines the temporary outputs for each object on a given frame into a unified
output. It fills in any missing objects either from the main output dictionary or leaves
@ -1497,8 +1465,8 @@ class SAM2VideoPredictor(SAM2Predictor):
Args:
frame_idx (int): The index of the frame for which to consolidate outputs.
is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.
run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after
consolidating the outputs.
run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after consolidating the
outputs.
Returns:
(dict): A consolidated output dictionary containing the combined results for all objects.
@ -1587,8 +1555,7 @@ class SAM2VideoPredictor(SAM2Predictor):
return consolidated_out
def _get_empty_mask_ptr(self, frame_idx):
"""
Get a dummy object pointer based on an empty mask on the current frame.
"""Get a dummy object pointer based on an empty mask on the current frame.
Args:
frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
@ -1618,8 +1585,7 @@ class SAM2VideoPredictor(SAM2Predictor):
return current_out["obj_ptr"]
def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
"""
Run the memory encoder on masks.
"""Run the memory encoder on masks.
This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
memory also needs to be computed again with the memory encoder.
@ -1651,8 +1617,7 @@ class SAM2VideoPredictor(SAM2Predictor):
), maskmem_pos_enc
def _add_output_per_object(self, frame_idx, current_out, storage_key):
"""
Split a multi-object output into per-object output slices and add them into output_dict_per_obj.
"""Split a multi-object output into per-object output slices and add them into output_dict_per_obj.
The resulting slices share the same tensor storage.
@ -1682,8 +1647,7 @@ class SAM2VideoPredictor(SAM2Predictor):
obj_output_dict[storage_key][frame_idx] = obj_out
def _clear_non_cond_mem_around_input(self, frame_idx):
"""
Remove the non-conditioning memory around the input frame.
"""Remove the non-conditioning memory around the input frame.
When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain
outdated object appearance information and could confuse the model. This method clears those non-conditioning
@ -1703,8 +1667,7 @@ class SAM2VideoPredictor(SAM2Predictor):
class SAM2DynamicInteractivePredictor(SAM2Predictor):
"""
SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
"""SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
sequence of images.
Attributes:
@ -1736,8 +1699,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
max_obj_num: int = 3,
_callbacks: dict[str, Any] | None = None,
) -> None:
"""
Initialize the predictor with configuration and optional overrides.
"""Initialize the predictor with configuration and optional overrides.
This constructor initializes the SAM2DynamicInteractivePredictor with a given configuration, applies any
specified overrides
@ -1783,12 +1745,11 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
obj_ids: list[int] | None = None,
update_memory: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform inference on a single image with optional bounding boxes, masks, points and object IDs. It has two
modes: one is to run inference on a single image without updating the memory, and the other is to update the
memory with the provided prompts and object IDs. When update_memory is True, it will update the memory with the
provided prompts and obj_ids. When update_memory is False, it will only run inference on the provided image
without updating the memory.
"""Perform inference on a single image with optional bounding boxes, masks, points and object IDs. It has two
modes: one is to run inference on a single image without updating the memory, and the other is to update
the memory with the provided prompts and object IDs. When update_memory is True, it will update the
memory with the provided prompts and obj_ids. When update_memory is False, it will only run inference on
the provided image without updating the memory.
Args:
im (torch.Tensor | np.ndarray): The input image tensor or numpy array.
@ -1842,8 +1803,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def get_im_features(self, img: torch.Tensor | np.ndarray) -> None:
"""
Initialize the image state by processing the input image and extracting features.
"""Initialize the image state by processing the input image and extracting features.
Args:
img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
@ -1866,8 +1826,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
labels: torch.Tensor | None = None,
masks: torch.Tensor | None = None,
) -> None:
"""
Append the imgState to the memory_bank and update the memory for the model.
"""Append the imgState to the memory_bank and update the memory for the model.
Args:
obj_ids (list[int]): List of object IDs corresponding to the prompts.
@ -1941,12 +1900,11 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
self.memory_bank.append(consolidated_out)
def _prepare_memory_conditioned_features(self, obj_idx: int | None) -> torch.Tensor:
"""
Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features for all
objects in the image. If there is no memory, it will directly add a no-memory embedding to the current vision
features. If there is memory, it will use the memory features from previous frames to condition the current
vision features using a transformer attention mechanism.
"""Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features
for all objects in the image. If there is no memory, it will directly add a no-memory embedding to the
current vision features. If there is memory, it will use the memory features from previous frames to
condition the current vision features using a transformer attention mechanism.
Args:
obj_idx (int | None): The index of the object for which to prepare the features.
@ -1989,8 +1947,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
return memory, memory_pos_embed
def _obj_id_to_idx(self, obj_id: int) -> int | None:
"""
Map client-side object id to model-side object index.
"""Map client-side object id to model-side object index.
Args:
obj_id (int): The client-side object ID.
@ -2007,8 +1964,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
label: torch.Tensor | None = None,
mask: torch.Tensor | None = None,
) -> dict[str, Any]:
"""
Tracking step for the current image state to predict masks.
"""Tracking step for the current image state to predict masks.
This method processes the image features and runs the SAM heads to predict masks. If obj_idx is provided, it
processes the features for a specific prompted object in the image. If obj_idx is None, it processes the

View file

@ -15,8 +15,7 @@ from .ops import HungarianMatcher
class DETRLoss(nn.Module):
"""
DETR (DEtection TRansformer) Loss class for calculating various loss components.
"""DETR (DEtection TRansformer) Loss class for calculating various loss components.
This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
object detection model.
@ -47,8 +46,7 @@ class DETRLoss(nn.Module):
gamma: float = 1.5,
alpha: float = 0.25,
):
"""
Initialize DETR loss function with customizable components and gains.
"""Initialize DETR loss function with customizable components and gains.
Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
losses and various loss types.
@ -82,8 +80,7 @@ class DETRLoss(nn.Module):
def _get_loss_class(
self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
) -> dict[str, torch.Tensor]:
"""
Compute classification loss based on predictions, target values, and ground truth scores.
"""Compute classification loss based on predictions, target values, and ground truth scores.
Args:
pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
@ -124,8 +121,7 @@ class DETRLoss(nn.Module):
def _get_loss_bbox(
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
) -> dict[str, torch.Tensor]:
"""
Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
"""Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
@ -199,8 +195,7 @@ class DETRLoss(nn.Module):
masks: torch.Tensor | None = None,
gt_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""
Get auxiliary losses for intermediate decoder layers.
"""Get auxiliary losses for intermediate decoder layers.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
@ -261,8 +256,7 @@ class DETRLoss(nn.Module):
@staticmethod
def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
"""
Extract batch indices, source indices, and destination indices from match indices.
"""Extract batch indices, source indices, and destination indices from match indices.
Args:
match_indices (list[tuple]): List of tuples containing matched indices.
@ -279,8 +273,7 @@ class DETRLoss(nn.Module):
def _get_assigned_bboxes(
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
"""Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes.
@ -317,8 +310,7 @@ class DETRLoss(nn.Module):
postfix: str = "",
match_indices: list[tuple] | None = None,
) -> dict[str, torch.Tensor]:
"""
Calculate losses for a single prediction layer.
"""Calculate losses for a single prediction layer.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes.
@ -364,8 +356,7 @@ class DETRLoss(nn.Module):
postfix: str = "",
**kwargs: Any,
) -> dict[str, torch.Tensor]:
"""
Calculate loss for predicted bounding boxes and scores.
"""Calculate loss for predicted bounding boxes and scores.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
@ -400,8 +391,7 @@ class DETRLoss(nn.Module):
class RTDETRDetectionLoss(DETRLoss):
"""
Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.
"""Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.
This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
an additional denoising training loss when provided with denoising metadata.
@ -415,8 +405,7 @@ class RTDETRDetectionLoss(DETRLoss):
dn_scores: torch.Tensor | None = None,
dn_meta: dict[str, Any] | None = None,
) -> dict[str, torch.Tensor]:
"""
Forward pass to compute detection loss with optional denoising loss.
"""Forward pass to compute detection loss with optional denoising loss.
Args:
preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
@ -452,8 +441,7 @@ class RTDETRDetectionLoss(DETRLoss):
def get_dn_match_indices(
dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
) -> list[tuple[torch.Tensor, torch.Tensor]]:
"""
Get match indices for denoising.
"""Get match indices for denoising.
Args:
dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.

View file

@ -14,8 +14,7 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
class HungarianMatcher(nn.Module):
"""
A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
"""A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
@ -56,8 +55,7 @@ class HungarianMatcher(nn.Module):
alpha: float = 0.25,
gamma: float = 2.0,
):
"""
Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
"""Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
Args:
cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
@ -88,8 +86,7 @@ class HungarianMatcher(nn.Module):
masks: torch.Tensor | None = None,
gt_mask: list[torch.Tensor] | None = None,
) -> list[tuple[torch.Tensor, torch.Tensor]]:
"""
Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
"""Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
@ -105,9 +102,9 @@ class HungarianMatcher(nn.Module):
gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
Returns:
(list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
(index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order) and
index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
(list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
the tensor of indices of the corresponding selected ground truth targets (in order).
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
"""
bs, nq, nc = pred_scores.shape
@ -198,16 +195,15 @@ def get_cdn_group(
box_noise_scale: float = 1.0,
training: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
"""
Generate contrastive denoising training group with positive and negative samples from ground truths.
"""Generate contrastive denoising training group with positive and negative samples from ground truths.
This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
boxes and class labels. It generates both positive and negative samples to improve model robustness.
Args:
batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground
truths per image.
batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)), 'gt_bboxes'
(torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground truths
per image.
num_classes (int): Total number of object classes.
num_queries (int): Number of object queries.
class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.

View file

@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
class ClassificationPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a classification model.
"""A class extending the BasePredictor class for prediction based on a classification model.
This predictor handles the specific requirements of classification models, including preprocessing images and
postprocessing predictions to generate classification results.
@ -36,8 +35,7 @@ class ClassificationPredictor(BasePredictor):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
"""Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
tasks. It ensures the task is set to 'classify' regardless of input configuration.
@ -72,8 +70,7 @@ class ClassificationPredictor(BasePredictor):
return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""
Process predictions to return Results objects with classification probabilities.
"""Process predictions to return Results objects with classification probabilities.
Args:
preds (torch.Tensor): Raw predictions from the model.

View file

@ -17,8 +17,7 @@ from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_fi
class ClassificationTrainer(BaseTrainer):
"""
A trainer class extending BaseTrainer for training image classification models.
"""A trainer class extending BaseTrainer for training image classification models.
This trainer handles the training process for image classification tasks, supporting both YOLO classification models
and torchvision models with comprehensive dataset handling and validation.
@ -51,8 +50,7 @@ class ClassificationTrainer(BaseTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a ClassificationTrainer object.
"""Initialize a ClassificationTrainer object.
Args:
cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
@ -71,8 +69,7 @@ class ClassificationTrainer(BaseTrainer):
self.model.names = self.data["names"]
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return a modified PyTorch model configured for training YOLO classification.
"""Return a modified PyTorch model configured for training YOLO classification.
Args:
cfg (Any, optional): Model configuration.
@ -96,8 +93,7 @@ class ClassificationTrainer(BaseTrainer):
return model
def setup_model(self):
"""
Load, create or download model for classification tasks.
"""Load, create or download model for classification tasks.
Returns:
(Any): Model checkpoint if applicable, otherwise None.
@ -115,8 +111,7 @@ class ClassificationTrainer(BaseTrainer):
return ckpt
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
"""
Create a ClassificationDataset instance given an image path and mode.
"""Create a ClassificationDataset instance given an image path and mode.
Args:
img_path (str): Path to the dataset images.
@ -129,8 +124,7 @@ class ClassificationTrainer(BaseTrainer):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
"""
Return PyTorch DataLoader with transforms to preprocess images.
"""Return PyTorch DataLoader with transforms to preprocess images.
Args:
dataset_path (str): Path to the dataset.
@ -177,8 +171,7 @@ class ClassificationTrainer(BaseTrainer):
)
def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
"""
Return a loss dict with labeled training loss items tensor.
"""Return a loss dict with labeled training loss items tensor.
Args:
loss_items (torch.Tensor, optional): Loss tensor items.
@ -195,8 +188,7 @@ class ClassificationTrainer(BaseTrainer):
return dict(zip(keys, loss_items))
def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
"""
Plot training samples with their annotations.
"""Plot training samples with their annotations.
Args:
batch (dict[str, torch.Tensor]): Batch containing images and class labels.

View file

@ -16,8 +16,7 @@ from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a classification model.
"""A class extending the BaseValidator class for validation based on a classification model.
This validator handles the validation process for classification models, including metrics calculation, confusion
matrix generation, and visualization of results.
@ -55,8 +54,7 @@ class ClassificationValidator(BaseValidator):
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
"""Initialize ClassificationValidator with dataloader, save directory, and other parameters.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
@ -96,8 +94,7 @@ class ClassificationValidator(BaseValidator):
return batch
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
"""
Update running metrics with model predictions and batch targets.
"""Update running metrics with model predictions and batch targets.
Args:
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
@ -112,8 +109,7 @@ class ClassificationValidator(BaseValidator):
self.targets.append(batch["cls"].type(torch.int32).cpu())
def finalize_metrics(self) -> None:
"""
Finalize metrics including confusion matrix and processing speed.
"""Finalize metrics including confusion matrix and processing speed.
Examples:
>>> validator = ClassificationValidator()
@ -161,8 +157,7 @@ class ClassificationValidator(BaseValidator):
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
"""
Build and return a data loader for classification validation.
"""Build and return a data loader for classification validation.
Args:
dataset_path (str | Path): Path to the dataset directory.
@ -180,8 +175,7 @@ class ClassificationValidator(BaseValidator):
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot validation image samples with their ground truth labels.
"""Plot validation image samples with their ground truth labels.
Args:
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
@ -201,8 +195,7 @@ class ClassificationValidator(BaseValidator):
)
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
"""
Plot images with their predicted class labels and save the visualization.
"""Plot images with their predicted class labels and save the visualization.
Args:
batch (dict[str, Any]): Batch data containing images and other information.

View file

@ -6,8 +6,7 @@ from ultralytics.utils import nms, ops
class DetectionPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a detection model.
"""A class extending the BasePredictor class for prediction based on a detection model.
This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
with bounding boxes and class predictions.
@ -32,8 +31,7 @@ class DetectionPredictor(BasePredictor):
"""
def postprocess(self, preds, img, orig_imgs, **kwargs):
"""
Post-process predictions and return a list of Results objects.
"""Post-process predictions and return a list of Results objects.
This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
further analysis.
@ -92,8 +90,7 @@ class DetectionPredictor(BasePredictor):
return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
def construct_results(self, preds, img, orig_imgs):
"""
Construct a list of Results objects from model predictions.
"""Construct a list of Results objects from model predictions.
Args:
preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
@ -109,8 +106,7 @@ class DetectionPredictor(BasePredictor):
]
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct a single Results object from one image prediction.
"""Construct a single Results object from one image prediction.
Args:
pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.

View file

@ -22,8 +22,7 @@ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_m
class DetectionTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
"""A class extending the BaseTrainer class for training based on a detection model.
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
object detection including dataset building, data loading, preprocessing, and model configuration.
@ -54,8 +53,7 @@ class DetectionTrainer(BaseTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a DetectionTrainer object for training YOLO object detection model training.
"""Initialize a DetectionTrainer object for training YOLO object detection model training.
Args:
cfg (dict, optional): Default configuration dictionary containing training parameters.
@ -65,8 +63,7 @@ class DetectionTrainer(BaseTrainer):
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation.
"""Build YOLO Dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.
@ -80,8 +77,7 @@ class DetectionTrainer(BaseTrainer):
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
"""
Construct and return dataloader for the specified mode.
"""Construct and return dataloader for the specified mode.
Args:
dataset_path (str): Path to the dataset.
@ -109,8 +105,7 @@ class DetectionTrainer(BaseTrainer):
)
def preprocess_batch(self, batch: dict) -> dict:
"""
Preprocess a batch of images by scaling and converting to float.
"""Preprocess a batch of images by scaling and converting to float.
Args:
batch (dict): Dictionary containing batch data with 'img' tensor.
@ -150,8 +145,7 @@ class DetectionTrainer(BaseTrainer):
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
"""
Return a YOLO detection model.
"""Return a YOLO detection model.
Args:
cfg (str, optional): Path to model configuration file.
@ -174,8 +168,7 @@ class DetectionTrainer(BaseTrainer):
)
def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
"""
Return a loss dict with labeled training loss items tensor.
"""Return a loss dict with labeled training loss items tensor.
Args:
loss_items (list[float], optional): List of loss values.
@ -202,8 +195,7 @@ class DetectionTrainer(BaseTrainer):
)
def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot training samples with their annotations.
"""Plot training samples with their annotations.
Args:
batch (dict[str, Any]): Dictionary containing batch data.
@ -223,8 +215,7 @@ class DetectionTrainer(BaseTrainer):
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
def auto_batch(self):
"""
Get optimal batch size by calculating memory occupation of model.
"""Get optimal batch size by calculating memory occupation of model.
Returns:
(int): Optimal batch size.

View file

@ -19,8 +19,7 @@ from ultralytics.utils.plotting import plot_images
class DetectionValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a detection model.
"""A class extending the BaseValidator class for validation based on a detection model.
This class implements validation functionality specific to object detection tasks, including metrics calculation,
prediction processing, and visualization of results.
@ -44,8 +43,7 @@ class DetectionValidator(BaseValidator):
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize detection validator with necessary variables and settings.
"""Initialize detection validator with necessary variables and settings.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
@ -63,8 +61,7 @@ class DetectionValidator(BaseValidator):
self.metrics = DetMetrics()
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""
Preprocess batch of images for YOLO validation.
"""Preprocess batch of images for YOLO validation.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
@ -79,8 +76,7 @@ class DetectionValidator(BaseValidator):
return batch
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO detection validation.
"""Initialize evaluation metrics for YOLO detection validation.
Args:
model (torch.nn.Module): Model to validate.
@ -107,15 +103,14 @@ class DetectionValidator(BaseValidator):
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
Apply Non-maximum suppression to prediction outputs.
"""Apply Non-maximum suppression to prediction outputs.
Args:
preds (torch.Tensor): Raw predictions from the model.
Returns:
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
'bboxes', 'conf', 'cls', and 'extra' tensors.
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains 'bboxes', 'conf',
'cls', and 'extra' tensors.
"""
outputs = nms.non_max_suppression(
preds,
@ -131,8 +126,7 @@ class DetectionValidator(BaseValidator):
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch of images and annotations for validation.
"""Prepare a batch of images and annotations for validation.
Args:
si (int): Batch index.
@ -159,8 +153,7 @@ class DetectionValidator(BaseValidator):
}
def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Prepare predictions for evaluation against ground truth.
"""Prepare predictions for evaluation against ground truth.
Args:
pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
@ -173,8 +166,7 @@ class DetectionValidator(BaseValidator):
return pred
def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
"""
Update metrics with new predictions and ground truth.
"""Update metrics with new predictions and ground truth.
Args:
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
@ -250,8 +242,7 @@ class DetectionValidator(BaseValidator):
self.metrics.clear_stats()
def get_stats(self) -> dict[str, Any]:
"""
Calculate and return metrics statistics.
"""Calculate and return metrics statistics.
Returns:
(dict[str, Any]): Dictionary containing metrics results.
@ -281,8 +272,7 @@ class DetectionValidator(BaseValidator):
)
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Return correct prediction matrix.
"""Return correct prediction matrix.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
@ -298,8 +288,7 @@ class DetectionValidator(BaseValidator):
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
"""
Build YOLO Dataset.
"""Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
@ -312,8 +301,7 @@ class DetectionValidator(BaseValidator):
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
"""
Construct and return dataloader.
"""Construct and return dataloader.
Args:
dataset_path (str): Path to the dataset.
@ -334,8 +322,7 @@ class DetectionValidator(BaseValidator):
)
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot validation image samples.
"""Plot validation image samples.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
@ -352,8 +339,7 @@ class DetectionValidator(BaseValidator):
def plot_predictions(
self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
) -> None:
"""
Plot predicted bounding boxes on input images and save the result.
"""Plot predicted bounding boxes on input images and save the result.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
@ -379,8 +365,7 @@ class DetectionValidator(BaseValidator):
) # pred
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO detections to a txt file in normalized coordinates in a specific format.
"""Save YOLO detections to a txt file in normalized coordinates in a specific format.
Args:
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
@ -398,12 +383,11 @@ class DetectionValidator(BaseValidator):
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Serialize YOLO predictions to COCO json format.
"""Serialize YOLO predictions to COCO json format.
Args:
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Examples:
@ -444,8 +428,7 @@ class DetectionValidator(BaseValidator):
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""
Evaluate YOLO output in JSON format and return performance statistics.
"""Evaluate YOLO output in JSON format and return performance statistics.
Args:
stats (dict[str, Any]): Current statistics dictionary.
@ -469,8 +452,7 @@ class DetectionValidator(BaseValidator):
iou_types: str | list[str] = "bbox",
suffix: str | list[str] = "Box",
) -> dict[str, Any]:
"""
Evaluate COCO/LVIS metrics using faster-coco-eval library.
"""Evaluate COCO/LVIS metrics using faster-coco-eval library.
Performs evaluation using the faster-coco-eval library to compute mAP metrics for object detection. Updates the
provided stats dictionary with computed metrics including mAP50, mAP50-95, and LVIS-specific metrics if
@ -480,10 +462,10 @@ class DetectionValidator(BaseValidator):
stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
pred_json (str | Path]): Path to JSON file containing predictions in COCO format.
anno_json (str | Path]): Path to JSON file containing ground truth annotations in COCO format.
iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings.
Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond
to iou_types if multiple types provided. Defaults to "Box".
iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings. Common
values include "bbox", "segm", "keypoints". Defaults to "bbox".
suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond to
iou_types if multiple types provided. Defaults to "Box".
Returns:
(dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.

View file

@ -24,8 +24,7 @@ from ultralytics.utils import ROOT, YAML
class YOLO(Model):
"""
YOLO (You Only Look Once) object detection model.
"""YOLO (You Only Look Once) object detection model.
This class provides a unified interface for YOLO models, automatically switching to specialized model types
(YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
@ -52,16 +51,15 @@ class YOLO(Model):
"""
def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
"""
Initialize a YOLO model.
"""Initialize a YOLO model.
This constructor initializes a YOLO model, automatically switching to specialized model types (YOLOWorld or
YOLOE) based on the model filename.
Args:
model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
Defaults to auto-detection based on model.
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'. Defaults
to auto-detection based on model.
verbose (bool): Display model info on load.
Examples:
@ -126,8 +124,7 @@ class YOLO(Model):
class YOLOWorld(Model):
"""
YOLO-World object detection model.
"""YOLO-World object detection model.
YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions without
requiring training on specific classes. It extends the YOLO architecture to support real-time open-vocabulary
@ -152,8 +149,7 @@ class YOLOWorld(Model):
"""
def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
"""
Initialize YOLOv8-World model with a pre-trained model file.
"""Initialize YOLOv8-World model with a pre-trained model file.
Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default COCO
class names.
@ -181,8 +177,7 @@ class YOLOWorld(Model):
}
def set_classes(self, classes: list[str]) -> None:
"""
Set the model's class names for detection.
"""Set the model's class names for detection.
Args:
classes (list[str]): A list of categories i.e. ["person"].
@ -200,8 +195,7 @@ class YOLOWorld(Model):
class YOLOE(Model):
"""
YOLOE object detection and segmentation model.
"""YOLOE object detection and segmentation model.
YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with improved
performance and additional features like visual and text positional embeddings.
@ -235,8 +229,7 @@ class YOLOE(Model):
"""
def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
"""
Initialize YOLOE model with a pre-trained model file.
"""Initialize YOLOE model with a pre-trained model file.
Args:
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@ -269,8 +262,7 @@ class YOLOE(Model):
return self.model.get_text_pe(texts)
def get_visual_pe(self, img, visual):
"""
Get visual positional embeddings for the given image and visual features.
"""Get visual positional embeddings for the given image and visual features.
This method extracts positional embeddings from visual features based on the input image. It requires that the
model is an instance of YOLOEModel.
@ -292,8 +284,7 @@ class YOLOE(Model):
return self.model.get_visual_pe(img, visual)
def set_vocab(self, vocab: list[str], names: list[str]) -> None:
"""
Set vocabulary and class names for the YOLOE model.
"""Set vocabulary and class names for the YOLOE model.
This method configures the vocabulary and class names used by the model for text processing and classification
tasks. The model must be an instance of YOLOEModel.
@ -318,8 +309,7 @@ class YOLOE(Model):
return self.model.get_vocab(names)
def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
"""
Set the model's class names and embeddings for detection.
"""Set the model's class names and embeddings for detection.
Args:
classes (list[str]): A list of categories i.e. ["person"].
@ -344,8 +334,7 @@ class YOLOE(Model):
refer_data: str | None = None,
**kwargs,
):
"""
Validate the model using text or visual prompts.
"""Validate the model using text or visual prompts.
Args:
validator (callable, optional): A callable validator function. If None, a default validator is loaded.
@ -373,19 +362,18 @@ class YOLOE(Model):
predictor=yolo.yoloe.YOLOEVPDetectPredictor,
**kwargs,
):
"""
Run prediction on images, videos, directories, streams, etc.
"""Run prediction on images, videos, directories, streams, etc.
Args:
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
stream (bool): Whether to stream the prediction results. If True, results are yielded as a
generator as they are computed.
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
'bboxes' and 'cls' keys when non-empty.
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths, directory
paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
stream (bool): Whether to stream the prediction results. If True, results are yielded as a generator as they
are computed.
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include 'bboxes'
and 'cls' keys when non-empty.
refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
loaded based on the task.
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically loaded
based on the task.
**kwargs (Any): Additional keyword arguments passed to the predictor.
Returns:

View file

@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
class OBBPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
"""A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
bounding boxes.
@ -27,8 +26,7 @@ class OBBPredictor(DetectionPredictor):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize OBBPredictor with optional model and data configuration overrides.
"""Initialize OBBPredictor with optional model and data configuration overrides.
Args:
cfg (dict, optional): Default configuration for the predictor.
@ -45,12 +43,11 @@ class OBBPredictor(DetectionPredictor):
self.args.task = "obb"
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct the result object from the prediction.
"""Construct the result object from the prediction.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
the last dimension contains [x, y, w, h, confidence, class_id, angle].
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
last dimension contains [x, y, w, h, confidence, class_id, angle].
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
orig_img (np.ndarray): The original image before preprocessing.
img_path (str): The path to the original image.

View file

@ -12,15 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK
class OBBTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
"""A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
objects at arbitrary angles rather than just axis-aligned rectangles.
Attributes:
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
and dfl_loss.
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
dfl_loss.
Methods:
get_model: Return OBBModel initialized with specified config and weights.
@ -34,14 +33,13 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
"""
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
"""Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
Args:
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
model configuration.
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
will take precedence over those in cfg.
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
configuration.
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
take precedence over those in cfg.
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
"""
if overrides is None:
@ -52,8 +50,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
def get_model(
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
) -> OBBModel:
"""
Return OBBModel initialized with specified config and weights.
"""Return OBBModel initialized with specified config and weights.
Args:
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary

View file

@ -15,8 +15,7 @@ from ultralytics.utils.nms import TorchNMS
class OBBValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
"""A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
satellite imagery where objects can appear at various orientations.
@ -44,8 +43,7 @@ class OBBValidator(DetectionValidator):
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
extends the DetectionValidator class and configures it specifically for the OBB task.
@ -61,8 +59,7 @@ class OBBValidator(DetectionValidator):
self.metrics = OBBMetrics()
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO obb validation.
"""Initialize evaluation metrics for YOLO obb validation.
Args:
model (torch.nn.Module): Model to validate.
@ -73,18 +70,17 @@ class OBBValidator(DetectionValidator):
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
"""
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
"""Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
Args:
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
class labels and bounding boxes.
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
class labels and bounding boxes.
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
labels and bounding boxes.
Returns:
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
predictions compared to the ground truth.
Examples:
@ -99,7 +95,8 @@ class OBBValidator(DetectionValidator):
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
"""Postprocess OBB predictions.
Args:
preds (torch.Tensor): Raw predictions from the model.
@ -112,8 +109,7 @@ class OBBValidator(DetectionValidator):
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare batch data for OBB validation with proper scaling and formatting.
"""Prepare batch data for OBB validation with proper scaling and formatting.
Args:
si (int): Batch index to process.
@ -146,8 +142,7 @@ class OBBValidator(DetectionValidator):
}
def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
"""
Plot predicted bounding boxes on input images and save the result.
"""Plot predicted bounding boxes on input images and save the result.
Args:
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
@ -166,12 +161,11 @@ class OBBValidator(DetectionValidator):
super().plot_predictions(batch, preds, ni) # plot bboxes
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
"""Convert YOLO predictions to COCO JSON format with rotated bounding box information.
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Notes:
@ -197,8 +191,7 @@ class OBBValidator(DetectionValidator):
)
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO OBB detections to a text file in normalized coordinates.
"""Save YOLO OBB detections to a text file in normalized coordinates.
Args:
predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
@ -233,8 +226,7 @@ class OBBValidator(DetectionValidator):
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""
Evaluate YOLO output in JSON format and save predictions in DOTA format.
"""Evaluate YOLO output in JSON format and save predictions in DOTA format.
Args:
stats (dict[str, Any]): Performance statistics dictionary.

View file

@ -5,8 +5,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
class PosePredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a pose model.
"""A class extending the DetectionPredictor class for prediction based on a pose model.
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
capabilities inherited from DetectionPredictor.
@ -27,8 +26,7 @@ class PosePredictor(DetectionPredictor):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize PosePredictor for pose estimation tasks.
"""Initialize PosePredictor for pose estimation tasks.
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
for Apple MPS.
@ -54,8 +52,7 @@ class PosePredictor(DetectionPredictor):
)
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct the result object from the prediction, including keypoints.
"""Construct the result object from the prediction, including keypoints.
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
result object.

View file

@ -12,8 +12,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER
class PoseTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training YOLO pose estimation models.
"""A class extending the DetectionTrainer class for training YOLO pose estimation models.
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
of pose keypoints alongside bounding boxes.
@ -39,8 +38,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a PoseTrainer object for training YOLO pose estimation models.
"""Initialize a PoseTrainer object for training YOLO pose estimation models.
Args:
cfg (dict, optional): Default configuration dictionary containing training parameters.
@ -68,8 +66,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
weights: str | Path | None = None,
verbose: bool = True,
) -> PoseModel:
"""
Get pose estimation model with specified configuration and weights.
"""Get pose estimation model with specified configuration and weights.
Args:
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
@ -105,8 +102,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
)
def get_dataset(self) -> dict[str, Any]:
"""
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
"""Retrieve the dataset and ensure it contains the required `kpt_shape` key.
Returns:
(dict): A dictionary containing the training/validation/test dataset and category names.

View file

@ -14,8 +14,7 @@ from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
class PoseValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a pose model.
"""A class extending the DetectionValidator class for validation based on a pose model.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
metrics for pose evaluation.
@ -33,8 +32,8 @@ class PoseValidator(DetectionValidator):
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
dimensions.
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
detections and ground truth.
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
and ground truth.
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
@ -49,8 +48,7 @@ class PoseValidator(DetectionValidator):
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize a PoseValidator object for pose estimation validation.
"""Initialize a PoseValidator object for pose estimation validation.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
specialized metrics for pose evaluation.
@ -106,8 +104,7 @@ class PoseValidator(DetectionValidator):
)
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO pose validation.
"""Initialize evaluation metrics for YOLO pose validation.
Args:
model (torch.nn.Module): Model to validate.
@ -119,16 +116,15 @@ class PoseValidator(DetectionValidator):
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
"""Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
Args:
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
bounding boxes, confidence scores, class predictions, and keypoint data.
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
scores, class predictions, and keypoint data.
Returns:
(dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
@ -148,8 +144,7 @@ class PoseValidator(DetectionValidator):
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
"""Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
Args:
si (int): Batch index.
@ -172,18 +167,18 @@ class PoseValidator(DetectionValidator):
return pbatch
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
"""Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
truth.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
and 'keypoints' for keypoint predictions.
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
for bounding boxes, and 'keypoints' for keypoint annotations.
Returns:
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
true positives across 10 IoU levels.
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
positives across 10 IoU levels.
Notes:
`0.53` scale factor used in area computation is referenced from
@ -202,8 +197,7 @@ class PoseValidator(DetectionValidator):
return tp
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO pose detections to a text file in normalized coordinates.
"""Save YOLO pose detections to a text file in normalized coordinates.
Args:
predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
@ -226,15 +220,14 @@ class PoseValidator(DetectionValidator):
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Convert YOLO predictions to COCO JSON format.
"""Convert YOLO predictions to COCO JSON format.
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
format, and appends the results to the internal JSON dictionary (self.jdict).
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
and 'keypoints' tensors.
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
tensors.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Notes:

View file

@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
class SegmentationPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a segmentation model.
"""A class extending the DetectionPredictor class for prediction based on a segmentation model.
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
prediction results.
@ -31,8 +30,7 @@ class SegmentationPredictor(DetectionPredictor):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
"""Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
prediction results.
@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""
Apply non-max suppression and process segmentation detections for each image in the input batch.
"""Apply non-max suppression and process segmentation detections for each image in the input batch.
Args:
preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
@ -55,8 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
Returns:
(list): List of Results objects containing the segmentation predictions for each image in the batch.
Each Results object includes both bounding boxes and segmentation masks.
(list): List of Results objects containing the segmentation predictions for each image in the batch. Each
Results object includes both bounding boxes and segmentation masks.
Examples:
>>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
@ -67,8 +64,7 @@ class SegmentationPredictor(DetectionPredictor):
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
def construct_results(self, preds, img, orig_imgs, protos):
"""
Construct a list of result objects from the predictions.
"""Construct a list of result objects from the predictions.
Args:
preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@ -77,8 +73,8 @@ class SegmentationPredictor(DetectionPredictor):
protos (list[torch.Tensor]): List of prototype masks.
Returns:
(list[Results]): List of result objects containing the original images, image paths, class names,
bounding boxes, and masks.
(list[Results]): List of result objects containing the original images, image paths, class names, bounding
boxes, and masks.
"""
return [
self.construct_result(pred, img, orig_img, img_path, proto)
@ -86,8 +82,7 @@ class SegmentationPredictor(DetectionPredictor):
]
def construct_result(self, pred, img, orig_img, img_path, proto):
"""
Construct a single result object from the prediction.
"""Construct a single result object from the prediction.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and masks.

View file

@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, RANK
class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a segmentation model.
"""A class extending the DetectionTrainer class for training based on a segmentation model.
This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
functionality including model initialization, validation, and visualization.
@ -28,8 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
"""
Initialize a SegmentationTrainer object.
"""Initialize a SegmentationTrainer object.
Args:
cfg (dict): Configuration dictionary with default training settings.
@ -42,8 +40,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
"""
Initialize and return a SegmentationModel with specified configuration and weights.
"""Initialize and return a SegmentationModel with specified configuration and weights.
Args:
cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.

View file

@ -17,8 +17,7 @@ from ultralytics.utils.metrics import SegmentMetrics, mask_iou
class SegmentationValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a segmentation model.
"""A class extending the DetectionValidator class for validation based on a segmentation model.
This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
compute metrics such as mAP for both detection and segmentation tasks.
@ -38,8 +37,7 @@ class SegmentationValidator(DetectionValidator):
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
@ -53,8 +51,7 @@ class SegmentationValidator(DetectionValidator):
self.metrics = SegmentMetrics()
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""
Preprocess batch of images for YOLO segmentation validation.
"""Preprocess batch of images for YOLO segmentation validation.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
@ -67,8 +64,7 @@ class SegmentationValidator(DetectionValidator):
return batch
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize metrics and select mask processing function based on save_json flag.
"""Initialize metrics and select mask processing function based on save_json flag.
Args:
model (torch.nn.Module): Model to validate.
@ -96,8 +92,7 @@ class SegmentationValidator(DetectionValidator):
)
def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
"""
Post-process YOLO predictions and return output detections with proto.
"""Post-process YOLO predictions and return output detections with proto.
Args:
preds (list[torch.Tensor]): Raw predictions from the model.
@ -122,8 +117,7 @@ class SegmentationValidator(DetectionValidator):
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch for training or inference by processing images and targets.
"""Prepare a batch for training or inference by processing images and targets.
Args:
si (int): Batch index.
@ -149,8 +143,7 @@ class SegmentationValidator(DetectionValidator):
return prepared_batch
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
"""Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
@ -179,8 +172,7 @@ class SegmentationValidator(DetectionValidator):
return tp
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
"""
Plot batch predictions with masks and bounding boxes.
"""Plot batch predictions with masks and bounding boxes.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
@ -195,8 +187,7 @@ class SegmentationValidator(DetectionValidator):
super().plot_predictions(batch, preds, ni, max_det=self.args.max_det) # plot bboxes
def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO detections to a txt file in normalized coordinates in a specific format.
"""Save YOLO detections to a txt file in normalized coordinates in a specific format.
Args:
predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
@ -215,8 +206,7 @@ class SegmentationValidator(DetectionValidator):
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Save one JSON result for COCO evaluation.
"""Save one JSON result for COCO evaluation.
Args:
predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.

View file

@ -24,8 +24,7 @@ def on_pretrain_routine_end(trainer) -> None:
class WorldTrainer(DetectionTrainer):
"""
A trainer class for fine-tuning YOLO World models on close-set datasets.
"""A trainer class for fine-tuning YOLO World models on close-set datasets.
This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
features for improved object detection and understanding. It handles text embedding generation and caching to
@ -54,8 +53,7 @@ class WorldTrainer(DetectionTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a WorldTrainer object with given arguments.
"""Initialize a WorldTrainer object with given arguments.
Args:
cfg (dict[str, Any]): Configuration for the trainer.
@ -69,8 +67,7 @@ class WorldTrainer(DetectionTrainer):
self.text_embeddings = None
def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
"""
Return WorldModel initialized with specified config and weights.
"""Return WorldModel initialized with specified config and weights.
Args:
cfg (dict[str, Any] | str, optional): Model configuration.
@ -95,8 +92,7 @@ class WorldTrainer(DetectionTrainer):
return model
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation.
"""Build YOLO Dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.
@ -115,8 +111,7 @@ class WorldTrainer(DetectionTrainer):
return dataset
def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
"""
Set text embeddings for datasets to accelerate training by caching category names.
"""Set text embeddings for datasets to accelerate training by caching category names.
This method collects unique category names from all datasets, then generates and caches text embeddings for
these categories to improve training efficiency.
@ -141,8 +136,7 @@ class WorldTrainer(DetectionTrainer):
self.text_embeddings = text_embeddings
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
"""
Generate text embeddings for a list of text samples.
"""Generate text embeddings for a list of text samples.
Args:
texts (list[str]): List of text samples to encode.

View file

@ -10,8 +10,7 @@ from ultralytics.utils.torch_utils import unwrap_model
class WorldTrainerFromScratch(WorldTrainer):
"""
A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
"""A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
supporting training YOLO-World models with combined vision-language capabilities.
@ -53,8 +52,7 @@ class WorldTrainerFromScratch(WorldTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize a WorldTrainerFromScratch object.
"""Initialize a WorldTrainerFromScratch object.
This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
detection and grounding datasets for vision-language capabilities.
@ -87,8 +85,7 @@ class WorldTrainerFromScratch(WorldTrainer):
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset for training or validation.
"""Build YOLO Dataset for training or validation.
This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
datasets and grounding datasets with different formats.
@ -122,8 +119,7 @@ class WorldTrainerFromScratch(WorldTrainer):
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
def get_dataset(self):
"""
Get train and validation paths from data dictionary.
"""Get train and validation paths from data dictionary.
Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
detection datasets and grounding datasets.
@ -187,8 +183,7 @@ class WorldTrainerFromScratch(WorldTrainer):
pass
def final_eval(self):
"""
Perform final evaluation and validation for the YOLO-World model.
"""Perform final evaluation and validation for the YOLO-World model.
Configures the validator with appropriate dataset and split information before running evaluation.

View file

@ -9,8 +9,7 @@ from ultralytics.models.yolo.segment import SegmentationPredictor
class YOLOEVPDetectPredictor(DetectionPredictor):
"""
A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
"""A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
handling, and preprocessing transformations.
@ -29,8 +28,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
"""
def setup_model(self, model, verbose: bool = True):
"""
Set up the model for prediction.
"""Set up the model for prediction.
Args:
model (torch.nn.Module): Model to load or use.
@ -40,18 +38,16 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
self.done_warmup = True
def set_prompts(self, prompts):
"""
Set the visual prompts for the model.
"""Set the visual prompts for the model.
Args:
prompts (dict): Dictionary containing class indices and bounding boxes or masks.
Must include a 'cls' key with class indices.
prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
with class indices.
"""
self.prompts = prompts
def pre_transform(self, im):
"""
Preprocess images and prompts before inference.
"""Preprocess images and prompts before inference.
This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
accordingly.
@ -94,8 +90,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
return img
def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
"""
Process a single image by resizing bounding boxes or masks and generating visuals.
"""Process a single image by resizing bounding boxes or masks and generating visuals.
Args:
dst_shape (tuple): The target shape (height, width) of the image.
@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
def inference(self, im, *args, **kwargs):
"""
Run inference with visual prompts.
"""Run inference with visual prompts.
Args:
im (torch.Tensor): Input image tensor.
@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
return super().inference(im, vpe=self.prompts, *args, **kwargs)
def get_vpe(self, source):
"""
Process the source to get the visual prompt embeddings (VPE).
"""Process the source to get the visual prompt embeddings (VPE).
Args:
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
of the image to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy
arrays, and torch tensors.
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
torch tensors.
Returns:
(torch.Tensor): The visual prompt embeddings (VPE) from the model.

View file

@ -19,8 +19,7 @@ from .val import YOLOEDetectValidator
class YOLOETrainer(DetectionTrainer):
"""
A trainer class for YOLOE object detection models.
"""A trainer class for YOLOE object detection models.
This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
model initialization, validation, and dataset building with multi-modal support.
@ -35,8 +34,7 @@ class YOLOETrainer(DetectionTrainer):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
"""
Initialize the YOLOE Trainer with specified configurations.
"""Initialize the YOLOE Trainer with specified configurations.
Args:
cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
@ -50,12 +48,11 @@ class YOLOETrainer(DetectionTrainer):
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return a YOLOEModel initialized with the specified configuration and weights.
"""Return a YOLOEModel initialized with the specified configuration and weights.
Args:
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
a direct path to a YAML file, or None to use default configuration.
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
path to a YAML file, or None to use default configuration.
weights (str | Path, optional): Path to pretrained weights file to load into the model.
verbose (bool): Whether to display model information during initialization.
@ -88,8 +85,7 @@ class YOLOETrainer(DetectionTrainer):
)
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset.
"""Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
@ -106,8 +102,7 @@ class YOLOETrainer(DetectionTrainer):
class YOLOEPETrainer(DetectionTrainer):
"""
Fine-tune YOLOE model using linear probing approach.
"""Fine-tune YOLOE model using linear probing approach.
This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
datasets while preserving pretrained features.
@ -117,8 +112,7 @@ class YOLOEPETrainer(DetectionTrainer):
"""
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return YOLOEModel initialized with specified config and weights.
"""Return YOLOEModel initialized with specified config and weights.
Args:
cfg (dict | str, optional): Model configuration.
@ -160,8 +154,7 @@ class YOLOEPETrainer(DetectionTrainer):
class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
"""
Train YOLOE models from scratch with text embedding support.
"""Train YOLOE models from scratch with text embedding support.
This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
text embeddings and grounding datasets.
@ -172,8 +165,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
"""
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation.
"""Build YOLO Dataset for training or validation.
This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
datasets and grounding datasets with different formats.
@ -189,8 +181,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
"""
Generate text embeddings for a list of text samples.
"""Generate text embeddings for a list of text samples.
Args:
texts (list[str]): List of text samples to encode.
@ -216,8 +207,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
"""
Train prompt-free YOLOE model.
"""Train prompt-free YOLOE model.
This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
require text prompts during inference.
@ -240,8 +230,7 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
return DetectionTrainer.preprocess_batch(self, batch)
def set_text_embeddings(self, datasets, batch: int):
"""
Set text embeddings for datasets to accelerate training by caching category names.
"""Set text embeddings for datasets to accelerate training by caching category names.
This method collects unique category names from all datasets, generates text embeddings for them, and caches
these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
@ -260,8 +249,7 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
class YOLOEVPTrainer(YOLOETrainerFromScratch):
"""
Train YOLOE model with visual prompts.
"""Train YOLOE model with visual prompts.
This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
alongside images to guide the detection process.
@ -271,8 +259,7 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
"""
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation with visual prompts.
"""Build YOLO Dataset for training or validation with visual prompts.
Args:
img_path (list[str] | str): Path to the folder containing images or list of paths.

View file

@ -11,8 +11,7 @@ from .val import YOLOESegValidator
class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
"""
Trainer class for YOLOE segmentation models.
"""Trainer class for YOLOE segmentation models.
This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
segmentation models, enabling both object detection and instance segmentation capabilities.
@ -24,8 +23,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""
Return YOLOESegModel initialized with specified config and weights.
"""Return YOLOESegModel initialized with specified config and weights.
Args:
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
return model
def get_validator(self):
"""
Create and return a validator for YOLOE segmentation model evaluation.
"""Create and return a validator for YOLOE segmentation model evaluation.
Returns:
(YOLOESegValidator): Validator for YOLOE segmentation models.
@ -62,8 +59,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
class YOLOEPESegTrainer(SegmentationTrainer):
"""
Fine-tune YOLOESeg model in linear probing way.
"""Fine-tune YOLOESeg model in linear probing way.
This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
most of the model and only training specific layers for efficient adaptation to new tasks.
@ -73,8 +69,7 @@ class YOLOEPESegTrainer(SegmentationTrainer):
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""
Return YOLOESegModel initialized with specified config and weights for linear probing.
"""Return YOLOESegModel initialized with specified config and weights for linear probing.
Args:
cfg (dict | str, optional): Model configuration dictionary or YAML file path.

View file

@ -21,8 +21,7 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
class YOLOEDetectValidator(DetectionValidator):
"""
A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
"""A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
@ -50,8 +49,7 @@ class YOLOEDetectValidator(DetectionValidator):
@smart_inference_mode()
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
"""
Extract visual prompt embeddings from training samples.
"""Extract visual prompt embeddings from training samples.
This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
@ -99,8 +97,7 @@ class YOLOEDetectValidator(DetectionValidator):
return visual_pe.unsqueeze(0)
def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
"""
Create a dataloader for LVIS training visual prompt samples.
"""Create a dataloader for LVIS training visual prompt samples.
This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
@ -140,8 +137,7 @@ class YOLOEDetectValidator(DetectionValidator):
refer_data: str | None = None,
load_vp: bool = False,
) -> dict[str, Any]:
"""
Run validation on the model using either text or visual prompt embeddings.
"""Run validation on the model using either text or visual prompt embeddings.
This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
supports validation during training (using a trainer object) or standalone validation with a provided model. For

View file

@ -23,8 +23,7 @@ from ultralytics.utils.nms import non_max_suppression
def check_class_names(names: list | dict) -> dict[int, str]:
"""
Check class names and convert to dict format if needed.
"""Check class names and convert to dict format if needed.
Args:
names (list | dict): Class names as list or dict format.
@ -53,8 +52,7 @@ def check_class_names(names: list | dict) -> dict[int, str]:
def default_class_names(data: str | Path | None = None) -> dict[int, str]:
"""
Apply default class names to an input YAML file or return numerical class names.
"""Apply default class names to an input YAML file or return numerical class names.
Args:
data (str | Path, optional): Path to YAML file containing class names.
@ -71,8 +69,7 @@ def default_class_names(data: str | Path | None = None) -> dict[int, str]:
class AutoBackend(nn.Module):
"""
Handle dynamic backend selection for running inference using Ultralytics YOLO models.
"""Handle dynamic backend selection for running inference using Ultralytics YOLO models.
The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
range of formats, each with specific naming conventions as outlined below:
@ -148,8 +145,7 @@ class AutoBackend(nn.Module):
fuse: bool = True,
verbose: bool = True,
):
"""
Initialize the AutoBackend for inference.
"""Initialize the AutoBackend for inference.
Args:
model (str | torch.nn.Module): Path to the model weights file or a module instance.
@ -639,8 +635,7 @@ class AutoBackend(nn.Module):
embed: list | None = None,
**kwargs: Any,
) -> torch.Tensor | list[torch.Tensor]:
"""
Run inference on an AutoBackend model.
"""Run inference on an AutoBackend model.
Args:
im (torch.Tensor): The image tensor to perform inference on.
@ -860,8 +855,7 @@ class AutoBackend(nn.Module):
return self.from_numpy(y)
def from_numpy(self, x: np.ndarray) -> torch.Tensor:
"""
Convert a numpy array to a tensor.
"""Convert a numpy array to a tensor.
Args:
x (np.ndarray): The array to be converted.
@ -872,8 +866,7 @@ class AutoBackend(nn.Module):
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
def warmup(self, imgsz: tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
"""
Warm up the model by running one forward pass with a dummy input.
"""Warm up the model by running one forward pass with a dummy input.
Args:
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
@ -889,8 +882,7 @@ class AutoBackend(nn.Module):
@staticmethod
def _model_type(p: str = "path/to/model.pt") -> list[bool]:
"""
Take a path to a model file and return the model type.
"""Take a path to a model file and return the model type.
Args:
p (str): Path to the model file.

View file

@ -6,8 +6,7 @@ import torch.nn as nn
class AGLU(nn.Module):
"""
Unified activation function module from AGLU.
"""Unified activation function module from AGLU.
This class implements a parameterized activation function with learnable parameters lambda and kappa, based on the
AGLU (Adaptive Gated Linear Unit) approach.
@ -40,8 +39,7 @@ class AGLU(nn.Module):
self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # kappa parameter
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply the Adaptive Gated Linear Unit (AGLU) activation function.
"""Apply the Adaptive Gated Linear Unit (AGLU) activation function.
This forward method implements the AGLU activation function with learnable parameters lambda and kappa. The
function applies a transformation that adaptively combines linear and non-linear components.

View file

@ -56,15 +56,13 @@ __all__ = (
class DFL(nn.Module):
"""
Integral module of Distribution Focal Loss (DFL).
"""Integral module of Distribution Focal Loss (DFL).
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
"""
def __init__(self, c1: int = 16):
"""
Initialize a convolutional layer with a given number of input channels.
"""Initialize a convolutional layer with a given number of input channels.
Args:
c1 (int): Number of input channels.
@ -86,8 +84,7 @@ class Proto(nn.Module):
"""Ultralytics YOLO models mask Proto module for segmentation models."""
def __init__(self, c1: int, c_: int = 256, c2: int = 32):
"""
Initialize the Ultralytics YOLO models mask Proto module with specified number of protos and masks.
"""Initialize the Ultralytics YOLO models mask Proto module with specified number of protos and masks.
Args:
c1 (int): Input channels.
@ -106,15 +103,13 @@ class Proto(nn.Module):
class HGStem(nn.Module):
"""
StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
"""StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
"""
def __init__(self, c1: int, cm: int, c2: int):
"""
Initialize the StemBlock of PPHGNetV2.
"""Initialize the StemBlock of PPHGNetV2.
Args:
c1 (int): Input channels.
@ -144,8 +139,7 @@ class HGStem(nn.Module):
class HGBlock(nn.Module):
"""
HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
"""HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
"""
@ -161,8 +155,7 @@ class HGBlock(nn.Module):
shortcut: bool = False,
act: nn.Module = nn.ReLU(),
):
"""
Initialize HGBlock with specified parameters.
"""Initialize HGBlock with specified parameters.
Args:
c1 (int): Input channels.
@ -193,8 +186,7 @@ class SPP(nn.Module):
"""Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
def __init__(self, c1: int, c2: int, k: tuple[int, ...] = (5, 9, 13)):
"""
Initialize the SPP layer with input/output channels and pooling kernel sizes.
"""Initialize the SPP layer with input/output channels and pooling kernel sizes.
Args:
c1 (int): Input channels.
@ -217,8 +209,7 @@ class SPPF(nn.Module):
"""Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
def __init__(self, c1: int, c2: int, k: int = 5):
"""
Initialize the SPPF layer with given input/output channels and kernel size.
"""Initialize the SPPF layer with given input/output channels and kernel size.
Args:
c1 (int): Input channels.
@ -245,8 +236,7 @@ class C1(nn.Module):
"""CSP Bottleneck with 1 convolution."""
def __init__(self, c1: int, c2: int, n: int = 1):
"""
Initialize the CSP Bottleneck with 1 convolution.
"""Initialize the CSP Bottleneck with 1 convolution.
Args:
c1 (int): Input channels.
@ -267,8 +257,7 @@ class C2(nn.Module):
"""CSP Bottleneck with 2 convolutions."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
"""
Initialize a CSP Bottleneck with 2 convolutions.
"""Initialize a CSP Bottleneck with 2 convolutions.
Args:
c1 (int): Input channels.
@ -295,8 +284,7 @@ class C2f(nn.Module):
"""Faster Implementation of CSP Bottleneck with 2 convolutions."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
"""
Initialize a CSP bottleneck with 2 convolutions.
"""Initialize a CSP bottleneck with 2 convolutions.
Args:
c1 (int): Input channels.
@ -330,8 +318,7 @@ class C3(nn.Module):
"""CSP Bottleneck with 3 convolutions."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
"""
Initialize the CSP Bottleneck with 3 convolutions.
"""Initialize the CSP Bottleneck with 3 convolutions.
Args:
c1 (int): Input channels.
@ -357,8 +344,7 @@ class C3x(C3):
"""C3 module with cross-convolutions."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
"""
Initialize C3 module with cross-convolutions.
"""Initialize C3 module with cross-convolutions.
Args:
c1 (int): Input channels.
@ -377,8 +363,7 @@ class RepC3(nn.Module):
"""Rep C3."""
def __init__(self, c1: int, c2: int, n: int = 3, e: float = 1.0):
"""
Initialize CSP Bottleneck with a single convolution.
"""Initialize CSP Bottleneck with a single convolution.
Args:
c1 (int): Input channels.
@ -402,8 +387,7 @@ class C3TR(C3):
"""C3 module with TransformerBlock()."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
"""
Initialize C3 module with TransformerBlock.
"""Initialize C3 module with TransformerBlock.
Args:
c1 (int): Input channels.
@ -422,8 +406,7 @@ class C3Ghost(C3):
"""C3 module with GhostBottleneck()."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
"""
Initialize C3 module with GhostBottleneck.
"""Initialize C3 module with GhostBottleneck.
Args:
c1 (int): Input channels.
@ -442,8 +425,7 @@ class GhostBottleneck(nn.Module):
"""Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones."""
def __init__(self, c1: int, c2: int, k: int = 3, s: int = 1):
"""
Initialize Ghost Bottleneck module.
"""Initialize Ghost Bottleneck module.
Args:
c1 (int): Input channels.
@ -473,8 +455,7 @@ class Bottleneck(nn.Module):
def __init__(
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: tuple[int, int] = (3, 3), e: float = 0.5
):
"""
Initialize a standard bottleneck module.
"""Initialize a standard bottleneck module.
Args:
c1 (int): Input channels.
@ -499,8 +480,7 @@ class BottleneckCSP(nn.Module):
"""CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
"""
Initialize CSP Bottleneck.
"""Initialize CSP Bottleneck.
Args:
c1 (int): Input channels.
@ -531,8 +511,7 @@ class ResNetBlock(nn.Module):
"""ResNet block with standard convolution layers."""
def __init__(self, c1: int, c2: int, s: int = 1, e: int = 4):
"""
Initialize ResNet block.
"""Initialize ResNet block.
Args:
c1 (int): Input channels.
@ -556,8 +535,7 @@ class ResNetLayer(nn.Module):
"""ResNet layer with multiple ResNet blocks."""
def __init__(self, c1: int, c2: int, s: int = 1, is_first: bool = False, n: int = 1, e: int = 4):
"""
Initialize ResNet layer.
"""Initialize ResNet layer.
Args:
c1 (int): Input channels.
@ -588,8 +566,7 @@ class MaxSigmoidAttnBlock(nn.Module):
"""Max Sigmoid attention block."""
def __init__(self, c1: int, c2: int, nh: int = 1, ec: int = 128, gc: int = 512, scale: bool = False):
"""
Initialize MaxSigmoidAttnBlock.
"""Initialize MaxSigmoidAttnBlock.
Args:
c1 (int): Input channels.
@ -609,8 +586,7 @@ class MaxSigmoidAttnBlock(nn.Module):
self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0
def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:
"""
Forward pass of MaxSigmoidAttnBlock.
"""Forward pass of MaxSigmoidAttnBlock.
Args:
x (torch.Tensor): Input tensor.
@ -653,8 +629,7 @@ class C2fAttn(nn.Module):
g: int = 1,
e: float = 0.5,
):
"""
Initialize C2f module with attention mechanism.
"""Initialize C2f module with attention mechanism.
Args:
c1 (int): Input channels.
@ -675,8 +650,7 @@ class C2fAttn(nn.Module):
self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh)
def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:
"""
Forward pass through C2f layer with attention.
"""Forward pass through C2f layer with attention.
Args:
x (torch.Tensor): Input tensor.
@ -691,8 +665,7 @@ class C2fAttn(nn.Module):
return self.cv2(torch.cat(y, 1))
def forward_split(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:
"""
Forward pass using split() instead of chunk().
"""Forward pass using split() instead of chunk().
Args:
x (torch.Tensor): Input tensor.
@ -713,8 +686,7 @@ class ImagePoolingAttn(nn.Module):
def __init__(
self, ec: int = 256, ch: tuple[int, ...] = (), ct: int = 512, nh: int = 8, k: int = 3, scale: bool = False
):
"""
Initialize ImagePoolingAttn module.
"""Initialize ImagePoolingAttn module.
Args:
ec (int): Embedding channels.
@ -741,8 +713,7 @@ class ImagePoolingAttn(nn.Module):
self.k = k
def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> torch.Tensor:
"""
Forward pass of ImagePoolingAttn.
"""Forward pass of ImagePoolingAttn.
Args:
x (list[torch.Tensor]): List of input feature maps.
@ -785,8 +756,7 @@ class ContrastiveHead(nn.Module):
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
"""
Forward function of contrastive learning.
"""Forward function of contrastive learning.
Args:
x (torch.Tensor): Image features.
@ -802,16 +772,14 @@ class ContrastiveHead(nn.Module):
class BNContrastiveHead(nn.Module):
"""
Batch Norm Contrastive Head using batch norm instead of l2-normalization.
"""Batch Norm Contrastive Head using batch norm instead of l2-normalization.
Args:
embed_dims (int): Embed dimensions of text and image features.
"""
def __init__(self, embed_dims: int):
"""
Initialize BNContrastiveHead.
"""Initialize BNContrastiveHead.
Args:
embed_dims (int): Embedding dimensions for features.
@ -835,8 +803,7 @@ class BNContrastiveHead(nn.Module):
return x
def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
"""
Forward function of contrastive learning with batch normalization.
"""Forward function of contrastive learning with batch normalization.
Args:
x (torch.Tensor): Image features.
@ -858,8 +825,7 @@ class RepBottleneck(Bottleneck):
def __init__(
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: tuple[int, int] = (3, 3), e: float = 0.5
):
"""
Initialize RepBottleneck.
"""Initialize RepBottleneck.
Args:
c1 (int): Input channels.
@ -878,8 +844,7 @@ class RepCSP(C3):
"""Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
"""
Initialize RepCSP layer.
"""Initialize RepCSP layer.
Args:
c1 (int): Input channels.
@ -898,8 +863,7 @@ class RepNCSPELAN4(nn.Module):
"""CSP-ELAN."""
def __init__(self, c1: int, c2: int, c3: int, c4: int, n: int = 1):
"""
Initialize CSP-ELAN layer.
"""Initialize CSP-ELAN layer.
Args:
c1 (int): Input channels.
@ -932,8 +896,7 @@ class ELAN1(RepNCSPELAN4):
"""ELAN1 module with 4 convolutions."""
def __init__(self, c1: int, c2: int, c3: int, c4: int):
"""
Initialize ELAN1 layer.
"""Initialize ELAN1 layer.
Args:
c1 (int): Input channels.
@ -953,8 +916,7 @@ class AConv(nn.Module):
"""AConv."""
def __init__(self, c1: int, c2: int):
"""
Initialize AConv module.
"""Initialize AConv module.
Args:
c1 (int): Input channels.
@ -973,8 +935,7 @@ class ADown(nn.Module):
"""ADown."""
def __init__(self, c1: int, c2: int):
"""
Initialize ADown module.
"""Initialize ADown module.
Args:
c1 (int): Input channels.
@ -999,8 +960,7 @@ class SPPELAN(nn.Module):
"""SPP-ELAN."""
def __init__(self, c1: int, c2: int, c3: int, k: int = 5):
"""
Initialize SPP-ELAN block.
"""Initialize SPP-ELAN block.
Args:
c1 (int): Input channels.
@ -1027,8 +987,7 @@ class CBLinear(nn.Module):
"""CBLinear."""
def __init__(self, c1: int, c2s: list[int], k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
"""
Initialize CBLinear module.
"""Initialize CBLinear module.
Args:
c1 (int): Input channels.
@ -1051,8 +1010,7 @@ class CBFuse(nn.Module):
"""CBFuse."""
def __init__(self, idx: list[int]):
"""
Initialize CBFuse module.
"""Initialize CBFuse module.
Args:
idx (list[int]): Indices for feature selection.
@ -1061,8 +1019,7 @@ class CBFuse(nn.Module):
self.idx = idx
def forward(self, xs: list[torch.Tensor]) -> torch.Tensor:
"""
Forward pass through CBFuse layer.
"""Forward pass through CBFuse layer.
Args:
xs (list[torch.Tensor]): List of input tensors.
@ -1079,8 +1036,7 @@ class C3f(nn.Module):
"""Faster Implementation of CSP Bottleneck with 2 convolutions."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
"""
Initialize CSP bottleneck layer with two convolutions.
"""Initialize CSP bottleneck layer with two convolutions.
Args:
c1 (int): Input channels.
@ -1110,8 +1066,7 @@ class C3k2(C2f):
def __init__(
self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True
):
"""
Initialize C3k2 module.
"""Initialize C3k2 module.
Args:
c1 (int): Input channels.
@ -1132,8 +1087,7 @@ class C3k(C3):
"""C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks."""
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):
"""
Initialize C3k module.
"""Initialize C3k module.
Args:
c1 (int): Input channels.
@ -1154,8 +1108,7 @@ class RepVGGDW(torch.nn.Module):
"""RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
def __init__(self, ed: int) -> None:
"""
Initialize RepVGGDW module.
"""Initialize RepVGGDW module.
Args:
ed (int): Input and output channels.
@ -1167,8 +1120,7 @@ class RepVGGDW(torch.nn.Module):
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass of the RepVGGDW block.
"""Perform a forward pass of the RepVGGDW block.
Args:
x (torch.Tensor): Input tensor.
@ -1179,8 +1131,7 @@ class RepVGGDW(torch.nn.Module):
return self.act(self.conv(x) + self.conv1(x))
def forward_fuse(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass of the RepVGGDW block without fusing the convolutions.
"""Perform a forward pass of the RepVGGDW block without fusing the convolutions.
Args:
x (torch.Tensor): Input tensor.
@ -1192,8 +1143,7 @@ class RepVGGDW(torch.nn.Module):
@torch.no_grad()
def fuse(self):
"""
Fuse the convolutional layers in the RepVGGDW block.
"""Fuse the convolutional layers in the RepVGGDW block.
This method fuses the convolutional layers and updates the weights and biases accordingly.
"""
@ -1218,8 +1168,7 @@ class RepVGGDW(torch.nn.Module):
class CIB(nn.Module):
"""
Conditional Identity Block (CIB) module.
"""Conditional Identity Block (CIB) module.
Args:
c1 (int): Number of input channels.
@ -1230,8 +1179,7 @@ class CIB(nn.Module):
"""
def __init__(self, c1: int, c2: int, shortcut: bool = True, e: float = 0.5, lk: bool = False):
"""
Initialize the CIB module.
"""Initialize the CIB module.
Args:
c1 (int): Input channels.
@ -1253,8 +1201,7 @@ class CIB(nn.Module):
self.add = shortcut and c1 == c2
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the CIB module.
"""Forward pass of the CIB module.
Args:
x (torch.Tensor): Input tensor.
@ -1266,8 +1213,7 @@ class CIB(nn.Module):
class C2fCIB(C2f):
"""
C2fCIB class represents a convolutional block with C2f and CIB modules.
"""C2fCIB class represents a convolutional block with C2f and CIB modules.
Args:
c1 (int): Number of input channels.
@ -1282,8 +1228,7 @@ class C2fCIB(C2f):
def __init__(
self, c1: int, c2: int, n: int = 1, shortcut: bool = False, lk: bool = False, g: int = 1, e: float = 0.5
):
"""
Initialize C2fCIB module.
"""Initialize C2fCIB module.
Args:
c1 (int): Input channels.
@ -1299,8 +1244,7 @@ class C2fCIB(C2f):
class Attention(nn.Module):
"""
Attention module that performs self-attention on the input tensor.
"""Attention module that performs self-attention on the input tensor.
Args:
dim (int): The input tensor dimension.
@ -1318,8 +1262,7 @@ class Attention(nn.Module):
"""
def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
"""
Initialize multi-head attention module.
"""Initialize multi-head attention module.
Args:
dim (int): Input dimension.
@ -1338,8 +1281,7 @@ class Attention(nn.Module):
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the Attention module.
"""Forward pass of the Attention module.
Args:
x (torch.Tensor): The input tensor.
@ -1362,8 +1304,7 @@ class Attention(nn.Module):
class PSABlock(nn.Module):
"""
PSABlock class implementing a Position-Sensitive Attention block for neural networks.
"""PSABlock class implementing a Position-Sensitive Attention block for neural networks.
This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers
with optional shortcut connections.
@ -1384,8 +1325,7 @@ class PSABlock(nn.Module):
"""
def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:
"""
Initialize the PSABlock.
"""Initialize the PSABlock.
Args:
c (int): Input and output channels.
@ -1400,8 +1340,7 @@ class PSABlock(nn.Module):
self.add = shortcut
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Execute a forward pass through PSABlock.
"""Execute a forward pass through PSABlock.
Args:
x (torch.Tensor): Input tensor.
@ -1415,8 +1354,7 @@ class PSABlock(nn.Module):
class PSA(nn.Module):
"""
PSA class for implementing Position-Sensitive Attention in neural networks.
"""PSA class for implementing Position-Sensitive Attention in neural networks.
This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to
input tensors, enhancing feature extraction and processing capabilities.
@ -1439,8 +1377,7 @@ class PSA(nn.Module):
"""
def __init__(self, c1: int, c2: int, e: float = 0.5):
"""
Initialize PSA module.
"""Initialize PSA module.
Args:
c1 (int): Input channels.
@ -1457,8 +1394,7 @@ class PSA(nn.Module):
self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Execute forward pass in PSA module.
"""Execute forward pass in PSA module.
Args:
x (torch.Tensor): Input tensor.
@ -1473,8 +1409,7 @@ class PSA(nn.Module):
class C2PSA(nn.Module):
"""
C2PSA module with attention mechanism for enhanced feature extraction and processing.
"""C2PSA module with attention mechanism for enhanced feature extraction and processing.
This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing
capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.
@ -1498,8 +1433,7 @@ class C2PSA(nn.Module):
"""
def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
"""
Initialize C2PSA module.
"""Initialize C2PSA module.
Args:
c1 (int): Input channels.
@ -1516,8 +1450,7 @@ class C2PSA(nn.Module):
self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Process the input tensor through a series of PSA blocks.
"""Process the input tensor through a series of PSA blocks.
Args:
x (torch.Tensor): Input tensor.
@ -1531,8 +1464,7 @@ class C2PSA(nn.Module):
class C2fPSA(C2f):
"""
C2fPSA module with enhanced feature extraction using PSA blocks.
"""C2fPSA module with enhanced feature extraction using PSA blocks.
This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature
extraction.
@ -1557,8 +1489,7 @@ class C2fPSA(C2f):
"""
def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
"""
Initialize C2fPSA module.
"""Initialize C2fPSA module.
Args:
c1 (int): Input channels.
@ -1572,8 +1503,7 @@ class C2fPSA(C2f):
class SCDown(nn.Module):
"""
SCDown module for downsampling with separable convolutions.
"""SCDown module for downsampling with separable convolutions.
This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in
efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.
@ -1596,8 +1526,7 @@ class SCDown(nn.Module):
"""
def __init__(self, c1: int, c2: int, k: int, s: int):
"""
Initialize SCDown module.
"""Initialize SCDown module.
Args:
c1 (int): Input channels.
@ -1610,8 +1539,7 @@ class SCDown(nn.Module):
self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply convolution and downsampling to the input tensor.
"""Apply convolution and downsampling to the input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -1623,8 +1551,7 @@ class SCDown(nn.Module):
class TorchVision(nn.Module):
"""
TorchVision module to allow loading any torchvision model.
"""TorchVision module to allow loading any torchvision model.
This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and
customize the model by truncating or unwrapping layers.
@ -1643,8 +1570,7 @@ class TorchVision(nn.Module):
def __init__(
self, model: str, weights: str = "DEFAULT", unwrap: bool = True, truncate: int = 2, split: bool = False
):
"""
Load the model and weights from torchvision.
"""Load the model and weights from torchvision.
Args:
model (str): Name of the torchvision model to load.
@ -1671,8 +1597,7 @@ class TorchVision(nn.Module):
self.m.head = self.m.heads = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the model.
"""Forward pass through the model.
Args:
x (torch.Tensor): Input tensor.
@ -1689,8 +1614,7 @@ class TorchVision(nn.Module):
class AAttn(nn.Module):
"""
Area-attention module for YOLO models, providing efficient attention mechanisms.
"""Area-attention module for YOLO models, providing efficient attention mechanisms.
This module implements an area-based attention mechanism that processes input features in a spatially-aware manner,
making it particularly effective for object detection tasks.
@ -1715,8 +1639,7 @@ class AAttn(nn.Module):
"""
def __init__(self, dim: int, num_heads: int, area: int = 1):
"""
Initialize an Area-attention module for YOLO models.
"""Initialize an Area-attention module for YOLO models.
Args:
dim (int): Number of hidden channels.
@ -1735,8 +1658,7 @@ class AAttn(nn.Module):
self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Process the input tensor through the area-attention.
"""Process the input tensor through the area-attention.
Args:
x (torch.Tensor): Input tensor.
@ -1775,8 +1697,7 @@ class AAttn(nn.Module):
class ABlock(nn.Module):
"""
Area-attention block module for efficient feature extraction in YOLO models.
"""Area-attention block module for efficient feature extraction in YOLO models.
This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.
It uses a novel area-based attention approach that is more efficient than traditional self-attention while
@ -1799,8 +1720,7 @@ class ABlock(nn.Module):
"""
def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1.2, area: int = 1):
"""
Initialize an Area-attention block module.
"""Initialize an Area-attention block module.
Args:
dim (int): Number of input channels.
@ -1817,8 +1737,7 @@ class ABlock(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module):
"""
Initialize weights using a truncated normal distribution.
"""Initialize weights using a truncated normal distribution.
Args:
m (nn.Module): Module to initialize.
@ -1829,8 +1748,7 @@ class ABlock(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through ABlock.
"""Forward pass through ABlock.
Args:
x (torch.Tensor): Input tensor.
@ -1843,8 +1761,7 @@ class ABlock(nn.Module):
class A2C2f(nn.Module):
"""
Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
"""Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature
processing. It supports both area-attention and standard convolution modes.
@ -1879,8 +1796,7 @@ class A2C2f(nn.Module):
g: int = 1,
shortcut: bool = True,
):
"""
Initialize Area-Attention C2f module.
"""Initialize Area-Attention C2f module.
Args:
c1 (int): Number of input channels.
@ -1910,8 +1826,7 @@ class A2C2f(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through A2C2f layer.
"""Forward pass through A2C2f layer.
Args:
x (torch.Tensor): Input tensor.
@ -1931,8 +1846,7 @@ class SwiGLUFFN(nn.Module):
"""SwiGLU Feed-Forward Network for transformer-based architectures."""
def __init__(self, gc: int, ec: int, e: int = 4) -> None:
"""
Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor.
"""Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor.
Args:
gc (int): Guide channels.
@ -1955,8 +1869,7 @@ class Residual(nn.Module):
"""Residual connection wrapper for neural network modules."""
def __init__(self, m: nn.Module) -> None:
"""
Initialize residual module with the wrapped module.
"""Initialize residual module with the wrapped module.
Args:
m (nn.Module): Module to wrap with residual connection.
@ -1977,8 +1890,7 @@ class SAVPE(nn.Module):
"""Spatial-Aware Visual Prompt Embedding module for feature enhancement."""
def __init__(self, ch: list[int], c3: int, embed: int):
"""
Initialize SAVPE module with channels, intermediate channels, and embedding dimension.
"""Initialize SAVPE module with channels, intermediate channels, and embedding dimension.
Args:
ch (list[int]): List of input channel dimensions.

View file

@ -37,8 +37,7 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
class Conv(nn.Module):
"""
Standard convolution module with batch normalization and activation.
"""Standard convolution module with batch normalization and activation.
Attributes:
conv (nn.Conv2d): Convolutional layer.
@ -50,8 +49,7 @@ class Conv(nn.Module):
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""
Initialize Conv layer with given parameters.
"""Initialize Conv layer with given parameters.
Args:
c1 (int): Number of input channels.
@ -69,8 +67,7 @@ class Conv(nn.Module):
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""
Apply convolution, batch normalization and activation to input tensor.
"""Apply convolution, batch normalization and activation to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -81,8 +78,7 @@ class Conv(nn.Module):
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
"""
Apply convolution and activation without batch normalization.
"""Apply convolution and activation without batch normalization.
Args:
x (torch.Tensor): Input tensor.
@ -94,8 +90,7 @@ class Conv(nn.Module):
class Conv2(Conv):
"""
Simplified RepConv module with Conv fusing.
"""Simplified RepConv module with Conv fusing.
Attributes:
conv (nn.Conv2d): Main 3x3 convolutional layer.
@ -105,8 +100,7 @@ class Conv2(Conv):
"""
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
"""
Initialize Conv2 layer with given parameters.
"""Initialize Conv2 layer with given parameters.
Args:
c1 (int): Number of input channels.
@ -122,8 +116,7 @@ class Conv2(Conv):
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
def forward(self, x):
"""
Apply convolution, batch normalization and activation to input tensor.
"""Apply convolution, batch normalization and activation to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -134,8 +127,7 @@ class Conv2(Conv):
return self.act(self.bn(self.conv(x) + self.cv2(x)))
def forward_fuse(self, x):
"""
Apply fused convolution, batch normalization and activation to input tensor.
"""Apply fused convolution, batch normalization and activation to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -156,8 +148,7 @@ class Conv2(Conv):
class LightConv(nn.Module):
"""
Light convolution module with 1x1 and depthwise convolutions.
"""Light convolution module with 1x1 and depthwise convolutions.
This implementation is based on the PaddleDetection HGNetV2 backbone.
@ -167,8 +158,7 @@ class LightConv(nn.Module):
"""
def __init__(self, c1, c2, k=1, act=nn.ReLU()):
"""
Initialize LightConv layer with given parameters.
"""Initialize LightConv layer with given parameters.
Args:
c1 (int): Number of input channels.
@ -181,8 +171,7 @@ class LightConv(nn.Module):
self.conv2 = DWConv(c2, c2, k, act=act)
def forward(self, x):
"""
Apply 2 convolutions to input tensor.
"""Apply 2 convolutions to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -197,8 +186,7 @@ class DWConv(Conv):
"""Depth-wise convolution module."""
def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
"""
Initialize depth-wise convolution with given parameters.
"""Initialize depth-wise convolution with given parameters.
Args:
c1 (int): Number of input channels.
@ -215,8 +203,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d):
"""Depth-wise transpose convolution module."""
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):
"""
Initialize depth-wise transpose convolution with given parameters.
"""Initialize depth-wise transpose convolution with given parameters.
Args:
c1 (int): Number of input channels.
@ -230,8 +217,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d):
class ConvTranspose(nn.Module):
"""
Convolution transpose module with optional batch normalization and activation.
"""Convolution transpose module with optional batch normalization and activation.
Attributes:
conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.
@ -243,8 +229,7 @@ class ConvTranspose(nn.Module):
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
"""
Initialize ConvTranspose layer with given parameters.
"""Initialize ConvTranspose layer with given parameters.
Args:
c1 (int): Number of input channels.
@ -261,8 +246,7 @@ class ConvTranspose(nn.Module):
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""
Apply transposed convolution, batch normalization and activation to input.
"""Apply transposed convolution, batch normalization and activation to input.
Args:
x (torch.Tensor): Input tensor.
@ -273,8 +257,7 @@ class ConvTranspose(nn.Module):
return self.act(self.bn(self.conv_transpose(x)))
def forward_fuse(self, x):
"""
Apply activation and convolution transpose operation to input.
"""Apply activation and convolution transpose operation to input.
Args:
x (torch.Tensor): Input tensor.
@ -286,8 +269,7 @@ class ConvTranspose(nn.Module):
class Focus(nn.Module):
"""
Focus module for concentrating feature information.
"""Focus module for concentrating feature information.
Slices input tensor into 4 parts and concatenates them in the channel dimension.
@ -296,8 +278,7 @@ class Focus(nn.Module):
"""
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
"""
Initialize Focus module with given parameters.
"""Initialize Focus module with given parameters.
Args:
c1 (int): Number of input channels.
@ -313,8 +294,7 @@ class Focus(nn.Module):
# self.contract = Contract(gain=2)
def forward(self, x):
"""
Apply Focus operation and convolution to input tensor.
"""Apply Focus operation and convolution to input tensor.
Input shape is (B, C, W, H) and output shape is (B, 4C, W/2, H/2).
@ -329,8 +309,7 @@ class Focus(nn.Module):
class GhostConv(nn.Module):
"""
Ghost Convolution module.
"""Ghost Convolution module.
Generates more features with fewer parameters by using cheap operations.
@ -343,8 +322,7 @@ class GhostConv(nn.Module):
"""
def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
"""
Initialize Ghost Convolution module with given parameters.
"""Initialize Ghost Convolution module with given parameters.
Args:
c1 (int): Number of input channels.
@ -360,8 +338,7 @@ class GhostConv(nn.Module):
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
def forward(self, x):
"""
Apply Ghost Convolution to input tensor.
"""Apply Ghost Convolution to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -374,8 +351,7 @@ class GhostConv(nn.Module):
class RepConv(nn.Module):
"""
RepConv module with training and deploy modes.
"""RepConv module with training and deploy modes.
This module is used in RT-DETR and can fuse convolutions during inference for efficiency.
@ -393,8 +369,7 @@ class RepConv(nn.Module):
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
"""
Initialize RepConv module with given parameters.
"""Initialize RepConv module with given parameters.
Args:
c1 (int): Number of input channels.
@ -420,8 +395,7 @@ class RepConv(nn.Module):
self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
def forward_fuse(self, x):
"""
Forward pass for deploy mode.
"""Forward pass for deploy mode.
Args:
x (torch.Tensor): Input tensor.
@ -432,8 +406,7 @@ class RepConv(nn.Module):
return self.act(self.conv(x))
def forward(self, x):
"""
Forward pass for training mode.
"""Forward pass for training mode.
Args:
x (torch.Tensor): Input tensor.
@ -445,8 +418,7 @@ class RepConv(nn.Module):
return self.act(self.conv1(x) + self.conv2(x) + id_out)
def get_equivalent_kernel_bias(self):
"""
Calculate equivalent kernel and bias by fusing convolutions.
"""Calculate equivalent kernel and bias by fusing convolutions.
Returns:
(torch.Tensor): Equivalent kernel
@ -459,8 +431,7 @@ class RepConv(nn.Module):
@staticmethod
def _pad_1x1_to_3x3_tensor(kernel1x1):
"""
Pad a 1x1 kernel to 3x3 size.
"""Pad a 1x1 kernel to 3x3 size.
Args:
kernel1x1 (torch.Tensor): 1x1 convolution kernel.
@ -474,8 +445,7 @@ class RepConv(nn.Module):
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self, branch):
"""
Fuse batch normalization with convolution weights.
"""Fuse batch normalization with convolution weights.
Args:
branch (Conv | nn.BatchNorm2d | None): Branch to fuse.
@ -540,8 +510,7 @@ class RepConv(nn.Module):
class ChannelAttention(nn.Module):
"""
Channel-attention module for feature recalibration.
"""Channel-attention module for feature recalibration.
Applies attention weights to channels based on global average pooling.
@ -555,8 +524,7 @@ class ChannelAttention(nn.Module):
"""
def __init__(self, channels: int) -> None:
"""
Initialize Channel-attention module.
"""Initialize Channel-attention module.
Args:
channels (int): Number of input channels.
@ -567,8 +535,7 @@ class ChannelAttention(nn.Module):
self.act = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply channel attention to input tensor.
"""Apply channel attention to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -580,8 +547,7 @@ class ChannelAttention(nn.Module):
class SpatialAttention(nn.Module):
"""
Spatial-attention module for feature recalibration.
"""Spatial-attention module for feature recalibration.
Applies attention weights to spatial dimensions based on channel statistics.
@ -591,8 +557,7 @@ class SpatialAttention(nn.Module):
"""
def __init__(self, kernel_size=7):
"""
Initialize Spatial-attention module.
"""Initialize Spatial-attention module.
Args:
kernel_size (int): Size of the convolutional kernel (3 or 7).
@ -604,8 +569,7 @@ class SpatialAttention(nn.Module):
self.act = nn.Sigmoid()
def forward(self, x):
"""
Apply spatial attention to input tensor.
"""Apply spatial attention to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -617,8 +581,7 @@ class SpatialAttention(nn.Module):
class CBAM(nn.Module):
"""
Convolutional Block Attention Module.
"""Convolutional Block Attention Module.
Combines channel and spatial attention mechanisms for comprehensive feature refinement.
@ -628,8 +591,7 @@ class CBAM(nn.Module):
"""
def __init__(self, c1, kernel_size=7):
"""
Initialize CBAM with given parameters.
"""Initialize CBAM with given parameters.
Args:
c1 (int): Number of input channels.
@ -640,8 +602,7 @@ class CBAM(nn.Module):
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
"""
Apply channel and spatial attention sequentially to input tensor.
"""Apply channel and spatial attention sequentially to input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -653,16 +614,14 @@ class CBAM(nn.Module):
class Concat(nn.Module):
"""
Concatenate a list of tensors along specified dimension.
"""Concatenate a list of tensors along specified dimension.
Attributes:
d (int): Dimension along which to concatenate tensors.
"""
def __init__(self, dimension=1):
"""
Initialize Concat module.
"""Initialize Concat module.
Args:
dimension (int): Dimension along which to concatenate tensors.
@ -671,8 +630,7 @@ class Concat(nn.Module):
self.d = dimension
def forward(self, x: list[torch.Tensor]):
"""
Concatenate input tensors along specified dimension.
"""Concatenate input tensors along specified dimension.
Args:
x (list[torch.Tensor]): List of input tensors.
@ -684,16 +642,14 @@ class Concat(nn.Module):
class Index(nn.Module):
"""
Returns a particular index of the input.
"""Returns a particular index of the input.
Attributes:
index (int): Index to select from input.
"""
def __init__(self, index=0):
"""
Initialize Index module.
"""Initialize Index module.
Args:
index (int): Index to select from input.
@ -702,8 +658,7 @@ class Index(nn.Module):
self.index = index
def forward(self, x: list[torch.Tensor]):
"""
Select and return a particular index from input.
"""Select and return a particular index from input.
Args:
x (list[torch.Tensor]): List of input tensors.

View file

@ -24,8 +24,7 @@ __all__ = "OBB", "Classify", "Detect", "Pose", "RTDETRDecoder", "Segment", "YOLO
class Detect(nn.Module):
"""
YOLO Detect head for object detection models.
"""YOLO Detect head for object detection models.
This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.
It supports both training and inference modes, with optional end-to-end detection capabilities.
@ -78,8 +77,7 @@ class Detect(nn.Module):
xyxy = False # xyxy or xywh output
def __init__(self, nc: int = 80, ch: tuple = ()):
"""
Initialize the YOLO detection layer with specified number of classes and channels.
"""Initialize the YOLO detection layer with specified number of classes and channels.
Args:
nc (int): Number of classes.
@ -126,15 +124,14 @@ class Detect(nn.Module):
return y if self.export else (y, x)
def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
"""
Perform forward pass of the v10Detect module.
"""Perform forward pass of the v10Detect module.
Args:
x (list[torch.Tensor]): Input feature maps from different levels.
Returns:
outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.
Inference mode returns processed detections or tuple with detections and raw outputs.
outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs. Inference mode returns
processed detections or tuple with detections and raw outputs.
"""
x_detach = [xi.detach() for xi in x]
one2one = [
@ -150,8 +147,7 @@ class Detect(nn.Module):
return y if self.export else (y, {"one2many": x, "one2one": one2one})
def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
"""
Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
"""Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
Args:
x (list[torch.Tensor]): List of feature maps from different detection layers.
@ -194,8 +190,7 @@ class Detect(nn.Module):
@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
"""
Post-process YOLO model predictions.
"""Post-process YOLO model predictions.
Args:
preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
@ -218,8 +213,7 @@ class Detect(nn.Module):
class Segment(Detect):
"""
YOLO Segment head for segmentation models.
"""YOLO Segment head for segmentation models.
This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
@ -240,8 +234,7 @@ class Segment(Detect):
"""
def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
"""
Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
"""Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
Args:
nc (int): Number of classes.
@ -270,8 +263,7 @@ class Segment(Detect):
class OBB(Detect):
"""
YOLO OBB detection head for detection with rotation models.
"""YOLO OBB detection head for detection with rotation models.
This class extends the Detect head to include oriented bounding box prediction with rotation angles.
@ -292,8 +284,7 @@ class OBB(Detect):
"""
def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
"""
Initialize OBB with number of classes `nc` and layer channels `ch`.
"""Initialize OBB with number of classes `nc` and layer channels `ch`.
Args:
nc (int): Number of classes.
@ -326,8 +317,7 @@ class OBB(Detect):
class Pose(Detect):
"""
YOLO Pose head for keypoints models.
"""YOLO Pose head for keypoints models.
This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
@ -348,8 +338,7 @@ class Pose(Detect):
"""
def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
"""
Initialize YOLO network with default parameters and Convolutional Layers.
"""Initialize YOLO network with default parameters and Convolutional Layers.
Args:
nc (int): Number of classes.
@ -396,8 +385,7 @@ class Pose(Detect):
class Classify(nn.Module):
"""
YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
This class implements a classification head that transforms feature maps into class predictions.
@ -421,8 +409,7 @@ class Classify(nn.Module):
export = False # export mode
def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
"""
Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
"""Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
Args:
c1 (int): Number of input channels.
@ -451,8 +438,7 @@ class Classify(nn.Module):
class WorldDetect(Detect):
"""
Head for integrating YOLO detection models with semantic understanding from text embeddings.
"""Head for integrating YOLO detection models with semantic understanding from text embeddings.
This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding in
object detection tasks.
@ -474,8 +460,7 @@ class WorldDetect(Detect):
"""
def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
"""
Initialize YOLO detection layer with nc classes and layer channels ch.
"""Initialize YOLO detection layer with nc classes and layer channels ch.
Args:
nc (int): Number of classes.
@ -509,8 +494,7 @@ class WorldDetect(Detect):
class LRPCHead(nn.Module):
"""
Lightweight Region Proposal and Classification Head for efficient object detection.
"""Lightweight Region Proposal and Classification Head for efficient object detection.
This head combines region proposal filtering with classification to enable efficient detection with dynamic
vocabulary support.
@ -534,8 +518,7 @@ class LRPCHead(nn.Module):
"""
def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):
"""
Initialize LRPCHead with vocabulary, proposal filter, and localization components.
"""Initialize LRPCHead with vocabulary, proposal filter, and localization components.
Args:
vocab (nn.Module): Vocabulary/classification module.
@ -574,8 +557,7 @@ class LRPCHead(nn.Module):
class YOLOEDetect(Detect):
"""
Head for integrating YOLO detection models with semantic understanding from text embeddings.
"""Head for integrating YOLO detection models with semantic understanding from text embeddings.
This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding
through text embeddings and visual prompt embeddings.
@ -607,8 +589,7 @@ class YOLOEDetect(Detect):
is_fused = False
def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
"""
Initialize YOLO detection layer with nc classes and layer channels ch.
"""Initialize YOLO detection layer with nc classes and layer channels ch.
Args:
nc (int): Number of classes.
@ -762,8 +743,7 @@ class YOLOEDetect(Detect):
class YOLOESegment(YOLOEDetect):
"""
YOLO segmentation head with text embedding capabilities.
"""YOLO segmentation head with text embedding capabilities.
This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks with
text-guided semantic understanding.
@ -788,8 +768,7 @@ class YOLOESegment(YOLOEDetect):
def __init__(
self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
):
"""
Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
"""Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
Args:
nc (int): Number of classes.
@ -830,8 +809,7 @@ class YOLOESegment(YOLOEDetect):
class RTDETRDecoder(nn.Module):
"""
Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
"""Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
@ -895,8 +873,7 @@ class RTDETRDecoder(nn.Module):
box_noise_scale: float = 1.0,
learnt_init_query: bool = False,
):
"""
Initialize the RTDETRDecoder module with the given parameters.
"""Initialize the RTDETRDecoder module with the given parameters.
Args:
nc (int): Number of classes.
@ -956,8 +933,7 @@ class RTDETRDecoder(nn.Module):
self._reset_parameters()
def forward(self, x: list[torch.Tensor], batch: dict | None = None) -> tuple | torch.Tensor:
"""
Run the forward pass of the module, returning bounding box and classification scores for the input.
"""Run the forward pass of the module, returning bounding box and classification scores for the input.
Args:
x (list[torch.Tensor]): List of feature maps from the backbone.
@ -1013,8 +989,7 @@ class RTDETRDecoder(nn.Module):
device: str = "cpu",
eps: float = 1e-2,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Generate anchor bounding boxes for given shapes with specific grid size and validate them.
"""Generate anchor bounding boxes for given shapes with specific grid size and validate them.
Args:
shapes (list): List of feature map shapes.
@ -1046,8 +1021,7 @@ class RTDETRDecoder(nn.Module):
return anchors, valid_mask
def _get_encoder_input(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]]]:
"""
Process and return encoder inputs by getting projection features from input and concatenating them.
"""Process and return encoder inputs by getting projection features from input and concatenating them.
Args:
x (list[torch.Tensor]): List of feature maps from the backbone.
@ -1079,8 +1053,7 @@ class RTDETRDecoder(nn.Module):
dn_embed: torch.Tensor | None = None,
dn_bbox: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generate and prepare the input required for the decoder from the provided features and shapes.
"""Generate and prepare the input required for the decoder from the provided features and shapes.
Args:
feats (torch.Tensor): Processed features from encoder.
@ -1158,8 +1131,7 @@ class RTDETRDecoder(nn.Module):
class v10Detect(Detect):
"""
v10 Detection head from https://arxiv.org/pdf/2405.14458.
"""v10 Detection head from https://arxiv.org/pdf/2405.14458.
This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions for
improved efficiency and performance.
@ -1186,8 +1158,7 @@ class v10Detect(Detect):
end2end = True
def __init__(self, nc: int = 80, ch: tuple = ()):
"""
Initialize the v10Detect object with the specified number of classes and input channels.
"""Initialize the v10Detect object with the specified number of classes and input channels.
Args:
nc (int): Number of classes.

View file

@ -30,8 +30,7 @@ __all__ = (
class TransformerEncoderLayer(nn.Module):
"""
A single layer of the transformer encoder.
"""A single layer of the transformer encoder.
This class implements a standard transformer encoder layer with multi-head attention and feedforward network,
supporting both pre-normalization and post-normalization configurations.
@ -58,8 +57,7 @@ class TransformerEncoderLayer(nn.Module):
act: nn.Module = nn.GELU(),
normalize_before: bool = False,
):
"""
Initialize the TransformerEncoderLayer with specified parameters.
"""Initialize the TransformerEncoderLayer with specified parameters.
Args:
c1 (int): Input dimension.
@ -102,8 +100,7 @@ class TransformerEncoderLayer(nn.Module):
src_key_padding_mask: torch.Tensor | None = None,
pos: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Perform forward pass with post-normalization.
"""Perform forward pass with post-normalization.
Args:
src (torch.Tensor): Input tensor.
@ -129,8 +126,7 @@ class TransformerEncoderLayer(nn.Module):
src_key_padding_mask: torch.Tensor | None = None,
pos: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Perform forward pass with pre-normalization.
"""Perform forward pass with pre-normalization.
Args:
src (torch.Tensor): Input tensor.
@ -156,8 +152,7 @@ class TransformerEncoderLayer(nn.Module):
src_key_padding_mask: torch.Tensor | None = None,
pos: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Forward propagate the input through the encoder module.
"""Forward propagate the input through the encoder module.
Args:
src (torch.Tensor): Input tensor.
@ -174,8 +169,7 @@ class TransformerEncoderLayer(nn.Module):
class AIFI(TransformerEncoderLayer):
"""
AIFI transformer layer for 2D data with positional embeddings.
"""AIFI transformer layer for 2D data with positional embeddings.
This class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional
embeddings and handling the spatial dimensions appropriately.
@ -190,8 +184,7 @@ class AIFI(TransformerEncoderLayer):
act: nn.Module = nn.GELU(),
normalize_before: bool = False,
):
"""
Initialize the AIFI instance with specified parameters.
"""Initialize the AIFI instance with specified parameters.
Args:
c1 (int): Input dimension.
@ -204,8 +197,7 @@ class AIFI(TransformerEncoderLayer):
super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the AIFI transformer layer.
"""Forward pass for the AIFI transformer layer.
Args:
x (torch.Tensor): Input tensor with shape [B, C, H, W].
@ -223,8 +215,7 @@ class AIFI(TransformerEncoderLayer):
def build_2d_sincos_position_embedding(
w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0
) -> torch.Tensor:
"""
Build 2D sine-cosine position embedding.
"""Build 2D sine-cosine position embedding.
Args:
w (int): Width of the feature map.
@ -253,8 +244,7 @@ class TransformerLayer(nn.Module):
"""Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
def __init__(self, c: int, num_heads: int):
"""
Initialize a self-attention mechanism using linear transformations and multi-head attention.
"""Initialize a self-attention mechanism using linear transformations and multi-head attention.
Args:
c (int): Input and output channel dimension.
@ -269,8 +259,7 @@ class TransformerLayer(nn.Module):
self.fc2 = nn.Linear(c, c, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply a transformer block to the input x and return the output.
"""Apply a transformer block to the input x and return the output.
Args:
x (torch.Tensor): Input tensor.
@ -283,8 +272,7 @@ class TransformerLayer(nn.Module):
class TransformerBlock(nn.Module):
"""
Vision Transformer block based on https://arxiv.org/abs/2010.11929.
"""Vision Transformer block based on https://arxiv.org/abs/2010.11929.
This class implements a complete transformer block with optional convolution layer for channel adjustment, learnable
position embedding, and multiple transformer layers.
@ -297,8 +285,7 @@ class TransformerBlock(nn.Module):
"""
def __init__(self, c1: int, c2: int, num_heads: int, num_layers: int):
"""
Initialize a Transformer module with position embedding and specified number of heads and layers.
"""Initialize a Transformer module with position embedding and specified number of heads and layers.
Args:
c1 (int): Input channel dimension.
@ -315,8 +302,7 @@ class TransformerBlock(nn.Module):
self.c2 = c2
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward propagate the input through the transformer block.
"""Forward propagate the input through the transformer block.
Args:
x (torch.Tensor): Input tensor with shape [b, c1, w, h].
@ -335,8 +321,7 @@ class MLPBlock(nn.Module):
"""A single block of a multi-layer perceptron."""
def __init__(self, embedding_dim: int, mlp_dim: int, act=nn.GELU):
"""
Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.
"""Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.
Args:
embedding_dim (int): Input and output dimension.
@ -349,8 +334,7 @@ class MLPBlock(nn.Module):
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MLPBlock.
"""Forward pass for the MLPBlock.
Args:
x (torch.Tensor): Input tensor.
@ -362,8 +346,7 @@ class MLPBlock(nn.Module):
class MLP(nn.Module):
"""
A simple multi-layer perceptron (also called FFN).
"""A simple multi-layer perceptron (also called FFN).
This class implements a configurable MLP with multiple linear layers, activation functions, and optional sigmoid
output activation.
@ -378,8 +361,7 @@ class MLP(nn.Module):
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False
):
"""
Initialize the MLP with specified input, hidden, output dimensions and number of layers.
"""Initialize the MLP with specified input, hidden, output dimensions and number of layers.
Args:
input_dim (int): Input dimension.
@ -397,8 +379,7 @@ class MLP(nn.Module):
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the entire MLP.
"""Forward pass for the entire MLP.
Args:
x (torch.Tensor): Input tensor.
@ -412,8 +393,7 @@ class MLP(nn.Module):
class LayerNorm2d(nn.Module):
"""
2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.
"""2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.
This class implements layer normalization for 2D feature maps, normalizing across the channel dimension while
preserving spatial dimensions.
@ -429,8 +409,7 @@ class LayerNorm2d(nn.Module):
"""
def __init__(self, num_channels: int, eps: float = 1e-6):
"""
Initialize LayerNorm2d with the given parameters.
"""Initialize LayerNorm2d with the given parameters.
Args:
num_channels (int): Number of channels in the input.
@ -442,8 +421,7 @@ class LayerNorm2d(nn.Module):
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform forward pass for 2D layer normalization.
"""Perform forward pass for 2D layer normalization.
Args:
x (torch.Tensor): Input tensor.
@ -458,8 +436,7 @@ class LayerNorm2d(nn.Module):
class MSDeformAttn(nn.Module):
"""
Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.
"""Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.
This module implements multiscale deformable attention that can attend to features at multiple scales with learnable
sampling locations and attention weights.
@ -480,8 +457,7 @@ class MSDeformAttn(nn.Module):
"""
def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):
"""
Initialize MSDeformAttn with the given parameters.
"""Initialize MSDeformAttn with the given parameters.
Args:
d_model (int): Model dimension.
@ -539,13 +515,12 @@ class MSDeformAttn(nn.Module):
value_shapes: list,
value_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Perform forward pass for multiscale deformable attention.
"""Perform forward pass for multiscale deformable attention.
Args:
query (torch.Tensor): Query tensor with shape [bs, query_length, C].
refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2],
range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area.
refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2], range in [0,
1], top-left (0,0), bottom-right (1, 1), including padding area.
value (torch.Tensor): Value tensor with shape [bs, value_length, C].
value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for non-padding
@ -584,8 +559,7 @@ class MSDeformAttn(nn.Module):
class DeformableTransformerDecoderLayer(nn.Module):
"""
Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.
"""Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.
This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable
attention, and a feedforward network.
@ -619,8 +593,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
n_levels: int = 4,
n_points: int = 4,
):
"""
Initialize the DeformableTransformerDecoderLayer with the given parameters.
"""Initialize the DeformableTransformerDecoderLayer with the given parameters.
Args:
d_model (int): Model dimension.
@ -657,8 +630,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor:
"""
Perform forward pass through the Feed-Forward Network part of the layer.
"""Perform forward pass through the Feed-Forward Network part of the layer.
Args:
tgt (torch.Tensor): Input tensor.
@ -680,8 +652,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
attn_mask: torch.Tensor | None = None,
query_pos: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Perform the forward pass through the entire decoder layer.
"""Perform the forward pass through the entire decoder layer.
Args:
embed (torch.Tensor): Input embeddings.
@ -715,8 +686,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
class DeformableTransformerDecoder(nn.Module):
"""
Deformable Transformer Decoder based on PaddleDetection implementation.
"""Deformable Transformer Decoder based on PaddleDetection implementation.
This class implements a complete deformable transformer decoder with multiple decoder layers and prediction heads
for bounding box regression and classification.
@ -732,8 +702,7 @@ class DeformableTransformerDecoder(nn.Module):
"""
def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1):
"""
Initialize the DeformableTransformerDecoder with the given parameters.
"""Initialize the DeformableTransformerDecoder with the given parameters.
Args:
hidden_dim (int): Hidden dimension.
@ -759,8 +728,7 @@ class DeformableTransformerDecoder(nn.Module):
attn_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
):
"""
Perform the forward pass through the entire decoder.
"""Perform the forward pass through the entire decoder.
Args:
embed (torch.Tensor): Decoder embeddings.

View file

@ -13,8 +13,7 @@ __all__ = "inverse_sigmoid", "multi_scale_deformable_attn_pytorch"
def _get_clones(module, n):
"""
Create a list of cloned modules from the given module.
"""Create a list of cloned modules from the given module.
Args:
module (nn.Module): The module to be cloned.
@ -34,8 +33,7 @@ def _get_clones(module, n):
def bias_init_with_prob(prior_prob=0.01):
"""
Initialize conv/fc bias value according to a given probability value.
"""Initialize conv/fc bias value according to a given probability value.
This function calculates the bias initialization value based on a prior probability using the inverse error
function. It's commonly used in object detection models to initialize classification layers with a specific positive
@ -56,8 +54,7 @@ def bias_init_with_prob(prior_prob=0.01):
def linear_init(module):
"""
Initialize the weights and biases of a linear module.
"""Initialize the weights and biases of a linear module.
This function initializes the weights of a linear module using a uniform distribution within bounds calculated from
the input dimension. If the module has a bias, it is also initialized.
@ -80,8 +77,7 @@ def linear_init(module):
def inverse_sigmoid(x, eps=1e-5):
"""
Calculate the inverse sigmoid function for a tensor.
"""Calculate the inverse sigmoid function for a tensor.
This function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network
operations, particularly in attention mechanisms and coordinate transformations.
@ -110,8 +106,7 @@ def multi_scale_deformable_attn_pytorch(
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
"""
Implement multi-scale deformable attention in PyTorch.
"""Implement multi-scale deformable attention in PyTorch.
This function performs deformable attention across multiple feature map scales, allowing the model to attend to
different spatial locations with learned offsets.
@ -119,10 +114,10 @@ def multi_scale_deformable_attn_pytorch(
Args:
value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).
value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).
sampling_locations (torch.Tensor): The sampling locations with shape
(bs, num_queries, num_heads, num_levels, num_points, 2).
attention_weights (torch.Tensor): The attention weights with shape
(bs, num_queries, num_heads, num_levels, num_points).
sampling_locations (torch.Tensor): The sampling locations with shape (bs, num_queries, num_heads, num_levels,
num_points, 2).
attention_weights (torch.Tensor): The attention weights with shape (bs, num_queries, num_heads, num_levels,
num_points).
Returns:
(torch.Tensor): The output tensor with shape (bs, num_queries, embed_dims).

View file

@ -95,8 +95,7 @@ from ultralytics.utils.torch_utils import (
class BaseModel(torch.nn.Module):
"""
Base class for all YOLO models in the Ultralytics family.
"""Base class for all YOLO models in the Ultralytics family.
This class provides common functionality for YOLO models including forward pass handling, model fusion, information
display, and weight loading capabilities.
@ -121,8 +120,7 @@ class BaseModel(torch.nn.Module):
"""
def forward(self, x, *args, **kwargs):
"""
Perform forward pass of the model for either training or inference.
"""Perform forward pass of the model for either training or inference.
If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
@ -139,8 +137,7 @@ class BaseModel(torch.nn.Module):
return self.predict(x, *args, **kwargs)
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
"""
Perform a forward pass through the network.
"""Perform a forward pass through the network.
Args:
x (torch.Tensor): The input tensor to the model.
@ -157,8 +154,7 @@ class BaseModel(torch.nn.Module):
return self._predict_once(x, profile, visualize, embed)
def _predict_once(self, x, profile=False, visualize=False, embed=None):
"""
Perform a forward pass through the network.
"""Perform a forward pass through the network.
Args:
x (torch.Tensor): The input tensor to the model.
@ -196,8 +192,7 @@ class BaseModel(torch.nn.Module):
return self._predict_once(x)
def _profile_one_layer(self, m, x, dt):
"""
Profile the computation time and FLOPs of a single layer of the model on a given input.
"""Profile the computation time and FLOPs of a single layer of the model on a given input.
Args:
m (torch.nn.Module): The layer to be profiled.
@ -222,8 +217,7 @@ class BaseModel(torch.nn.Module):
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
def fuse(self, verbose=True):
"""
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
"""Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
efficiency.
Returns:
@ -254,8 +248,7 @@ class BaseModel(torch.nn.Module):
return self
def is_fused(self, thresh=10):
"""
Check if the model has less than a certain threshold of BatchNorm layers.
"""Check if the model has less than a certain threshold of BatchNorm layers.
Args:
thresh (int, optional): The threshold number of BatchNorm layers.
@ -267,8 +260,7 @@ class BaseModel(torch.nn.Module):
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
def info(self, detailed=False, verbose=True, imgsz=640):
"""
Print model information.
"""Print model information.
Args:
detailed (bool): If True, prints out detailed information about the model.
@ -278,8 +270,7 @@ class BaseModel(torch.nn.Module):
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
def _apply(self, fn):
"""
Apply a function to all tensors in the model that are not parameters or registered buffers.
"""Apply a function to all tensors in the model that are not parameters or registered buffers.
Args:
fn (function): The function to apply to the model.
@ -298,8 +289,7 @@ class BaseModel(torch.nn.Module):
return self
def load(self, weights, verbose=True):
"""
Load weights into the model.
"""Load weights into the model.
Args:
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@ -324,8 +314,7 @@ class BaseModel(torch.nn.Module):
LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")
def loss(self, batch, preds=None):
"""
Compute loss.
"""Compute loss.
Args:
batch (dict): Batch to compute loss on.
@ -344,8 +333,7 @@ class BaseModel(torch.nn.Module):
class DetectionModel(BaseModel):
"""
YOLO detection model.
"""YOLO detection model.
This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
inference, and loss computation for object detection tasks.
@ -373,8 +361,7 @@ class DetectionModel(BaseModel):
"""
def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
"""
Initialize the YOLO detection model with the given config and parameters.
"""Initialize the YOLO detection model with the given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -429,8 +416,7 @@ class DetectionModel(BaseModel):
LOGGER.info("")
def _predict_augment(self, x):
"""
Perform augmentations on input image x and return augmented inference and train outputs.
"""Perform augmentations on input image x and return augmented inference and train outputs.
Args:
x (torch.Tensor): Input image tensor.
@ -455,8 +441,7 @@ class DetectionModel(BaseModel):
@staticmethod
def _descale_pred(p, flips, scale, img_size, dim=1):
"""
De-scale predictions following augmented inference (inverse operation).
"""De-scale predictions following augmented inference (inverse operation).
Args:
p (torch.Tensor): Predictions tensor.
@ -477,8 +462,7 @@ class DetectionModel(BaseModel):
return torch.cat((x, y, wh, cls), dim)
def _clip_augmented(self, y):
"""
Clip YOLO augmented inference tails.
"""Clip YOLO augmented inference tails.
Args:
y (list[torch.Tensor]): List of detection tensors.
@ -501,8 +485,7 @@ class DetectionModel(BaseModel):
class OBBModel(DetectionModel):
"""
YOLO Oriented Bounding Box (OBB) model.
"""YOLO Oriented Bounding Box (OBB) model.
This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
computation for rotated object detection.
@ -518,8 +501,7 @@ class OBBModel(DetectionModel):
"""
def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
"""
Initialize YOLO OBB model with given config and parameters.
"""Initialize YOLO OBB model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -535,8 +517,7 @@ class OBBModel(DetectionModel):
class SegmentationModel(DetectionModel):
"""
YOLO segmentation model.
"""YOLO segmentation model.
This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
pixel-level object detection and segmentation.
@ -552,8 +533,7 @@ class SegmentationModel(DetectionModel):
"""
def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
"""
Initialize Ultralytics YOLO segmentation model with given config and parameters.
"""Initialize Ultralytics YOLO segmentation model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -569,8 +549,7 @@ class SegmentationModel(DetectionModel):
class PoseModel(DetectionModel):
"""
YOLO pose model.
"""YOLO pose model.
This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
keypoint detection and pose estimation.
@ -589,8 +568,7 @@ class PoseModel(DetectionModel):
"""
def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
"""
Initialize Ultralytics YOLO Pose model.
"""Initialize Ultralytics YOLO Pose model.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -612,8 +590,7 @@ class PoseModel(DetectionModel):
class ClassificationModel(BaseModel):
"""
YOLO classification model.
"""YOLO classification model.
This class implements the YOLO classification architecture for image classification tasks, providing model
initialization, configuration, and output reshaping capabilities.
@ -637,8 +614,7 @@ class ClassificationModel(BaseModel):
"""
def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
"""
Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
"""Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -650,8 +626,7 @@ class ClassificationModel(BaseModel):
self._from_yaml(cfg, ch, nc, verbose)
def _from_yaml(self, cfg, ch, nc, verbose):
"""
Set Ultralytics YOLO model configurations and define the model architecture.
"""Set Ultralytics YOLO model configurations and define the model architecture.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -675,8 +650,7 @@ class ClassificationModel(BaseModel):
@staticmethod
def reshape_outputs(model, nc):
"""
Update a TorchVision classification model to class count 'n' if required.
        """Update a TorchVision classification model to class count 'nc' if required.
Args:
model (torch.nn.Module): Model to update.
@ -708,8 +682,7 @@ class ClassificationModel(BaseModel):
class RTDETRDetectionModel(DetectionModel):
"""
RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
"""RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
the training and inference processes. RTDETR is an object detection and tracking model that extends from the
@ -732,8 +705,7 @@ class RTDETRDetectionModel(DetectionModel):
"""
def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
"""
Initialize the RTDETRDetectionModel.
"""Initialize the RTDETRDetectionModel.
Args:
cfg (str | dict): Configuration file name or path.
@ -744,8 +716,7 @@ class RTDETRDetectionModel(DetectionModel):
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def _apply(self, fn):
"""
Apply a function to all tensors in the model that are not parameters or registered buffers.
"""Apply a function to all tensors in the model that are not parameters or registered buffers.
Args:
fn (function): The function to apply to the model.
@ -766,8 +737,7 @@ class RTDETRDetectionModel(DetectionModel):
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
def loss(self, batch, preds=None):
"""
Compute the loss for the given batch of data.
"""Compute the loss for the given batch of data.
Args:
batch (dict): Dictionary containing image and label data.
@ -813,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
)
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
"""
Perform a forward pass through the model.
"""Perform a forward pass through the model.
Args:
x (torch.Tensor): The input tensor.
@ -849,8 +818,7 @@ class RTDETRDetectionModel(DetectionModel):
class WorldModel(DetectionModel):
"""
YOLOv8 World Model.
"""YOLOv8 World Model.
This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
specification and CLIP model integration for zero-shot detection capabilities.
@ -874,8 +842,7 @@ class WorldModel(DetectionModel):
"""
def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
"""
Initialize YOLOv8 world model with given config and parameters.
"""Initialize YOLOv8 world model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -888,8 +855,7 @@ class WorldModel(DetectionModel):
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def set_classes(self, text, batch=80, cache_clip_model=True):
"""
Set classes in advance so that model could do offline-inference without clip model.
"""Set classes in advance so that model could do offline-inference without clip model.
Args:
text (list[str]): List of class names.
@ -900,8 +866,7 @@ class WorldModel(DetectionModel):
self.model[-1].nc = len(text)
def get_text_pe(self, text, batch=80, cache_clip_model=True):
"""
Set classes in advance so that model could do offline-inference without clip model.
    """Get text positional embeddings so that the model can run offline inference without the CLIP model.
Args:
text (list[str]): List of class names.
@ -924,8 +889,7 @@ class WorldModel(DetectionModel):
return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
"""
Perform a forward pass through the model.
"""Perform a forward pass through the model.
Args:
x (torch.Tensor): The input tensor.
@ -969,8 +933,7 @@ class WorldModel(DetectionModel):
return x
def loss(self, batch, preds=None):
"""
Compute loss.
"""Compute loss.
Args:
batch (dict): Batch to compute loss on.
@ -985,8 +948,7 @@ class WorldModel(DetectionModel):
class YOLOEModel(DetectionModel):
"""
YOLOE detection model.
"""YOLOE detection model.
This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
both prompt-based and prompt-free inference modes.
@ -1013,8 +975,7 @@ class YOLOEModel(DetectionModel):
"""
def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
"""
Initialize YOLOE model with given config and parameters.
"""Initialize YOLOE model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -1026,8 +987,7 @@ class YOLOEModel(DetectionModel):
@smart_inference_mode()
def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
"""
Set classes in advance so that model could do offline-inference without clip model.
    """Get text positional embeddings so that the model can run offline inference without the CLIP model.
Args:
text (list[str]): List of class names.
@ -1059,8 +1019,7 @@ class YOLOEModel(DetectionModel):
@smart_inference_mode()
def get_visual_pe(self, img, visual):
"""
Get visual embeddings.
"""Get visual embeddings.
Args:
img (torch.Tensor): Input image tensor.
@ -1072,8 +1031,7 @@ class YOLOEModel(DetectionModel):
return self(img, vpe=visual, return_vpe=True)
def set_vocab(self, vocab, names):
"""
Set vocabulary for the prompt-free model.
"""Set vocabulary for the prompt-free model.
Args:
vocab (nn.ModuleList): List of vocabulary items.
@ -1101,8 +1059,7 @@ class YOLOEModel(DetectionModel):
self.names = check_class_names(names)
def get_vocab(self, names):
"""
Get fused vocabulary layer from the model.
"""Get fused vocabulary layer from the model.
Args:
names (list): List of class names.
@ -1127,8 +1084,7 @@ class YOLOEModel(DetectionModel):
return vocab
def set_classes(self, names, embeddings):
"""
Set classes in advance so that model could do offline-inference without clip model.
"""Set classes in advance so that model could do offline-inference without clip model.
Args:
names (list[str]): List of class names.
@ -1143,8 +1099,7 @@ class YOLOEModel(DetectionModel):
self.names = check_class_names(names)
def get_cls_pe(self, tpe, vpe):
"""
Get class positional embeddings.
"""Get class positional embeddings.
Args:
tpe (torch.Tensor, optional): Text positional embeddings.
@ -1167,8 +1122,7 @@ class YOLOEModel(DetectionModel):
def predict(
self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
):
"""
Perform a forward pass through the model.
"""Perform a forward pass through the model.
Args:
x (torch.Tensor): The input tensor.
@ -1215,8 +1169,7 @@ class YOLOEModel(DetectionModel):
return x
def loss(self, batch, preds=None):
"""
Compute loss.
"""Compute loss.
Args:
batch (dict): Batch to compute loss on.
@ -1234,8 +1187,7 @@ class YOLOEModel(DetectionModel):
class YOLOESegModel(YOLOEModel, SegmentationModel):
"""
YOLOE segmentation model.
"""YOLOE segmentation model.
This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
specialized loss computation for pixel-level object detection and segmentation.
@ -1251,8 +1203,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
"""
def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
"""
Initialize YOLOE segmentation model with given config and parameters.
"""Initialize YOLOE segmentation model with given config and parameters.
Args:
cfg (str | dict): Model configuration file path or dictionary.
@ -1263,8 +1214,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def loss(self, batch, preds=None):
"""
Compute loss.
"""Compute loss.
Args:
batch (dict): Batch to compute loss on.
@ -1282,8 +1232,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
class Ensemble(torch.nn.ModuleList):
"""
Ensemble of models.
"""Ensemble of models.
This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
or other ensemble techniques.
@ -1305,8 +1254,7 @@ class Ensemble(torch.nn.ModuleList):
super().__init__()
def forward(self, x, augment=False, profile=False, visualize=False):
"""
Generate the YOLO network's final layer.
"""Generate the YOLO network's final layer.
Args:
x (torch.Tensor): Input tensor.
@ -1330,8 +1278,7 @@ class Ensemble(torch.nn.ModuleList):
@contextlib.contextmanager
def temporary_modules(modules=None, attributes=None):
"""
Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
"""Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
moved a module from one location to another, but you still want to support the old import paths for backwards
@ -1393,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
"""Custom Unpickler that replaces unknown classes with SafeClass."""
def find_class(self, module, name):
"""
Attempt to find a class, returning SafeClass if not among safe modules.
"""Attempt to find a class, returning SafeClass if not among safe modules.
Args:
module (str): Module name.
@ -1419,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):
def torch_safe_load(weight, safe_only=False):
"""
Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
After installation, the function again attempts to load the model using torch.load().
"""Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
the error, logs a warning message, and attempts to install the missing module via the check_requirements()
function. After installation, the function again attempts to load the model using torch.load().
Args:
weight (str): The file path of the PyTorch model.
@ -1501,8 +1446,7 @@ def torch_safe_load(weight, safe_only=False):
def load_checkpoint(weight, device=None, inplace=True, fuse=False):
"""
Load a single model weights.
    """Load a single model's weights.
Args:
weight (str | Path): Model weight path.
@ -1539,8 +1483,7 @@ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
def parse_model(d, ch, verbose=True):
"""
Parse a YOLO model.yaml dictionary into a PyTorch model.
"""Parse a YOLO model.yaml dictionary into a PyTorch model.
Args:
d (dict): Model dictionary.
@ -1718,8 +1661,7 @@ def parse_model(d, ch, verbose=True):
def yaml_model_load(path):
"""
Load a YOLOv8 model from a YAML file.
"""Load a YOLOv8 model from a YAML file.
Args:
path (str | Path): Path to the YAML file.
@ -1742,8 +1684,7 @@ def yaml_model_load(path):
def guess_model_scale(model_path):
"""
Extract the size character n, s, m, l, or x of the model's scale from the model path.
"""Extract the size character n, s, m, l, or x of the model's scale from the model path.
Args:
model_path (str | Path): The path to the YOLO model's YAML file.
@ -1758,8 +1699,7 @@ def guess_model_scale(model_path):
def guess_model_task(model):
"""
Guess the task of a PyTorch model from its architecture or configuration.
"""Guess the task of a PyTorch model from its architecture or configuration.
Args:
model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.

View file

@ -20,8 +20,7 @@ except ImportError:
class TextModel(nn.Module):
"""
Abstract base class for text encoding models.
"""Abstract base class for text encoding models.
This class defines the interface for text encoding models used in vision-language tasks. Subclasses must implement
the tokenize and encode_text methods to provide text tokenization and encoding functionality.
@ -47,8 +46,7 @@ class TextModel(nn.Module):
class CLIP(TextModel):
"""
Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
"""Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
This class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors that
are aligned with corresponding image features in a shared embedding space.
@ -71,8 +69,7 @@ class CLIP(TextModel):
"""
def __init__(self, size: str, device: torch.device) -> None:
"""
Initialize the CLIP text encoder.
"""Initialize the CLIP text encoder.
This class implements the TextModel interface using OpenAI's CLIP model for text encoding. It loads a
pre-trained CLIP model of the specified size and prepares it for text encoding tasks.
@ -93,8 +90,7 @@ class CLIP(TextModel):
self.eval()
def tokenize(self, texts: str | list[str]) -> torch.Tensor:
"""
Convert input texts to CLIP tokens.
"""Convert input texts to CLIP tokens.
Args:
texts (str | list[str]): Input text or list of texts to tokenize.
@ -111,8 +107,7 @@ class CLIP(TextModel):
@smart_inference_mode()
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Encode tokenized texts into normalized feature vectors.
"""Encode tokenized texts into normalized feature vectors.
This method processes tokenized text inputs through the CLIP model to generate feature vectors, which are then
normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
@ -137,15 +132,14 @@ class CLIP(TextModel):
@smart_inference_mode()
def encode_image(self, image: Image.Image | torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Encode preprocessed images into normalized feature vectors.
"""Encode preprocessed images into normalized feature vectors.
This method processes preprocessed image inputs through the CLIP model to generate feature vectors, which are
then normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
Args:
image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be
converted to a tensor using the model's image preprocessing function.
image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be converted
to a tensor using the model's image preprocessing function.
dtype (torch.dtype, optional): Data type for output features.
Returns:
@ -169,8 +163,7 @@ class CLIP(TextModel):
class MobileCLIP(TextModel):
"""
Implement Apple's MobileCLIP text encoder for efficient text encoding.
"""Implement Apple's MobileCLIP text encoder for efficient text encoding.
This class implements the TextModel interface using Apple's MobileCLIP model, providing efficient text encoding
capabilities for vision-language tasks with reduced computational requirements compared to standard CLIP models.
@ -195,8 +188,7 @@ class MobileCLIP(TextModel):
config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}
def __init__(self, size: str, device: torch.device) -> None:
"""
Initialize the MobileCLIP text encoder.
"""Initialize the MobileCLIP text encoder.
This class implements the TextModel interface using Apple's MobileCLIP model for efficient text encoding.
@ -236,8 +228,7 @@ class MobileCLIP(TextModel):
self.eval()
def tokenize(self, texts: list[str]) -> torch.Tensor:
"""
Convert input texts to MobileCLIP tokens.
"""Convert input texts to MobileCLIP tokens.
Args:
texts (list[str]): List of text strings to tokenize.
@ -253,8 +244,7 @@ class MobileCLIP(TextModel):
@smart_inference_mode()
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Encode tokenized texts into normalized feature vectors.
"""Encode tokenized texts into normalized feature vectors.
Args:
texts (torch.Tensor): Tokenized text inputs.
@ -276,8 +266,7 @@ class MobileCLIP(TextModel):
class MobileCLIPTS(TextModel):
"""
Load a TorchScript traced version of MobileCLIP.
"""Load a TorchScript traced version of MobileCLIP.
This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format, providing
efficient text encoding capabilities for vision-language tasks with optimized inference performance.
@ -299,8 +288,7 @@ class MobileCLIPTS(TextModel):
"""
def __init__(self, device: torch.device):
"""
Initialize the MobileCLIP TorchScript text encoder.
"""Initialize the MobileCLIP TorchScript text encoder.
This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format for efficient
text encoding with optimized inference performance.
@ -321,8 +309,7 @@ class MobileCLIPTS(TextModel):
self.device = device
def tokenize(self, texts: list[str]) -> torch.Tensor:
"""
Convert input texts to MobileCLIP tokens.
"""Convert input texts to MobileCLIP tokens.
Args:
texts (list[str]): List of text strings to tokenize.
@ -338,8 +325,7 @@ class MobileCLIPTS(TextModel):
@smart_inference_mode()
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Encode tokenized texts into normalized feature vectors.
"""Encode tokenized texts into normalized feature vectors.
Args:
texts (torch.Tensor): Tokenized text inputs.
@ -360,8 +346,7 @@ class MobileCLIPTS(TextModel):
def build_text_model(variant: str, device: torch.device = None) -> TextModel:
"""
Build a text encoding model based on the specified variant.
"""Build a text encoding model based on the specified variant.
Args:
variant (str): Model variant in format "base:size" (e.g., "clip:ViT-B/32" or "mobileclip:s0").

View file

@ -7,8 +7,7 @@ from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, Sol
class AIGym(BaseSolution):
"""
A class to manage gym steps of people in a real-time video stream based on their poses.
"""A class to manage gym steps of people in a real-time video stream based on their poses.
This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
repetitions of exercises based on predefined angle thresholds for up and down positions.
@ -32,8 +31,7 @@ class AIGym(BaseSolution):
"""
def __init__(self, **kwargs: Any) -> None:
"""
Initialize AIGym for workout monitoring using pose estimation and predefined angles.
"""Initialize AIGym for workout monitoring using pose estimation and predefined angles.
Args:
**kwargs (Any): Keyword arguments passed to the parent class constructor including:
@ -49,8 +47,7 @@ class AIGym(BaseSolution):
self.kpts = self.CFG["kpts"] # User selected kpts of workouts storage for further usage
def process(self, im0) -> SolutionResults:
"""
Monitor workouts using Ultralytics YOLO Pose Model.
"""Monitor workouts using Ultralytics YOLO Pose Model.
This function processes an input image to track and analyze human poses for workout monitoring. It uses the YOLO
Pose model to detect keypoints, estimate angles, and count repetitions based on predefined angle thresholds.

View file

@ -12,8 +12,7 @@ from ultralytics.solutions.solutions import BaseSolution, SolutionResults # Imp
class Analytics(BaseSolution):
"""
A class for creating and updating various types of charts for visual analytics.
"""A class for creating and updating various types of charts for visual analytics.
This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts based on
object detection and tracking data.
@ -92,8 +91,7 @@ class Analytics(BaseSolution):
self.ax.axis("equal")
def process(self, im0: np.ndarray, frame_number: int) -> SolutionResults:
"""
Process image data and run object tracking to update analytics charts.
"""Process image data and run object tracking to update analytics charts.
Args:
im0 (np.ndarray): Input image for processing.
@ -139,13 +137,12 @@ class Analytics(BaseSolution):
def update_graph(
self, frame_number: int, count_dict: dict[str, int] | None = None, plot: str = "line"
) -> np.ndarray:
"""
Update the graph with new data for single or multiple classes.
"""Update the graph with new data for single or multiple classes.
Args:
frame_number (int): The current frame number.
count_dict (dict[str, int], optional): Dictionary with class names as keys and counts as values for
multiple classes. If None, updates a single line graph.
count_dict (dict[str, int], optional): Dictionary with class names as keys and counts as values for multiple
classes. If None, updates a single line graph.
plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
Returns:

View file

@ -10,8 +10,7 @@ import cv2
@dataclass
class SolutionConfig:
"""
Manages configuration parameters for Ultralytics Vision AI solutions.
"""Manages configuration parameters for Ultralytics Vision AI solutions.
The SolutionConfig class serves as a centralized configuration container for all the Ultralytics solution modules:
https://docs.ultralytics.com/solutions/#solutions. It leverages Python `dataclass` for clear, type-safe, and

View file

@ -10,8 +10,7 @@ from ultralytics.utils.plotting import colors
class DistanceCalculation(BaseSolution):
"""
A class to calculate distance between two objects in a real-time video stream based on their tracks.
"""A class to calculate distance between two objects in a real-time video stream based on their tracks.
This class extends BaseSolution to provide functionality for selecting objects and calculating the distance between
them in a video stream using YOLO object detection and tracking.
@ -43,8 +42,7 @@ class DistanceCalculation(BaseSolution):
self.centroids: list[list[int]] = [] # Store centroids of selected objects
def mouse_event_for_distance(self, event: int, x: int, y: int, flags: int, param: Any) -> None:
"""
Handle mouse events to select regions in a real-time video stream for distance calculation.
"""Handle mouse events to select regions in a real-time video stream for distance calculation.
Args:
event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
@ -69,8 +67,7 @@ class DistanceCalculation(BaseSolution):
self.left_mouse_count = 0
def process(self, im0) -> SolutionResults:
"""
Process a video frame and calculate the distance between two selected bounding boxes.
"""Process a video frame and calculate the distance between two selected bounding boxes.
This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance between
two user-selected objects if they have been chosen.
@ -79,8 +76,8 @@ class DistanceCalculation(BaseSolution):
im0 (np.ndarray): The input image frame to process.
Returns:
(SolutionResults): Contains processed image `plot_im`, `total_tracks` (int) representing the total number
of tracked objects, and `pixels_distance` (float) representing the distance between selected objects
(SolutionResults): Contains processed image `plot_im`, `total_tracks` (int) representing the total number of
tracked objects, and `pixels_distance` (float) representing the distance between selected objects
in pixels.
Examples:

View file

@ -12,8 +12,7 @@ from ultralytics.solutions.solutions import SolutionAnnotator, SolutionResults
class Heatmap(ObjectCounter):
"""
A class to draw heatmaps in real-time video streams based on object tracks.
"""A class to draw heatmaps in real-time video streams based on object tracks.
This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
streams. It uses tracked object positions to create a cumulative heatmap effect over time.
@ -36,8 +35,7 @@ class Heatmap(ObjectCounter):
"""
def __init__(self, **kwargs: Any) -> None:
"""
Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.
"""Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.
Args:
**kwargs (Any): Keyword arguments passed to the parent ObjectCounter class.
@ -53,8 +51,7 @@ class Heatmap(ObjectCounter):
self.heatmap = None
def heatmap_effect(self, box: list[float]) -> None:
"""
Efficiently calculate heatmap area and effect location for applying colormap.
"""Efficiently calculate heatmap area and effect location for applying colormap.
Args:
box (list[float]): Bounding box coordinates [x0, y0, x1, y1].
@ -75,17 +72,15 @@ class Heatmap(ObjectCounter):
self.heatmap[y0:y1, x0:x1][within_radius] += 2
def process(self, im0: np.ndarray) -> SolutionResults:
"""
Generate heatmap for each frame using Ultralytics tracking.
"""Generate heatmap for each frame using Ultralytics tracking.
Args:
im0 (np.ndarray): Input image array for processing.
Returns:
(SolutionResults): Contains processed image `plot_im`,
'in_count' (int, count of objects entering the region), 'out_count' (int, count of objects exiting the
region), 'classwise_count' (dict, per-class object count), and 'total_tracks' (int, total number of
tracked objects).
(SolutionResults): Contains processed image `plot_im`, 'in_count' (int, count of objects entering the
region), 'out_count' (int, count of objects exiting the region), 'classwise_count' (dict, per-class
object count), and 'total_tracks' (int, total number of tracked objects).
"""
if not self.initialized:
self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99

Some files were not shown because too many files have changed in this diff Show more