mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
Update Google-style docstrings (#22565)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
b9a1365450
commit
0aef7a9a51
159 changed files with 1931 additions and 3432 deletions
|
|
@ -41,8 +41,7 @@ This example illustrates a Google-style docstring. Ensure that both input and ou
|
|||
|
||||
```python
|
||||
def example_function(arg1, arg2=4):
|
||||
"""
|
||||
Example function demonstrating Google-style docstrings.
|
||||
"""Example function demonstrating Google-style docstrings.
|
||||
|
||||
Args:
|
||||
arg1 (int): The first argument.
|
||||
|
|
@ -65,8 +64,7 @@ This example includes both a Google-style docstring and [type hints](https://doc
|
|||
|
||||
```python
|
||||
def example_function(arg1: int, arg2: int = 4) -> bool:
|
||||
"""
|
||||
Example function demonstrating Google-style docstrings.
|
||||
"""Example function demonstrating Google-style docstrings.
|
||||
|
||||
Args:
|
||||
arg1: The first argument.
|
||||
|
|
|
|||
|
|
@ -262,8 +262,7 @@ def remove_macros():
|
|||
|
||||
|
||||
def remove_comments_and_empty_lines(content: str, file_type: str) -> str:
|
||||
"""
|
||||
Remove comments and empty lines from a string of code, preserving newlines and URLs.
|
||||
"""Remove comments and empty lines from a string of code, preserving newlines and URLs.
|
||||
|
||||
Args:
|
||||
content (str): Code content to process.
|
||||
|
|
|
|||
|
|
@ -401,8 +401,7 @@ import ros_numpy
|
|||
|
||||
|
||||
def pointcloud2_to_array(pointcloud2: PointCloud2) -> tuple:
|
||||
"""
|
||||
Convert a ROS PointCloud2 message to a numpy array.
|
||||
"""Convert a ROS PointCloud2 message to a numpy array.
|
||||
|
||||
Args:
|
||||
pointcloud2 (PointCloud2): The PointCloud2 message.
|
||||
|
|
@ -472,8 +471,7 @@ for index, class_id in enumerate(classes):
|
|||
|
||||
|
||||
def pointcloud2_to_array(pointcloud2: PointCloud2) -> tuple:
|
||||
"""
|
||||
Convert a ROS PointCloud2 message to a numpy array.
|
||||
"""Convert a ROS PointCloud2 message to a numpy array.
|
||||
|
||||
Args:
|
||||
pointcloud2 (PointCloud2): The PointCloud2 message.
|
||||
|
|
|
|||
|
|
@ -58,8 +58,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
|
|||
|
||||
```python
|
||||
def example_function(arg1, arg2=4):
|
||||
"""
|
||||
Example function demonstrating Google-style docstrings.
|
||||
"""Example function demonstrating Google-style docstrings.
|
||||
|
||||
Args:
|
||||
arg1 (int): The first argument.
|
||||
|
|
@ -81,8 +80,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
|
|||
|
||||
```python
|
||||
def example_function(arg1, arg2=4):
|
||||
"""
|
||||
Example function demonstrating Google-style docstrings.
|
||||
"""Example function demonstrating Google-style docstrings.
|
||||
|
||||
Args:
|
||||
arg1 (int): The first argument.
|
||||
|
|
@ -104,8 +102,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
|
|||
|
||||
```python
|
||||
def example_function(arg1, arg2=4):
|
||||
"""
|
||||
Example function demonstrating Google-style docstrings.
|
||||
"""Example function demonstrating Google-style docstrings.
|
||||
|
||||
Args:
|
||||
arg1 (int): The first argument.
|
||||
|
|
@ -146,8 +143,7 @@ When adding new functions or classes, include [Google-style docstrings](https://
|
|||
|
||||
```python
|
||||
def example_function(arg1: int, arg2: int = 4) -> bool:
|
||||
"""
|
||||
Example function demonstrating Google-style docstrings.
|
||||
"""Example function demonstrating Google-style docstrings.
|
||||
|
||||
Args:
|
||||
arg1: The first argument.
|
||||
|
|
|
|||
|
|
@ -359,8 +359,7 @@ Finally, after all threads have completed their task, the windows displaying the
|
|||
|
||||
|
||||
def run_tracker_in_thread(model_name, filename):
|
||||
"""
|
||||
Run YOLO tracker in its own thread for concurrent processing.
|
||||
"""Run YOLO tracker in its own thread for concurrent processing.
|
||||
|
||||
Args:
|
||||
model_name (str): Name or path of the YOLO11 model to load.
|
||||
|
|
@ -449,8 +448,7 @@ To run object tracking on multiple video streams simultaneously, you can use Pyt
|
|||
|
||||
|
||||
def run_tracker_in_thread(model_name, filename):
|
||||
"""
|
||||
Run YOLO tracker in its own thread for concurrent processing.
|
||||
"""Run YOLO tracker in its own thread for concurrent processing.
|
||||
|
||||
Args:
|
||||
model_name (str): Name or path of the YOLO11 model to load.
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ import yaml
|
|||
|
||||
|
||||
def download_file(url: str, local_path: str) -> str:
|
||||
"""
|
||||
Download a file from a URL to a local path.
|
||||
"""Download a file from a URL to a local path.
|
||||
|
||||
Args:
|
||||
url (str): URL of the file to download.
|
||||
|
|
@ -34,8 +33,7 @@ def download_file(url: str, local_path: str) -> str:
|
|||
|
||||
|
||||
class RTDETR:
|
||||
"""
|
||||
RT-DETR (Real-Time Detection Transformer) object detection model for ONNX inference and visualization.
|
||||
"""RT-DETR (Real-Time Detection Transformer) object detection model for ONNX inference and visualization.
|
||||
|
||||
This class implements the RT-DETR model for object detection tasks, supporting ONNX model inference and
|
||||
visualization of detection results with bounding boxes and class labels.
|
||||
|
|
@ -77,16 +75,15 @@ class RTDETR:
|
|||
iou_thres: float = 0.5,
|
||||
class_names: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the RT-DETR object detection model.
|
||||
"""Initialize the RT-DETR object detection model.
|
||||
|
||||
Args:
|
||||
model_path (str): Path to the ONNX model file.
|
||||
img_path (str): Path to the input image.
|
||||
conf_thres (float, optional): Confidence threshold for filtering detections.
|
||||
iou_thres (float, optional): IoU threshold for non-maximum suppression.
|
||||
class_names (Optional[str], optional): Path to a YAML file containing class names.
|
||||
If None, uses COCO dataset classes.
|
||||
class_names (Optional[str], optional): Path to a YAML file containing class names. If None, uses COCO
|
||||
dataset classes.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.img_path = img_path
|
||||
|
|
@ -157,8 +154,7 @@ class RTDETR:
|
|||
)
|
||||
|
||||
def preprocess(self) -> np.ndarray:
|
||||
"""
|
||||
Preprocess the input image for model inference.
|
||||
"""Preprocess the input image for model inference.
|
||||
|
||||
Loads the image, converts color space from BGR to RGB, resizes to model input dimensions, and normalizes pixel
|
||||
values to [0, 1] range.
|
||||
|
|
@ -190,12 +186,11 @@ class RTDETR:
|
|||
return image_data
|
||||
|
||||
def bbox_cxcywh_to_xyxy(self, boxes: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert bounding boxes from center format to corner format.
|
||||
"""Convert bounding boxes from center format to corner format.
|
||||
|
||||
Args:
|
||||
boxes (np.ndarray): Array of shape (N, 4) where each row represents a bounding box in
|
||||
(center_x, center_y, width, height) format.
|
||||
boxes (np.ndarray): Array of shape (N, 4) where each row represents a bounding box in (center_x, center_y,
|
||||
width, height) format.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Array of shape (N, 4) with bounding boxes in (x_min, y_min, x_max, y_max) format.
|
||||
|
|
@ -214,8 +209,7 @@ class RTDETR:
|
|||
return np.column_stack((x_min, y_min, x_max, y_max))
|
||||
|
||||
def postprocess(self, model_output: list[np.ndarray]) -> np.ndarray:
|
||||
"""
|
||||
Postprocess model output to extract and visualize detections.
|
||||
"""Postprocess model output to extract and visualize detections.
|
||||
|
||||
Applies confidence thresholding, converts bounding box format, scales coordinates to original image dimensions,
|
||||
and draws detection annotations.
|
||||
|
|
@ -255,8 +249,7 @@ class RTDETR:
|
|||
return self.img
|
||||
|
||||
def main(self) -> np.ndarray:
|
||||
"""
|
||||
Execute the complete object detection pipeline on the input image.
|
||||
"""Execute the complete object detection pipeline on the input image.
|
||||
|
||||
Performs preprocessing, ONNX model inference, and postprocessing to generate annotated detection results.
|
||||
|
||||
|
|
|
|||
|
|
@ -55,8 +55,7 @@ selected_center = None
|
|||
|
||||
|
||||
def get_center(x1: int, y1: int, x2: int, y2: int) -> tuple[int, int]:
|
||||
"""
|
||||
Calculate the center point of a bounding box.
|
||||
"""Calculate the center point of a bounding box.
|
||||
|
||||
Args:
|
||||
x1 (int): Top-left X coordinate.
|
||||
|
|
@ -72,8 +71,7 @@ def get_center(x1: int, y1: int, x2: int, y2: int) -> tuple[int, int]:
|
|||
|
||||
|
||||
def extend_line_from_edge(mid_x: int, mid_y: int, direction: str, img_shape: tuple[int, int, int]) -> tuple[int, int]:
|
||||
"""
|
||||
Calculate the endpoint to extend a line from the center toward an image edge.
|
||||
"""Calculate the endpoint to extend a line from the center toward an image edge.
|
||||
|
||||
Args:
|
||||
mid_x (int): X-coordinate of the midpoint.
|
||||
|
|
@ -99,8 +97,7 @@ def extend_line_from_edge(mid_x: int, mid_y: int, direction: str, img_shape: tup
|
|||
|
||||
|
||||
def draw_tracking_scope(im, bbox: tuple, color: tuple) -> None:
|
||||
"""
|
||||
Draw tracking scope lines extending from the bounding box to image edges.
|
||||
"""Draw tracking scope lines extending from the bounding box to image edges.
|
||||
|
||||
Args:
|
||||
im (np.ndarray): Image array to draw on.
|
||||
|
|
@ -119,8 +116,7 @@ def draw_tracking_scope(im, bbox: tuple, color: tuple) -> None:
|
|||
|
||||
|
||||
def click_event(event: int, x: int, y: int, flags: int, param) -> None:
|
||||
"""
|
||||
Handle mouse click events to select an object for focused tracking.
|
||||
"""Handle mouse click events to select an object for focused tracking.
|
||||
|
||||
Args:
|
||||
event (int): OpenCV mouse event type.
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ from ultralytics.utils.torch_utils import select_device
|
|||
|
||||
|
||||
class TorchVisionVideoClassifier:
|
||||
"""
|
||||
Video classifier using pretrained TorchVision models for action recognition.
|
||||
"""Video classifier using pretrained TorchVision models for action recognition.
|
||||
|
||||
This class provides an interface for video classification using various pretrained models from TorchVision's video
|
||||
model collection, supporting models like S3D, R3D, Swin3D, and MViT architectures.
|
||||
|
|
@ -72,8 +71,7 @@ class TorchVisionVideoClassifier:
|
|||
}
|
||||
|
||||
def __init__(self, model_name: str, device: str | torch.device = ""):
|
||||
"""
|
||||
Initialize the VideoClassifier with the specified model name and device.
|
||||
"""Initialize the VideoClassifier with the specified model name and device.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to use. Must be one of the available models.
|
||||
|
|
@ -87,8 +85,7 @@ class TorchVisionVideoClassifier:
|
|||
|
||||
@staticmethod
|
||||
def available_model_names() -> list[str]:
|
||||
"""
|
||||
Get the list of available model names.
|
||||
"""Get the list of available model names.
|
||||
|
||||
Returns:
|
||||
(list[str]): List of available model names that can be used with this classifier.
|
||||
|
|
@ -98,8 +95,7 @@ class TorchVisionVideoClassifier:
|
|||
def preprocess_crops_for_video_cls(
|
||||
self, crops: list[np.ndarray], input_size: list[int] | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess a list of crops for video classification.
|
||||
"""Preprocess a list of crops for video classification.
|
||||
|
||||
Args:
|
||||
crops (list[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C).
|
||||
|
|
@ -124,8 +120,7 @@ class TorchVisionVideoClassifier:
|
|||
return torch.stack(processed_crops).unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device)
|
||||
|
||||
def __call__(self, sequences: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Perform inference on the given sequences.
|
||||
"""Perform inference on the given sequences.
|
||||
|
||||
Args:
|
||||
sequences (torch.Tensor): The input sequences for the model with dimensions (B, T, C, H, W) for batched
|
||||
|
|
@ -138,8 +133,7 @@ class TorchVisionVideoClassifier:
|
|||
return self.model(sequences)
|
||||
|
||||
def postprocess(self, outputs: torch.Tensor) -> tuple[list[str], list[float]]:
|
||||
"""
|
||||
Postprocess the model's batch output.
|
||||
"""Postprocess the model's batch output.
|
||||
|
||||
Args:
|
||||
outputs (torch.Tensor): The model's output logits.
|
||||
|
|
@ -161,8 +155,7 @@ class TorchVisionVideoClassifier:
|
|||
|
||||
|
||||
class HuggingFaceVideoClassifier:
|
||||
"""
|
||||
Zero-shot video classifier using Hugging Face transformer models.
|
||||
"""Zero-shot video classifier using Hugging Face transformer models.
|
||||
|
||||
This class provides an interface for zero-shot video classification using Hugging Face models, supporting custom
|
||||
label sets and various transformer architectures for video understanding.
|
||||
|
|
@ -195,8 +188,7 @@ class HuggingFaceVideoClassifier:
|
|||
device: str | torch.device = "",
|
||||
fp16: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the HuggingFaceVideoClassifier with the specified model name.
|
||||
"""Initialize the HuggingFaceVideoClassifier with the specified model name.
|
||||
|
||||
Args:
|
||||
labels (list[str]): List of labels for zero-shot classification.
|
||||
|
|
@ -216,8 +208,7 @@ class HuggingFaceVideoClassifier:
|
|||
def preprocess_crops_for_video_cls(
|
||||
self, crops: list[np.ndarray], input_size: list[int] | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess a list of crops for video classification.
|
||||
"""Preprocess a list of crops for video classification.
|
||||
|
||||
Args:
|
||||
crops (list[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C).
|
||||
|
|
@ -247,8 +238,7 @@ class HuggingFaceVideoClassifier:
|
|||
return output
|
||||
|
||||
def __call__(self, sequences: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Perform inference on the given sequences.
|
||||
"""Perform inference on the given sequences.
|
||||
|
||||
Args:
|
||||
sequences (torch.Tensor): Batched input video frames with shape (B, T, H, W, C).
|
||||
|
|
@ -266,8 +256,7 @@ class HuggingFaceVideoClassifier:
|
|||
return outputs.logits_per_video
|
||||
|
||||
def postprocess(self, outputs: torch.Tensor) -> tuple[list[list[str]], list[list[float]]]:
|
||||
"""
|
||||
Postprocess the model's batch output.
|
||||
"""Postprocess the model's batch output.
|
||||
|
||||
Args:
|
||||
outputs (torch.Tensor): The model's output logits.
|
||||
|
|
@ -294,8 +283,7 @@ class HuggingFaceVideoClassifier:
|
|||
|
||||
|
||||
def crop_and_pad(frame: np.ndarray, box: list[float], margin_percent: int) -> np.ndarray:
|
||||
"""
|
||||
Crop box with margin and take square crop from frame.
|
||||
"""Crop box with margin and take square crop from frame.
|
||||
|
||||
Args:
|
||||
frame (np.ndarray): The input frame to crop from.
|
||||
|
|
@ -338,8 +326,7 @@ def run(
|
|||
video_classifier_model: str = "microsoft/xclip-base-patch32",
|
||||
labels: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run action recognition on a video source using YOLO for object detection and a video classifier.
|
||||
"""Run action recognition on a video source using YOLO for object detection and a video classifier.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the YOLO model weights.
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ from ultralytics.utils.checks import check_requirements, check_yaml
|
|||
|
||||
|
||||
class YOLOv8:
|
||||
"""
|
||||
YOLOv8 object detection model class for handling ONNX inference and visualization.
|
||||
"""YOLOv8 object detection model class for handling ONNX inference and visualization.
|
||||
|
||||
This class provides functionality to load a YOLOv8 ONNX model, perform inference on images, and visualize the
|
||||
detection results with bounding boxes and labels.
|
||||
|
|
@ -47,8 +46,7 @@ class YOLOv8:
|
|||
"""
|
||||
|
||||
def __init__(self, onnx_model: str, input_image: str, confidence_thres: float, iou_thres: float):
|
||||
"""
|
||||
Initialize an instance of the YOLOv8 class.
|
||||
"""Initialize an instance of the YOLOv8 class.
|
||||
|
||||
Args:
|
||||
onnx_model (str): Path to the ONNX model.
|
||||
|
|
@ -68,8 +66,7 @@ class YOLOv8:
|
|||
self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
|
||||
|
||||
def letterbox(self, img: np.ndarray, new_shape: tuple[int, int] = (640, 640)) -> tuple[np.ndarray, tuple[int, int]]:
|
||||
"""
|
||||
Resize and reshape images while maintaining aspect ratio by adding padding.
|
||||
"""Resize and reshape images while maintaining aspect ratio by adding padding.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): Input image to be resized.
|
||||
|
|
@ -126,8 +123,7 @@ class YOLOv8:
|
|||
cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
|
||||
|
||||
def preprocess(self) -> tuple[np.ndarray, tuple[int, int]]:
|
||||
"""
|
||||
Preprocess the input image before performing inference.
|
||||
"""Preprocess the input image before performing inference.
|
||||
|
||||
This method reads the input image, converts its color space, applies letterboxing to maintain aspect ratio,
|
||||
normalizes pixel values, and prepares the image data for model input.
|
||||
|
|
@ -160,8 +156,7 @@ class YOLOv8:
|
|||
return image_data, pad
|
||||
|
||||
def postprocess(self, input_image: np.ndarray, output: list[np.ndarray], pad: tuple[int, int]) -> np.ndarray:
|
||||
"""
|
||||
Perform post-processing on the model's output to extract and visualize detections.
|
||||
"""Perform post-processing on the model's output to extract and visualize detections.
|
||||
|
||||
This method processes the raw model output to extract bounding boxes, scores, and class IDs. It applies
|
||||
non-maximum suppression to filter overlapping detections and draws the results on the input image.
|
||||
|
|
@ -234,8 +229,7 @@ class YOLOv8:
|
|||
return input_image
|
||||
|
||||
def main(self) -> np.ndarray:
|
||||
"""
|
||||
Perform inference using an ONNX model and return the output image with drawn detections.
|
||||
"""Perform inference using an ONNX model and return the output image with drawn detections.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output image with drawn detections.
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))
|
|||
def draw_bounding_box(
|
||||
img: np.ndarray, class_id: int, confidence: float, x: int, y: int, x_plus_w: int, y_plus_h: int
|
||||
) -> None:
|
||||
"""
|
||||
Draw bounding boxes on the input image based on the provided arguments.
|
||||
"""Draw bounding boxes on the input image based on the provided arguments.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The input image to draw the bounding box on.
|
||||
|
|
@ -37,8 +36,7 @@ def draw_bounding_box(
|
|||
|
||||
|
||||
def main(onnx_model: str, input_image: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load ONNX model, perform inference, draw bounding boxes, and display the output image.
|
||||
"""Load ONNX model, perform inference, draw bounding boxes, and display the output image.
|
||||
|
||||
Args:
|
||||
onnx_model (str): Path to the ONNX model.
|
||||
|
|
|
|||
|
|
@ -40,8 +40,7 @@ counting_regions = [
|
|||
|
||||
|
||||
def mouse_callback(event: int, x: int, y: int, flags: int, param: Any) -> None:
|
||||
"""
|
||||
Handle mouse events for region manipulation in the video frame.
|
||||
"""Handle mouse events for region manipulation in the video frame.
|
||||
|
||||
This function enables interactive region selection and dragging functionality for counting regions. It responds to
|
||||
mouse button down, move, and up events to allow users to select and reposition counting regions in real-time.
|
||||
|
|
@ -97,8 +96,7 @@ def run(
|
|||
track_thickness: int = 2,
|
||||
region_thickness: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Run object detection and counting within specified regions using YOLO and ByteTrack.
|
||||
"""Run object detection and counting within specified regions using YOLO and ByteTrack.
|
||||
|
||||
This function performs real-time object detection, tracking, and counting within user-defined polygonal or
|
||||
rectangular regions. It supports interactive region manipulation, multiple counting areas, and both live viewing and
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from ultralytics.utils.files import increment_path
|
|||
|
||||
|
||||
class SAHIInference:
|
||||
"""
|
||||
Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.
|
||||
"""Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.
|
||||
|
||||
This class integrates SAHI (Slicing Aided Hyper Inference) with YOLO11 models to perform efficient object detection
|
||||
on large images by slicing them into smaller pieces, running inference on each slice, and then merging the results.
|
||||
|
|
@ -36,8 +35,7 @@ class SAHIInference:
|
|||
self.detection_model = None
|
||||
|
||||
def load_model(self, weights: str, device: str) -> None:
|
||||
"""
|
||||
Load a YOLO11 model with specified weights for object detection using SAHI.
|
||||
"""Load a YOLO11 model with specified weights for object detection using SAHI.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the model weights file.
|
||||
|
|
@ -63,8 +61,7 @@ class SAHIInference:
|
|||
slice_width: int = 512,
|
||||
slice_height: int = 512,
|
||||
) -> None:
|
||||
"""
|
||||
Run object detection on a video using YOLO11 and SAHI.
|
||||
"""Run object detection on a video using YOLO11 and SAHI.
|
||||
|
||||
The function processes each frame of the video, applies sliced inference using SAHI, and optionally displays
|
||||
and/or saves the results with bounding boxes and labels.
|
||||
|
|
@ -123,8 +120,7 @@ class SAHIInference:
|
|||
|
||||
@staticmethod
|
||||
def parse_opt() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command line arguments for the inference process.
|
||||
"""Parse command line arguments for the inference process.
|
||||
|
||||
Returns:
|
||||
(argparse.Namespace): Parsed command line arguments.
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from ultralytics.utils.checks import check_yaml
|
|||
|
||||
|
||||
class YOLOv8Seg:
|
||||
"""
|
||||
YOLOv8 segmentation model for performing instance segmentation using ONNX Runtime.
|
||||
"""YOLOv8 segmentation model for performing instance segmentation using ONNX Runtime.
|
||||
|
||||
This class implements a YOLOv8 instance segmentation model using ONNX Runtime for inference. It handles
|
||||
preprocessing of input images, running inference with the ONNX model, and postprocessing the results to generate
|
||||
|
|
@ -43,15 +42,14 @@ class YOLOv8Seg:
|
|||
"""
|
||||
|
||||
def __init__(self, onnx_model: str, conf: float = 0.25, iou: float = 0.7, imgsz: int | tuple[int, int] = 640):
|
||||
"""
|
||||
Initialize the instance segmentation model using an ONNX model.
|
||||
"""Initialize the instance segmentation model using an ONNX model.
|
||||
|
||||
Args:
|
||||
onnx_model (str): Path to the ONNX model file.
|
||||
conf (float, optional): Confidence threshold for filtering detections.
|
||||
iou (float, optional): IoU threshold for non-maximum suppression.
|
||||
imgsz (int | tuple[int, int], optional): Input image size of the model. Can be an integer for square
|
||||
input or a tuple for rectangular input.
|
||||
imgsz (int | tuple[int, int], optional): Input image size of the model. Can be an integer for square input
|
||||
or a tuple for rectangular input.
|
||||
"""
|
||||
self.session = ort.InferenceSession(
|
||||
onnx_model,
|
||||
|
|
@ -66,8 +64,7 @@ class YOLOv8Seg:
|
|||
self.iou = iou
|
||||
|
||||
def __call__(self, img: np.ndarray) -> list[Results]:
|
||||
"""
|
||||
Run inference on the input image using the ONNX model.
|
||||
"""Run inference on the input image using the ONNX model.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The original input image in BGR format.
|
||||
|
|
@ -81,8 +78,7 @@ class YOLOv8Seg:
|
|||
return self.postprocess(img, prep_img, outs)
|
||||
|
||||
def letterbox(self, img: np.ndarray, new_shape: tuple[int, int] = (640, 640)) -> np.ndarray:
|
||||
"""
|
||||
Resize and pad image while maintaining aspect ratio.
|
||||
"""Resize and pad image while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): Input image in BGR format.
|
||||
|
|
@ -109,16 +105,15 @@ class YOLOv8Seg:
|
|||
return img
|
||||
|
||||
def preprocess(self, img: np.ndarray, new_shape: tuple[int, int]) -> np.ndarray:
|
||||
"""
|
||||
Preprocess the input image before feeding it into the model.
|
||||
"""Preprocess the input image before feeding it into the model.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The input image in BGR format.
|
||||
new_shape (tuple[int, int]): The target shape for resizing as (height, width).
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Preprocessed image ready for model inference, with shape (1, 3, height, width) and
|
||||
normalized to [0, 1].
|
||||
(np.ndarray): Preprocessed image ready for model inference, with shape (1, 3, height, width) and normalized
|
||||
to [0, 1].
|
||||
"""
|
||||
img = self.letterbox(img, new_shape)
|
||||
img = img[..., ::-1].transpose([2, 0, 1])[None] # BGR to RGB, BHWC to BCHW
|
||||
|
|
@ -127,8 +122,7 @@ class YOLOv8Seg:
|
|||
return img
|
||||
|
||||
def postprocess(self, img: np.ndarray, prep_img: np.ndarray, outs: list) -> list[Results]:
|
||||
"""
|
||||
Post-process model predictions to extract meaningful results.
|
||||
"""Post-process model predictions to extract meaningful results.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The original input image.
|
||||
|
|
@ -152,8 +146,7 @@ class YOLOv8Seg:
|
|||
def process_mask(
|
||||
self, protos: torch.Tensor, masks_in: torch.Tensor, bboxes: torch.Tensor, shape: tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process prototype masks with predicted mask coefficients to generate instance segmentation masks.
|
||||
"""Process prototype masks with predicted mask coefficients to generate instance segmentation masks.
|
||||
|
||||
Args:
|
||||
protos (torch.Tensor): Prototype masks with shape (mask_dim, mask_h, mask_w).
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ except ImportError:
|
|||
|
||||
|
||||
class YOLOv8TFLite:
|
||||
"""
|
||||
A YOLOv8 object detection class using TensorFlow Lite for efficient inference.
|
||||
"""A YOLOv8 object detection class using TensorFlow Lite for efficient inference.
|
||||
|
||||
This class handles model loading, preprocessing, inference, and visualization of detection results for YOLOv8 models
|
||||
converted to TensorFlow Lite format.
|
||||
|
|
@ -56,8 +55,7 @@ class YOLOv8TFLite:
|
|||
"""
|
||||
|
||||
def __init__(self, model: str, conf: float = 0.25, iou: float = 0.45, metadata: str | None = None):
|
||||
"""
|
||||
Initialize the YOLOv8TFLite detector.
|
||||
"""Initialize the YOLOv8TFLite detector.
|
||||
|
||||
Args:
|
||||
model (str): Path to the TFLite model file.
|
||||
|
|
@ -94,8 +92,7 @@ class YOLOv8TFLite:
|
|||
def letterbox(
|
||||
self, img: np.ndarray, new_shape: tuple[int, int] = (640, 640)
|
||||
) -> tuple[np.ndarray, tuple[float, float]]:
|
||||
"""
|
||||
Resize and pad image while maintaining aspect ratio.
|
||||
"""Resize and pad image while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): Input image with shape (H, W, C).
|
||||
|
|
@ -123,8 +120,7 @@ class YOLOv8TFLite:
|
|||
return img, (top / img.shape[0], left / img.shape[1])
|
||||
|
||||
def draw_detections(self, img: np.ndarray, box: np.ndarray, score: np.float32, class_id: int) -> None:
|
||||
"""
|
||||
Draw bounding boxes and labels on the input image based on detected objects.
|
||||
"""Draw bounding boxes and labels on the input image based on detected objects.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The input image to draw detections on.
|
||||
|
|
@ -161,8 +157,7 @@ class YOLOv8TFLite:
|
|||
cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
|
||||
|
||||
def preprocess(self, img: np.ndarray) -> tuple[np.ndarray, tuple[float, float]]:
|
||||
"""
|
||||
Preprocess the input image before performing inference.
|
||||
"""Preprocess the input image before performing inference.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The input image to be preprocessed with shape (H, W, C).
|
||||
|
|
@ -178,8 +173,7 @@ class YOLOv8TFLite:
|
|||
return img / 255, pad # Normalize to [0, 1]
|
||||
|
||||
def postprocess(self, img: np.ndarray, outputs: np.ndarray, pad: tuple[float, float]) -> np.ndarray:
|
||||
"""
|
||||
Process model outputs to extract and visualize detections.
|
||||
"""Process model outputs to extract and visualize detections.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The original input image.
|
||||
|
|
@ -216,8 +210,7 @@ class YOLOv8TFLite:
|
|||
return img
|
||||
|
||||
def detect(self, img_path: str) -> np.ndarray:
|
||||
"""
|
||||
Perform object detection on an input image.
|
||||
"""Perform object detection on an input image.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the input image file.
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ def pytest_addoption(parser):
|
|||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""
|
||||
Modify the list of test items to exclude tests marked as slow if the --slow option is not specified.
|
||||
"""Modify the list of test items to exclude tests marked as slow if the --slow option is not specified.
|
||||
|
||||
Args:
|
||||
config: The pytest configuration object that provides access to command-line options.
|
||||
|
|
@ -23,8 +22,7 @@ def pytest_collection_modifyitems(config, items):
|
|||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
"""
|
||||
Initialize session configurations for pytest.
|
||||
"""Initialize session configurations for pytest.
|
||||
|
||||
This function is automatically called by pytest after the 'Session' object has been created but before performing
|
||||
test collection. It sets the initial seeds for the test session.
|
||||
|
|
@ -38,8 +36,7 @@ def pytest_sessionstart(session):
|
|||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter, exitstatus, config):
|
||||
"""
|
||||
Cleanup operations after pytest session.
|
||||
"""Cleanup operations after pytest session.
|
||||
|
||||
This function is automatically called by pytest at the end of the entire test session. It removes certain files and
|
||||
directories used during testing.
|
||||
|
|
|
|||
|
|
@ -175,8 +175,7 @@ def test_youtube():
|
|||
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
def test_track_stream(model, tmp_path):
|
||||
"""
|
||||
Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
|
||||
"""Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
|
||||
|
||||
Note imgsz=160 required for tracking for higher confidence and better matches.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -242,12 +242,11 @@ CFG_BOOL_KEYS = frozenset(
|
|||
|
||||
|
||||
def cfg2dict(cfg: str | Path | dict | SimpleNamespace) -> dict:
|
||||
"""
|
||||
Convert a configuration object to a dictionary.
|
||||
"""Convert a configuration object to a dictionary.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | dict | SimpleNamespace): Configuration object to be converted. Can be a file path,
|
||||
a string, a dictionary, or a SimpleNamespace object.
|
||||
cfg (str | Path | dict | SimpleNamespace): Configuration object to be converted. Can be a file path, a string, a
|
||||
dictionary, or a SimpleNamespace object.
|
||||
|
||||
Returns:
|
||||
(dict): Configuration object in dictionary format.
|
||||
|
|
@ -279,8 +278,7 @@ def cfg2dict(cfg: str | Path | dict | SimpleNamespace) -> dict:
|
|||
def get_cfg(
|
||||
cfg: str | Path | dict | SimpleNamespace = DEFAULT_CFG_DICT, overrides: dict | None = None
|
||||
) -> SimpleNamespace:
|
||||
"""
|
||||
Load and merge configuration data from a file or dictionary, with optional overrides.
|
||||
"""Load and merge configuration data from a file or dictionary, with optional overrides.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | dict | SimpleNamespace): Configuration data source. Can be a file path, dictionary, or
|
||||
|
|
@ -327,8 +325,7 @@ def get_cfg(
|
|||
|
||||
|
||||
def check_cfg(cfg: dict, hard: bool = True) -> None:
|
||||
"""
|
||||
Check configuration argument types and values for the Ultralytics library.
|
||||
"""Check configuration argument types and values for the Ultralytics library.
|
||||
|
||||
This function validates the types and values of configuration arguments, ensuring correctness and converting them if
|
||||
necessary. It checks for specific key types defined in global variables such as `CFG_FLOAT_KEYS`,
|
||||
|
|
@ -389,14 +386,13 @@ def check_cfg(cfg: dict, hard: bool = True) -> None:
|
|||
|
||||
|
||||
def get_save_dir(args: SimpleNamespace, name: str | None = None) -> Path:
|
||||
"""
|
||||
Return the directory path for saving outputs, derived from arguments or default settings.
|
||||
"""Return the directory path for saving outputs, derived from arguments or default settings.
|
||||
|
||||
Args:
|
||||
args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task',
|
||||
'mode', and 'save_dir'.
|
||||
name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name'
|
||||
or the 'args.mode'.
|
||||
args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task', 'mode',
|
||||
and 'save_dir'.
|
||||
name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name' or the
|
||||
'args.mode'.
|
||||
|
||||
Returns:
|
||||
(Path): Directory path where outputs should be saved.
|
||||
|
|
@ -421,8 +417,7 @@ def get_save_dir(args: SimpleNamespace, name: str | None = None) -> Path:
|
|||
|
||||
|
||||
def _handle_deprecation(custom: dict) -> dict:
|
||||
"""
|
||||
Handle deprecated configuration keys by mapping them to current equivalents with deprecation warnings.
|
||||
"""Handle deprecated configuration keys by mapping them to current equivalents with deprecation warnings.
|
||||
|
||||
Args:
|
||||
custom (dict): Configuration dictionary potentially containing deprecated keys.
|
||||
|
|
@ -465,8 +460,7 @@ def _handle_deprecation(custom: dict) -> dict:
|
|||
|
||||
|
||||
def check_dict_alignment(base: dict, custom: dict, e: Exception | None = None) -> None:
|
||||
"""
|
||||
Check alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
|
||||
"""Check alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
|
||||
messages for mismatched keys.
|
||||
|
||||
Args:
|
||||
|
|
@ -505,8 +499,7 @@ def check_dict_alignment(base: dict, custom: dict, e: Exception | None = None) -
|
|||
|
||||
|
||||
def merge_equals_args(args: list[str]) -> list[str]:
|
||||
"""
|
||||
Merge arguments around isolated '=' in a list of strings and join fragments with brackets.
|
||||
"""Merge arguments around isolated '=' in a list of strings and join fragments with brackets.
|
||||
|
||||
This function handles the following cases:
|
||||
1. ['arg', '=', 'val'] becomes ['arg=val']
|
||||
|
|
@ -565,15 +558,14 @@ def merge_equals_args(args: list[str]) -> list[str]:
|
|||
|
||||
|
||||
def handle_yolo_hub(args: list[str]) -> None:
|
||||
"""
|
||||
Handle Ultralytics HUB command-line interface (CLI) commands for authentication.
|
||||
"""Handle Ultralytics HUB command-line interface (CLI) commands for authentication.
|
||||
|
||||
This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a
|
||||
script with arguments related to HUB authentication.
|
||||
|
||||
Args:
|
||||
args (list[str]): A list of command line arguments. The first argument should be either 'login'
|
||||
or 'logout'. For 'login', an optional second argument can be the API key.
|
||||
args (list[str]): A list of command line arguments. The first argument should be either 'login' or 'logout'. For
|
||||
'login', an optional second argument can be the API key.
|
||||
|
||||
Examples:
|
||||
$ yolo login YOUR_API_KEY
|
||||
|
|
@ -595,8 +587,7 @@ def handle_yolo_hub(args: list[str]) -> None:
|
|||
|
||||
|
||||
def handle_yolo_settings(args: list[str]) -> None:
|
||||
"""
|
||||
Handle YOLO settings command-line interface (CLI) commands.
|
||||
"""Handle YOLO settings command-line interface (CLI) commands.
|
||||
|
||||
This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be
|
||||
called when executing a script with arguments related to YOLO settings management.
|
||||
|
|
@ -638,8 +629,7 @@ def handle_yolo_settings(args: list[str]) -> None:
|
|||
|
||||
|
||||
def handle_yolo_solutions(args: list[str]) -> None:
|
||||
"""
|
||||
Process YOLO solutions arguments and run the specified computer vision solutions pipeline.
|
||||
"""Process YOLO solutions arguments and run the specified computer vision solutions pipeline.
|
||||
|
||||
Args:
|
||||
args (list[str]): Command-line arguments for configuring and running the Ultralytics YOLO solutions.
|
||||
|
|
@ -748,8 +738,7 @@ def handle_yolo_solutions(args: list[str]) -> None:
|
|||
|
||||
|
||||
def parse_key_value_pair(pair: str = "key=value") -> tuple:
|
||||
"""
|
||||
Parse a key-value pair string into separate key and value components.
|
||||
"""Parse a key-value pair string into separate key and value components.
|
||||
|
||||
Args:
|
||||
pair (str): A string containing a key-value pair in the format "key=value".
|
||||
|
|
@ -782,8 +771,7 @@ def parse_key_value_pair(pair: str = "key=value") -> tuple:
|
|||
|
||||
|
||||
def smart_value(v: str) -> Any:
|
||||
"""
|
||||
Convert a string representation of a value to its appropriate Python type.
|
||||
"""Convert a string representation of a value to its appropriate Python type.
|
||||
|
||||
This function attempts to convert a given string into a Python object of the most appropriate type. It handles
|
||||
conversions to None, bool, int, float, and other types that can be evaluated safely.
|
||||
|
|
@ -792,8 +780,8 @@ def smart_value(v: str) -> Any:
|
|||
v (str): The string representation of the value to be converted.
|
||||
|
||||
Returns:
|
||||
(Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion
|
||||
is applicable.
|
||||
(Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion is
|
||||
applicable.
|
||||
|
||||
Examples:
|
||||
>>> smart_value("42")
|
||||
|
|
@ -827,8 +815,7 @@ def smart_value(v: str) -> Any:
|
|||
|
||||
|
||||
def entrypoint(debug: str = "") -> None:
|
||||
"""
|
||||
Ultralytics entrypoint function for parsing and executing command-line arguments.
|
||||
"""Ultralytics entrypoint function for parsing and executing command-line arguments.
|
||||
|
||||
This function serves as the main entry point for the Ultralytics CLI, parsing command-line arguments and executing
|
||||
the corresponding tasks such as training, validation, prediction, exporting models, and more.
|
||||
|
|
@ -1000,8 +987,7 @@ def entrypoint(debug: str = "") -> None:
|
|||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_cfg() -> None:
|
||||
"""
|
||||
Copy the default configuration file and create a new one with '_copy' appended to its name.
|
||||
"""Copy the default configuration file and create a new one with '_copy' appended to its name.
|
||||
|
||||
This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it with '_copy'
|
||||
appended to its name in the current working directory. It provides a convenient way to create a custom configuration
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ def auto_annotate(
|
|||
classes: list[int] | None = None,
|
||||
output_dir: str | Path | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Automatically annotate images using a YOLO object detection model and a SAM segmentation model.
|
||||
"""Automatically annotate images using a YOLO object detection model and a SAM segmentation model.
|
||||
|
||||
This function processes images in a specified directory, detects objects using a YOLO model, and then generates
|
||||
segmentation masks using a SAM model. The resulting annotations are saved as text files in YOLO format.
|
||||
|
|
@ -35,8 +34,8 @@ def auto_annotate(
|
|||
imgsz (int): Input image resize dimension.
|
||||
max_det (int): Maximum number of detections per image.
|
||||
classes (list[int], optional): Filter predictions to specified class IDs, returning only relevant detections.
|
||||
output_dir (str | Path, optional): Directory to save the annotated results. If None, creates a default
|
||||
directory based on the input data path.
|
||||
output_dir (str | Path, optional): Directory to save the annotated results. If None, creates a default directory
|
||||
based on the input data path.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.data.annotator import auto_annotate
|
||||
|
|
|
|||
|
|
@ -26,8 +26,7 @@ DEFAULT_STD = (1.0, 1.0, 1.0)
|
|||
|
||||
|
||||
class BaseTransform:
|
||||
"""
|
||||
Base class for image transformations in the Ultralytics library.
|
||||
"""Base class for image transformations in the Ultralytics library.
|
||||
|
||||
This class serves as a foundation for implementing various image processing operations, designed to be compatible
|
||||
with both classification and semantic segmentation tasks.
|
||||
|
|
@ -45,8 +44,7 @@ class BaseTransform:
|
|||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the BaseTransform object.
|
||||
"""Initialize the BaseTransform object.
|
||||
|
||||
This constructor sets up the base transformation object, which can be extended for specific image processing
|
||||
tasks. It is designed to be compatible with both classification and semantic segmentation.
|
||||
|
|
@ -57,15 +55,14 @@ class BaseTransform:
|
|||
pass
|
||||
|
||||
def apply_image(self, labels):
|
||||
"""
|
||||
Apply image transformations to labels.
|
||||
"""Apply image transformations to labels.
|
||||
|
||||
This method is intended to be overridden by subclasses to implement specific image transformation
|
||||
logic. In its base form, it returns the input labels unchanged.
|
||||
|
||||
Args:
|
||||
labels (Any): The input labels to be transformed. The exact type and structure of labels may
|
||||
vary depending on the specific implementation.
|
||||
labels (Any): The input labels to be transformed. The exact type and structure of labels may vary depending
|
||||
on the specific implementation.
|
||||
|
||||
Returns:
|
||||
(Any): The transformed labels. In the base implementation, this is identical to the input.
|
||||
|
|
@ -80,8 +77,7 @@ class BaseTransform:
|
|||
pass
|
||||
|
||||
def apply_instances(self, labels):
|
||||
"""
|
||||
Apply transformations to object instances in labels.
|
||||
"""Apply transformations to object instances in labels.
|
||||
|
||||
This method is responsible for applying various transformations to object instances within the given
|
||||
labels. It is designed to be overridden by subclasses to implement specific instance transformation
|
||||
|
|
@ -101,8 +97,7 @@ class BaseTransform:
|
|||
pass
|
||||
|
||||
def apply_semantic(self, labels):
|
||||
"""
|
||||
Apply semantic segmentation transformations to an image.
|
||||
"""Apply semantic segmentation transformations to an image.
|
||||
|
||||
This method is intended to be overridden by subclasses to implement specific semantic segmentation
|
||||
transformations. In its base form, it does not perform any operations.
|
||||
|
|
@ -121,16 +116,15 @@ class BaseTransform:
|
|||
pass
|
||||
|
||||
def __call__(self, labels):
|
||||
"""
|
||||
Apply all label transformations to an image, instances, and semantic masks.
|
||||
"""Apply all label transformations to an image, instances, and semantic masks.
|
||||
|
||||
This method orchestrates the application of various transformations defined in the BaseTransform class to the
|
||||
input labels. It sequentially calls the apply_image and apply_instances methods to process the image and object
|
||||
instances, respectively.
|
||||
|
||||
Args:
|
||||
labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for
|
||||
the image data, and 'instances' for object instances.
|
||||
labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for the image
|
||||
data, and 'instances' for object instances.
|
||||
|
||||
Returns:
|
||||
(dict): The input labels dictionary with transformed image and instances.
|
||||
|
|
@ -146,8 +140,7 @@ class BaseTransform:
|
|||
|
||||
|
||||
class Compose:
|
||||
"""
|
||||
A class for composing multiple image transformations.
|
||||
"""A class for composing multiple image transformations.
|
||||
|
||||
Attributes:
|
||||
transforms (list[Callable]): A list of transformation functions to be applied sequentially.
|
||||
|
|
@ -169,8 +162,7 @@ class Compose:
|
|||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
"""
|
||||
Initialize the Compose object with a list of transforms.
|
||||
"""Initialize the Compose object with a list of transforms.
|
||||
|
||||
Args:
|
||||
transforms (list[Callable]): A list of callable transform objects to be applied sequentially.
|
||||
|
|
@ -183,14 +175,13 @@ class Compose:
|
|||
self.transforms = transforms if isinstance(transforms, list) else [transforms]
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
Apply a series of transformations to input data.
|
||||
"""Apply a series of transformations to input data.
|
||||
|
||||
This method sequentially applies each transformation in the Compose object's transforms to the input data.
|
||||
|
||||
Args:
|
||||
data (Any): The input data to be transformed. This can be of any type, depending on the
|
||||
transformations in the list.
|
||||
data (Any): The input data to be transformed. This can be of any type, depending on the transformations in
|
||||
the list.
|
||||
|
||||
Returns:
|
||||
(Any): The transformed data after applying all transformations in sequence.
|
||||
|
|
@ -205,8 +196,7 @@ class Compose:
|
|||
return data
|
||||
|
||||
def append(self, transform):
|
||||
"""
|
||||
Append a new transform to the existing list of transforms.
|
||||
"""Append a new transform to the existing list of transforms.
|
||||
|
||||
Args:
|
||||
transform (BaseTransform): The transformation to be added to the composition.
|
||||
|
|
@ -218,8 +208,7 @@ class Compose:
|
|||
self.transforms.append(transform)
|
||||
|
||||
def insert(self, index, transform):
|
||||
"""
|
||||
Insert a new transform at a specified index in the existing list of transforms.
|
||||
"""Insert a new transform at a specified index in the existing list of transforms.
|
||||
|
||||
Args:
|
||||
index (int): The index at which to insert the new transform.
|
||||
|
|
@ -234,8 +223,7 @@ class Compose:
|
|||
self.transforms.insert(index, transform)
|
||||
|
||||
def __getitem__(self, index: list | int) -> Compose:
|
||||
"""
|
||||
Retrieve a specific transform or a set of transforms using indexing.
|
||||
"""Retrieve a specific transform or a set of transforms using indexing.
|
||||
|
||||
Args:
|
||||
index (int | list[int]): Index or list of indices of the transforms to retrieve.
|
||||
|
|
@ -256,8 +244,7 @@ class Compose:
|
|||
return Compose([self.transforms[i] for i in index]) if isinstance(index, list) else self.transforms[index]
|
||||
|
||||
def __setitem__(self, index: list | int, value: list | int) -> None:
|
||||
"""
|
||||
Set one or more transforms in the composition using indexing.
|
||||
"""Set one or more transforms in the composition using indexing.
|
||||
|
||||
Args:
|
||||
index (int | list[int]): Index or list of indices to set transforms at.
|
||||
|
|
@ -283,8 +270,7 @@ class Compose:
|
|||
self.transforms[i] = v
|
||||
|
||||
def tolist(self):
|
||||
"""
|
||||
Convert the list of transforms to a standard Python list.
|
||||
"""Convert the list of transforms to a standard Python list.
|
||||
|
||||
Returns:
|
||||
(list): A list containing all the transform objects in the Compose instance.
|
||||
|
|
@ -299,8 +285,7 @@ class Compose:
|
|||
return self.transforms
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Return a string representation of the Compose object.
|
||||
"""Return a string representation of the Compose object.
|
||||
|
||||
Returns:
|
||||
(str): A string representation of the Compose object, including the list of transforms.
|
||||
|
|
@ -318,8 +303,7 @@ class Compose:
|
|||
|
||||
|
||||
class BaseMixTransform:
|
||||
"""
|
||||
Base class for mix transformations like Cutmix, MixUp and Mosaic.
|
||||
"""Base class for mix transformations like Cutmix, MixUp and Mosaic.
|
||||
|
||||
This class provides a foundation for implementing mix transformations on datasets. It handles the probability-based
|
||||
application of transforms and manages the mixing of multiple images and labels.
|
||||
|
|
@ -349,8 +333,7 @@ class BaseMixTransform:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
|
||||
"""
|
||||
Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
|
||||
"""Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
|
||||
|
||||
This class serves as a base for implementing mix transformations in image processing pipelines.
|
||||
|
||||
|
|
@ -369,8 +352,7 @@ class BaseMixTransform:
|
|||
self.p = p
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
|
||||
"""Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
|
||||
|
||||
This method determines whether to apply the mix transform based on a probability factor. If applied, it selects
|
||||
additional images, applies pre-transforms if specified, and then performs the mix transform.
|
||||
|
|
@ -409,8 +391,7 @@ class BaseMixTransform:
|
|||
return labels
|
||||
|
||||
def _mix_transform(self, labels: dict[str, Any]):
|
||||
"""
|
||||
Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.
|
||||
"""Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.
|
||||
|
||||
This method should be implemented by subclasses to perform specific mix transformations like CutMix, MixUp or
|
||||
Mosaic. It modifies the input label dictionary in-place with the augmented data.
|
||||
|
|
@ -430,8 +411,7 @@ class BaseMixTransform:
|
|||
raise NotImplementedError
|
||||
|
||||
def get_indexes(self):
|
||||
"""
|
||||
Get a list of shuffled indexes for mosaic augmentation.
|
||||
"""Get a list of shuffled indexes for mosaic augmentation.
|
||||
|
||||
Returns:
|
||||
(list[int]): A list of shuffled indexes from the dataset.
|
||||
|
|
@ -445,15 +425,14 @@ class BaseMixTransform:
|
|||
|
||||
@staticmethod
|
||||
def _update_label_text(labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Update label text and class IDs for mixed labels in image augmentation.
|
||||
"""Update label text and class IDs for mixed labels in image augmentation.
|
||||
|
||||
This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels, creating
|
||||
a unified set of text labels and updating class IDs accordingly.
|
||||
|
||||
Args:
|
||||
labels (dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields,
|
||||
and optionally a 'mix_labels' field with additional label dictionaries.
|
||||
labels (dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields, and
|
||||
optionally a 'mix_labels' field with additional label dictionaries.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): The updated labels dictionary with unified text labels and updated class IDs.
|
||||
|
|
@ -490,8 +469,7 @@ class BaseMixTransform:
|
|||
|
||||
|
||||
class Mosaic(BaseMixTransform):
|
||||
"""
|
||||
Mosaic augmentation for image datasets.
|
||||
"""Mosaic augmentation for image datasets.
|
||||
|
||||
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
|
||||
augmentation is applied to a dataset with a given probability.
|
||||
|
|
@ -520,8 +498,7 @@ class Mosaic(BaseMixTransform):
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, imgsz: int = 640, p: float = 1.0, n: int = 4):
|
||||
"""
|
||||
Initialize the Mosaic augmentation object.
|
||||
"""Initialize the Mosaic augmentation object.
|
||||
|
||||
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
|
||||
augmentation is applied to a dataset with a given probability.
|
||||
|
|
@ -546,15 +523,14 @@ class Mosaic(BaseMixTransform):
|
|||
self.buffer_enabled = self.dataset.cache != "ram"
|
||||
|
||||
def get_indexes(self):
|
||||
"""
|
||||
Return a list of random indexes from the dataset for mosaic augmentation.
|
||||
"""Return a list of random indexes from the dataset for mosaic augmentation.
|
||||
|
||||
This method selects random image indexes either from a buffer or from the entire dataset, depending on the
|
||||
'buffer' parameter. It is used to choose images for creating mosaic augmentations.
|
||||
|
||||
Returns:
|
||||
(list[int]): A list of random image indexes. The length of the list is n-1, where n is the number
|
||||
of images used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
|
||||
(list[int]): A list of random image indexes. The length of the list is n-1, where n is the number of images
|
||||
used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
|
||||
|
||||
Examples:
|
||||
>>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
|
||||
|
|
@ -567,8 +543,7 @@ class Mosaic(BaseMixTransform):
|
|||
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
|
||||
|
||||
def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply mosaic augmentation to the input image and labels.
|
||||
"""Apply mosaic augmentation to the input image and labels.
|
||||
|
||||
This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute. It
|
||||
ensures that rectangular annotations are not present and that there are other images available for mosaic
|
||||
|
|
@ -596,8 +571,7 @@ class Mosaic(BaseMixTransform):
|
|||
) # This code is modified for mosaic3 method.
|
||||
|
||||
def _mosaic3(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Create a 1x3 image mosaic by combining three images.
|
||||
"""Create a 1x3 image mosaic by combining three images.
|
||||
|
||||
This method arranges three images in a horizontal layout, with the main image in the center and two additional
|
||||
images on either side. It's part of the Mosaic augmentation technique used in object detection.
|
||||
|
|
@ -655,8 +629,7 @@ class Mosaic(BaseMixTransform):
|
|||
return final_labels
|
||||
|
||||
def _mosaic4(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Create a 2x2 image mosaic from four input images.
|
||||
"""Create a 2x2 image mosaic from four input images.
|
||||
|
||||
This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also updates the
|
||||
corresponding labels for each image in the mosaic.
|
||||
|
|
@ -714,8 +687,7 @@ class Mosaic(BaseMixTransform):
|
|||
return final_labels
|
||||
|
||||
def _mosaic9(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Create a 3x3 image mosaic from the input image and eight additional images.
|
||||
"""Create a 3x3 image mosaic from the input image and eight additional images.
|
||||
|
||||
This method combines nine images into a single mosaic image. The input image is placed at the center, and eight
|
||||
additional images from the dataset are placed around it in a 3x3 grid pattern.
|
||||
|
|
@ -788,8 +760,7 @@ class Mosaic(BaseMixTransform):
|
|||
|
||||
@staticmethod
|
||||
def _update_labels(labels, padw: int, padh: int) -> dict[str, Any]:
|
||||
"""
|
||||
Update label coordinates with padding values.
|
||||
"""Update label coordinates with padding values.
|
||||
|
||||
This method adjusts the bounding box coordinates of object instances in the labels by adding padding
|
||||
values. It also denormalizes the coordinates if they were previously normalized.
|
||||
|
|
@ -814,8 +785,7 @@ class Mosaic(BaseMixTransform):
|
|||
return labels
|
||||
|
||||
def _cat_labels(self, mosaic_labels: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""
|
||||
Concatenate and process labels for mosaic augmentation.
|
||||
"""Concatenate and process labels for mosaic augmentation.
|
||||
|
||||
This method combines labels from multiple images used in mosaic augmentation, clips instances to the mosaic
|
||||
border, and removes zero-area boxes.
|
||||
|
|
@ -866,8 +836,7 @@ class Mosaic(BaseMixTransform):
|
|||
|
||||
|
||||
class MixUp(BaseMixTransform):
|
||||
"""
|
||||
Apply MixUp augmentation to image datasets.
|
||||
"""Apply MixUp augmentation to image datasets.
|
||||
|
||||
This class implements the MixUp augmentation technique as described in the paper [mixup: Beyond Empirical Risk
|
||||
Minimization](https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight.
|
||||
|
|
@ -888,8 +857,7 @@ class MixUp(BaseMixTransform):
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, pre_transform=None, p: float = 0.0) -> None:
|
||||
"""
|
||||
Initialize the MixUp augmentation object.
|
||||
"""Initialize the MixUp augmentation object.
|
||||
|
||||
MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel values
|
||||
and labels. This implementation is designed for use with the Ultralytics YOLO framework.
|
||||
|
|
@ -907,8 +875,7 @@ class MixUp(BaseMixTransform):
|
|||
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
|
||||
|
||||
def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply MixUp augmentation to the input labels.
|
||||
"""Apply MixUp augmentation to the input labels.
|
||||
|
||||
This method implements the MixUp augmentation technique as described in the paper "mixup: Beyond Empirical Risk
|
||||
Minimization" (https://arxiv.org/abs/1710.09412).
|
||||
|
|
@ -932,8 +899,7 @@ class MixUp(BaseMixTransform):
|
|||
|
||||
|
||||
class CutMix(BaseMixTransform):
|
||||
"""
|
||||
Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
|
||||
"""Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
|
||||
|
||||
CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from
|
||||
another image, and adjusts the labels proportionally to the area of the mixed region.
|
||||
|
|
@ -957,8 +923,7 @@ class CutMix(BaseMixTransform):
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, pre_transform=None, p: float = 0.0, beta: float = 1.0, num_areas: int = 3) -> None:
|
||||
"""
|
||||
Initialize the CutMix augmentation object.
|
||||
"""Initialize the CutMix augmentation object.
|
||||
|
||||
Args:
|
||||
dataset (Any): The dataset to which CutMix augmentation will be applied.
|
||||
|
|
@ -972,8 +937,7 @@ class CutMix(BaseMixTransform):
|
|||
self.num_areas = num_areas
|
||||
|
||||
def _rand_bbox(self, width: int, height: int) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Generate random bounding box coordinates for the cut region.
|
||||
"""Generate random bounding box coordinates for the cut region.
|
||||
|
||||
Args:
|
||||
width (int): Width of the image.
|
||||
|
|
@ -1002,8 +966,7 @@ class CutMix(BaseMixTransform):
|
|||
return x1, y1, x2, y2
|
||||
|
||||
def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply CutMix augmentation to the input labels.
|
||||
"""Apply CutMix augmentation to the input labels.
|
||||
|
||||
Args:
|
||||
labels (dict[str, Any]): A dictionary containing the original image and label information.
|
||||
|
|
@ -1050,8 +1013,7 @@ class CutMix(BaseMixTransform):
|
|||
|
||||
|
||||
class RandomPerspective:
|
||||
"""
|
||||
Implement random perspective and affine transformations on images and corresponding annotations.
|
||||
"""Implement random perspective and affine transformations on images and corresponding annotations.
|
||||
|
||||
This class applies random rotations, translations, scaling, shearing, and perspective transformations to images and
|
||||
their associated bounding boxes, segments, and keypoints. It can be used as part of an augmentation pipeline for
|
||||
|
|
@ -1093,8 +1055,7 @@ class RandomPerspective:
|
|||
border: tuple[int, int] = (0, 0),
|
||||
pre_transform=None,
|
||||
):
|
||||
"""
|
||||
Initialize RandomPerspective object with transformation parameters.
|
||||
"""Initialize RandomPerspective object with transformation parameters.
|
||||
|
||||
This class implements random perspective and affine transformations on images and corresponding bounding boxes,
|
||||
segments, and keypoints. Transformations include rotation, translation, scaling, and shearing.
|
||||
|
|
@ -1122,8 +1083,7 @@ class RandomPerspective:
|
|||
self.pre_transform = pre_transform
|
||||
|
||||
def affine_transform(self, img: np.ndarray, border: tuple[int, int]) -> tuple[np.ndarray, np.ndarray, float]:
|
||||
"""
|
||||
Apply a sequence of affine transformations centered around the image center.
|
||||
"""Apply a sequence of affine transformations centered around the image center.
|
||||
|
||||
This function performs a series of geometric transformations on the input image, including translation,
|
||||
perspective change, rotation, scaling, and shearing. The transformations are applied in a specific order to
|
||||
|
|
@ -1186,15 +1146,14 @@ class RandomPerspective:
|
|||
return img, M, s
|
||||
|
||||
def apply_bboxes(self, bboxes: np.ndarray, M: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Apply affine transformation to bounding boxes.
|
||||
"""Apply affine transformation to bounding boxes.
|
||||
|
||||
This function applies an affine transformation to a set of bounding boxes using the provided transformation
|
||||
matrix.
|
||||
|
||||
Args:
|
||||
bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number
|
||||
of bounding boxes.
|
||||
bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number of bounding
|
||||
boxes.
|
||||
M (np.ndarray): Affine transformation matrix with shape (3, 3).
|
||||
|
||||
Returns:
|
||||
|
|
@ -1220,8 +1179,7 @@ class RandomPerspective:
|
|||
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
|
||||
|
||||
def apply_segments(self, segments: np.ndarray, M: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Apply affine transformations to segments and generate new bounding boxes.
|
||||
"""Apply affine transformations to segments and generate new bounding boxes.
|
||||
|
||||
This function applies affine transformations to input segments and generates new bounding boxes based on the
|
||||
transformed segments. It clips the transformed segments to fit within the new bounding boxes.
|
||||
|
|
@ -1256,16 +1214,15 @@ class RandomPerspective:
|
|||
return bboxes, segments
|
||||
|
||||
def apply_keypoints(self, keypoints: np.ndarray, M: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Apply affine transformation to keypoints.
|
||||
"""Apply affine transformation to keypoints.
|
||||
|
||||
This method transforms the input keypoints using the provided affine transformation matrix. It handles
|
||||
perspective rescaling if necessary and updates the visibility of keypoints that fall outside the image
|
||||
boundaries after transformation.
|
||||
|
||||
Args:
|
||||
keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances,
|
||||
17 is the number of keypoints per instance, and 3 represents (x, y, visibility).
|
||||
keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances, 17 is
|
||||
the number of keypoints per instance, and 3 represents (x, y, visibility).
|
||||
M (np.ndarray): 3x3 affine transformation matrix.
|
||||
|
||||
Returns:
|
||||
|
|
@ -1290,8 +1247,7 @@ class RandomPerspective:
|
|||
return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply random perspective and affine transformations to an image and its associated labels.
|
||||
"""Apply random perspective and affine transformations to an image and its associated labels.
|
||||
|
||||
This method performs a series of transformations including rotation, translation, scaling, shearing, and
|
||||
perspective distortion on the input image and adjusts the corresponding bounding boxes, segments, and keypoints
|
||||
|
|
@ -1378,29 +1334,27 @@ class RandomPerspective:
|
|||
area_thr: float = 0.1,
|
||||
eps: float = 1e-16,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute candidate boxes for further processing based on size and aspect ratio criteria.
|
||||
"""Compute candidate boxes for further processing based on size and aspect ratio criteria.
|
||||
|
||||
This method compares boxes before and after augmentation to determine if they meet specified thresholds for
|
||||
width, height, aspect ratio, and area. It's used to filter out boxes that have been overly distorted or reduced
|
||||
by the augmentation process.
|
||||
|
||||
Args:
|
||||
box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where n is the
|
||||
number of boxes. Format is [x1, y1, x2, y2] in absolute coordinates.
|
||||
box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is
|
||||
[x1, y1, x2, y2] in absolute coordinates.
|
||||
wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either
|
||||
dimension are rejected.
|
||||
ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this
|
||||
value are rejected.
|
||||
area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than
|
||||
this value are rejected.
|
||||
box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where N is the number of boxes. Format
|
||||
is [x1, y1, x2, y2] in absolute coordinates.
|
||||
box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is [x1, y1, x2, y2] in
|
||||
absolute coordinates.
|
||||
wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either dimension are
|
||||
rejected.
|
||||
ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this value are rejected.
|
||||
area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than this value are
|
||||
rejected.
|
||||
eps (float): Small epsilon value to prevent division by zero.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Boolean array of shape (n) indicating which boxes are candidates.
|
||||
True values correspond to boxes that meet all criteria.
|
||||
(np.ndarray): Boolean array of shape (N,) indicating which boxes are candidates. True values correspond to
|
||||
boxes that meet all criteria.
|
||||
|
||||
Examples:
|
||||
>>> random_perspective = RandomPerspective()
|
||||
|
|
@ -1417,8 +1371,7 @@ class RandomPerspective:
|
|||
|
||||
|
||||
class RandomHSV:
|
||||
"""
|
||||
Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.
|
||||
"""Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.
|
||||
|
||||
This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.
|
||||
|
||||
|
|
@ -1441,8 +1394,7 @@ class RandomHSV:
|
|||
"""
|
||||
|
||||
def __init__(self, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5) -> None:
|
||||
"""
|
||||
Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
|
||||
"""Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
|
||||
|
||||
This class applies random adjustments to the HSV channels of an image within specified limits.
|
||||
|
||||
|
|
@ -1460,15 +1412,14 @@ class RandomHSV:
|
|||
self.vgain = vgain
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply random HSV augmentation to an image within predefined limits.
|
||||
"""Apply random HSV augmentation to an image within predefined limits.
|
||||
|
||||
This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels. The
|
||||
adjustments are made within the limits set by hgain, sgain, and vgain during initialization.
|
||||
|
||||
Args:
|
||||
labels (dict[str, Any]): A dictionary containing image data and metadata. Must include an 'img' key with
|
||||
the image as a numpy array.
|
||||
labels (dict[str, Any]): A dictionary containing image data and metadata. Must include an 'img' key with the
|
||||
image as a numpy array.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): A dictionary containing the HSV-augmented image and its labels.
|
||||
|
|
@ -1500,8 +1451,7 @@ class RandomHSV:
|
|||
|
||||
|
||||
class RandomFlip:
|
||||
"""
|
||||
Apply a random horizontal or vertical flip to an image with a given probability.
|
||||
"""Apply a random horizontal or vertical flip to an image with a given probability.
|
||||
|
||||
This class performs random image flipping and updates corresponding instance annotations such as bounding boxes and
|
||||
keypoints.
|
||||
|
|
@ -1522,8 +1472,7 @@ class RandomFlip:
|
|||
"""
|
||||
|
||||
def __init__(self, p: float = 0.5, direction: str = "horizontal", flip_idx: list[int] | None = None) -> None:
|
||||
"""
|
||||
Initialize the RandomFlip class with probability and direction.
|
||||
"""Initialize the RandomFlip class with probability and direction.
|
||||
|
||||
This class applies a random horizontal or vertical flip to an image with a given probability. It also updates
|
||||
any instances (bounding boxes, keypoints, etc.) accordingly.
|
||||
|
|
@ -1548,8 +1497,7 @@ class RandomFlip:
|
|||
self.flip_idx = flip_idx
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.
|
||||
"""Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.
|
||||
|
||||
This method randomly flips the input image either horizontally or vertically based on the initialized
|
||||
probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to match the
|
||||
|
|
@ -1594,8 +1542,7 @@ class RandomFlip:
|
|||
|
||||
|
||||
class LetterBox:
|
||||
"""
|
||||
Resize image and padding for detection, instance segmentation, pose.
|
||||
"""Resize image and padding for detection, instance segmentation, pose.
|
||||
|
||||
This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates corresponding
|
||||
labels and bounding boxes.
|
||||
|
|
@ -1629,8 +1576,7 @@ class LetterBox:
|
|||
padding_value: int = 114,
|
||||
interpolation: int = cv2.INTER_LINEAR,
|
||||
):
|
||||
"""
|
||||
Initialize LetterBox object for resizing and padding images.
|
||||
"""Initialize LetterBox object for resizing and padding images.
|
||||
|
||||
This class is designed to resize and pad images for object detection, instance segmentation, and pose estimation
|
||||
tasks. It supports various resizing modes including auto-sizing, scale-fill, and letterboxing.
|
||||
|
|
@ -1668,8 +1614,7 @@ class LetterBox:
|
|||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, labels: dict[str, Any] | None = None, image: np.ndarray = None) -> dict[str, Any] | np.ndarray:
|
||||
"""
|
||||
Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.
|
||||
"""Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.
|
||||
|
||||
This method applies letterboxing to the input image, which involves resizing the image while maintaining its
|
||||
aspect ratio and adding padding to fit the new shape. It also updates any associated labels accordingly.
|
||||
|
|
@ -1748,8 +1693,7 @@ class LetterBox:
|
|||
|
||||
@staticmethod
|
||||
def _update_labels(labels: dict[str, Any], ratio: tuple[float, float], padw: float, padh: float) -> dict[str, Any]:
|
||||
"""
|
||||
Update labels after applying letterboxing to an image.
|
||||
"""Update labels after applying letterboxing to an image.
|
||||
|
||||
This method modifies the bounding box coordinates of instances in the labels to account for resizing and padding
|
||||
applied during letterboxing.
|
||||
|
|
@ -1778,8 +1722,7 @@ class LetterBox:
|
|||
|
||||
|
||||
class CopyPaste(BaseMixTransform):
|
||||
"""
|
||||
CopyPaste class for applying Copy-Paste augmentation to image datasets.
|
||||
"""CopyPaste class for applying Copy-Paste augmentation to image datasets.
|
||||
|
||||
This class implements the Copy-Paste augmentation technique as described in the paper "Simple Copy-Paste is a Strong
|
||||
Data Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It combines objects from
|
||||
|
|
@ -1878,8 +1821,7 @@ class CopyPaste(BaseMixTransform):
|
|||
|
||||
|
||||
class Albumentations:
|
||||
"""
|
||||
Albumentations transformations for image augmentation.
|
||||
"""Albumentations transformations for image augmentation.
|
||||
|
||||
This class applies various image transformations using the Albumentations library. It includes operations such as
|
||||
Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes
|
||||
|
|
@ -1904,8 +1846,7 @@ class Albumentations:
|
|||
"""
|
||||
|
||||
def __init__(self, p: float = 1.0) -> None:
|
||||
"""
|
||||
Initialize the Albumentations transform object for YOLO bbox formatted parameters.
|
||||
"""Initialize the Albumentations transform object for YOLO bbox formatted parameters.
|
||||
|
||||
This class applies various image augmentations using the Albumentations library, including Blur, Median Blur,
|
||||
conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and
|
||||
|
|
@ -2018,8 +1959,7 @@ class Albumentations:
|
|||
LOGGER.info(f"{prefix}{e}")
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Apply Albumentations transformations to input labels.
|
||||
"""Apply Albumentations transformations to input labels.
|
||||
|
||||
This method applies a series of image augmentations using the Albumentations library. It can perform both
|
||||
spatial and non-spatial transformations on the input image and its corresponding labels.
|
||||
|
|
@ -2075,8 +2015,7 @@ class Albumentations:
|
|||
|
||||
|
||||
class Format:
|
||||
"""
|
||||
A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
|
||||
"""A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
|
||||
|
||||
This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader.
|
||||
|
||||
|
|
@ -2116,8 +2055,7 @@ class Format:
|
|||
batch_idx: bool = True,
|
||||
bgr: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Initialize the Format class with given parameters for image and instance annotation formatting.
|
||||
"""Initialize the Format class with given parameters for image and instance annotation formatting.
|
||||
|
||||
This class standardizes image and instance annotations for object detection, instance segmentation, and pose
|
||||
estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`.
|
||||
|
|
@ -2160,8 +2098,7 @@ class Format:
|
|||
self.bgr = bgr
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Format image annotations for object detection, instance segmentation, and pose estimation tasks.
|
||||
"""Format image annotations for object detection, instance segmentation, and pose estimation tasks.
|
||||
|
||||
This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch
|
||||
DataLoader. It processes the input labels dictionary, converting annotations to the specified format and
|
||||
|
|
@ -2229,8 +2166,7 @@ class Format:
|
|||
return labels
|
||||
|
||||
def _format_img(self, img: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Format an image for YOLO from a Numpy array to a PyTorch tensor.
|
||||
"""Format an image for YOLO from a Numpy array to a PyTorch tensor.
|
||||
|
||||
This function performs the following operations:
|
||||
1. Ensures the image has 3 dimensions (adds a channel dimension if needed).
|
||||
|
|
@ -2262,8 +2198,7 @@ class Format:
|
|||
def _format_segments(
|
||||
self, instances: Instances, cls: np.ndarray, w: int, h: int
|
||||
) -> tuple[np.ndarray, Instances, np.ndarray]:
|
||||
"""
|
||||
Convert polygon segments to bitmap masks.
|
||||
"""Convert polygon segments to bitmap masks.
|
||||
|
||||
Args:
|
||||
instances (Instances): Object containing segment information.
|
||||
|
|
@ -2297,8 +2232,7 @@ class LoadVisualPrompt:
|
|||
"""Create visual prompts from bounding boxes or masks for model input."""
|
||||
|
||||
def __init__(self, scale_factor: float = 1 / 8) -> None:
|
||||
"""
|
||||
Initialize the LoadVisualPrompt with a scale factor.
|
||||
"""Initialize the LoadVisualPrompt with a scale factor.
|
||||
|
||||
Args:
|
||||
scale_factor (float): Factor to scale the input image dimensions.
|
||||
|
|
@ -2306,8 +2240,7 @@ class LoadVisualPrompt:
|
|||
self.scale_factor = scale_factor
|
||||
|
||||
def make_mask(self, boxes: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
||||
"""
|
||||
Create binary masks from bounding boxes.
|
||||
"""Create binary masks from bounding boxes.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): Bounding boxes in xyxy format, shape: (N, 4).
|
||||
|
|
@ -2324,8 +2257,7 @@ class LoadVisualPrompt:
|
|||
return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Process labels to create visual prompts.
|
||||
"""Process labels to create visual prompts.
|
||||
|
||||
Args:
|
||||
labels (dict[str, Any]): Dictionary containing image data and annotations.
|
||||
|
|
@ -2351,8 +2283,7 @@ class LoadVisualPrompt:
|
|||
bboxes: np.ndarray | torch.Tensor = None,
|
||||
masks: np.ndarray | torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate visual masks based on bounding boxes or masks.
|
||||
"""Generate visual masks based on bounding boxes or masks.
|
||||
|
||||
Args:
|
||||
category (int | np.ndarray | torch.Tensor): The category labels for the objects.
|
||||
|
|
@ -2393,8 +2324,7 @@ class LoadVisualPrompt:
|
|||
|
||||
|
||||
class RandomLoadText:
|
||||
"""
|
||||
Randomly sample positive and negative texts and update class indices accordingly.
|
||||
"""Randomly sample positive and negative texts and update class indices accordingly.
|
||||
|
||||
This class is responsible for sampling texts from a given set of class texts, including both positive (present in
|
||||
the image) and negative (not present in the image) samples. It updates the class indices to reflect the sampled
|
||||
|
|
@ -2426,20 +2356,19 @@ class RandomLoadText:
|
|||
padding: bool = False,
|
||||
padding_value: list[str] = [""],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the RandomLoadText class for randomly sampling positive and negative texts.
|
||||
"""Initialize the RandomLoadText class for randomly sampling positive and negative texts.
|
||||
|
||||
This class is designed to randomly sample positive texts and negative texts, and update the class indices
|
||||
accordingly to the number of samples. It can be used for text-based object detection tasks.
|
||||
|
||||
Args:
|
||||
prompt_format (str): Format string for the prompt. The format string should
|
||||
contain a single pair of curly braces {} where the text will be inserted.
|
||||
neg_samples (tuple[int, int]): A range to randomly sample negative texts. The first integer
|
||||
specifies the minimum number of negative samples, and the second integer specifies the maximum.
|
||||
prompt_format (str): Format string for the prompt. The format string should contain a single pair of curly
|
||||
braces {} where the text will be inserted.
|
||||
neg_samples (tuple[int, int]): A range to randomly sample negative texts. The first integer specifies the
|
||||
minimum number of negative samples, and the second integer specifies the maximum.
|
||||
max_samples (int): The maximum number of different text samples in one image.
|
||||
padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always
|
||||
be equal to max_samples.
|
||||
padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always be equal to
|
||||
max_samples.
|
||||
padding_value (list[str]): The padding text to use when padding is True.
|
||||
|
||||
Attributes:
|
||||
|
|
@ -2465,8 +2394,7 @@ class RandomLoadText:
|
|||
self.padding_value = padding_value
|
||||
|
||||
def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Randomly sample positive and negative texts and update class indices accordingly.
|
||||
"""Randomly sample positive and negative texts and update class indices accordingly.
|
||||
|
||||
This method samples positive texts based on the existing class labels in the image, and randomly selects
|
||||
negative texts from the remaining classes. It then updates the class indices to match the new sampled text
|
||||
|
|
@ -2532,8 +2460,7 @@ class RandomLoadText:
|
|||
|
||||
|
||||
def v8_transforms(dataset, imgsz: int, hyp: IterableSimpleNamespace, stretch: bool = False):
|
||||
"""
|
||||
Apply a series of image transformations for training.
|
||||
"""Apply a series of image transformations for training.
|
||||
|
||||
This function creates a composition of image augmentation techniques to prepare images for YOLO training. It
|
||||
includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.
|
||||
|
|
@ -2608,8 +2535,7 @@ def classify_transforms(
|
|||
interpolation: str = "BILINEAR",
|
||||
crop_fraction: float | None = None,
|
||||
):
|
||||
"""
|
||||
Create a composition of image transforms for classification tasks.
|
||||
"""Create a composition of image transforms for classification tasks.
|
||||
|
||||
This function generates a sequence of torchvision transforms suitable for preprocessing images for classification
|
||||
models during evaluation or inference. The transforms include resizing, center cropping, conversion to tensor, and
|
||||
|
|
@ -2668,8 +2594,7 @@ def classify_augmentations(
|
|||
erasing: float = 0.0,
|
||||
interpolation: str = "BILINEAR",
|
||||
):
|
||||
"""
|
||||
Create a composition of image augmentation transforms for classification tasks.
|
||||
"""Create a composition of image augmentation transforms for classification tasks.
|
||||
|
||||
This function generates a set of image transformations suitable for training classification models. It includes
|
||||
options for resizing, flipping, color jittering, auto augmentation, and random erasing.
|
||||
|
|
@ -2757,8 +2682,7 @@ def classify_augmentations(
|
|||
|
||||
# NOTE: keep this class for backward compatibility
|
||||
class ClassifyLetterBox:
|
||||
"""
|
||||
A class for resizing and padding images for classification tasks.
|
||||
"""A class for resizing and padding images for classification tasks.
|
||||
|
||||
This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). It
|
||||
resizes and pads images to a specified size while maintaining the original aspect ratio.
|
||||
|
|
@ -2781,15 +2705,14 @@ class ClassifyLetterBox:
|
|||
"""
|
||||
|
||||
def __init__(self, size: int | tuple[int, int] = (640, 640), auto: bool = False, stride: int = 32):
|
||||
"""
|
||||
Initialize the ClassifyLetterBox object for image preprocessing.
|
||||
"""Initialize the ClassifyLetterBox object for image preprocessing.
|
||||
|
||||
This class is designed to be part of a transformation pipeline for image classification tasks. It resizes and
|
||||
pads images to a specified size while maintaining the original aspect ratio.
|
||||
|
||||
Args:
|
||||
size (int | tuple[int, int]): Target size for the letterboxed image. If an int, a square image of
|
||||
(size, size) is created. If a tuple, it should be (height, width).
|
||||
size (int | tuple[int, int]): Target size for the letterboxed image. If an int, a square image of (size,
|
||||
size) is created. If a tuple, it should be (height, width).
|
||||
auto (bool): If True, automatically calculates the short side based on stride.
|
||||
stride (int): The stride value, used when 'auto' is True.
|
||||
|
||||
|
|
@ -2812,8 +2735,7 @@ class ClassifyLetterBox:
|
|||
self.stride = stride # used with auto
|
||||
|
||||
def __call__(self, im: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Resize and pad an image using the letterbox method.
|
||||
"""Resize and pad an image using the letterbox method.
|
||||
|
||||
This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio,
|
||||
then pads the resized image to match the target size.
|
||||
|
|
@ -2822,8 +2744,8 @@ class ClassifyLetterBox:
|
|||
im (np.ndarray): Input image as a numpy array with shape (H, W, C).
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are
|
||||
the target height and width respectively.
|
||||
(np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are the
|
||||
target height and width respectively.
|
||||
|
||||
Examples:
|
||||
>>> letterbox = ClassifyLetterBox(size=(640, 640))
|
||||
|
|
@ -2848,8 +2770,7 @@ class ClassifyLetterBox:
|
|||
|
||||
# NOTE: keep this class for backward compatibility
|
||||
class CenterCrop:
|
||||
"""
|
||||
Apply center cropping to images for classification tasks.
|
||||
"""Apply center cropping to images for classification tasks.
|
||||
|
||||
This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect
|
||||
ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
|
||||
|
|
@ -2870,15 +2791,14 @@ class CenterCrop:
|
|||
"""
|
||||
|
||||
def __init__(self, size: int | tuple[int, int] = (640, 640)):
|
||||
"""
|
||||
Initialize the CenterCrop object for image preprocessing.
|
||||
"""Initialize the CenterCrop object for image preprocessing.
|
||||
|
||||
This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
|
||||
It performs a center crop on input images to a specified size.
|
||||
|
||||
Args:
|
||||
size (int | tuple[int, int]): The desired output size of the crop. If size is an int, a square crop
|
||||
(size, size) is made. If size is a sequence like (h, w), it is used as the output size.
|
||||
size (int | tuple[int, int]): The desired output size of the crop. If size is an int, a square crop (size,
|
||||
size) is made. If size is a sequence like (h, w), it is used as the output size.
|
||||
|
||||
Returns:
|
||||
(None): This method initializes the object and does not return anything.
|
||||
|
|
@ -2894,15 +2814,14 @@ class CenterCrop:
|
|||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
|
||||
def __call__(self, im: Image.Image | np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Apply center cropping to an input image.
|
||||
"""Apply center cropping to an input image.
|
||||
|
||||
This method resizes and crops the center of the image using a letterbox method. It maintains the aspect ratio of
|
||||
the original image while fitting it into the specified dimensions.
|
||||
|
||||
Args:
|
||||
im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a
|
||||
PIL Image object.
|
||||
im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a PIL Image
|
||||
object.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C).
|
||||
|
|
@ -2923,8 +2842,7 @@ class CenterCrop:
|
|||
|
||||
# NOTE: keep this class for backward compatibility
|
||||
class ToTensor:
|
||||
"""
|
||||
Convert an image from a numpy array to a PyTorch tensor.
|
||||
"""Convert an image from a numpy array to a PyTorch tensor.
|
||||
|
||||
This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
|
||||
|
||||
|
|
@ -2947,8 +2865,7 @@ class ToTensor:
|
|||
"""
|
||||
|
||||
def __init__(self, half: bool = False):
|
||||
"""
|
||||
Initialize the ToTensor object for converting images to PyTorch tensors.
|
||||
"""Initialize the ToTensor object for converting images to PyTorch tensors.
|
||||
|
||||
This class is designed to be used as part of a transformation pipeline for image preprocessing in the
|
||||
Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option for
|
||||
|
|
@ -2968,8 +2885,7 @@ class ToTensor:
|
|||
self.half = half
|
||||
|
||||
def __call__(self, im: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Transform an image from a numpy array to a PyTorch tensor.
|
||||
"""Transform an image from a numpy array to a PyTorch tensor.
|
||||
|
||||
This method converts the input image from a numpy array to a PyTorch tensor, applying optional half-precision
|
||||
conversion and normalization. The image is transposed from HWC to CHW format and the color channels are reversed
|
||||
|
|
@ -2979,8 +2895,8 @@ class ToTensor:
|
|||
im (np.ndarray): Input image as a numpy array with shape (H, W, C) in RGB order.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized
|
||||
to [0, 1] with shape (C, H, W) in RGB order.
|
||||
(torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized to [0, 1] with
|
||||
shape (C, H, W) in RGB order.
|
||||
|
||||
Examples:
|
||||
>>> transform = ToTensor(half=True)
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ from ultralytics.utils.patches import imread
|
|||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
"""
|
||||
Base dataset class for loading and processing image data.
|
||||
"""Base dataset class for loading and processing image data.
|
||||
|
||||
This class provides core functionality for loading images, caching, and preparing data for training and inference in
|
||||
object detection tasks.
|
||||
|
|
@ -86,8 +85,7 @@ class BaseDataset(Dataset):
|
|||
fraction: float = 1.0,
|
||||
channels: int = 3,
|
||||
):
|
||||
"""
|
||||
Initialize BaseDataset with given configuration and options.
|
||||
"""Initialize BaseDataset with given configuration and options.
|
||||
|
||||
Args:
|
||||
img_path (str | list[str]): Path to the folder containing images or list of image paths.
|
||||
|
|
@ -148,8 +146,7 @@ class BaseDataset(Dataset):
|
|||
self.transforms = self.build_transforms(hyp=hyp)
|
||||
|
||||
def get_img_files(self, img_path: str | list[str]) -> list[str]:
|
||||
"""
|
||||
Read image files from the specified path.
|
||||
"""Read image files from the specified path.
|
||||
|
||||
Args:
|
||||
img_path (str | list[str]): Path or list of paths to image directories or files.
|
||||
|
|
@ -186,8 +183,7 @@ class BaseDataset(Dataset):
|
|||
return im_files
|
||||
|
||||
def update_labels(self, include_class: list[int] | None) -> None:
|
||||
"""
|
||||
Update labels to include only specified classes.
|
||||
"""Update labels to include only specified classes.
|
||||
|
||||
Args:
|
||||
include_class (list[int], optional): List of classes to include. If None, all classes are included.
|
||||
|
|
@ -210,8 +206,7 @@ class BaseDataset(Dataset):
|
|||
self.labels[i]["cls"][:, 0] = 0
|
||||
|
||||
def load_image(self, i: int, rect_mode: bool = True) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
|
||||
"""
|
||||
Load an image from dataset index 'i'.
|
||||
"""Load an image from dataset index 'i'.
|
||||
|
||||
Args:
|
||||
i (int): Index of the image to load.
|
||||
|
|
@ -286,8 +281,7 @@ class BaseDataset(Dataset):
|
|||
np.save(f.as_posix(), imread(self.im_files[i]), allow_pickle=False)
|
||||
|
||||
def check_cache_disk(self, safety_margin: float = 0.5) -> bool:
|
||||
"""
|
||||
Check if there's enough disk space for caching images.
|
||||
"""Check if there's enough disk space for caching images.
|
||||
|
||||
Args:
|
||||
safety_margin (float): Safety margin factor for disk space calculation.
|
||||
|
|
@ -322,8 +316,7 @@ class BaseDataset(Dataset):
|
|||
return True
|
||||
|
||||
def check_cache_ram(self, safety_margin: float = 0.5) -> bool:
|
||||
"""
|
||||
Check if there's enough RAM for caching images.
|
||||
"""Check if there's enough RAM for caching images.
|
||||
|
||||
Args:
|
||||
safety_margin (float): Safety margin factor for RAM calculation.
|
||||
|
|
@ -381,8 +374,7 @@ class BaseDataset(Dataset):
|
|||
return self.transforms(self.get_image_and_label(index))
|
||||
|
||||
def get_image_and_label(self, index: int) -> dict[str, Any]:
|
||||
"""
|
||||
Get and return label information from the dataset.
|
||||
"""Get and return label information from the dataset.
|
||||
|
||||
Args:
|
||||
index (int): Index of the image to retrieve.
|
||||
|
|
@ -410,8 +402,7 @@ class BaseDataset(Dataset):
|
|||
return label
|
||||
|
||||
def build_transforms(self, hyp: dict[str, Any] | None = None):
|
||||
"""
|
||||
Users can customize augmentations here.
|
||||
"""Users can customize augmentations here.
|
||||
|
||||
Examples:
|
||||
>>> if self.augment:
|
||||
|
|
@ -424,8 +415,7 @@ class BaseDataset(Dataset):
|
|||
raise NotImplementedError
|
||||
|
||||
def get_labels(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Users can customize their own format here.
|
||||
"""Users can customize their own format here.
|
||||
|
||||
Examples:
|
||||
Ensure output is a dictionary with the following keys:
|
||||
|
|
|
|||
|
|
@ -35,8 +35,7 @@ from ultralytics.utils.torch_utils import TORCH_2_0
|
|||
|
||||
|
||||
class InfiniteDataLoader(dataloader.DataLoader):
|
||||
"""
|
||||
Dataloader that reuses workers for infinite iteration.
|
||||
"""Dataloader that reuses workers for infinite iteration.
|
||||
|
||||
This dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency
|
||||
for training loops that need to iterate through the dataset multiple times without recreating workers.
|
||||
|
|
@ -94,8 +93,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
|
|||
|
||||
|
||||
class _RepeatSampler:
|
||||
"""
|
||||
Sampler that repeats forever for infinite iteration.
|
||||
"""Sampler that repeats forever for infinite iteration.
|
||||
|
||||
This sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration over a
|
||||
dataset without recreating the sampler.
|
||||
|
|
@ -115,8 +113,7 @@ class _RepeatSampler:
|
|||
|
||||
|
||||
class ContiguousDistributedSampler(torch.utils.data.Sampler):
|
||||
"""
|
||||
Distributed sampler that assigns contiguous batch-aligned chunks of the dataset to each GPU.
|
||||
"""Distributed sampler that assigns contiguous batch-aligned chunks of the dataset to each GPU.
|
||||
|
||||
Unlike PyTorch's DistributedSampler which distributes samples in a round-robin fashion (GPU 0 gets indices
|
||||
[0,2,4,...], GPU 1 gets [1,3,5,...]), this sampler gives each GPU contiguous batches of the dataset (GPU 0 gets
|
||||
|
|
@ -132,8 +129,8 @@ class ContiguousDistributedSampler(torch.utils.data.Sampler):
|
|||
num_replicas (int, optional): Number of distributed processes. Defaults to world size.
|
||||
batch_size (int, optional): Batch size used by dataloader. Defaults to dataset batch size.
|
||||
rank (int, optional): Rank of current process. Defaults to current rank.
|
||||
shuffle (bool, optional): Whether to shuffle indices within each rank's chunk. Defaults to False.
|
||||
When True, shuffling is deterministic and controlled by set_epoch() for reproducibility.
|
||||
shuffle (bool, optional): Whether to shuffle indices within each rank's chunk. Defaults to False. When True,
|
||||
shuffling is deterministic and controlled by set_epoch() for reproducibility.
|
||||
|
||||
Examples:
|
||||
>>> # For validation with size-grouped images
|
||||
|
|
@ -202,8 +199,7 @@ class ContiguousDistributedSampler(torch.utils.data.Sampler):
|
|||
return end_idx - start_idx
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
"""
|
||||
Set the epoch for this sampler to ensure different shuffling patterns across epochs.
|
||||
"""Set the epoch for this sampler to ensure different shuffling patterns across epochs.
|
||||
|
||||
Args:
|
||||
epoch (int): Epoch number to use as the random seed for shuffling.
|
||||
|
|
@ -289,8 +285,7 @@ def build_dataloader(
|
|||
drop_last: bool = False,
|
||||
pin_memory: bool = True,
|
||||
):
|
||||
"""
|
||||
Create and return an InfiniteDataLoader or DataLoader for training or validation.
|
||||
"""Create and return an InfiniteDataLoader or DataLoader for training or validation.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): Dataset to load data from.
|
||||
|
|
@ -337,8 +332,7 @@ def build_dataloader(
|
|||
|
||||
|
||||
def check_source(source):
|
||||
"""
|
||||
Check the type of input source and return corresponding flag values.
|
||||
"""Check the type of input source and return corresponding flag values.
|
||||
|
||||
Args:
|
||||
source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.
|
||||
|
|
@ -386,8 +380,7 @@ def check_source(source):
|
|||
|
||||
|
||||
def load_inference_source(source=None, batch: int = 1, vid_stride: int = 1, buffer: bool = False, channels: int = 3):
|
||||
"""
|
||||
Load an inference source for object detection and apply necessary transformations.
|
||||
"""Load an inference source for object detection and apply necessary transformations.
|
||||
|
||||
Args:
|
||||
source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.
|
||||
|
|
|
|||
|
|
@ -21,12 +21,11 @@ from ultralytics.utils.files import increment_path
|
|||
|
||||
|
||||
def coco91_to_coco80_class() -> list[int]:
|
||||
"""
|
||||
Convert 91-index COCO class IDs to 80-index COCO class IDs.
|
||||
"""Convert 91-index COCO class IDs to 80-index COCO class IDs.
|
||||
|
||||
Returns:
|
||||
(list[int]): A list of 91 class IDs where the index represents the 91-index class ID and the value
|
||||
is the corresponding 80-index class ID.
|
||||
(list[int]): A list of 91 class IDs where the index represents the 91-index class ID and the value is the
|
||||
corresponding 80-index class ID.
|
||||
"""
|
||||
return [
|
||||
0,
|
||||
|
|
@ -124,8 +123,7 @@ def coco91_to_coco80_class() -> list[int]:
|
|||
|
||||
|
||||
def coco80_to_coco91_class() -> list[int]:
|
||||
r"""
|
||||
Convert 80-index (val2014) to 91-index (paper).
|
||||
r"""Convert 80-index (val2014) to 91-index (paper).
|
||||
|
||||
Returns:
|
||||
(list[int]): A list of 80 class IDs where each value is the corresponding 91-index class ID.
|
||||
|
|
@ -236,8 +234,7 @@ def convert_coco(
|
|||
cls91to80: bool = True,
|
||||
lvis: bool = False,
|
||||
):
|
||||
"""
|
||||
Convert COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
|
||||
"""Convert COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
|
||||
|
||||
Args:
|
||||
labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
|
||||
|
|
@ -348,8 +345,7 @@ def convert_coco(
|
|||
|
||||
|
||||
def convert_segment_masks_to_yolo_seg(masks_dir: str, output_dir: str, classes: int):
|
||||
"""
|
||||
Convert a dataset of segmentation mask images to the YOLO segmentation format.
|
||||
"""Convert a dataset of segmentation mask images to the YOLO segmentation format.
|
||||
|
||||
This function takes the directory containing the binary format mask images and converts them into YOLO segmentation
|
||||
format. The converted masks are saved in the specified output directory.
|
||||
|
|
@ -424,8 +420,7 @@ def convert_segment_masks_to_yolo_seg(masks_dir: str, output_dir: str, classes:
|
|||
|
||||
|
||||
def convert_dota_to_yolo_obb(dota_root_path: str):
|
||||
"""
|
||||
Convert DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.
|
||||
"""Convert DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.
|
||||
|
||||
The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the
|
||||
associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.
|
||||
|
|
@ -517,8 +512,7 @@ def convert_dota_to_yolo_obb(dota_root_path: str):
|
|||
|
||||
|
||||
def min_index(arr1: np.ndarray, arr2: np.ndarray):
|
||||
"""
|
||||
Find a pair of indexes with the shortest distance between two arrays of 2D points.
|
||||
"""Find a pair of indexes with the shortest distance between two arrays of 2D points.
|
||||
|
||||
Args:
|
||||
arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.
|
||||
|
|
@ -533,14 +527,14 @@ def min_index(arr1: np.ndarray, arr2: np.ndarray):
|
|||
|
||||
|
||||
def merge_multi_segment(segments: list[list]):
|
||||
"""
|
||||
Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.
|
||||
"""Merge multiple segments into one list by connecting the coordinates with the minimum distance between each
|
||||
segment.
|
||||
|
||||
This function connects these coordinates with a thin line to merge all segments into one.
|
||||
|
||||
Args:
|
||||
segments (list[list]): Original segmentations in COCO's JSON file.
|
||||
Each element is a list of coordinates, like [segmentation1, segmentation2,...].
|
||||
segments (list[list]): Original segmentations in COCO's JSON file. Each element is a list of coordinates, like
|
||||
[segmentation1, segmentation2,...].
|
||||
|
||||
Returns:
|
||||
s (list[np.ndarray]): A list of connected segments represented as NumPy arrays.
|
||||
|
|
@ -584,14 +578,13 @@ def merge_multi_segment(segments: list[list]):
|
|||
|
||||
|
||||
def yolo_bbox2segment(im_dir: str | Path, save_dir: str | Path | None = None, sam_model: str = "sam_b.pt", device=None):
|
||||
"""
|
||||
Convert existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) in
|
||||
YOLO format. Generate segmentation data using SAM auto-annotator as needed.
|
||||
"""Convert existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB)
|
||||
in YOLO format. Generate segmentation data using SAM auto-annotator as needed.
|
||||
|
||||
Args:
|
||||
im_dir (str | Path): Path to image directory to convert.
|
||||
save_dir (str | Path, optional): Path to save the generated labels. If None, labels are saved
|
||||
into `labels-segment` at the same directory level as `im_dir`.
|
||||
save_dir (str | Path, optional): Path to save the generated labels. If None, labels are saved into
|
||||
`labels-segment` at the same directory level as `im_dir`.
|
||||
sam_model (str): Segmentation model to use for intermediate segmentation data.
|
||||
device (int | str, optional): The specific device to run SAM models.
|
||||
|
||||
|
|
@ -648,8 +641,7 @@ def yolo_bbox2segment(im_dir: str | Path, save_dir: str | Path | None = None, sa
|
|||
|
||||
|
||||
def create_synthetic_coco_dataset():
|
||||
"""
|
||||
Create a synthetic COCO dataset with random images based on filenames from label lists.
|
||||
"""Create a synthetic COCO dataset with random images based on filenames from label lists.
|
||||
|
||||
This function downloads COCO labels, reads image filenames from label list files, creates synthetic images for
|
||||
train2017 and val2017 subsets, and organizes them in the COCO dataset structure. It uses multithreading to generate
|
||||
|
|
@ -704,8 +696,7 @@ def create_synthetic_coco_dataset():
|
|||
|
||||
|
||||
def convert_to_multispectral(path: str | Path, n_channels: int = 10, replace: bool = False, zip: bool = False):
|
||||
"""
|
||||
Convert RGB images to multispectral images by interpolating across wavelength bands.
|
||||
"""Convert RGB images to multispectral images by interpolating across wavelength bands.
|
||||
|
||||
This function takes RGB images and interpolates them to create multispectral images with a specified number of
|
||||
channels. It can process either a single image or a directory of images.
|
||||
|
|
@ -756,8 +747,7 @@ def convert_to_multispectral(path: str | Path, n_channels: int = 10, replace: bo
|
|||
|
||||
|
||||
async def convert_ndjson_to_yolo(ndjson_path: str | Path, output_path: str | Path | None = None) -> Path:
|
||||
"""
|
||||
Convert NDJSON dataset format to Ultralytics YOLO11 dataset structure.
|
||||
"""Convert NDJSON dataset format to Ultralytics YOLO11 dataset structure.
|
||||
|
||||
This function converts datasets stored in NDJSON (Newline Delimited JSON) format to the standard YOLO format with
|
||||
separate directories for images and labels. It supports parallel processing for efficient conversion of large
|
||||
|
|
@ -769,8 +759,8 @@ async def convert_ndjson_to_yolo(ndjson_path: str | Path, output_path: str | Pat
|
|||
|
||||
Args:
|
||||
ndjson_path (Union[str, Path]): Path to the input NDJSON file containing dataset information.
|
||||
output_path (Optional[Union[str, Path]], optional): Directory where the converted YOLO dataset
|
||||
will be saved. If None, uses the parent directory of the NDJSON file. Defaults to None.
|
||||
output_path (Optional[Union[str, Path]], optional): Directory where the converted YOLO dataset will be saved. If
|
||||
None, uses the parent directory of the NDJSON file. Defaults to None.
|
||||
|
||||
Returns:
|
||||
(Path): Path to the generated data.yaml file that can be used for YOLO training.
|
||||
|
|
|
|||
|
|
@ -47,8 +47,7 @@ DATASET_CACHE_VERSION = "1.0.3"
|
|||
|
||||
|
||||
class YOLODataset(BaseDataset):
|
||||
"""
|
||||
Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
||||
"""Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
||||
|
||||
This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
|
||||
(OBB) tasks using the YOLO format.
|
||||
|
|
@ -73,8 +72,7 @@ class YOLODataset(BaseDataset):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
|
||||
"""
|
||||
Initialize the YOLODataset.
|
||||
"""Initialize the YOLODataset.
|
||||
|
||||
Args:
|
||||
data (dict, optional): Dataset configuration dictionary.
|
||||
|
|
@ -90,8 +88,7 @@ class YOLODataset(BaseDataset):
|
|||
super().__init__(*args, channels=self.data.get("channels", 3), **kwargs)
|
||||
|
||||
def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
|
||||
"""
|
||||
Cache dataset labels, check images and read shapes.
|
||||
"""Cache dataset labels, check images and read shapes.
|
||||
|
||||
Args:
|
||||
path (Path): Path where to save the cache file.
|
||||
|
|
@ -158,8 +155,7 @@ class YOLODataset(BaseDataset):
|
|||
return x
|
||||
|
||||
def get_labels(self) -> list[dict]:
|
||||
"""
|
||||
Return dictionary of labels for YOLO training.
|
||||
"""Return dictionary of labels for YOLO training.
|
||||
|
||||
This method loads labels from disk or cache, verifies their integrity, and prepares them for training.
|
||||
|
||||
|
|
@ -208,8 +204,7 @@ class YOLODataset(BaseDataset):
|
|||
return labels
|
||||
|
||||
def build_transforms(self, hyp: dict | None = None) -> Compose:
|
||||
"""
|
||||
Build and append transforms to the list.
|
||||
"""Build and append transforms to the list.
|
||||
|
||||
Args:
|
||||
hyp (dict, optional): Hyperparameters for transforms.
|
||||
|
|
@ -240,8 +235,7 @@ class YOLODataset(BaseDataset):
|
|||
return transforms
|
||||
|
||||
def close_mosaic(self, hyp: dict) -> None:
|
||||
"""
|
||||
Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.
|
||||
"""Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.
|
||||
|
||||
Args:
|
||||
hyp (dict): Hyperparameters for transforms.
|
||||
|
|
@ -253,8 +247,7 @@ class YOLODataset(BaseDataset):
|
|||
self.transforms = self.build_transforms(hyp)
|
||||
|
||||
def update_labels_info(self, label: dict) -> dict:
|
||||
"""
|
||||
Update label format for different tasks.
|
||||
"""Update label format for different tasks.
|
||||
|
||||
Args:
|
||||
label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
|
||||
|
|
@ -287,8 +280,7 @@ class YOLODataset(BaseDataset):
|
|||
|
||||
@staticmethod
|
||||
def collate_fn(batch: list[dict]) -> dict:
|
||||
"""
|
||||
Collate data samples into batches.
|
||||
"""Collate data samples into batches.
|
||||
|
||||
Args:
|
||||
batch (list[dict]): List of dictionaries containing sample data.
|
||||
|
|
@ -317,8 +309,7 @@ class YOLODataset(BaseDataset):
|
|||
|
||||
|
||||
class YOLOMultiModalDataset(YOLODataset):
|
||||
"""
|
||||
Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
|
||||
"""Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
|
||||
|
||||
This class extends YOLODataset to add text information for multi-modal model training, enabling models to process
|
||||
both image and text data.
|
||||
|
|
@ -334,8 +325,7 @@ class YOLOMultiModalDataset(YOLODataset):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
|
||||
"""
|
||||
Initialize a YOLOMultiModalDataset.
|
||||
"""Initialize a YOLOMultiModalDataset.
|
||||
|
||||
Args:
|
||||
data (dict, optional): Dataset configuration dictionary.
|
||||
|
|
@ -346,8 +336,7 @@ class YOLOMultiModalDataset(YOLODataset):
|
|||
super().__init__(*args, data=data, task=task, **kwargs)
|
||||
|
||||
def update_labels_info(self, label: dict) -> dict:
|
||||
"""
|
||||
Add text information for multi-modal model training.
|
||||
"""Add text information for multi-modal model training.
|
||||
|
||||
Args:
|
||||
label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
|
||||
|
|
@ -363,8 +352,7 @@ class YOLOMultiModalDataset(YOLODataset):
|
|||
return labels
|
||||
|
||||
def build_transforms(self, hyp: dict | None = None) -> Compose:
|
||||
"""
|
||||
Enhance data transformations with optional text augmentation for multi-modal training.
|
||||
"""Enhance data transformations with optional text augmentation for multi-modal training.
|
||||
|
||||
Args:
|
||||
hyp (dict, optional): Hyperparameters for transforms.
|
||||
|
|
@ -388,8 +376,7 @@ class YOLOMultiModalDataset(YOLODataset):
|
|||
|
||||
@property
|
||||
def category_names(self):
|
||||
"""
|
||||
Return category names for the dataset.
|
||||
"""Return category names for the dataset.
|
||||
|
||||
Returns:
|
||||
(set[str]): Set of class names.
|
||||
|
|
@ -418,8 +405,7 @@ class YOLOMultiModalDataset(YOLODataset):
|
|||
|
||||
|
||||
class GroundingDataset(YOLODataset):
|
||||
"""
|
||||
Dataset class for object detection tasks using annotations from a JSON file in grounding format.
|
||||
"""Dataset class for object detection tasks using annotations from a JSON file in grounding format.
|
||||
|
||||
This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than the standard
|
||||
YOLO format text files.
|
||||
|
|
@ -438,8 +424,7 @@ class GroundingDataset(YOLODataset):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs):
|
||||
"""
|
||||
Initialize a GroundingDataset for object detection.
|
||||
"""Initialize a GroundingDataset for object detection.
|
||||
|
||||
Args:
|
||||
json_file (str): Path to the JSON file containing annotations.
|
||||
|
|
@ -454,8 +439,7 @@ class GroundingDataset(YOLODataset):
|
|||
super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
|
||||
|
||||
def get_img_files(self, img_path: str) -> list:
|
||||
"""
|
||||
Return an empty list, since the image files are read in the `get_labels` function.
|
||||
"""Return an empty list, since the image files are read in the `get_labels` function.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the directory containing images.
|
||||
|
|
@ -466,21 +450,19 @@ class GroundingDataset(YOLODataset):
|
|||
return []
|
||||
|
||||
def verify_labels(self, labels: list[dict[str, Any]]) -> None:
|
||||
"""
|
||||
Verify the number of instances in the dataset matches expected counts.
|
||||
"""Verify the number of instances in the dataset matches expected counts.
|
||||
|
||||
This method checks if the total number of bounding box instances in the provided labels matches the expected
|
||||
count for known datasets. It performs validation against a predefined set of datasets with known instance
|
||||
counts.
|
||||
|
||||
Args:
|
||||
labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary
|
||||
contains dataset annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor
|
||||
containing bounding box coordinates.
|
||||
labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary contains dataset
|
||||
annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor containing bounding
|
||||
box coordinates.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the actual instance count doesn't match the expected count
|
||||
for a recognized dataset.
|
||||
AssertionError: If the actual instance count doesn't match the expected count for a recognized dataset.
|
||||
|
||||
Notes:
|
||||
For unrecognized datasets (those not in the predefined expected_counts),
|
||||
|
|
@ -501,8 +483,7 @@ class GroundingDataset(YOLODataset):
|
|||
LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'")
|
||||
|
||||
def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]:
|
||||
"""
|
||||
Load annotations from a JSON file, filter, and normalize bounding boxes for each image.
|
||||
"""Load annotations from a JSON file, filter, and normalize bounding boxes for each image.
|
||||
|
||||
Args:
|
||||
path (Path): Path where to save the cache file.
|
||||
|
|
@ -592,8 +573,7 @@ class GroundingDataset(YOLODataset):
|
|||
return x
|
||||
|
||||
def get_labels(self) -> list[dict]:
|
||||
"""
|
||||
Load labels from cache or generate them from JSON file.
|
||||
"""Load labels from cache or generate them from JSON file.
|
||||
|
||||
Returns:
|
||||
(list[dict]): List of label dictionaries, each containing information about an image and its annotations.
|
||||
|
|
@ -614,8 +594,7 @@ class GroundingDataset(YOLODataset):
|
|||
return labels
|
||||
|
||||
def build_transforms(self, hyp: dict | None = None) -> Compose:
|
||||
"""
|
||||
Configure augmentations for training with optional text loading.
|
||||
"""Configure augmentations for training with optional text loading.
|
||||
|
||||
Args:
|
||||
hyp (dict, optional): Hyperparameters for transforms.
|
||||
|
|
@ -661,8 +640,7 @@ class GroundingDataset(YOLODataset):
|
|||
|
||||
|
||||
class YOLOConcatDataset(ConcatDataset):
|
||||
"""
|
||||
Dataset as a concatenation of multiple datasets.
|
||||
"""Dataset as a concatenation of multiple datasets.
|
||||
|
||||
This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same collation
|
||||
function.
|
||||
|
|
@ -678,8 +656,7 @@ class YOLOConcatDataset(ConcatDataset):
|
|||
|
||||
@staticmethod
|
||||
def collate_fn(batch: list[dict]) -> dict:
|
||||
"""
|
||||
Collate data samples into batches.
|
||||
"""Collate data samples into batches.
|
||||
|
||||
Args:
|
||||
batch (list[dict]): List of dictionaries containing sample data.
|
||||
|
|
@ -690,8 +667,7 @@ class YOLOConcatDataset(ConcatDataset):
|
|||
return YOLODataset.collate_fn(batch)
|
||||
|
||||
def close_mosaic(self, hyp: dict) -> None:
|
||||
"""
|
||||
Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
|
||||
"""Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
|
||||
|
||||
Args:
|
||||
hyp (dict): Hyperparameters for transforms.
|
||||
|
|
@ -712,8 +688,7 @@ class SemanticDataset(BaseDataset):
|
|||
|
||||
|
||||
class ClassificationDataset:
|
||||
"""
|
||||
Dataset class for image classification tasks extending torchvision ImageFolder functionality.
|
||||
"""Dataset class for image classification tasks extending torchvision ImageFolder functionality.
|
||||
|
||||
This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
|
||||
handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
|
||||
|
|
@ -735,8 +710,7 @@ class ClassificationDataset:
|
|||
"""
|
||||
|
||||
def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
|
||||
"""
|
||||
Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.
|
||||
"""Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.
|
||||
|
||||
Args:
|
||||
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
|
||||
|
|
@ -787,8 +761,7 @@ class ClassificationDataset:
|
|||
)
|
||||
|
||||
def __getitem__(self, i: int) -> dict:
|
||||
"""
|
||||
Return subset of data and targets corresponding to given indices.
|
||||
"""Return subset of data and targets corresponding to given indices.
|
||||
|
||||
Args:
|
||||
i (int): Index of the sample to retrieve.
|
||||
|
|
@ -816,8 +789,7 @@ class ClassificationDataset:
|
|||
return len(self.samples)
|
||||
|
||||
def verify_images(self) -> list[tuple]:
|
||||
"""
|
||||
Verify all images in dataset.
|
||||
"""Verify all images in dataset.
|
||||
|
||||
Returns:
|
||||
(list): List of valid samples after verification.
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ from ultralytics.utils.patches import imread
|
|||
|
||||
@dataclass
|
||||
class SourceTypes:
|
||||
"""
|
||||
Class to represent various types of input sources for predictions.
|
||||
"""Class to represent various types of input sources for predictions.
|
||||
|
||||
This class uses dataclass to define boolean flags for different types of input sources that can be used for making
|
||||
predictions with YOLO models.
|
||||
|
|
@ -52,8 +51,7 @@ class SourceTypes:
|
|||
|
||||
|
||||
class LoadStreams:
|
||||
"""
|
||||
Stream Loader for various types of video streams.
|
||||
"""Stream Loader for various types of video streams.
|
||||
|
||||
Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video streams
|
||||
simultaneously, making it suitable for real-time video analysis tasks.
|
||||
|
|
@ -94,8 +92,7 @@ class LoadStreams:
|
|||
"""
|
||||
|
||||
def __init__(self, sources: str = "file.streams", vid_stride: int = 1, buffer: bool = False, channels: int = 3):
|
||||
"""
|
||||
Initialize stream loader for multiple video sources, supporting various stream types.
|
||||
"""Initialize stream loader for multiple video sources, supporting various stream types.
|
||||
|
||||
Args:
|
||||
sources (str): Path to streams file or single stream URL.
|
||||
|
|
@ -227,8 +224,7 @@ class LoadStreams:
|
|||
|
||||
|
||||
class LoadScreenshots:
|
||||
"""
|
||||
Ultralytics screenshot dataloader for capturing and processing screen images.
|
||||
"""Ultralytics screenshot dataloader for capturing and processing screen images.
|
||||
|
||||
This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with `yolo
|
||||
predict source=screen`.
|
||||
|
|
@ -259,8 +255,7 @@ class LoadScreenshots:
|
|||
"""
|
||||
|
||||
def __init__(self, source: str, channels: int = 3):
|
||||
"""
|
||||
Initialize screenshot capture with specified screen and region parameters.
|
||||
"""Initialize screenshot capture with specified screen and region parameters.
|
||||
|
||||
Args:
|
||||
source (str): Screen capture source string in format "screen_num left top width height".
|
||||
|
|
@ -307,8 +302,7 @@ class LoadScreenshots:
|
|||
|
||||
|
||||
class LoadImagesAndVideos:
|
||||
"""
|
||||
A class for loading and processing images and videos for YOLO object detection.
|
||||
"""A class for loading and processing images and videos for YOLO object detection.
|
||||
|
||||
This class manages the loading and pre-processing of image and video data from various sources, including single
|
||||
image files, video files, and lists of image and video paths.
|
||||
|
|
@ -347,8 +341,7 @@ class LoadImagesAndVideos:
|
|||
"""
|
||||
|
||||
def __init__(self, path: str | Path | list, batch: int = 1, vid_stride: int = 1, channels: int = 3):
|
||||
"""
|
||||
Initialize dataloader for images and videos, supporting various input formats.
|
||||
"""Initialize dataloader for images and videos, supporting various input formats.
|
||||
|
||||
Args:
|
||||
path (str | Path | list): Path to images/videos, directory, or list of paths.
|
||||
|
|
@ -490,8 +483,7 @@ class LoadImagesAndVideos:
|
|||
|
||||
|
||||
class LoadPilAndNumpy:
|
||||
"""
|
||||
Load images from PIL and Numpy arrays for batch processing.
|
||||
"""Load images from PIL and Numpy arrays for batch processing.
|
||||
|
||||
This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic
|
||||
validation and format conversion to ensure that the images are in the required format for downstream processing.
|
||||
|
|
@ -517,8 +509,7 @@ class LoadPilAndNumpy:
|
|||
"""
|
||||
|
||||
def __init__(self, im0: Image.Image | np.ndarray | list, channels: int = 3):
|
||||
"""
|
||||
Initialize a loader for PIL and Numpy images, converting inputs to a standardized format.
|
||||
"""Initialize a loader for PIL and Numpy images, converting inputs to a standardized format.
|
||||
|
||||
Args:
|
||||
im0 (PIL.Image.Image | np.ndarray | list): Single image or list of images in PIL or numpy format.
|
||||
|
|
@ -564,8 +555,7 @@ class LoadPilAndNumpy:
|
|||
|
||||
|
||||
class LoadTensor:
|
||||
"""
|
||||
A class for loading and processing tensor data for object detection tasks.
|
||||
"""A class for loading and processing tensor data for object detection tasks.
|
||||
|
||||
This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for further
|
||||
processing in object detection pipelines.
|
||||
|
|
@ -588,8 +578,7 @@ class LoadTensor:
|
|||
"""
|
||||
|
||||
def __init__(self, im0: torch.Tensor) -> None:
|
||||
"""
|
||||
Initialize LoadTensor object for processing torch.Tensor image data.
|
||||
"""Initialize LoadTensor object for processing torch.Tensor image data.
|
||||
|
||||
Args:
|
||||
im0 (torch.Tensor): Input tensor with shape (B, C, H, W).
|
||||
|
|
@ -656,8 +645,7 @@ def autocast_list(source: list[Any]) -> list[Image.Image | np.ndarray]:
|
|||
|
||||
|
||||
def get_best_youtube_url(url: str, method: str = "pytube") -> str | None:
|
||||
"""
|
||||
Retrieve the URL of the best quality MP4 video stream from a given YouTube video.
|
||||
"""Retrieve the URL of the best quality MP4 video stream from a given YouTube video.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the YouTube video.
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from ultralytics.utils import DATASETS_DIR, LOGGER, TQDM
|
|||
|
||||
|
||||
def split_classify_dataset(source_dir: str | Path, train_ratio: float = 0.8) -> Path:
|
||||
"""
|
||||
Split classification dataset into train and val directories in a new directory.
|
||||
"""Split classification dataset into train and val directories in a new directory.
|
||||
|
||||
Creates a new directory '{source_dir}_split' with train/val subdirectories, preserving the original class structure
|
||||
with an 80/20 split by default.
|
||||
|
|
@ -101,8 +100,8 @@ def autosplit(
|
|||
weights: tuple[float, float, float] = (0.9, 0.1, 0.0),
|
||||
annotated_only: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
|
||||
"""Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt
|
||||
files.
|
||||
|
||||
Args:
|
||||
path (Path): Path to images directory.
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ from ultralytics.utils.checks import check_requirements
|
|||
|
||||
|
||||
def bbox_iof(polygon1: np.ndarray, bbox2: np.ndarray, eps: float = 1e-6) -> np.ndarray:
|
||||
"""
|
||||
Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.
|
||||
"""Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.
|
||||
|
||||
Args:
|
||||
polygon1 (np.ndarray): Polygon coordinates with shape (N, 8).
|
||||
|
|
@ -65,8 +64,7 @@ def bbox_iof(polygon1: np.ndarray, bbox2: np.ndarray, eps: float = 1e-6) -> np.n
|
|||
|
||||
|
||||
def load_yolo_dota(data_root: str, split: str = "train") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load DOTA dataset annotations and image information.
|
||||
"""Load DOTA dataset annotations and image information.
|
||||
|
||||
Args:
|
||||
data_root (str): Data root directory.
|
||||
|
|
@ -107,8 +105,7 @@ def get_windows(
|
|||
im_rate_thr: float = 0.6,
|
||||
eps: float = 0.01,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get the coordinates of sliding windows for image cropping.
|
||||
"""Get the coordinates of sliding windows for image cropping.
|
||||
|
||||
Args:
|
||||
im_size (tuple[int, int]): Original image size, (H, W).
|
||||
|
|
@ -175,8 +172,7 @@ def crop_and_save(
|
|||
lb_dir: str,
|
||||
allow_background_images: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Crop images and save new labels for each window.
|
||||
"""Crop images and save new labels for each window.
|
||||
|
||||
Args:
|
||||
anno (dict[str, Any]): Annotation dict, including 'filepath', 'label', 'ori_size' as its keys.
|
||||
|
|
@ -226,8 +222,7 @@ def split_images_and_labels(
|
|||
crop_sizes: tuple[int, ...] = (1024,),
|
||||
gaps: tuple[int, ...] = (200,),
|
||||
) -> None:
|
||||
"""
|
||||
Split both images and labels for a given dataset split.
|
||||
"""Split both images and labels for a given dataset split.
|
||||
|
||||
Args:
|
||||
data_root (str): Root directory of the dataset.
|
||||
|
|
@ -265,8 +260,7 @@ def split_images_and_labels(
|
|||
def split_trainval(
|
||||
data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: tuple[float, ...] = (1.0,)
|
||||
) -> None:
|
||||
"""
|
||||
Split train and val sets of DOTA dataset with multiple scaling rates.
|
||||
"""Split train and val sets of DOTA dataset with multiple scaling rates.
|
||||
|
||||
Args:
|
||||
data_root (str): Root directory of the dataset.
|
||||
|
|
@ -304,8 +298,7 @@ def split_trainval(
|
|||
def split_test(
|
||||
data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: tuple[float, ...] = (1.0,)
|
||||
) -> None:
|
||||
"""
|
||||
Split test set of DOTA dataset, labels are not included within this set.
|
||||
"""Split test set of DOTA dataset, labels are not included within this set.
|
||||
|
||||
Args:
|
||||
data_root (str): Root directory of the dataset.
|
||||
|
|
|
|||
|
|
@ -51,8 +51,7 @@ def img2label_paths(img_paths: list[str]) -> list[str]:
|
|||
def check_file_speeds(
|
||||
files: list[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = ""
|
||||
):
|
||||
"""
|
||||
Check dataset file access speed and provide performance feedback.
|
||||
"""Check dataset file access speed and provide performance feedback.
|
||||
|
||||
This function tests the access speed of dataset files by measuring ping (stat call) time and read speed. It samples
|
||||
up to 5 files from the provided list and warns if access times exceed the threshold.
|
||||
|
|
@ -251,8 +250,7 @@ def verify_image_label(args: tuple) -> list:
|
|||
|
||||
|
||||
def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[int, str]):
|
||||
"""
|
||||
Visualize YOLO annotations (bounding boxes and class labels) on an image.
|
||||
"""Visualize YOLO annotations (bounding boxes and class labels) on an image.
|
||||
|
||||
This function reads an image and its corresponding annotation file in YOLO format, then draws bounding boxes around
|
||||
detected objects and labels them with their respective class names. The bounding box colors are assigned based on
|
||||
|
|
@ -297,13 +295,12 @@ def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[
|
|||
def polygon2mask(
|
||||
imgsz: tuple[int, int], polygons: list[np.ndarray], color: int = 1, downsample_ratio: int = 1
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Convert a list of polygons to a binary mask of the specified image size.
|
||||
"""Convert a list of polygons to a binary mask of the specified image size.
|
||||
|
||||
Args:
|
||||
imgsz (tuple[int, int]): The size of the image as (height, width).
|
||||
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where
|
||||
N is the number of polygons, and M is the number of points such that M % 2 = 0.
|
||||
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
|
||||
number of polygons, and M is the number of points such that M % 2 = 0.
|
||||
color (int, optional): The color value to fill in the polygons on the mask.
|
||||
downsample_ratio (int, optional): Factor by which to downsample the mask.
|
||||
|
||||
|
|
@ -322,13 +319,12 @@ def polygon2mask(
|
|||
def polygons2masks(
|
||||
imgsz: tuple[int, int], polygons: list[np.ndarray], color: int, downsample_ratio: int = 1
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Convert a list of polygons to a set of binary masks of the specified image size.
|
||||
"""Convert a list of polygons to a set of binary masks of the specified image size.
|
||||
|
||||
Args:
|
||||
imgsz (tuple[int, int]): The size of the image as (height, width).
|
||||
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where
|
||||
N is the number of polygons, and M is the number of points such that M % 2 = 0.
|
||||
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
|
||||
number of polygons, and M is the number of points such that M % 2 = 0.
|
||||
color (int): The color value to fill in the polygons on the masks.
|
||||
downsample_ratio (int, optional): Factor by which to downsample each mask.
|
||||
|
||||
|
|
@ -368,8 +364,7 @@ def polygons2masks_overlap(
|
|||
|
||||
|
||||
def find_dataset_yaml(path: Path) -> Path:
|
||||
"""
|
||||
Find and return the YAML file associated with a Detect, Segment or Pose dataset.
|
||||
"""Find and return the YAML file associated with a Detect, Segment or Pose dataset.
|
||||
|
||||
This function searches for a YAML file at the root level of the provided directory first, and if not found, it
|
||||
performs a recursive search. It prefers YAML files that have the same stem as the provided path.
|
||||
|
|
@ -389,8 +384,7 @@ def find_dataset_yaml(path: Path) -> Path:
|
|||
|
||||
|
||||
def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]:
|
||||
"""
|
||||
Download, verify, and/or unzip a dataset if not found locally.
|
||||
"""Download, verify, and/or unzip a dataset if not found locally.
|
||||
|
||||
This function checks the availability of a specified dataset, and if not found, it has the option to download and
|
||||
unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
|
||||
|
|
@ -484,8 +478,7 @@ def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]
|
|||
|
||||
|
||||
def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
|
||||
"""
|
||||
Check a classification dataset such as Imagenet.
|
||||
"""Check a classification dataset such as Imagenet.
|
||||
|
||||
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information. If the
|
||||
dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
|
||||
|
|
@ -581,8 +574,7 @@ def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
|
|||
|
||||
|
||||
class HUBDatasetStats:
|
||||
"""
|
||||
A class for generating HUB dataset JSON and `-hub` dataset directory.
|
||||
"""A class for generating HUB dataset JSON and `-hub` dataset directory.
|
||||
|
||||
Args:
|
||||
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
|
||||
|
|
@ -748,10 +740,9 @@ class HUBDatasetStats:
|
|||
|
||||
|
||||
def compress_one_image(f: str, f_new: str | None = None, max_dim: int = 1920, quality: int = 50):
|
||||
"""
|
||||
Compress a single image file to reduced size while preserving its aspect ratio and quality using either the Python
|
||||
Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
|
||||
resized.
|
||||
"""Compress a single image file to reduced size while preserving its aspect ratio and quality using either the
|
||||
Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it
|
||||
will not be resized.
|
||||
|
||||
Args:
|
||||
f (str): The path to the input image file.
|
||||
|
|
|
|||
|
|
@ -192,8 +192,7 @@ def best_onnx_opset(onnx, cuda=False) -> int:
|
|||
|
||||
|
||||
def validate_args(format, passed_args, valid_args):
|
||||
"""
|
||||
Validate arguments based on the export format.
|
||||
"""Validate arguments based on the export format.
|
||||
|
||||
Args:
|
||||
format (str): The export format.
|
||||
|
|
@ -238,8 +237,7 @@ def try_export(inner_func):
|
|||
|
||||
|
||||
class Exporter:
|
||||
"""
|
||||
A class for exporting YOLO models to various formats.
|
||||
"""A class for exporting YOLO models to various formats.
|
||||
|
||||
This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,
|
||||
TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export
|
||||
|
|
@ -289,8 +287,7 @@ class Exporter:
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the Exporter class.
|
||||
"""Initialize the Exporter class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file.
|
||||
|
|
@ -1359,8 +1356,7 @@ class IOSDetectModel(torch.nn.Module):
|
|||
"""Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
|
||||
|
||||
def __init__(self, model, im, mlprogram=True):
|
||||
"""
|
||||
Initialize the IOSDetectModel class with a YOLO model and example image.
|
||||
"""Initialize the IOSDetectModel class with a YOLO model and example image.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The YOLO model to wrap.
|
||||
|
|
@ -1394,8 +1390,7 @@ class NMSModel(torch.nn.Module):
|
|||
"""Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
|
||||
|
||||
def __init__(self, model, args):
|
||||
"""
|
||||
Initialize the NMSModel.
|
||||
"""Initialize the NMSModel.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to wrap with NMS postprocessing.
|
||||
|
|
@ -1408,15 +1403,14 @@ class NMSModel(torch.nn.Module):
|
|||
self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
|
||||
"""Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
|
||||
number of detections after NMS.
|
||||
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the number
|
||||
of detections after NMS.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
|
|
|
|||
|
|
@ -27,8 +27,7 @@ from ultralytics.utils import (
|
|||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
"""
|
||||
A base class for implementing YOLO models, unifying APIs across different model types.
|
||||
"""A base class for implementing YOLO models, unifying APIs across different model types.
|
||||
|
||||
This class provides a common interface for various operations related to YOLO models, such as training, validation,
|
||||
prediction, exporting, and benchmarking. It handles different types of models, including those loaded from local
|
||||
|
|
@ -85,19 +84,17 @@ class Model(torch.nn.Module):
|
|||
task: str | None = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new instance of the YOLO model class.
|
||||
"""Initialize a new instance of the YOLO model class.
|
||||
|
||||
This constructor sets up the model based on the provided model path or name. It handles various types of model
|
||||
sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
|
||||
important attributes of the model and prepares it for operations like training, prediction, or export.
|
||||
|
||||
Args:
|
||||
model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a
|
||||
model name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
|
||||
model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a model
|
||||
name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
|
||||
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
||||
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
|
||||
operations.
|
||||
verbose (bool): If True, enables verbose output during the model's initialization and subsequent operations.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
||||
|
|
@ -160,22 +157,21 @@ class Model(torch.nn.Module):
|
|||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list:
|
||||
"""
|
||||
Alias for the predict method, enabling the model instance to be callable for predictions.
|
||||
"""Alias for the predict method, enabling the model instance to be callable for predictions.
|
||||
|
||||
This method simplifies the process of making predictions by allowing the model instance to be called directly
|
||||
with the required arguments.
|
||||
|
||||
Args:
|
||||
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of
|
||||
the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch tensor, or
|
||||
a list/tuple of these.
|
||||
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
|
||||
to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch tensor, or a list/tuple
|
||||
of these.
|
||||
stream (bool): If True, treat the input source as a continuous stream for predictions.
|
||||
**kwargs (Any): Additional keyword arguments to configure the prediction process.
|
||||
|
||||
Returns:
|
||||
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
||||
Results object.
|
||||
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
|
||||
object.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
|
|
@ -187,8 +183,7 @@ class Model(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def is_triton_model(model: str) -> bool:
|
||||
"""
|
||||
Check if the given model string is a Triton Server URL.
|
||||
"""Check if the given model string is a Triton Server URL.
|
||||
|
||||
This static method determines whether the provided model string represents a valid Triton Server URL by parsing
|
||||
its components using urllib.parse.urlsplit().
|
||||
|
|
@ -212,8 +207,7 @@ class Model(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def is_hub_model(model: str) -> bool:
|
||||
"""
|
||||
Check if the provided model is an Ultralytics HUB model.
|
||||
"""Check if the provided model is an Ultralytics HUB model.
|
||||
|
||||
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
||||
identifier.
|
||||
|
|
@ -235,8 +229,7 @@ class Model(torch.nn.Module):
|
|||
return model.startswith(f"{HUB_WEB_ROOT}/models/")
|
||||
|
||||
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
||||
"""
|
||||
Initialize a new model and infer the task type from model definitions.
|
||||
"""Initialize a new model and infer the task type from model definitions.
|
||||
|
||||
Creates a new model instance based on the provided configuration file. Loads the model configuration, infers the
|
||||
task type if not specified, and initializes the model using the appropriate class from the task map.
|
||||
|
|
@ -244,8 +237,8 @@ class Model(torch.nn.Module):
|
|||
Args:
|
||||
cfg (str): Path to the model configuration file in YAML format.
|
||||
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
||||
model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of
|
||||
creating a new one.
|
||||
model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of creating
|
||||
a new one.
|
||||
verbose (bool): If True, displays model information during loading.
|
||||
|
||||
Raises:
|
||||
|
|
@ -269,8 +262,7 @@ class Model(torch.nn.Module):
|
|||
self.model_name = cfg
|
||||
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""
|
||||
Load a model from a checkpoint file or initialize it from a weights file.
|
||||
"""Load a model from a checkpoint file or initialize it from a weights file.
|
||||
|
||||
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets up the
|
||||
model, task, and related attributes based on the loaded weights.
|
||||
|
|
@ -307,8 +299,7 @@ class Model(torch.nn.Module):
|
|||
self.model_name = weights
|
||||
|
||||
def _check_is_pytorch_model(self) -> None:
|
||||
"""
|
||||
Check if the model is a PyTorch model and raise TypeError if it's not.
|
||||
"""Check if the model is a PyTorch model and raise TypeError if it's not.
|
||||
|
||||
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that certain
|
||||
operations that require a PyTorch model are only performed on compatible model types.
|
||||
|
|
@ -335,8 +326,7 @@ class Model(torch.nn.Module):
|
|||
)
|
||||
|
||||
def reset_weights(self) -> Model:
|
||||
"""
|
||||
Reset the model's weights to their initial state.
|
||||
"""Reset the model's weights to their initial state.
|
||||
|
||||
This method iterates through all modules in the model and resets their parameters if they have a
|
||||
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
|
||||
|
|
@ -361,8 +351,7 @@ class Model(torch.nn.Module):
|
|||
return self
|
||||
|
||||
def load(self, weights: str | Path = "yolo11n.pt") -> Model:
|
||||
"""
|
||||
Load parameters from the specified weights file into the model.
|
||||
"""Load parameters from the specified weights file into the model.
|
||||
|
||||
This method supports loading weights from a file or directly from a weights object. It matches parameters by
|
||||
name and shape and transfers them to the model.
|
||||
|
|
@ -389,8 +378,7 @@ class Model(torch.nn.Module):
|
|||
return self
|
||||
|
||||
def save(self, filename: str | Path = "saved_model.pt") -> None:
|
||||
"""
|
||||
Save the current model state to a file.
|
||||
"""Save the current model state to a file.
|
||||
|
||||
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as the
|
||||
date, Ultralytics version, license information, and a link to the documentation.
|
||||
|
|
@ -421,8 +409,7 @@ class Model(torch.nn.Module):
|
|||
torch.save({**self.ckpt, **updates}, filename)
|
||||
|
||||
def info(self, detailed: bool = False, verbose: bool = True):
|
||||
"""
|
||||
Display model information.
|
||||
"""Display model information.
|
||||
|
||||
This method provides an overview or detailed information about the model, depending on the arguments
|
||||
passed. It can control the verbosity of the output and return the information as a list.
|
||||
|
|
@ -432,8 +419,8 @@ class Model(torch.nn.Module):
|
|||
verbose (bool): If True, prints the information. If False, returns the information as a list.
|
||||
|
||||
Returns:
|
||||
(list[str]): A list of strings containing various types of information about the model, including
|
||||
model summary, layer details, and parameter counts. Empty if verbose is True.
|
||||
(list[str]): A list of strings containing various types of information about the model, including model
|
||||
summary, layer details, and parameter counts. Empty if verbose is True.
|
||||
|
||||
Examples:
|
||||
>>> model = Model("yolo11n.pt")
|
||||
|
|
@ -444,8 +431,7 @@ class Model(torch.nn.Module):
|
|||
return self.model.info(detailed=detailed, verbose=verbose)
|
||||
|
||||
def fuse(self) -> None:
|
||||
"""
|
||||
Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
|
||||
"""Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
|
||||
|
||||
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers into a
|
||||
single layer. This fusion can significantly improve inference speed by reducing the number of operations and
|
||||
|
|
@ -469,15 +455,14 @@ class Model(torch.nn.Module):
|
|||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list:
|
||||
"""
|
||||
Generate image embeddings based on the provided source.
|
||||
"""Generate image embeddings based on the provided source.
|
||||
|
||||
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
|
||||
source. It allows customization of the embedding process through various keyword arguments.
|
||||
|
||||
Args:
|
||||
source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for
|
||||
generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
|
||||
source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for generating
|
||||
embeddings. Can be a file path, URL, PIL image, numpy array, etc.
|
||||
stream (bool): If True, predictions are streamed.
|
||||
**kwargs (Any): Additional keyword arguments for configuring the embedding process.
|
||||
|
||||
|
|
@ -501,25 +486,24 @@ class Model(torch.nn.Module):
|
|||
predictor=None,
|
||||
**kwargs: Any,
|
||||
) -> list[Results]:
|
||||
"""
|
||||
Perform predictions on the given image source using the YOLO model.
|
||||
"""Perform predictions on the given image source using the YOLO model.
|
||||
|
||||
This method facilitates the prediction process, allowing various configurations through keyword arguments. It
|
||||
supports predictions with custom predictors or the default predictor method. The method handles different types
|
||||
of image sources and can operate in a streaming mode.
|
||||
|
||||
Args:
|
||||
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
|
||||
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL images,
|
||||
numpy arrays, and torch tensors.
|
||||
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
|
||||
to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
|
||||
torch tensors.
|
||||
stream (bool): If True, treats the input source as a continuous stream for predictions.
|
||||
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
|
||||
If None, the method uses a default predictor.
|
||||
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions. If
|
||||
None, the method uses a default predictor.
|
||||
**kwargs (Any): Additional keyword arguments for configuring the prediction process.
|
||||
|
||||
Returns:
|
||||
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
||||
Results object.
|
||||
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
|
||||
object.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
|
|
@ -562,8 +546,7 @@ class Model(torch.nn.Module):
|
|||
persist: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[Results]:
|
||||
"""
|
||||
Conduct object tracking on the specified input source using the registered trackers.
|
||||
"""Conduct object tracking on the specified input source using the registered trackers.
|
||||
|
||||
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
|
||||
various input sources such as file paths or video streams, and supports customization through keyword arguments.
|
||||
|
|
@ -604,8 +587,7 @@ class Model(torch.nn.Module):
|
|||
validator=None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Validate the model using a specified dataset and validation configuration.
|
||||
"""Validate the model using a specified dataset and validation configuration.
|
||||
|
||||
This method facilitates the model validation process, allowing for customization through various settings. It
|
||||
supports validation with a custom validator or the default validation approach. The method combines default
|
||||
|
|
@ -636,8 +618,7 @@ class Model(torch.nn.Module):
|
|||
return validator.metrics
|
||||
|
||||
def benchmark(self, data=None, format="", verbose=False, **kwargs: Any):
|
||||
"""
|
||||
Benchmark the model across various export formats to evaluate performance.
|
||||
"""Benchmark the model across various export formats to evaluate performance.
|
||||
|
||||
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. It
|
||||
uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured using
|
||||
|
|
@ -655,8 +636,8 @@ class Model(torch.nn.Module):
|
|||
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the results of the benchmarking process, including metrics for
|
||||
different export formats.
|
||||
(dict): A dictionary containing the results of the benchmarking process, including metrics for different
|
||||
export formats.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
|
@ -690,8 +671,7 @@ class Model(torch.nn.Module):
|
|||
self,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Export the model to a different format suitable for deployment.
|
||||
"""Export the model to a different format suitable for deployment.
|
||||
|
||||
This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
|
||||
purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
|
||||
|
|
@ -738,8 +718,7 @@ class Model(torch.nn.Module):
|
|||
trainer=None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Train the model using the specified dataset and training configuration.
|
||||
"""Train the model using the specified dataset and training configuration.
|
||||
|
||||
This method facilitates model training with a range of customizable settings. It supports training with a custom
|
||||
trainer or the default training approach. The method handles scenarios such as resuming training from a
|
||||
|
|
@ -811,8 +790,7 @@ class Model(torch.nn.Module):
|
|||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
|
||||
"""Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
|
||||
|
||||
This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. When Ray Tune
|
||||
is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. Otherwise, it uses
|
||||
|
|
@ -853,16 +831,15 @@ class Model(torch.nn.Module):
|
|||
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
||||
|
||||
def _apply(self, fn) -> Model:
|
||||
"""
|
||||
Apply a function to model tensors that are not parameters or registered buffers.
|
||||
"""Apply a function to model tensors that are not parameters or registered buffers.
|
||||
|
||||
This method extends the functionality of the parent class's _apply method by additionally resetting the
|
||||
predictor and updating the device in the model's overrides. It's typically used for operations like moving the
|
||||
model to a different device or changing its precision.
|
||||
|
||||
Args:
|
||||
fn (Callable): A function to be applied to the model's tensors. This is typically a method like
|
||||
to(), cpu(), cuda(), half(), or float().
|
||||
fn (Callable): A function to be applied to the model's tensors. This is typically a method like to(), cpu(),
|
||||
cuda(), half(), or float().
|
||||
|
||||
Returns:
|
||||
(Model): The model instance with the function applied and updated attributes.
|
||||
|
|
@ -882,8 +859,7 @@ class Model(torch.nn.Module):
|
|||
|
||||
@property
|
||||
def names(self) -> dict[int, str]:
|
||||
"""
|
||||
Retrieve the class names associated with the loaded model.
|
||||
"""Retrieve the class names associated with the loaded model.
|
||||
|
||||
This property returns the class names if they are defined in the model. It checks the class names for validity
|
||||
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
|
||||
|
|
@ -913,8 +889,7 @@ class Model(torch.nn.Module):
|
|||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
"""
|
||||
Get the device on which the model's parameters are allocated.
|
||||
"""Get the device on which the model's parameters are allocated.
|
||||
|
||||
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
|
||||
applicable only to models that are instances of torch.nn.Module.
|
||||
|
|
@ -937,8 +912,7 @@ class Model(torch.nn.Module):
|
|||
|
||||
@property
|
||||
def transforms(self):
|
||||
"""
|
||||
Retrieve the transformations applied to the input data of the loaded model.
|
||||
"""Retrieve the transformations applied to the input data of the loaded model.
|
||||
|
||||
This property returns the transformations if they are defined in the model. The transforms typically include
|
||||
preprocessing steps like resizing, normalization, and data augmentation that are applied to input data before it
|
||||
|
|
@ -958,18 +932,17 @@ class Model(torch.nn.Module):
|
|||
return self.model.transforms if hasattr(self.model, "transforms") else None
|
||||
|
||||
def add_callback(self, event: str, func) -> None:
|
||||
"""
|
||||
Add a callback function for a specified event.
|
||||
"""Add a callback function for a specified event.
|
||||
|
||||
This method allows registering custom callback functions that are triggered on specific events during model
|
||||
operations such as training or inference. Callbacks provide a way to extend and customize the behavior of the
|
||||
model at various stages of its lifecycle.
|
||||
|
||||
Args:
|
||||
event (str): The name of the event to attach the callback to. Must be a valid event name recognized
|
||||
by the Ultralytics framework.
|
||||
func (Callable): The callback function to be registered. This function will be called when the
|
||||
specified event occurs.
|
||||
event (str): The name of the event to attach the callback to. Must be a valid event name recognized by the
|
||||
Ultralytics framework.
|
||||
func (Callable): The callback function to be registered. This function will be called when the specified
|
||||
event occurs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the event name is not recognized or is invalid.
|
||||
|
|
@ -984,8 +957,7 @@ class Model(torch.nn.Module):
|
|||
self.callbacks[event].append(func)
|
||||
|
||||
def clear_callback(self, event: str) -> None:
|
||||
"""
|
||||
Clear all callback functions registered for a specified event.
|
||||
"""Clear all callback functions registered for a specified event.
|
||||
|
||||
This method removes all custom and default callback functions associated with the given event. It resets the
|
||||
callback list for the specified event to an empty list, effectively removing all registered callbacks for that
|
||||
|
|
@ -1012,8 +984,7 @@ class Model(torch.nn.Module):
|
|||
self.callbacks[event] = []
|
||||
|
||||
def reset_callbacks(self) -> None:
|
||||
"""
|
||||
Reset all callbacks to their default functions.
|
||||
"""Reset all callbacks to their default functions.
|
||||
|
||||
This method reinstates the default callback functions for all events, removing any custom callbacks that were
|
||||
previously added. It iterates through all default callback events and replaces the current callbacks with the
|
||||
|
|
@ -1036,8 +1007,7 @@ class Model(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Reset specific arguments when loading a PyTorch model checkpoint.
|
||||
"""Reset specific arguments when loading a PyTorch model checkpoint.
|
||||
|
||||
This method filters the input arguments dictionary to retain only a specific set of keys that are considered
|
||||
important for model loading. It's used to ensure that only relevant arguments are preserved when loading a model
|
||||
|
|
@ -1064,8 +1034,7 @@ class Model(torch.nn.Module):
|
|||
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def _smart_load(self, key: str):
|
||||
"""
|
||||
Intelligently load the appropriate module based on the model task.
|
||||
"""Intelligently load the appropriate module based on the model task.
|
||||
|
||||
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) based
|
||||
on the current task of the model and the provided key. It uses the task_map dictionary to determine the
|
||||
|
|
@ -1094,8 +1063,7 @@ class Model(torch.nn.Module):
|
|||
|
||||
@property
|
||||
def task_map(self) -> dict:
|
||||
"""
|
||||
Provide a mapping from model tasks to corresponding classes for different modes.
|
||||
"""Provide a mapping from model tasks to corresponding classes for different modes.
|
||||
|
||||
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) to a
|
||||
nested dictionary. The nested dictionary contains mappings for different operational modes (model, trainer,
|
||||
|
|
@ -1119,8 +1087,7 @@ class Model(torch.nn.Module):
|
|||
raise NotImplementedError("Please provide task map for your model!")
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
Sets the model to evaluation mode.
|
||||
"""Sets the model to evaluation mode.
|
||||
|
||||
This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
|
||||
that behave differently during training and evaluation. In evaluation mode, these layers use running statistics
|
||||
|
|
@ -1138,8 +1105,7 @@ class Model(torch.nn.Module):
|
|||
return self
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
Enable accessing model attributes directly through the Model class.
|
||||
"""Enable accessing model attributes directly through the Model class.
|
||||
|
||||
This method provides a way to access attributes of the underlying model directly through the Model class
|
||||
instance. It first checks if the requested attribute is 'model', in which case it returns the model from
|
||||
|
|
|
|||
|
|
@ -68,8 +68,7 @@ Example:
|
|||
|
||||
|
||||
class BasePredictor:
|
||||
"""
|
||||
A base class for creating predictors.
|
||||
"""A base class for creating predictors.
|
||||
|
||||
This class provides the foundation for prediction functionality, handling model setup, inference, and result
|
||||
processing across various input sources.
|
||||
|
|
@ -115,8 +114,7 @@ class BasePredictor:
|
|||
overrides: dict[str, Any] | None = None,
|
||||
_callbacks: dict[str, list[callable]] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the BasePredictor class.
|
||||
"""Initialize the BasePredictor class.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Path to a configuration file or a configuration dictionary.
|
||||
|
|
@ -151,8 +149,7 @@ class BasePredictor:
|
|||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
|
||||
"""
|
||||
Prepare input image before inference.
|
||||
"""Prepare input image before inference.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
|
||||
|
|
@ -185,8 +182,7 @@ class BasePredictor:
|
|||
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
||||
|
||||
def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
|
||||
"""
|
||||
Pre-transform input image before inference.
|
||||
"""Pre-transform input image before inference.
|
||||
|
||||
Args:
|
||||
im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].
|
||||
|
|
@ -209,8 +205,7 @@ class BasePredictor:
|
|||
return preds
|
||||
|
||||
def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
|
||||
"""
|
||||
Perform inference on an image or stream.
|
||||
"""Perform inference on an image or stream.
|
||||
|
||||
Args:
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
||||
|
|
@ -230,8 +225,7 @@ class BasePredictor:
|
|||
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
|
||||
|
||||
def predict_cli(self, source=None, model=None):
|
||||
"""
|
||||
Method used for Command Line Interface (CLI) prediction.
|
||||
"""Method used for Command Line Interface (CLI) prediction.
|
||||
|
||||
This function is designed to run predictions using the CLI. It sets up the source and model, then processes the
|
||||
inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
|
||||
|
|
@ -251,12 +245,11 @@ class BasePredictor:
|
|||
pass
|
||||
|
||||
def setup_source(self, source):
|
||||
"""
|
||||
Set up source and inference mode.
|
||||
"""Set up source and inference mode.
|
||||
|
||||
Args:
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor):
|
||||
Source for inference.
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
|
||||
inference.
|
||||
"""
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
self.dataset = load_inference_source(
|
||||
|
|
@ -282,8 +275,7 @@ class BasePredictor:
|
|||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
||||
"""
|
||||
Stream real-time inference on camera feed and save results to file.
|
||||
"""Stream real-time inference on camera feed and save results to file.
|
||||
|
||||
Args:
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
||||
|
|
@ -388,8 +380,7 @@ class BasePredictor:
|
|||
self.run_callbacks("on_predict_end")
|
||||
|
||||
def setup_model(self, model, verbose: bool = True):
|
||||
"""
|
||||
Initialize YOLO model with given parameters and set it to evaluation mode.
|
||||
"""Initialize YOLO model with given parameters and set it to evaluation mode.
|
||||
|
||||
Args:
|
||||
model (str | Path | torch.nn.Module, optional): Model to load or use.
|
||||
|
|
@ -413,8 +404,7 @@ class BasePredictor:
|
|||
self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
|
||||
|
||||
def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
|
||||
"""
|
||||
Write inference results to a file or directory.
|
||||
"""Write inference results to a file or directory.
|
||||
|
||||
Args:
|
||||
i (int): Index of the current image in the batch.
|
||||
|
|
@ -464,8 +454,7 @@ class BasePredictor:
|
|||
return string
|
||||
|
||||
def save_predicted_images(self, save_path: Path, frame: int = 0):
|
||||
"""
|
||||
Save video predictions as mp4 or images as jpg at specified path.
|
||||
"""Save video predictions as mp4 or images as jpg at specified path.
|
||||
|
||||
Args:
|
||||
save_path (Path): Path to save the results.
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ from ultralytics.utils.plotting import Annotator, colors, save_one_box
|
|||
|
||||
|
||||
class BaseTensor(SimpleClass):
|
||||
"""
|
||||
Base tensor class with additional methods for easy manipulation and device handling.
|
||||
"""Base tensor class with additional methods for easy manipulation and device handling.
|
||||
|
||||
This class provides a foundation for tensor-like objects with device management capabilities, supporting both
|
||||
PyTorch tensors and NumPy arrays. It includes methods for moving data between devices and converting between tensor
|
||||
|
|
@ -49,8 +48,7 @@ class BaseTensor(SimpleClass):
|
|||
"""
|
||||
|
||||
def __init__(self, data: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
|
||||
"""
|
||||
Initialize BaseTensor with prediction data and the original shape of the image.
|
||||
"""Initialize BaseTensor with prediction data and the original shape of the image.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints.
|
||||
|
|
@ -68,8 +66,7 @@ class BaseTensor(SimpleClass):
|
|||
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
"""
|
||||
Return the shape of the underlying data tensor.
|
||||
"""Return the shape of the underlying data tensor.
|
||||
|
||||
Returns:
|
||||
(tuple[int, ...]): The shape of the data tensor.
|
||||
|
|
@ -83,8 +80,7 @@ class BaseTensor(SimpleClass):
|
|||
return self.data.shape
|
||||
|
||||
def cpu(self):
|
||||
"""
|
||||
Return a copy of the tensor stored in CPU memory.
|
||||
"""Return a copy of the tensor stored in CPU memory.
|
||||
|
||||
Returns:
|
||||
(BaseTensor): A new BaseTensor object with the data tensor moved to CPU memory.
|
||||
|
|
@ -101,8 +97,7 @@ class BaseTensor(SimpleClass):
|
|||
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
"""
|
||||
Return a copy of the tensor as a numpy array.
|
||||
"""Return a copy of the tensor as a numpy array.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): A numpy array containing the same data as the original tensor.
|
||||
|
|
@ -118,12 +113,11 @@ class BaseTensor(SimpleClass):
|
|||
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
"""
|
||||
Move the tensor to GPU memory.
|
||||
"""Move the tensor to GPU memory.
|
||||
|
||||
Returns:
|
||||
(BaseTensor): A new BaseTensor instance with the data moved to GPU memory if it's not already a
|
||||
numpy array, otherwise returns self.
|
||||
(BaseTensor): A new BaseTensor instance with the data moved to GPU memory if it's not already a numpy array,
|
||||
otherwise returns self.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
|
|
@ -137,8 +131,7 @@ class BaseTensor(SimpleClass):
|
|||
return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
Return a copy of the tensor with the specified device and dtype.
|
||||
"""Return a copy of the tensor with the specified device and dtype.
|
||||
|
||||
Args:
|
||||
*args (Any): Variable length argument list to be passed to torch.Tensor.to().
|
||||
|
|
@ -155,8 +148,7 @@ class BaseTensor(SimpleClass):
|
|||
return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Return the length of the underlying data tensor.
|
||||
"""Return the length of the underlying data tensor.
|
||||
|
||||
Returns:
|
||||
(int): The number of elements in the first dimension of the data tensor.
|
||||
|
|
@ -170,8 +162,7 @@ class BaseTensor(SimpleClass):
|
|||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Return a new BaseTensor instance containing the specified indexed elements of the data tensor.
|
||||
"""Return a new BaseTensor instance containing the specified indexed elements of the data tensor.
|
||||
|
||||
Args:
|
||||
idx (int | list[int] | torch.Tensor): Index or indices to select from the data tensor.
|
||||
|
|
@ -190,8 +181,7 @@ class BaseTensor(SimpleClass):
|
|||
|
||||
|
||||
class Results(SimpleClass, DataExportMixin):
|
||||
"""
|
||||
A class for storing and manipulating inference results.
|
||||
"""A class for storing and manipulating inference results.
|
||||
|
||||
This class provides comprehensive functionality for handling inference results from various Ultralytics models,
|
||||
including detection, segmentation, classification, and pose estimation. It supports visualization, data export, and
|
||||
|
|
@ -249,8 +239,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
obb: torch.Tensor | None = None,
|
||||
speed: dict[str, float] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Results class for storing and manipulating inference results.
|
||||
"""Initialize the Results class for storing and manipulating inference results.
|
||||
|
||||
Args:
|
||||
orig_img (np.ndarray): The original image as a numpy array.
|
||||
|
|
@ -290,8 +279,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
self._keys = "boxes", "masks", "probs", "keypoints", "obb"
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Return a Results object for a specific index of inference results.
|
||||
"""Return a Results object for a specific index of inference results.
|
||||
|
||||
Args:
|
||||
idx (int | slice): Index or slice to retrieve from the Results object.
|
||||
|
|
@ -307,12 +295,11 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return self._apply("__getitem__", idx)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Return the number of detections in the Results object.
|
||||
"""Return the number of detections in the Results object.
|
||||
|
||||
Returns:
|
||||
(int): The number of detections, determined by the length of the first non-empty
|
||||
attribute in (masks, probs, keypoints, or obb).
|
||||
(int): The number of detections, determined by the length of the first non-empty attribute in (masks, probs,
|
||||
keypoints, or obb).
|
||||
|
||||
Examples:
|
||||
>>> results = Results(orig_img, path, names, boxes=torch.rand(5, 4))
|
||||
|
|
@ -332,15 +319,14 @@ class Results(SimpleClass, DataExportMixin):
|
|||
obb: torch.Tensor | None = None,
|
||||
keypoints: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
Update the Results object with new detection data.
|
||||
"""Update the Results object with new detection data.
|
||||
|
||||
This method allows updating the boxes, masks, probabilities, and oriented bounding boxes (OBB) of the Results
|
||||
object. It ensures that boxes are clipped to the original image shape.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor | None): A tensor of shape (N, 6) containing bounding box coordinates and
|
||||
confidence scores. The format is (x1, y1, x2, y2, conf, class).
|
||||
boxes (torch.Tensor | None): A tensor of shape (N, 6) containing bounding box coordinates and confidence
|
||||
scores. The format is (x1, y1, x2, y2, conf, class).
|
||||
masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.
|
||||
probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.
|
||||
obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.
|
||||
|
|
@ -363,8 +349,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
self.keypoints = Keypoints(keypoints, self.orig_shape)
|
||||
|
||||
def _apply(self, fn: str, *args, **kwargs):
|
||||
"""
|
||||
Apply a function to all non-empty attributes and return a new Results object with modified attributes.
|
||||
"""Apply a function to all non-empty attributes and return a new Results object with modified attributes.
|
||||
|
||||
This method is internally called by methods like .to(), .cuda(), .cpu(), etc.
|
||||
|
||||
|
|
@ -390,8 +375,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return r
|
||||
|
||||
def cpu(self):
|
||||
"""
|
||||
Return a copy of the Results object with all its tensors moved to CPU memory.
|
||||
"""Return a copy of the Results object with all its tensors moved to CPU memory.
|
||||
|
||||
This method creates a new Results object with all tensor attributes (boxes, masks, probs, keypoints, obb)
|
||||
transferred to CPU memory. It's useful for moving data from GPU to CPU for further processing or saving.
|
||||
|
|
@ -407,8 +391,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return self._apply("cpu")
|
||||
|
||||
def numpy(self):
|
||||
"""
|
||||
Convert all tensors in the Results object to numpy arrays.
|
||||
"""Convert all tensors in the Results object to numpy arrays.
|
||||
|
||||
Returns:
|
||||
(Results): A new Results object with all tensors converted to numpy arrays.
|
||||
|
|
@ -426,8 +409,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return self._apply("numpy")
|
||||
|
||||
def cuda(self):
|
||||
"""
|
||||
Move all tensors in the Results object to GPU memory.
|
||||
"""Move all tensors in the Results object to GPU memory.
|
||||
|
||||
Returns:
|
||||
(Results): A new Results object with all tensors moved to CUDA device.
|
||||
|
|
@ -441,8 +423,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return self._apply("cuda")
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
Move all tensors in the Results object to the specified device and dtype.
|
||||
"""Move all tensors in the Results object to the specified device and dtype.
|
||||
|
||||
Args:
|
||||
*args (Any): Variable length argument list to be passed to torch.Tensor.to().
|
||||
|
|
@ -460,8 +441,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return self._apply("to", *args, **kwargs)
|
||||
|
||||
def new(self):
|
||||
"""
|
||||
Create a new Results object with the same image, path, names, and speed attributes.
|
||||
"""Create a new Results object with the same image, path, names, and speed attributes.
|
||||
|
||||
Returns:
|
||||
(Results): A new Results object with copied attributes from the original instance.
|
||||
|
|
@ -493,8 +473,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
color_mode: str = "class",
|
||||
txt_color: tuple[int, int, int] = (255, 255, 255),
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Plot detection results on an input RGB image.
|
||||
"""Plot detection results on an input RGB image.
|
||||
|
||||
Args:
|
||||
conf (bool): Whether to plot detection confidence scores.
|
||||
|
|
@ -613,8 +592,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return annotator.im if pil else annotator.result()
|
||||
|
||||
def show(self, *args, **kwargs):
|
||||
"""
|
||||
Display the image with annotated inference results.
|
||||
"""Display the image with annotated inference results.
|
||||
|
||||
This method plots the detection results on the original image and displays it. It's a convenient way to
|
||||
visualize the model's predictions directly.
|
||||
|
|
@ -632,15 +610,14 @@ class Results(SimpleClass, DataExportMixin):
|
|||
self.plot(show=True, *args, **kwargs)
|
||||
|
||||
def save(self, filename: str | None = None, *args, **kwargs) -> str:
|
||||
"""
|
||||
Save annotated inference results image to file.
|
||||
"""Save annotated inference results image to file.
|
||||
|
||||
This method plots the detection results on the original image and saves the annotated image to a file. It
|
||||
utilizes the `plot` method to generate the annotated image and then saves it to the specified filename.
|
||||
|
||||
Args:
|
||||
filename (str | Path | None): The filename to save the annotated image. If None, a default filename
|
||||
is generated based on the original image path.
|
||||
filename (str | Path | None): The filename to save the annotated image. If None, a default filename is
|
||||
generated based on the original image path.
|
||||
*args (Any): Variable length argument list to be passed to the `plot` method.
|
||||
**kwargs (Any): Arbitrary keyword arguments to be passed to the `plot` method.
|
||||
|
||||
|
|
@ -661,15 +638,14 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return filename
|
||||
|
||||
def verbose(self) -> str:
|
||||
"""
|
||||
Return a log string for each task in the results, detailing detection and classification outcomes.
|
||||
"""Return a log string for each task in the results, detailing detection and classification outcomes.
|
||||
|
||||
This method generates a human-readable string summarizing the detection and classification results. It includes
|
||||
the number of detections for each class and the top probabilities for classification tasks.
|
||||
|
||||
Returns:
|
||||
(str): A formatted string containing a summary of the results. For detection tasks, it includes the
|
||||
number of detections per class. For classification tasks, it includes the top 5 class probabilities.
|
||||
(str): A formatted string containing a summary of the results. For detection tasks, it includes the number
|
||||
of detections per class. For classification tasks, it includes the top 5 class probabilities.
|
||||
|
||||
Examples:
|
||||
>>> results = model("path/to/image.jpg")
|
||||
|
|
@ -693,8 +669,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return "".join(f"{n} {self.names[i]}{'s' * (n > 1)}, " for i, n in enumerate(counts) if n > 0)
|
||||
|
||||
def save_txt(self, txt_file: str | Path, save_conf: bool = False) -> str:
|
||||
"""
|
||||
Save detection results to a text file.
|
||||
"""Save detection results to a text file.
|
||||
|
||||
Args:
|
||||
txt_file (str | Path): Path to the output text file.
|
||||
|
|
@ -750,8 +725,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
return str(txt_file)
|
||||
|
||||
def save_crop(self, save_dir: str | Path, file_name: str | Path = Path("im.jpg")):
|
||||
"""
|
||||
Save cropped detection images to specified directory.
|
||||
"""Save cropped detection images to specified directory.
|
||||
|
||||
This method saves cropped images of detected objects to a specified directory. Each crop is saved in a
|
||||
subdirectory named after the object's class, with the filename based on the input file_name.
|
||||
|
|
@ -786,8 +760,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
)
|
||||
|
||||
def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert inference results to a summarized dictionary with optional normalization for box coordinates.
|
||||
"""Convert inference results to a summarized dictionary with optional normalization for box coordinates.
|
||||
|
||||
This method creates a list of detection dictionaries, each containing information about a single detection or
|
||||
classification result. For classification tasks, it returns the top class and its
|
||||
|
|
@ -853,8 +826,7 @@ class Results(SimpleClass, DataExportMixin):
|
|||
|
||||
|
||||
class Boxes(BaseTensor):
|
||||
"""
|
||||
A class for managing and manipulating detection boxes.
|
||||
"""A class for managing and manipulating detection boxes.
|
||||
|
||||
This class provides comprehensive functionality for handling detection boxes, including their coordinates,
|
||||
confidence scores, class labels, and optional tracking IDs. It supports various box formats and offers methods for
|
||||
|
|
@ -890,17 +862,15 @@ class Boxes(BaseTensor):
|
|||
"""
|
||||
|
||||
def __init__(self, boxes: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
|
||||
"""
|
||||
Initialize the Boxes class with detection box data and the original image shape.
|
||||
"""Initialize the Boxes class with detection box data and the original image shape.
|
||||
|
||||
This class manages detection boxes, providing easy access and manipulation of box coordinates, confidence
|
||||
scores, class identifiers, and optional tracking IDs. It supports multiple formats for box coordinates,
|
||||
including both absolute and normalized forms.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape
|
||||
(num_boxes, 6) or (num_boxes, 7). Columns should contain [x1, y1, x2, y2, (optional) track_id,
|
||||
confidence, class].
|
||||
boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape (num_boxes, 6) or
|
||||
(num_boxes, 7). Columns should contain [x1, y1, x2, y2, (optional) track_id, confidence, class].
|
||||
orig_shape (tuple[int, int]): The original image shape as (height, width). Used for normalization.
|
||||
|
||||
Attributes:
|
||||
|
|
@ -926,12 +896,11 @@ class Boxes(BaseTensor):
|
|||
|
||||
@property
|
||||
def xyxy(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return bounding boxes in [x1, y1, x2, y2] format.
|
||||
"""Return bounding boxes in [x1, y1, x2, y2] format.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (n, 4) containing bounding box
|
||||
coordinates in [x1, y1, x2, y2] format, where n is the number of boxes.
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (n, 4) containing bounding box coordinates in
|
||||
[x1, y1, x2, y2] format, where n is the number of boxes.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -943,12 +912,11 @@ class Boxes(BaseTensor):
|
|||
|
||||
@property
|
||||
def conf(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return the confidence scores for each detection box.
|
||||
"""Return the confidence scores for each detection box.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A 1D tensor or array containing confidence scores for each detection,
|
||||
with shape (N,) where N is the number of detections.
|
||||
(torch.Tensor | np.ndarray): A 1D tensor or array containing confidence scores for each detection, with
|
||||
shape (N,) where N is the number of detections.
|
||||
|
||||
Examples:
|
||||
>>> boxes = Boxes(torch.tensor([[10, 20, 30, 40, 0.9, 0]]), orig_shape=(100, 100))
|
||||
|
|
@ -960,12 +928,11 @@ class Boxes(BaseTensor):
|
|||
|
||||
@property
|
||||
def cls(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return the class ID tensor representing category predictions for each bounding box.
|
||||
"""Return the class ID tensor representing category predictions for each bounding box.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class IDs for each detection box.
|
||||
The shape is (N,), where N is the number of boxes.
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class IDs for each detection box. The
|
||||
shape is (N,), where N is the number of boxes.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -977,12 +944,11 @@ class Boxes(BaseTensor):
|
|||
|
||||
@property
|
||||
def id(self) -> torch.Tensor | np.ndarray | None:
|
||||
"""
|
||||
Return the tracking IDs for each detection box if available.
|
||||
"""Return the tracking IDs for each detection box if available.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | None): A tensor containing tracking IDs for each box if tracking is enabled,
|
||||
otherwise None. Shape is (N,) where N is the number of boxes.
|
||||
(torch.Tensor | None): A tensor containing tracking IDs for each box if tracking is enabled, otherwise None.
|
||||
Shape is (N,) where N is the number of boxes.
|
||||
|
||||
Examples:
|
||||
>>> results = model.track("path/to/video.mp4")
|
||||
|
|
@ -1003,13 +969,12 @@ class Boxes(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xywh(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Convert bounding boxes from [x1, y1, x2, y2] format to [x, y, width, height] format.
|
||||
"""Convert bounding boxes from [x1, y1, x2, y2] format to [x, y, width, height] format.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): Boxes in [x_center, y_center, width, height] format, where x_center,
|
||||
y_center are the coordinates of the center point of the bounding box, width, height are the dimensions
|
||||
of the bounding box and the shape of the returned tensor is (N, 4), where N is the number of boxes.
|
||||
(torch.Tensor | np.ndarray): Boxes in [x_center, y_center, width, height] format, where x_center, y_center
|
||||
are the coordinates of the center point of the bounding box, width, height are the dimensions of the
|
||||
bounding box and the shape of the returned tensor is (N, 4), where N is the number of boxes.
|
||||
|
||||
Examples:
|
||||
>>> boxes = Boxes(torch.tensor([[100, 50, 150, 100], [200, 150, 300, 250]]), orig_shape=(480, 640))
|
||||
|
|
@ -1023,15 +988,14 @@ class Boxes(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyn(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return normalized bounding box coordinates relative to the original image size.
|
||||
"""Return normalized bounding box coordinates relative to the original image size.
|
||||
|
||||
This property calculates and returns the bounding box coordinates in [x1, y1, x2, y2] format, normalized to the
|
||||
range [0, 1] based on the original image dimensions.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): Normalized bounding box coordinates with shape (N, 4), where N is
|
||||
the number of boxes. Each row contains [x1, y1, x2, y2] values normalized to [0, 1].
|
||||
(torch.Tensor | np.ndarray): Normalized bounding box coordinates with shape (N, 4), where N is the number of
|
||||
boxes. Each row contains [x1, y1, x2, y2] values normalized to [0, 1].
|
||||
|
||||
Examples:
|
||||
>>> boxes = Boxes(torch.tensor([[100, 50, 300, 400, 0.9, 0]]), orig_shape=(480, 640))
|
||||
|
|
@ -1047,16 +1011,15 @@ class Boxes(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xywhn(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return normalized bounding boxes in [x, y, width, height] format.
|
||||
"""Return normalized bounding boxes in [x, y, width, height] format.
|
||||
|
||||
This property calculates and returns the normalized bounding box coordinates in the format [x_center, y_center,
|
||||
width, height], where all values are relative to the original image dimensions.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): Normalized bounding boxes with shape (N, 4), where N is the
|
||||
number of boxes. Each row contains [x_center, y_center, width, height] values normalized to [0, 1] based
|
||||
on the original image dimensions.
|
||||
(torch.Tensor | np.ndarray): Normalized bounding boxes with shape (N, 4), where N is the number of boxes.
|
||||
Each row contains [x_center, y_center, width, height] values normalized to [0, 1] based on the original
|
||||
image dimensions.
|
||||
|
||||
Examples:
|
||||
>>> boxes = Boxes(torch.tensor([[100, 50, 150, 100, 0.9, 0]]), orig_shape=(480, 640))
|
||||
|
|
@ -1071,8 +1034,7 @@ class Boxes(BaseTensor):
|
|||
|
||||
|
||||
class Masks(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection masks.
|
||||
"""A class for storing and manipulating detection masks.
|
||||
|
||||
This class extends BaseTensor and provides functionality for handling segmentation masks, including methods for
|
||||
converting between pixel and normalized coordinates.
|
||||
|
|
@ -1098,8 +1060,7 @@ class Masks(BaseTensor):
|
|||
"""
|
||||
|
||||
def __init__(self, masks: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
|
||||
"""
|
||||
Initialize the Masks class with detection mask data and the original image shape.
|
||||
"""Initialize the Masks class with detection mask data and the original image shape.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor | np.ndarray): Detection masks with shape (num_masks, height, width).
|
||||
|
|
@ -1119,15 +1080,14 @@ class Masks(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self) -> list[np.ndarray]:
|
||||
"""
|
||||
Return normalized xy-coordinates of the segmentation masks.
|
||||
"""Return normalized xy-coordinates of the segmentation masks.
|
||||
|
||||
This property calculates and caches the normalized xy-coordinates of the segmentation masks. The coordinates are
|
||||
normalized relative to the original image shape.
|
||||
|
||||
Returns:
|
||||
(list[np.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates
|
||||
of a single segmentation mask. Each array has shape (N, 2), where N is the number of points in the
|
||||
(list[np.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates of a
|
||||
single segmentation mask. Each array has shape (N, 2), where N is the number of points in the
|
||||
mask contour.
|
||||
|
||||
Examples:
|
||||
|
|
@ -1144,16 +1104,14 @@ class Masks(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xy(self) -> list[np.ndarray]:
|
||||
"""
|
||||
Return the [x, y] pixel coordinates for each segment in the mask tensor.
|
||||
"""Return the [x, y] pixel coordinates for each segment in the mask tensor.
|
||||
|
||||
This property calculates and returns a list of pixel coordinates for each segmentation mask in the Masks object.
|
||||
The coordinates are scaled to match the original image dimensions.
|
||||
|
||||
Returns:
|
||||
(list[np.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel
|
||||
coordinates for a single segmentation mask. Each array has shape (N, 2), where N is the number of points
|
||||
in the segment.
|
||||
(list[np.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel coordinates for a
|
||||
single segmentation mask. Each array has shape (N, 2), where N is the number of points in the segment.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -1169,8 +1127,7 @@ class Masks(BaseTensor):
|
|||
|
||||
|
||||
class Keypoints(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection keypoints.
|
||||
"""A class for storing and manipulating detection keypoints.
|
||||
|
||||
This class encapsulates functionality for handling keypoint data, including coordinate manipulation, normalization,
|
||||
and confidence values. It supports keypoint detection results with optional visibility information.
|
||||
|
|
@ -1201,8 +1158,7 @@ class Keypoints(BaseTensor):
|
|||
"""
|
||||
|
||||
def __init__(self, keypoints: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
|
||||
"""
|
||||
Initialize the Keypoints object with detection keypoints and original image dimensions.
|
||||
"""Initialize the Keypoints object with detection keypoints and original image dimensions.
|
||||
|
||||
This method processes the input keypoints tensor, handling both 2D and 3D formats. For 3D tensors (x, y,
|
||||
confidence), it masks out low-confidence keypoints by setting their coordinates to zero.
|
||||
|
|
@ -1226,12 +1182,11 @@ class Keypoints(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xy(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return x, y coordinates of keypoints.
|
||||
"""Return x, y coordinates of keypoints.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): A tensor containing the x, y coordinates of keypoints with shape (N, K, 2), where N is
|
||||
the number of detections and K is the number of keypoints per detection.
|
||||
(torch.Tensor): A tensor containing the x, y coordinates of keypoints with shape (N, K, 2), where N is the
|
||||
number of detections and K is the number of keypoints per detection.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -1250,8 +1205,7 @@ class Keypoints(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return normalized coordinates (x, y) of keypoints relative to the original image size.
|
||||
"""Return normalized coordinates (x, y) of keypoints relative to the original image size.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A tensor or array of shape (N, K, 2) containing normalized keypoint
|
||||
|
|
@ -1272,13 +1226,11 @@ class Keypoints(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def conf(self) -> torch.Tensor | np.ndarray | None:
|
||||
"""
|
||||
Return confidence values for each keypoint.
|
||||
"""Return confidence values for each keypoint.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | None): A tensor containing confidence scores for each keypoint if available,
|
||||
otherwise None. Shape is (num_detections, num_keypoints) for batched data or (num_keypoints,) for
|
||||
single detection.
|
||||
(torch.Tensor | None): A tensor containing confidence scores for each keypoint if available, otherwise None.
|
||||
Shape is (num_detections, num_keypoints) for batched data or (num_keypoints,) for single detection.
|
||||
|
||||
Examples:
|
||||
>>> keypoints = Keypoints(torch.rand(1, 17, 3), orig_shape=(640, 640)) # 1 detection, 17 keypoints
|
||||
|
|
@ -1289,8 +1241,7 @@ class Keypoints(BaseTensor):
|
|||
|
||||
|
||||
class Probs(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating classification probabilities.
|
||||
"""A class for storing and manipulating classification probabilities.
|
||||
|
||||
This class extends BaseTensor and provides methods for accessing and manipulating classification probabilities,
|
||||
including top-1 and top-5 predictions.
|
||||
|
|
@ -1323,16 +1274,15 @@ class Probs(BaseTensor):
|
|||
"""
|
||||
|
||||
def __init__(self, probs: torch.Tensor | np.ndarray, orig_shape: tuple[int, int] | None = None) -> None:
|
||||
"""
|
||||
Initialize the Probs class with classification probabilities.
|
||||
"""Initialize the Probs class with classification probabilities.
|
||||
|
||||
This class stores and manages classification probabilities, providing easy access to top predictions and their
|
||||
confidences.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor | np.ndarray): A 1D tensor or array of classification probabilities.
|
||||
orig_shape (tuple | None): The original image shape as (height, width). Not used in this class but kept
|
||||
for consistency with other result classes.
|
||||
orig_shape (tuple | None): The original image shape as (height, width). Not used in this class but kept for
|
||||
consistency with other result classes.
|
||||
|
||||
Attributes:
|
||||
data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.
|
||||
|
|
@ -1357,8 +1307,7 @@ class Probs(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top1(self) -> int:
|
||||
"""
|
||||
Return the index of the class with the highest probability.
|
||||
"""Return the index of the class with the highest probability.
|
||||
|
||||
Returns:
|
||||
(int): Index of the class with the highest probability.
|
||||
|
|
@ -1373,8 +1322,7 @@ class Probs(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top5(self) -> list[int]:
|
||||
"""
|
||||
Return the indices of the top 5 class probabilities.
|
||||
"""Return the indices of the top 5 class probabilities.
|
||||
|
||||
Returns:
|
||||
(list[int]): A list containing the indices of the top 5 class probabilities, sorted in descending order.
|
||||
|
|
@ -1389,8 +1337,7 @@ class Probs(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top1conf(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return the confidence score of the highest probability class.
|
||||
"""Return the confidence score of the highest probability class.
|
||||
|
||||
This property retrieves the confidence score (probability) of the class with the highest predicted probability
|
||||
from the classification results.
|
||||
|
|
@ -1409,16 +1356,15 @@ class Probs(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top5conf(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return confidence scores for the top 5 classification predictions.
|
||||
"""Return confidence scores for the top 5 classification predictions.
|
||||
|
||||
This property retrieves the confidence scores corresponding to the top 5 class probabilities predicted by the
|
||||
model. It provides a quick way to access the most likely class predictions along with their associated
|
||||
confidence levels.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A tensor or array containing the confidence scores for the
|
||||
top 5 predicted classes, sorted in descending order of probability.
|
||||
(torch.Tensor | np.ndarray): A tensor or array containing the confidence scores for the top 5 predicted
|
||||
classes, sorted in descending order of probability.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -1430,8 +1376,7 @@ class Probs(BaseTensor):
|
|||
|
||||
|
||||
class OBB(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating Oriented Bounding Boxes (OBB).
|
||||
"""A class for storing and manipulating Oriented Bounding Boxes (OBB).
|
||||
|
||||
This class provides functionality to handle oriented bounding boxes, including conversion between different formats,
|
||||
normalization, and access to various properties of the boxes. It supports both tracking and non-tracking scenarios.
|
||||
|
|
@ -1463,16 +1408,15 @@ class OBB(BaseTensor):
|
|||
"""
|
||||
|
||||
def __init__(self, boxes: torch.Tensor | np.ndarray, orig_shape: tuple[int, int]) -> None:
|
||||
"""
|
||||
Initialize an OBB (Oriented Bounding Box) instance with oriented bounding box data and original image shape.
|
||||
"""Initialize an OBB (Oriented Bounding Box) instance with oriented bounding box data and original image shape.
|
||||
|
||||
This class stores and manipulates Oriented Bounding Boxes (OBB) for object detection tasks. It provides various
|
||||
properties and methods to access and transform the OBB data.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor | np.ndarray): A tensor or numpy array containing the detection boxes,
|
||||
with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values.
|
||||
If present, the third last column contains track IDs, and the fifth column contains rotation.
|
||||
boxes (torch.Tensor | np.ndarray): A tensor or numpy array containing the detection boxes, with shape
|
||||
(num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values. If present,
|
||||
the third last column contains track IDs, and the fifth column contains rotation.
|
||||
orig_shape (tuple[int, int]): Original image size, in the format (height, width).
|
||||
|
||||
Attributes:
|
||||
|
|
@ -1500,8 +1444,7 @@ class OBB(BaseTensor):
|
|||
|
||||
@property
|
||||
def xywhr(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return boxes in [x_center, y_center, width, height, rotation] format.
|
||||
"""Return boxes in [x_center, y_center, width, height, rotation] format.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array containing the oriented bounding boxes with format
|
||||
|
|
@ -1518,15 +1461,14 @@ class OBB(BaseTensor):
|
|||
|
||||
@property
|
||||
def conf(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return the confidence scores for Oriented Bounding Boxes (OBBs).
|
||||
"""Return the confidence scores for Oriented Bounding Boxes (OBBs).
|
||||
|
||||
This property retrieves the confidence values associated with each OBB detection. The confidence score
|
||||
represents the model's certainty in the detection.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (N,) containing confidence scores
|
||||
for N detections, where each score is in the range [0, 1].
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array of shape (N,) containing confidence scores for N
|
||||
detections, where each score is in the range [0, 1].
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -1538,12 +1480,11 @@ class OBB(BaseTensor):
|
|||
|
||||
@property
|
||||
def cls(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Return the class values of the oriented bounding boxes.
|
||||
"""Return the class values of the oriented bounding boxes.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class values for each oriented
|
||||
bounding box. The shape is (N,), where N is the number of boxes.
|
||||
(torch.Tensor | np.ndarray): A tensor or numpy array containing the class values for each oriented bounding
|
||||
box. The shape is (N,), where N is the number of boxes.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -1556,12 +1497,11 @@ class OBB(BaseTensor):
|
|||
|
||||
@property
|
||||
def id(self) -> torch.Tensor | np.ndarray | None:
|
||||
"""
|
||||
Return the tracking IDs of the oriented bounding boxes (if available).
|
||||
"""Return the tracking IDs of the oriented bounding boxes (if available).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray | None): A tensor or numpy array containing the tracking IDs for each
|
||||
oriented bounding box. Returns None if tracking IDs are not available.
|
||||
(torch.Tensor | np.ndarray | None): A tensor or numpy array containing the tracking IDs for each oriented
|
||||
bounding box. Returns None if tracking IDs are not available.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg", tracker=True) # Run inference with tracking
|
||||
|
|
@ -1576,12 +1516,11 @@ class OBB(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyxyxy(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Convert OBB format to 8-point (xyxyxyxy) coordinate format for rotated bounding boxes.
|
||||
"""Convert OBB format to 8-point (xyxyxyxy) coordinate format for rotated bounding boxes.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): Rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), where N is
|
||||
the number of boxes. Each box is represented by 4 points (x, y), starting from the top-left corner and
|
||||
(torch.Tensor | np.ndarray): Rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), where N is the
|
||||
number of boxes. Each box is represented by 4 points (x, y), starting from the top-left corner and
|
||||
moving clockwise.
|
||||
|
||||
Examples:
|
||||
|
|
@ -1595,8 +1534,7 @@ class OBB(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyxyxyn(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Convert rotated bounding boxes to normalized xyxyxyxy format.
|
||||
"""Convert rotated bounding boxes to normalized xyxyxyxy format.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): Normalized rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2),
|
||||
|
|
@ -1617,16 +1555,15 @@ class OBB(BaseTensor):
|
|||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxy(self) -> torch.Tensor | np.ndarray:
|
||||
"""
|
||||
Convert oriented bounding boxes (OBB) to axis-aligned bounding boxes in xyxy format.
|
||||
"""Convert oriented bounding boxes (OBB) to axis-aligned bounding boxes in xyxy format.
|
||||
|
||||
This property calculates the minimal enclosing rectangle for each oriented bounding box and returns it in xyxy
|
||||
format (x1, y1, x2, y2). This is useful for operations that require axis-aligned bounding boxes, such as IoU
|
||||
calculation with non-rotated boxes.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | np.ndarray): Axis-aligned bounding boxes in xyxy format with shape (N, 4), where N
|
||||
is the number of boxes. Each row contains [x1, y1, x2, y2] coordinates.
|
||||
(torch.Tensor | np.ndarray): Axis-aligned bounding boxes in xyxy format with shape (N, 4), where N is the
|
||||
number of boxes. Each row contains [x1, y1, x2, y2] coordinates.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
|
|
|
|||
|
|
@ -63,8 +63,7 @@ from ultralytics.utils.torch_utils import (
|
|||
|
||||
|
||||
class BaseTrainer:
|
||||
"""
|
||||
A base class for creating trainers.
|
||||
"""A base class for creating trainers.
|
||||
|
||||
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
|
||||
and various training utilities. It supports both single-GPU and multi-GPU distributed training.
|
||||
|
|
@ -114,8 +113,7 @@ class BaseTrainer:
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the BaseTrainer class.
|
||||
"""Initialize the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file.
|
||||
|
|
@ -620,8 +618,7 @@ class BaseTrainer:
|
|||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
||||
|
||||
def get_dataset(self):
|
||||
"""
|
||||
Get train and validation datasets from data dictionary.
|
||||
"""Get train and validation datasets from data dictionary.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the training/validation/test dataset and category names.
|
||||
|
|
@ -656,8 +653,7 @@ class BaseTrainer:
|
|||
return data
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
Load, create, or download model for any task.
|
||||
"""Load, create, or download model for any task.
|
||||
|
||||
Returns:
|
||||
(dict): Optional checkpoint to resume training from.
|
||||
|
|
@ -690,8 +686,7 @@ class BaseTrainer:
|
|||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Run validation on val set using self.validator.
|
||||
"""Run validation on val set using self.validator.
|
||||
|
||||
Returns:
|
||||
metrics (dict): Dictionary of validation metrics.
|
||||
|
|
@ -726,8 +721,7 @@ class BaseTrainer:
|
|||
raise NotImplementedError("build_dataset function not implemented in trainer")
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
"""
|
||||
Return a loss dict with labeled training loss items tensor.
|
||||
"""Return a loss dict with labeled training loss items tensor.
|
||||
|
||||
Notes:
|
||||
This is not needed for classification but necessary for segmentation & detection
|
||||
|
|
@ -895,18 +889,16 @@ class BaseTrainer:
|
|||
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
|
||||
|
||||
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||
"""
|
||||
Construct an optimizer for the given model.
|
||||
"""Construct an optimizer for the given model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model for which to build an optimizer.
|
||||
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
|
||||
based on the number of iterations.
|
||||
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
|
||||
number of iterations.
|
||||
lr (float, optional): The learning rate for the optimizer.
|
||||
momentum (float, optional): The momentum factor for the optimizer.
|
||||
decay (float, optional): The weight decay for the optimizer.
|
||||
iterations (float, optional): The number of iterations, which determines the optimizer if
|
||||
name is 'auto'.
|
||||
iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
|
||||
|
||||
Returns:
|
||||
(torch.optim.Optimizer): The constructed optimizer.
|
||||
|
|
|
|||
|
|
@ -34,8 +34,7 @@ from ultralytics.utils.plotting import plot_tune_results
|
|||
|
||||
|
||||
class Tuner:
|
||||
"""
|
||||
A class for hyperparameter tuning of YOLO models.
|
||||
"""A class for hyperparameter tuning of YOLO models.
|
||||
|
||||
The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
|
||||
search space and retraining the model to evaluate their performance. Supports both local CSV storage and distributed
|
||||
|
|
@ -83,8 +82,7 @@ class Tuner:
|
|||
"""
|
||||
|
||||
def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
|
||||
"""
|
||||
Initialize the Tuner with configurations.
|
||||
"""Initialize the Tuner with configurations.
|
||||
|
||||
Args:
|
||||
args (dict): Configuration for hyperparameter evolution.
|
||||
|
|
@ -142,8 +140,7 @@ class Tuner:
|
|||
)
|
||||
|
||||
def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
|
||||
"""
|
||||
Create MongoDB client with exponential backoff retry on connection failures.
|
||||
"""Create MongoDB client with exponential backoff retry on connection failures.
|
||||
|
||||
Args:
|
||||
uri (str): MongoDB connection string with credentials and cluster information.
|
||||
|
|
@ -183,8 +180,7 @@ class Tuner:
|
|||
time.sleep(wait_time)
|
||||
|
||||
def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
|
||||
"""
|
||||
Initialize MongoDB connection for distributed tuning.
|
||||
"""Initialize MongoDB connection for distributed tuning.
|
||||
|
||||
Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines. Each worker
|
||||
saves results to a shared collection and reads the latest best hyperparameters from all workers for evolution.
|
||||
|
|
@ -205,8 +201,7 @@ class Tuner:
|
|||
LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
|
||||
|
||||
def _get_mongodb_results(self, n: int = 5) -> list:
|
||||
"""
|
||||
Get top N results from MongoDB sorted by fitness.
|
||||
"""Get top N results from MongoDB sorted by fitness.
|
||||
|
||||
Args:
|
||||
n (int): Number of top results to retrieve.
|
||||
|
|
@ -220,8 +215,7 @@ class Tuner:
|
|||
return []
|
||||
|
||||
def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
|
||||
"""
|
||||
Save results to MongoDB with proper type conversion.
|
||||
"""Save results to MongoDB with proper type conversion.
|
||||
|
||||
Args:
|
||||
fitness (float): Fitness score achieved with these hyperparameters.
|
||||
|
|
@ -243,8 +237,7 @@ class Tuner:
|
|||
LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
|
||||
|
||||
def _sync_mongodb_to_csv(self):
|
||||
"""
|
||||
Sync MongoDB results to CSV for plotting compatibility.
|
||||
"""Sync MongoDB results to CSV for plotting compatibility.
|
||||
|
||||
Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
|
||||
the existing plotting functions to work seamlessly with distributed MongoDB data.
|
||||
|
|
@ -287,8 +280,7 @@ class Tuner:
|
|||
mutation: float = 0.5,
|
||||
sigma: float = 0.2,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
|
||||
"""Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
|
||||
|
||||
Args:
|
||||
parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
|
||||
|
|
@ -348,8 +340,7 @@ class Tuner:
|
|||
return hyp
|
||||
|
||||
def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
|
||||
"""
|
||||
Execute the hyperparameter evolution process when the Tuner instance is called.
|
||||
"""Execute the hyperparameter evolution process when the Tuner instance is called.
|
||||
|
||||
This method iterates through the specified number of iterations, performing the following steps:
|
||||
1. Sync MongoDB results to CSV (if using distributed mode)
|
||||
|
|
|
|||
|
|
@ -41,8 +41,7 @@ from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_
|
|||
|
||||
|
||||
class BaseValidator:
|
||||
"""
|
||||
A base class for creating validators.
|
||||
"""A base class for creating validators.
|
||||
|
||||
This class provides the foundation for validation processes, including model evaluation, metric computation, and
|
||||
result visualization.
|
||||
|
|
@ -62,8 +61,8 @@ class BaseValidator:
|
|||
nc (int): Number of classes.
|
||||
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
||||
jdict (list): List to store JSON validation results.
|
||||
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
||||
batch processing times in milliseconds.
|
||||
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
|
||||
processing times in milliseconds.
|
||||
save_dir (Path): Directory to save results.
|
||||
plots (dict): Dictionary to store plots for visualization.
|
||||
callbacks (dict): Dictionary to store various callback functions.
|
||||
|
|
@ -93,8 +92,7 @@ class BaseValidator:
|
|||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
|
||||
"""
|
||||
Initialize a BaseValidator instance.
|
||||
"""Initialize a BaseValidator instance.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
||||
|
|
@ -131,8 +129,7 @@ class BaseValidator:
|
|||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
Execute validation process, running inference on dataloader and computing performance metrics.
|
||||
"""Execute validation process, running inference on dataloader and computing performance metrics.
|
||||
|
||||
Args:
|
||||
trainer (object, optional): Trainer object that contains the model to validate.
|
||||
|
|
@ -269,8 +266,7 @@ class BaseValidator:
|
|||
def match_predictions(
|
||||
self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Match predictions to ground truth objects using IoU.
|
||||
"""Match predictions to ground truth objects using IoU.
|
||||
|
||||
Args:
|
||||
pred_classes (torch.Tensor): Predicted class indices of shape (N,).
|
||||
|
|
|
|||
|
|
@ -23,15 +23,14 @@ __all__ = (
|
|||
|
||||
|
||||
def login(api_key: str | None = None, save: bool = True) -> bool:
|
||||
"""
|
||||
Log in to the Ultralytics HUB API using the provided API key.
|
||||
"""Log in to the Ultralytics HUB API using the provided API key.
|
||||
|
||||
The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
|
||||
environment variable if successfully authenticated.
|
||||
|
||||
Args:
|
||||
api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from
|
||||
SETTINGS or HUB_API_KEY environment variable.
|
||||
api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from SETTINGS
|
||||
or HUB_API_KEY environment variable.
|
||||
save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
|
||||
|
||||
Returns:
|
||||
|
|
@ -91,8 +90,7 @@ def export_fmts_hub():
|
|||
|
||||
|
||||
def export_model(model_id: str = "", format: str = "torchscript"):
|
||||
"""
|
||||
Export a model to a specified format for deployment via the Ultralytics HUB API.
|
||||
"""Export a model to a specified format for deployment via the Ultralytics HUB API.
|
||||
|
||||
Args:
|
||||
model_id (str): The ID of the model to export. An empty string will use the default model.
|
||||
|
|
@ -117,13 +115,11 @@ def export_model(model_id: str = "", format: str = "torchscript"):
|
|||
|
||||
|
||||
def get_export(model_id: str = "", format: str = "torchscript"):
|
||||
"""
|
||||
Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
|
||||
"""Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
|
||||
|
||||
Args:
|
||||
model_id (str): The ID of the model to retrieve from Ultralytics HUB.
|
||||
format (str): The export format to retrieve. Must be one of the supported formats returned by
|
||||
export_fmts_hub().
|
||||
format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().
|
||||
|
||||
Returns:
|
||||
(dict): JSON response containing the exported model information.
|
||||
|
|
@ -148,8 +144,7 @@ def get_export(model_id: str = "", format: str = "torchscript"):
|
|||
|
||||
|
||||
def check_dataset(path: str, task: str) -> None:
|
||||
"""
|
||||
Check HUB dataset Zip file for errors before upload.
|
||||
"""Check HUB dataset Zip file for errors before upload.
|
||||
|
||||
Args:
|
||||
path (str): Path to data.zip (with data.yaml inside data.zip).
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
|
|||
|
||||
|
||||
class Auth:
|
||||
"""
|
||||
Manages authentication processes including API key handling, cookie-based authentication, and header generation.
|
||||
"""Manages authentication processes including API key handling, cookie-based authentication, and header generation.
|
||||
|
||||
The class supports different methods of authentication:
|
||||
1. Directly using an API key.
|
||||
|
|
@ -37,8 +36,7 @@ class Auth:
|
|||
id_token = api_key = model_key = False
|
||||
|
||||
def __init__(self, api_key: str = "", verbose: bool = False):
|
||||
"""
|
||||
Initialize Auth class and authenticate user.
|
||||
"""Initialize Auth class and authenticate user.
|
||||
|
||||
Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
|
||||
authentication.
|
||||
|
|
@ -82,8 +80,7 @@ class Auth:
|
|||
LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
|
||||
|
||||
def request_api_key(self, max_attempts: int = 3) -> bool:
|
||||
"""
|
||||
Prompt the user to input their API key.
|
||||
"""Prompt the user to input their API key.
|
||||
|
||||
Args:
|
||||
max_attempts (int): Maximum number of authentication attempts.
|
||||
|
|
@ -102,8 +99,7 @@ class Auth:
|
|||
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
|
||||
|
||||
def authenticate(self) -> bool:
|
||||
"""
|
||||
Attempt to authenticate with the server using either id_token or API key.
|
||||
"""Attempt to authenticate with the server using either id_token or API key.
|
||||
|
||||
Returns:
|
||||
(bool): True if authentication is successful, False otherwise.
|
||||
|
|
@ -123,8 +119,7 @@ class Auth:
|
|||
return False
|
||||
|
||||
def auth_with_cookies(self) -> bool:
|
||||
"""
|
||||
Attempt to fetch authentication via cookies and set id_token.
|
||||
"""Attempt to fetch authentication via cookies and set id_token.
|
||||
|
||||
User must be logged in to HUB and running in a supported browser.
|
||||
|
||||
|
|
@ -145,8 +140,7 @@ class Auth:
|
|||
return False
|
||||
|
||||
def get_auth_header(self):
|
||||
"""
|
||||
Get the authentication header for making API requests.
|
||||
"""Get the authentication header for making API requests.
|
||||
|
||||
Returns:
|
||||
(dict | None): The authentication header if id_token or API key is set, None otherwise.
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ import time
|
|||
|
||||
|
||||
class GCPRegions:
|
||||
"""
|
||||
A class for managing and analyzing Google Cloud Platform (GCP) regions.
|
||||
"""A class for managing and analyzing Google Cloud Platform (GCP) regions.
|
||||
|
||||
This class provides functionality to initialize, categorize, and analyze GCP regions based on their geographical
|
||||
location, tier classification, and network latency.
|
||||
|
|
@ -82,8 +81,7 @@ class GCPRegions:
|
|||
|
||||
@staticmethod
|
||||
def _ping_region(region: str, attempts: int = 1) -> tuple[str, float, float, float, float]:
|
||||
"""
|
||||
Ping a specified GCP region and measure network latency statistics.
|
||||
"""Ping a specified GCP region and measure network latency statistics.
|
||||
|
||||
Args:
|
||||
region (str): The GCP region identifier to ping (e.g., 'us-central1').
|
||||
|
|
@ -126,8 +124,7 @@ class GCPRegions:
|
|||
tier: int | None = None,
|
||||
attempts: int = 1,
|
||||
) -> list[tuple[str, float, float, float, float]]:
|
||||
"""
|
||||
Determine the GCP regions with the lowest latency based on ping tests.
|
||||
"""Determine the GCP regions with the lowest latency based on ping tests.
|
||||
|
||||
Args:
|
||||
top (int, optional): Number of top regions to return.
|
||||
|
|
@ -136,8 +133,8 @@ class GCPRegions:
|
|||
attempts (int, optional): Number of ping attempts per region.
|
||||
|
||||
Returns:
|
||||
(list[tuple[str, float, float, float, float]]): List of tuples containing region information and
|
||||
latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
|
||||
(list[tuple[str, float, float, float, float]]): List of tuples containing region information and latency
|
||||
statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
|
||||
|
||||
Examples:
|
||||
>>> regions = GCPRegions()
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version_
|
|||
|
||||
|
||||
class HUBTrainingSession:
|
||||
"""
|
||||
HUB training session for Ultralytics HUB YOLO models.
|
||||
"""HUB training session for Ultralytics HUB YOLO models.
|
||||
|
||||
This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
|
||||
model creation, metrics tracking, and checkpoint uploading.
|
||||
|
|
@ -45,12 +44,11 @@ class HUBTrainingSession:
|
|||
"""
|
||||
|
||||
def __init__(self, identifier: str):
|
||||
"""
|
||||
Initialize the HUBTrainingSession with the provided model identifier.
|
||||
"""Initialize the HUBTrainingSession with the provided model identifier.
|
||||
|
||||
Args:
|
||||
identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string
|
||||
or a model key with specific format.
|
||||
identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string or a
|
||||
model key with specific format.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided model identifier is invalid.
|
||||
|
|
@ -93,8 +91,7 @@ class HUBTrainingSession:
|
|||
|
||||
@classmethod
|
||||
def create_session(cls, identifier: str, args: dict[str, Any] | None = None):
|
||||
"""
|
||||
Create an authenticated HUBTrainingSession or return None.
|
||||
"""Create an authenticated HUBTrainingSession or return None.
|
||||
|
||||
Args:
|
||||
identifier (str): Model identifier used to initialize the HUB training session.
|
||||
|
|
@ -114,8 +111,7 @@ class HUBTrainingSession:
|
|||
return None
|
||||
|
||||
def load_model(self, model_id: str):
|
||||
"""
|
||||
Load an existing model from Ultralytics HUB using the provided model identifier.
|
||||
"""Load an existing model from Ultralytics HUB using the provided model identifier.
|
||||
|
||||
Args:
|
||||
model_id (str): The identifier of the model to load.
|
||||
|
|
@ -140,8 +136,7 @@ class HUBTrainingSession:
|
|||
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
||||
|
||||
def create_model(self, model_args: dict[str, Any]):
|
||||
"""
|
||||
Initialize a HUB training session with the specified model arguments.
|
||||
"""Initialize a HUB training session with the specified model arguments.
|
||||
|
||||
Args:
|
||||
model_args (dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
|
||||
|
|
@ -186,8 +181,7 @@ class HUBTrainingSession:
|
|||
|
||||
@staticmethod
|
||||
def _parse_identifier(identifier: str):
|
||||
"""
|
||||
Parse the given identifier to determine the type and extract relevant components.
|
||||
"""Parse the given identifier to determine the type and extract relevant components.
|
||||
|
||||
The method supports different identifier formats:
|
||||
- A HUB model URL https://hub.ultralytics.com/models/MODEL
|
||||
|
|
@ -218,8 +212,7 @@ class HUBTrainingSession:
|
|||
return api_key, model_id, filename
|
||||
|
||||
def _set_train_args(self):
|
||||
"""
|
||||
Initialize training arguments and create a model entry on the Ultralytics HUB.
|
||||
"""Initialize training arguments and create a model entry on the Ultralytics HUB.
|
||||
|
||||
This method sets up training arguments based on the model's state and updates them with any additional arguments
|
||||
provided. It handles different states of the model, such as whether it's resumable, pretrained, or requires
|
||||
|
|
@ -261,8 +254,7 @@ class HUBTrainingSession:
|
|||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Execute request_func with retries, timeout handling, optional threading, and progress tracking.
|
||||
"""Execute request_func with retries, timeout handling, optional threading, and progress tracking.
|
||||
|
||||
Args:
|
||||
request_func (callable): The function to execute.
|
||||
|
|
@ -342,8 +334,7 @@ class HUBTrainingSession:
|
|||
return status_code in retry_codes
|
||||
|
||||
def _get_failure_message(self, response, retry: int, timeout: int) -> str:
|
||||
"""
|
||||
Generate a retry message based on the response status code.
|
||||
"""Generate a retry message based on the response status code.
|
||||
|
||||
Args:
|
||||
response (requests.Response): The HTTP response object.
|
||||
|
|
@ -379,8 +370,7 @@ class HUBTrainingSession:
|
|||
map: float = 0.0,
|
||||
final: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Upload a model checkpoint to Ultralytics HUB.
|
||||
"""Upload a model checkpoint to Ultralytics HUB.
|
||||
|
||||
Args:
|
||||
epoch (int): The current training epoch.
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/h
|
|||
|
||||
|
||||
def request_with_credentials(url: str) -> Any:
|
||||
"""
|
||||
Make an AJAX request with cookies attached in a Google Colab environment.
|
||||
"""Make an AJAX request with cookies attached in a Google Colab environment.
|
||||
|
||||
Args:
|
||||
url (str): The URL to make the request to.
|
||||
|
|
@ -62,8 +61,7 @@ def request_with_credentials(url: str) -> Any:
|
|||
|
||||
|
||||
def requests_with_progress(method: str, url: str, **kwargs):
|
||||
"""
|
||||
Make an HTTP request using the specified method and URL, with an optional progress bar.
|
||||
"""Make an HTTP request using the specified method and URL, with an optional progress bar.
|
||||
|
||||
Args:
|
||||
method (str): The HTTP method to use (e.g. 'GET', 'POST').
|
||||
|
|
@ -106,8 +104,7 @@ def smart_request(
|
|||
progress: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
||||
"""Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
||||
|
||||
Args:
|
||||
method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from .val import FastSAMValidator
|
|||
|
||||
|
||||
class FastSAM(Model):
|
||||
"""
|
||||
FastSAM model interface for segment anything tasks.
|
||||
"""FastSAM model interface for segment anything tasks.
|
||||
|
||||
This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
|
||||
Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.
|
||||
|
|
@ -53,15 +52,14 @@ class FastSAM(Model):
|
|||
texts: list | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Perform segmentation prediction on image or video source.
|
||||
"""Perform segmentation prediction on image or video source.
|
||||
|
||||
Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these prompts
|
||||
and passes them to the parent class predict method for processing.
|
||||
|
||||
Args:
|
||||
source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
|
||||
or numpy array.
|
||||
source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image, or
|
||||
numpy array.
|
||||
stream (bool): Whether to enable real-time streaming mode for video inputs.
|
||||
bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
|
||||
points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ from .utils import adjust_bboxes_to_image_border
|
|||
|
||||
|
||||
class FastSAMPredictor(SegmentationPredictor):
|
||||
"""
|
||||
FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
|
||||
"""FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
|
||||
|
||||
This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
|
||||
adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
|
||||
|
|
@ -33,8 +32,7 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the FastSAMPredictor with configuration and callbacks.
|
||||
"""Initialize the FastSAMPredictor with configuration and callbacks.
|
||||
|
||||
This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
|
||||
extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
|
||||
|
|
@ -49,8 +47,7 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
self.prompts = {}
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Apply postprocessing to FastSAM predictions and handle prompts.
|
||||
"""Apply postprocessing to FastSAM predictions and handle prompts.
|
||||
|
||||
Args:
|
||||
preds (list[torch.Tensor]): Raw predictions from the model.
|
||||
|
|
@ -77,8 +74,7 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
|
||||
|
||||
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
|
||||
"""
|
||||
Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
|
||||
"""Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
|
||||
|
||||
Args:
|
||||
results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
|
||||
|
|
@ -151,8 +147,7 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
return prompt_results
|
||||
|
||||
def _clip_inference(self, images, texts):
|
||||
"""
|
||||
Perform CLIP inference to calculate similarity between images and text prompts.
|
||||
"""Perform CLIP inference to calculate similarity between images and text prompts.
|
||||
|
||||
Args:
|
||||
images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@
|
|||
|
||||
|
||||
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
|
||||
"""
|
||||
Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
||||
"""Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ from ultralytics.models.yolo.segment import SegmentationValidator
|
|||
|
||||
|
||||
class FastSAMValidator(SegmentationValidator):
|
||||
"""
|
||||
Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
|
||||
"""Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
|
||||
|
||||
Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
|
||||
sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
|
||||
|
|
@ -23,8 +22,7 @@ class FastSAMValidator(SegmentationValidator):
|
|||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
|
||||
"""Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ from .val import NASValidator
|
|||
|
||||
|
||||
class NAS(Model):
|
||||
"""
|
||||
YOLO-NAS model for object detection.
|
||||
"""YOLO-NAS model for object detection.
|
||||
|
||||
This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. It
|
||||
is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
|
||||
|
|
@ -48,8 +47,7 @@ class NAS(Model):
|
|||
super().__init__(model, task="detect")
|
||||
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""
|
||||
Load an existing NAS model weights or create a new NAS model with pretrained weights.
|
||||
"""Load an existing NAS model weights or create a new NAS model with pretrained weights.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the model weights file or model name.
|
||||
|
|
@ -83,8 +81,7 @@ class NAS(Model):
|
|||
self.model.eval()
|
||||
|
||||
def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
|
||||
"""
|
||||
Log model information.
|
||||
"""Log model information.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ from ultralytics.utils import ops
|
|||
|
||||
|
||||
class NASPredictor(DetectionPredictor):
|
||||
"""
|
||||
Ultralytics YOLO NAS Predictor for object detection.
|
||||
"""Ultralytics YOLO NAS Predictor for object detection.
|
||||
|
||||
This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the raw
|
||||
predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and scaling the
|
||||
|
|
@ -33,8 +32,7 @@ class NASPredictor(DetectionPredictor):
|
|||
"""
|
||||
|
||||
def postprocess(self, preds_in, img, orig_imgs):
|
||||
"""
|
||||
Postprocess NAS model predictions to generate final detection results.
|
||||
"""Postprocess NAS model predictions to generate final detection results.
|
||||
|
||||
This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
|
||||
post-processing operations to generate the final detection results compatible with Ultralytics result
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ __all__ = ["NASValidator"]
|
|||
|
||||
|
||||
class NASValidator(DetectionValidator):
|
||||
"""
|
||||
Ultralytics YOLO NAS Validator for object detection.
|
||||
"""Ultralytics YOLO NAS Validator for object detection.
|
||||
|
||||
Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
|
||||
generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ from .val import RTDETRValidator
|
|||
|
||||
|
||||
class RTDETR(Model):
|
||||
"""
|
||||
Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
|
||||
"""Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
|
||||
|
||||
This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
|
||||
selection, and adaptable inference speed.
|
||||
|
|
@ -39,8 +38,7 @@ class RTDETR(Model):
|
|||
"""
|
||||
|
||||
def __init__(self, model: str = "rtdetr-l.pt") -> None:
|
||||
"""
|
||||
Initialize the RT-DETR model with the given pre-trained model file.
|
||||
"""Initialize the RT-DETR model with the given pre-trained model file.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
|
||||
|
|
@ -50,8 +48,7 @@ class RTDETR(Model):
|
|||
|
||||
@property
|
||||
def task_map(self) -> dict:
|
||||
"""
|
||||
Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
|
||||
"""Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from ultralytics.utils import ops
|
|||
|
||||
|
||||
class RTDETRPredictor(BasePredictor):
|
||||
"""
|
||||
RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
|
||||
"""RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
|
||||
|
||||
This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy. It
|
||||
supports key features like efficient hybrid encoding and IoU-aware query selection.
|
||||
|
|
@ -34,21 +33,20 @@ class RTDETRPredictor(BasePredictor):
|
|||
"""
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
|
||||
"""Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
|
||||
|
||||
The method filters detections based on confidence and class if specified in `self.args`. It converts model
|
||||
predictions to Results objects containing properly scaled bounding boxes.
|
||||
|
||||
Args:
|
||||
preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
|
||||
bounding boxes and scores.
|
||||
preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
|
||||
and scores.
|
||||
img (torch.Tensor): Processed input images with shape (N, 3, H, W).
|
||||
orig_imgs (list | torch.Tensor): Original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
|
||||
confidence scores, and class labels.
|
||||
results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
|
||||
scores, and class labels.
|
||||
"""
|
||||
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
|
||||
preds = [preds, None]
|
||||
|
|
@ -75,15 +73,14 @@ class RTDETRPredictor(BasePredictor):
|
|||
return results
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
Pre-transform input images before feeding them into the model for inference.
|
||||
"""Pre-transform input images before feeding them into the model for inference.
|
||||
|
||||
The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square (640)
|
||||
and scale_filled.
|
||||
|
||||
Args:
|
||||
im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
|
||||
[(H, W, 3) x N] for list.
|
||||
im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
|
||||
list.
|
||||
|
||||
Returns:
|
||||
(list): List of pre-transformed images ready for model inference.
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from .val import RTDETRDataset, RTDETRValidator
|
|||
|
||||
|
||||
class RTDETRTrainer(DetectionTrainer):
|
||||
"""
|
||||
Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
|
||||
"""Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
|
||||
|
||||
This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
|
||||
RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
|
||||
|
|
@ -43,8 +42,7 @@ class RTDETRTrainer(DetectionTrainer):
|
|||
"""
|
||||
|
||||
def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
|
||||
"""
|
||||
Initialize and return an RT-DETR model for object detection tasks.
|
||||
"""Initialize and return an RT-DETR model for object detection tasks.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Model configuration.
|
||||
|
|
@ -60,8 +58,7 @@ class RTDETRTrainer(DetectionTrainer):
|
|||
return model
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
|
||||
"""
|
||||
Build and return an RT-DETR dataset for training or validation.
|
||||
"""Build and return an RT-DETR dataset for training or validation.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
|
|
|
|||
|
|
@ -16,8 +16,7 @@ __all__ = ("RTDETRValidator",) # tuple or list
|
|||
|
||||
|
||||
class RTDETRDataset(YOLODataset):
|
||||
"""
|
||||
Real-Time DEtection TRansformer (RT-DETR) dataset class extending the base YOLODataset class.
|
||||
"""Real-Time DEtection TRansformer (RT-DETR) dataset class extending the base YOLODataset class.
|
||||
|
||||
This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
|
||||
real-time detection and tracking tasks.
|
||||
|
|
@ -40,8 +39,7 @@ class RTDETRDataset(YOLODataset):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, data=None, **kwargs):
|
||||
"""
|
||||
Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
|
||||
"""Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
|
||||
|
||||
This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection TRansformer)
|
||||
model, building upon the base YOLODataset functionality.
|
||||
|
|
@ -54,8 +52,7 @@ class RTDETRDataset(YOLODataset):
|
|||
super().__init__(*args, data=data, **kwargs)
|
||||
|
||||
def load_image(self, i, rect_mode=False):
|
||||
"""
|
||||
Load one image from dataset index 'i'.
|
||||
"""Load one image from dataset index 'i'.
|
||||
|
||||
Args:
|
||||
i (int): Index of the image to load.
|
||||
|
|
@ -73,8 +70,7 @@ class RTDETRDataset(YOLODataset):
|
|||
return super().load_image(i=i, rect_mode=rect_mode)
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
"""
|
||||
Build transformation pipeline for the dataset.
|
||||
"""Build transformation pipeline for the dataset.
|
||||
|
||||
Args:
|
||||
hyp (dict, optional): Hyperparameters for transformations.
|
||||
|
|
@ -105,8 +101,7 @@ class RTDETRDataset(YOLODataset):
|
|||
|
||||
|
||||
class RTDETRValidator(DetectionValidator):
|
||||
"""
|
||||
RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
|
||||
"""RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
|
||||
the RT-DETR (Real-Time DETR) object detection model.
|
||||
|
||||
The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
|
||||
|
|
@ -132,8 +127,7 @@ class RTDETRValidator(DetectionValidator):
|
|||
"""
|
||||
|
||||
def build_dataset(self, img_path, mode="val", batch=None):
|
||||
"""
|
||||
Build an RTDETR Dataset.
|
||||
"""Build an RTDETR Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
|
|
@ -159,8 +153,7 @@ class RTDETRValidator(DetectionValidator):
|
|||
def postprocess(
|
||||
self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
|
||||
) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Apply Non-maximum suppression to prediction outputs.
|
||||
"""Apply Non-maximum suppression to prediction outputs.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
|
||||
|
|
@ -191,12 +184,11 @@ class RTDETRValidator(DetectionValidator):
|
|||
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Serialize YOLO predictions to COCO json format.
|
||||
"""Serialize YOLO predictions to COCO json format.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
|
||||
with bounding box coordinates, confidence scores, and class predictions.
|
||||
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
|
||||
bounding box coordinates, confidence scores, and class predictions.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
"""
|
||||
path = Path(pbatch["im_file"])
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ import torch
|
|||
def is_box_near_crop_edge(
|
||||
boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
|
||||
"""Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): Bounding boxes in XYXY format.
|
||||
|
|
@ -42,8 +41,7 @@ def is_box_near_crop_edge(
|
|||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
|
||||
"""
|
||||
Yield batches of data from input arguments with specified batch size for efficient processing.
|
||||
"""Yield batches of data from input arguments with specified batch size for efficient processing.
|
||||
|
||||
This function takes a batch size and any number of iterables, then yields batches of elements from those
|
||||
iterables. All input iterables must have the same length.
|
||||
|
|
@ -71,8 +69,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
|
|||
|
||||
|
||||
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
|
||||
"""
|
||||
Compute the stability score for a batch of masks.
|
||||
"""Compute the stability score for a batch of masks.
|
||||
|
||||
The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
|
||||
low values.
|
||||
|
|
@ -117,8 +114,7 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
|
|||
def generate_crop_boxes(
|
||||
im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
|
||||
) -> tuple[list[list[int]], list[int]]:
|
||||
"""
|
||||
Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
|
||||
"""Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
|
||||
|
||||
Args:
|
||||
im_size (tuple[int, ...]): Height and width of the input image.
|
||||
|
|
@ -198,8 +194,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w:
|
|||
|
||||
|
||||
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Remove small disconnected regions or holes in a mask based on area threshold and mode.
|
||||
"""Remove small disconnected regions or holes in a mask based on area threshold and mode.
|
||||
|
||||
Args:
|
||||
mask (np.ndarray): Binary mask to process.
|
||||
|
|
@ -236,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tup
|
|||
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate bounding boxes in XYXY format around binary masks.
|
||||
"""Calculate bounding boxes in XYXY format around binary masks.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
|
||||
|
|
|
|||
|
|
@ -127,8 +127,7 @@ def _build_sam(
|
|||
checkpoint=None,
|
||||
mobile_sam=False,
|
||||
):
|
||||
"""
|
||||
Build a Segment Anything Model (SAM) with specified encoder parameters.
|
||||
"""Build a Segment Anything Model (SAM) with specified encoder parameters.
|
||||
|
||||
Args:
|
||||
encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
|
||||
|
|
@ -224,8 +223,7 @@ def _build_sam2(
|
|||
encoder_window_spec=[8, 4, 16, 8],
|
||||
checkpoint=None,
|
||||
):
|
||||
"""
|
||||
Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
|
||||
"""Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
|
||||
|
||||
Args:
|
||||
encoder_embed_dim (int, optional): Embedding dimension for the encoder.
|
||||
|
|
@ -326,8 +324,7 @@ sam_model_map = {
|
|||
|
||||
|
||||
def build_sam(ckpt="sam_b.pt"):
|
||||
"""
|
||||
Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
|
||||
"""Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
|
||||
|
||||
Args:
|
||||
ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ from .predict import Predictor, SAM2Predictor
|
|||
|
||||
|
||||
class SAM(Model):
|
||||
"""
|
||||
SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
|
||||
"""SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
|
||||
|
||||
This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable
|
||||
segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or
|
||||
|
|
@ -49,8 +48,7 @@ class SAM(Model):
|
|||
"""
|
||||
|
||||
def __init__(self, model: str = "sam_b.pt") -> None:
|
||||
"""
|
||||
Initialize the SAM (Segment Anything Model) instance.
|
||||
"""Initialize the SAM (Segment Anything Model) instance.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
|
||||
|
|
@ -68,8 +66,7 @@ class SAM(Model):
|
|||
super().__init__(model=model, task="segment")
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
"""
|
||||
Load the specified weights into the SAM model.
|
||||
"""Load the specified weights into the SAM model.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
|
||||
|
|
@ -84,12 +81,11 @@ class SAM(Model):
|
|||
self.model = build_sam(weights)
|
||||
|
||||
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""
|
||||
Perform segmentation prediction on the given image or video source.
|
||||
"""Perform segmentation prediction on the given image or video source.
|
||||
|
||||
Args:
|
||||
source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or
|
||||
a np.ndarray object.
|
||||
source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or a
|
||||
np.ndarray object.
|
||||
stream (bool): If True, enables real-time streaming.
|
||||
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
|
||||
points (list[list[float]] | None): List of points for prompted segmentation.
|
||||
|
|
@ -111,15 +107,14 @@ class SAM(Model):
|
|||
return super().predict(source, stream, prompts=prompts, **kwargs)
|
||||
|
||||
def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""
|
||||
Perform segmentation prediction on the given image or video source.
|
||||
"""Perform segmentation prediction on the given image or video source.
|
||||
|
||||
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for
|
||||
segmentation tasks.
|
||||
|
||||
Args:
|
||||
source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image
|
||||
object, or a np.ndarray object.
|
||||
source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image object, or a
|
||||
np.ndarray object.
|
||||
stream (bool): If True, enables real-time streaming.
|
||||
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
|
||||
points (list[list[float]] | None): List of points for prompted segmentation.
|
||||
|
|
@ -137,8 +132,7 @@ class SAM(Model):
|
|||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
||||
def info(self, detailed: bool = False, verbose: bool = True):
|
||||
"""
|
||||
Log information about the SAM model.
|
||||
"""Log information about the SAM model.
|
||||
|
||||
Args:
|
||||
detailed (bool): If True, displays detailed information about the model layers and operations.
|
||||
|
|
@ -156,8 +150,7 @@ class SAM(Model):
|
|||
|
||||
@property
|
||||
def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
|
||||
"""
|
||||
Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
|
||||
"""Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
|
||||
|
||||
Returns:
|
||||
(dict[str, dict[str, type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
|
||||
|
|
|
|||
|
|
@ -17,8 +17,7 @@ from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis,
|
|||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""
|
||||
Implements stochastic depth regularization for neural networks during training.
|
||||
"""Implements stochastic depth regularization for neural networks during training.
|
||||
|
||||
Attributes:
|
||||
drop_prob (float): Probability of dropping a path during training.
|
||||
|
|
@ -52,15 +51,14 @@ class DropPath(nn.Module):
|
|||
|
||||
|
||||
class MaskDownSampler(nn.Module):
|
||||
"""
|
||||
A mask downsampling and embedding module for efficient processing of input masks.
|
||||
"""A mask downsampling and embedding module for efficient processing of input masks.
|
||||
|
||||
This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks while
|
||||
expanding their channel dimensions using convolutional layers, layer normalization, and activation functions.
|
||||
|
||||
Attributes:
|
||||
encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
|
||||
activation functions for downsampling and embedding masks.
|
||||
encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and activation
|
||||
functions for downsampling and embedding masks.
|
||||
|
||||
Methods:
|
||||
forward: Downsamples and encodes input mask to embed_dim channels.
|
||||
|
|
@ -111,8 +109,7 @@ class MaskDownSampler(nn.Module):
|
|||
|
||||
|
||||
class CXBlock(nn.Module):
|
||||
"""
|
||||
ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
||||
"""ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
||||
|
||||
This block implements a modified version of the ConvNeXt architecture, offering improved performance and flexibility
|
||||
in feature extraction.
|
||||
|
|
@ -147,8 +144,7 @@ class CXBlock(nn.Module):
|
|||
layer_scale_init_value: float = 1e-6,
|
||||
use_dwconv: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
||||
"""Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
||||
|
||||
This block implements a modified version of the ConvNeXt architecture, offering improved performance and
|
||||
flexibility in feature extraction.
|
||||
|
|
@ -205,8 +201,7 @@ class CXBlock(nn.Module):
|
|||
|
||||
|
||||
class Fuser(nn.Module):
|
||||
"""
|
||||
A module for fusing features through multiple layers of a neural network.
|
||||
"""A module for fusing features through multiple layers of a neural network.
|
||||
|
||||
This class applies a series of identical layers to an input tensor, optionally projecting the input first.
|
||||
|
||||
|
|
@ -227,8 +222,7 @@ class Fuser(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, layer: nn.Module, num_layers: int, dim: int | None = None, input_projection: bool = False):
|
||||
"""
|
||||
Initialize the Fuser module for feature fusion through multiple layers.
|
||||
"""Initialize the Fuser module for feature fusion through multiple layers.
|
||||
|
||||
This module creates a sequence of identical layers and optionally applies an input projection.
|
||||
|
||||
|
|
@ -261,8 +255,7 @@ class Fuser(nn.Module):
|
|||
|
||||
|
||||
class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
|
||||
"""
|
||||
A two-way attention block for performing self-attention and cross-attention in both directions.
|
||||
"""A two-way attention block for performing self-attention and cross-attention in both directions.
|
||||
|
||||
This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse inputs,
|
||||
cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from dense to sparse
|
||||
|
|
@ -298,8 +291,7 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
|
|||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
|
||||
"""Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
|
||||
|
||||
This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
|
||||
inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention from
|
||||
|
|
@ -324,8 +316,7 @@ class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
|
|||
|
||||
|
||||
class SAM2TwoWayTransformer(TwoWayTransformer):
|
||||
"""
|
||||
A Two-Way Transformer module for simultaneous attention to image and query points.
|
||||
"""A Two-Way Transformer module for simultaneous attention to image and query points.
|
||||
|
||||
This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an input
|
||||
image using queries with supplied positional embeddings. It is particularly useful for tasks like object detection,
|
||||
|
|
@ -361,8 +352,7 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
|
|||
activation: type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a SAM2TwoWayTransformer instance.
|
||||
"""Initialize a SAM2TwoWayTransformer instance.
|
||||
|
||||
This transformer decoder attends to an input image using queries with supplied positional embeddings. It is
|
||||
designed for tasks like object detection, image segmentation, and point cloud processing.
|
||||
|
|
@ -402,8 +392,7 @@ class SAM2TwoWayTransformer(TwoWayTransformer):
|
|||
|
||||
|
||||
class RoPEAttention(Attention):
|
||||
"""
|
||||
Implements rotary position encoding for attention mechanisms in transformer architectures.
|
||||
"""Implements rotary position encoding for attention mechanisms in transformer architectures.
|
||||
|
||||
This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance the
|
||||
positional awareness of the attention mechanism.
|
||||
|
|
@ -500,8 +489,7 @@ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.T
|
|||
|
||||
|
||||
class MultiScaleAttention(nn.Module):
|
||||
"""
|
||||
Implements multiscale self-attention with optional query pooling for efficient feature extraction.
|
||||
"""Implements multiscale self-attention with optional query pooling for efficient feature extraction.
|
||||
|
||||
This class provides a flexible implementation of multiscale attention, allowing for optional downsampling of query
|
||||
features through pooling. It's designed to enhance the model's ability to capture multiscale information in visual
|
||||
|
|
@ -580,8 +568,7 @@ class MultiScaleAttention(nn.Module):
|
|||
|
||||
|
||||
class MultiScaleBlock(nn.Module):
|
||||
"""
|
||||
A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
|
||||
"""A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
|
||||
|
||||
This class implements a multiscale attention mechanism with optional window partitioning and downsampling, designed
|
||||
for use in vision transformer architectures.
|
||||
|
|
@ -695,8 +682,7 @@ class MultiScaleBlock(nn.Module):
|
|||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
A module for generating sinusoidal positional embeddings for 2D inputs like images.
|
||||
"""A module for generating sinusoidal positional embeddings for 2D inputs like images.
|
||||
|
||||
This class implements sinusoidal position encoding for 2D spatial positions, which can be used in transformer-based
|
||||
models for computer vision tasks.
|
||||
|
|
@ -810,8 +796,7 @@ class PositionEmbeddingSine(nn.Module):
|
|||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""Positional encoding using random spatial frequencies.
|
||||
|
||||
This class generates positional embeddings for input coordinates using random spatial frequencies. It is
|
||||
particularly useful for transformer-based models that require position information.
|
||||
|
|
@ -877,8 +862,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""
|
||||
Transformer block with support for window attention and residual propagation.
|
||||
"""Transformer block with support for window attention and residual propagation.
|
||||
|
||||
This class implements a transformer block that can use either global or windowed self-attention, followed by a
|
||||
feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
|
||||
|
|
@ -916,8 +900,7 @@ class Block(nn.Module):
|
|||
window_size: int = 0,
|
||||
input_size: tuple[int, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a transformer block with optional window attention and relative positional embeddings.
|
||||
"""Initialize a transformer block with optional window attention and relative positional embeddings.
|
||||
|
||||
This constructor sets up a transformer block that can use either global or windowed self-attention, followed by
|
||||
a feed-forward network. It supports relative positional embeddings and is designed for use in vision transformer
|
||||
|
|
@ -977,8 +960,7 @@ class Block(nn.Module):
|
|||
|
||||
|
||||
class REAttention(nn.Module):
|
||||
"""
|
||||
Relative Position Attention module for efficient self-attention in transformer architectures.
|
||||
"""Relative Position Attention module for efficient self-attention in transformer architectures.
|
||||
|
||||
This class implements a multi-head attention mechanism with relative positional embeddings, designed for use in
|
||||
vision transformer models. It supports optional query pooling and window partitioning for efficient processing of
|
||||
|
|
@ -1013,8 +995,7 @@ class REAttention(nn.Module):
|
|||
rel_pos_zero_init: bool = True,
|
||||
input_size: tuple[int, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a Relative Position Attention module for transformer-based architectures.
|
||||
"""Initialize a Relative Position Attention module for transformer-based architectures.
|
||||
|
||||
This module implements multi-head attention with optional relative positional encodings, designed specifically
|
||||
for vision tasks in transformer models.
|
||||
|
|
@ -1069,8 +1050,7 @@ class REAttention(nn.Module):
|
|||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding module for vision transformer architectures.
|
||||
"""Image to Patch Embedding module for vision transformer architectures.
|
||||
|
||||
This module converts an input image into a sequence of patch embeddings using a convolutional layer. It is commonly
|
||||
used as the first layer in vision transformer architectures to transform image data into a suitable format for
|
||||
|
|
@ -1098,8 +1078,7 @@ class PatchEmbed(nn.Module):
|
|||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the PatchEmbed module for converting image patches to embeddings.
|
||||
"""Initialize the PatchEmbed module for converting image patches to embeddings.
|
||||
|
||||
This module is typically used as the first layer in vision transformer architectures to transform image data
|
||||
into a suitable format for subsequent transformer blocks.
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
|
|||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
"""
|
||||
Decoder module for generating masks and their associated quality scores using a transformer architecture.
|
||||
"""Decoder module for generating masks and their associated quality scores using a transformer architecture.
|
||||
|
||||
This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
|
||||
generate mask predictions along with their quality scores.
|
||||
|
|
@ -47,8 +46,7 @@ class MaskDecoder(nn.Module):
|
|||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MaskDecoder module for generating masks and their associated quality scores.
|
||||
"""Initialize the MaskDecoder module for generating masks and their associated quality scores.
|
||||
|
||||
Args:
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
|
|
@ -94,8 +92,7 @@ class MaskDecoder(nn.Module):
|
|||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
"""Predict masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): Embeddings from the image encoder.
|
||||
|
|
@ -172,8 +169,7 @@ class MaskDecoder(nn.Module):
|
|||
|
||||
|
||||
class SAM2MaskDecoder(nn.Module):
|
||||
"""
|
||||
Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
|
||||
"""Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
|
||||
|
||||
This class extends the functionality of the MaskDecoder, incorporating additional features such as high-resolution
|
||||
feature processing, dynamic multimask output, and object score prediction.
|
||||
|
|
@ -233,8 +229,7 @@ class SAM2MaskDecoder(nn.Module):
|
|||
pred_obj_scores_mlp: bool = False,
|
||||
use_multimask_token_for_obj_ptr: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
|
||||
"""Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
|
||||
|
||||
This decoder extends the functionality of MaskDecoder, incorporating additional features such as high-resolution
|
||||
feature processing, dynamic multimask output, and object score prediction.
|
||||
|
|
@ -319,8 +314,7 @@ class SAM2MaskDecoder(nn.Module):
|
|||
repeat_image: bool,
|
||||
high_res_features: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
"""Predict masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
|
||||
|
|
@ -458,8 +452,7 @@ class SAM2MaskDecoder(nn.Module):
|
|||
return torch.where(area_u > 0, area_i / area_u, 1.0)
|
||||
|
||||
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
|
||||
"""
|
||||
Dynamically select the most stable mask output based on stability scores and IoU predictions.
|
||||
"""Dynamically select the most stable mask output based on stability scores and IoU predictions.
|
||||
|
||||
This method is used when outputting a single mask. If the stability score from the current single-mask output
|
||||
(based on output token 0) falls below a threshold, it instead selects from multi-mask outputs (based on output
|
||||
|
|
@ -467,8 +460,8 @@ class SAM2MaskDecoder(nn.Module):
|
|||
tracking scenarios.
|
||||
|
||||
Args:
|
||||
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
|
||||
batch size, N is number of masks (typically 4), and H, W are mask dimensions.
|
||||
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is batch size, N
|
||||
is number of masks (typically 4), and H, W are mask dimensions.
|
||||
all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
|
||||
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ from .blocks import (
|
|||
|
||||
|
||||
class ImageEncoderViT(nn.Module):
|
||||
"""
|
||||
An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
|
||||
"""An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
|
||||
|
||||
This class processes images by splitting them into patches, applying transformer blocks, and generating a final
|
||||
encoded representation through a neck module.
|
||||
|
|
@ -64,8 +63,7 @@ class ImageEncoderViT(nn.Module):
|
|||
window_size: int = 0,
|
||||
global_attn_indexes: tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
|
||||
"""Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size, assumed to be square.
|
||||
|
|
@ -156,8 +154,7 @@ class ImageEncoderViT(nn.Module):
|
|||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
"""
|
||||
Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
|
||||
"""Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
|
||||
|
||||
Attributes:
|
||||
embed_dim (int): Dimension of the embeddings.
|
||||
|
|
@ -193,8 +190,7 @@ class PromptEncoder(nn.Module):
|
|||
mask_in_chans: int,
|
||||
activation: type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the PromptEncoder module for encoding various types of prompts.
|
||||
"""Initialize the PromptEncoder module for encoding various types of prompts.
|
||||
|
||||
Args:
|
||||
embed_dim (int): The dimension of the embeddings.
|
||||
|
|
@ -236,15 +232,14 @@ class PromptEncoder(nn.Module):
|
|||
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the dense positional encoding used for encoding point prompts.
|
||||
"""Return the dense positional encoding used for encoding point prompts.
|
||||
|
||||
Generate a positional encoding for a dense set of points matching the shape of the image
|
||||
encoding. The encoding is used to provide spatial information to the model when processing point prompts.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
|
||||
height and width of the image embedding size, respectively.
|
||||
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the height and
|
||||
width of the image embedding size, respectively.
|
||||
|
||||
Examples:
|
||||
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
|
|
@ -306,12 +301,11 @@ class PromptEncoder(nn.Module):
|
|||
boxes: torch.Tensor | None,
|
||||
masks: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Embed different types of prompts, returning both sparse and dense embeddings.
|
||||
"""Embed different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Args:
|
||||
points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
|
||||
tensor contains coordinates of shape (B, N, 2), and the second tensor contains labels of shape (B, N).
|
||||
points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first tensor
|
||||
contains coordinates of shape (B, N, 2), and the second tensor contains labels of shape (B, N).
|
||||
boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
|
||||
masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
|
||||
|
||||
|
|
@ -353,8 +347,7 @@ class PromptEncoder(nn.Module):
|
|||
|
||||
|
||||
class MemoryEncoder(nn.Module):
|
||||
"""
|
||||
Encode pixel features and masks into a memory representation for efficient image segmentation.
|
||||
"""Encode pixel features and masks into a memory representation for efficient image segmentation.
|
||||
|
||||
This class processes pixel-level features and masks, fusing them to generate encoded memory representations suitable
|
||||
for downstream tasks in image segmentation models like SAM (Segment Anything Model).
|
||||
|
|
@ -384,8 +377,7 @@ class MemoryEncoder(nn.Module):
|
|||
out_dim,
|
||||
in_dim=256, # in_dim of pix_feats
|
||||
):
|
||||
"""
|
||||
Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
|
||||
"""Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
|
||||
|
||||
This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations
|
||||
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
|
||||
|
|
@ -438,8 +430,7 @@ class MemoryEncoder(nn.Module):
|
|||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
"""
|
||||
Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
|
||||
"""Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
|
||||
|
||||
This class combines a trunk network for feature extraction with a neck network for feature refinement and positional
|
||||
encoding generation. It can optionally discard the lowest resolution features.
|
||||
|
|
@ -468,8 +459,7 @@ class ImageEncoder(nn.Module):
|
|||
neck: nn.Module,
|
||||
scalp: int = 0,
|
||||
):
|
||||
"""
|
||||
Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
|
||||
"""Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
|
||||
|
||||
This encoder combines a trunk network for feature extraction with a neck network for feature refinement and
|
||||
positional encoding generation. It can optionally discard the lowest resolution features.
|
||||
|
|
@ -512,8 +502,7 @@ class ImageEncoder(nn.Module):
|
|||
|
||||
|
||||
class FpnNeck(nn.Module):
|
||||
"""
|
||||
A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
|
||||
"""A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
|
||||
|
||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to ViT
|
||||
positional embedding interpolation.
|
||||
|
|
@ -549,8 +538,7 @@ class FpnNeck(nn.Module):
|
|||
fuse_type: str = "sum",
|
||||
fpn_top_down_levels: list[int] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize a modified Feature Pyramid Network (FPN) neck.
|
||||
"""Initialize a modified Feature Pyramid Network (FPN) neck.
|
||||
|
||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, similar to
|
||||
ViT positional embedding interpolation.
|
||||
|
|
@ -602,8 +590,7 @@ class FpnNeck(nn.Module):
|
|||
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
||||
|
||||
def forward(self, xs: list[torch.Tensor]):
|
||||
"""
|
||||
Perform forward pass through the Feature Pyramid Network (FPN) neck.
|
||||
"""Perform forward pass through the Feature Pyramid Network (FPN) neck.
|
||||
|
||||
This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
|
||||
and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
|
||||
|
|
@ -612,8 +599,8 @@ class FpnNeck(nn.Module):
|
|||
xs (list[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
|
||||
|
||||
Returns:
|
||||
out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape
|
||||
(B, d_model, H, W).
|
||||
out (list[torch.Tensor]): Output feature maps after FPN processing, each with shape (B, d_model, H, W).
|
||||
pos (list[torch.Tensor]): List of positional encodings corresponding to each output feature map.
|
||||
|
||||
Examples:
|
||||
|
|
@ -655,8 +642,7 @@ class FpnNeck(nn.Module):
|
|||
|
||||
|
||||
class Hiera(nn.Module):
|
||||
"""
|
||||
Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
|
||||
"""Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
|
||||
|
||||
This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for efficient
|
||||
multiscale feature extraction. It uses a series of transformer blocks organized into stages, with optional pooling
|
||||
|
|
@ -714,8 +700,7 @@ class Hiera(nn.Module):
|
|||
),
|
||||
return_interm_layers=True, # return feats from every stage
|
||||
):
|
||||
"""
|
||||
Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
|
||||
"""Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
|
||||
|
||||
Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction in
|
||||
image processing tasks. It uses a series of transformer blocks organized into stages, with optional pooling and
|
||||
|
|
@ -816,8 +801,7 @@ class Hiera(nn.Module):
|
|||
return pos_embed
|
||||
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
Perform forward pass through Hiera model, extracting multiscale features from input images.
|
||||
"""Perform forward pass through Hiera model, extracting multiscale features from input images.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from .blocks import RoPEAttention
|
|||
|
||||
|
||||
class MemoryAttentionLayer(nn.Module):
|
||||
"""
|
||||
Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
|
||||
"""Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
|
||||
|
||||
This class combines self-attention, cross-attention, and feedforward components to process input tensors and
|
||||
generate memory-based attention outputs.
|
||||
|
|
@ -61,8 +60,7 @@ class MemoryAttentionLayer(nn.Module):
|
|||
pos_enc_at_cross_attn_keys: bool = True,
|
||||
pos_enc_at_cross_attn_queries: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
|
||||
"""Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
|
||||
|
||||
Args:
|
||||
d_model (int): Dimensionality of the model.
|
||||
|
|
@ -145,8 +143,7 @@ class MemoryAttentionLayer(nn.Module):
|
|||
query_pos: torch.Tensor | None = None,
|
||||
num_k_exclude_rope: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process input tensors through self-attention, cross-attention, and feedforward network layers.
|
||||
"""Process input tensors through self-attention, cross-attention, and feedforward network layers.
|
||||
|
||||
Args:
|
||||
tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
|
||||
|
|
@ -168,8 +165,7 @@ class MemoryAttentionLayer(nn.Module):
|
|||
|
||||
|
||||
class MemoryAttention(nn.Module):
|
||||
"""
|
||||
Memory attention module for processing sequential data with self and cross-attention mechanisms.
|
||||
"""Memory attention module for processing sequential data with self and cross-attention mechanisms.
|
||||
|
||||
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
|
||||
processing sequential data, particularly useful in transformer-like architectures.
|
||||
|
|
@ -206,8 +202,7 @@ class MemoryAttention(nn.Module):
|
|||
num_layers: int,
|
||||
batch_first: bool = True, # Do layers expect batch first input?
|
||||
):
|
||||
"""
|
||||
Initialize MemoryAttention with specified layers and normalization for sequential data processing.
|
||||
"""Initialize MemoryAttention with specified layers and normalization for sequential data processing.
|
||||
|
||||
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
|
||||
processing sequential data, particularly useful in transformer-like architectures.
|
||||
|
|
@ -247,8 +242,7 @@ class MemoryAttention(nn.Module):
|
|||
memory_pos: torch.Tensor | None = None, # pos_enc for cross-attention inputs
|
||||
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process inputs through attention layers, applying self and cross-attention with positional encoding.
|
||||
"""Process inputs through attention layers, applying self and cross-attention with positional encoding.
|
||||
|
||||
Args:
|
||||
curr (torch.Tensor): Self-attention input tensor, representing the current state.
|
||||
|
|
|
|||
|
|
@ -23,8 +23,7 @@ NO_OBJ_SCORE = -1024.0
|
|||
|
||||
|
||||
class SAMModel(nn.Module):
|
||||
"""
|
||||
Segment Anything Model (SAM) for object segmentation tasks.
|
||||
"""Segment Anything Model (SAM) for object segmentation tasks.
|
||||
|
||||
This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
|
||||
prompts.
|
||||
|
|
@ -61,8 +60,7 @@ class SAMModel(nn.Module):
|
|||
pixel_mean: list[float] = (123.675, 116.28, 103.53),
|
||||
pixel_std: list[float] = (58.395, 57.12, 57.375),
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the SAMModel class to predict object masks from an image and input prompts.
|
||||
"""Initialize the SAMModel class to predict object masks from an image and input prompts.
|
||||
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
|
||||
|
|
@ -98,8 +96,7 @@ class SAMModel(nn.Module):
|
|||
|
||||
|
||||
class SAM2Model(torch.nn.Module):
|
||||
"""
|
||||
SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
|
||||
"""SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
|
||||
|
||||
This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
|
||||
consistency and efficient tracking of objects across frames.
|
||||
|
|
@ -136,24 +133,24 @@ class SAM2Model(torch.nn.Module):
|
|||
use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
|
||||
no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
|
||||
max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
|
||||
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
|
||||
first frame.
|
||||
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
|
||||
conditioning frames.
|
||||
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
|
||||
frame.
|
||||
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
|
||||
frames.
|
||||
multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
|
||||
multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
|
||||
multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
|
||||
iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
|
||||
memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
|
||||
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
|
||||
memory encoder during evaluation.
|
||||
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
|
||||
encoder during evaluation.
|
||||
sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
|
||||
sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
|
||||
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
|
||||
with clicks during evaluation.
|
||||
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
|
||||
prompt encoder and mask decoder on frames with mask input.
|
||||
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
|
||||
clicks during evaluation.
|
||||
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
|
||||
encoder and mask decoder on frames with mask input.
|
||||
|
||||
Methods:
|
||||
forward_image: Process image batch through encoder to extract multi-level features.
|
||||
|
|
@ -208,8 +205,7 @@ class SAM2Model(torch.nn.Module):
|
|||
sam_mask_decoder_extra_args=None,
|
||||
compile_image_encoder: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the SAM2Model for video object segmentation with memory-based tracking.
|
||||
"""Initialize the SAM2Model for video object segmentation with memory-based tracking.
|
||||
|
||||
Args:
|
||||
image_encoder (nn.Module): Visual encoder for extracting image features.
|
||||
|
|
@ -220,35 +216,35 @@ class SAM2Model(torch.nn.Module):
|
|||
backbone_stride (int): Stride of the image backbone output.
|
||||
sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
|
||||
sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
|
||||
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
|
||||
with clicks during evaluation.
|
||||
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
|
||||
clicks during evaluation.
|
||||
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
|
||||
prompt encoder and mask decoder on frames with mask input.
|
||||
max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
|
||||
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
|
||||
first frame.
|
||||
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
|
||||
frame.
|
||||
use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
|
||||
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
|
||||
conditioning frames.
|
||||
multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
|
||||
frames.
|
||||
multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
|
||||
multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
|
||||
multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
|
||||
iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
|
||||
memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
|
||||
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
|
||||
memory encoder during evaluation.
|
||||
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
|
||||
encoder during evaluation.
|
||||
use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
|
||||
max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
|
||||
cross-attention.
|
||||
add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
|
||||
the encoder.
|
||||
add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
|
||||
encoder.
|
||||
proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
|
||||
encoding in object pointers.
|
||||
use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
|
||||
in the object pointers.
|
||||
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
|
||||
during evaluation.
|
||||
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
|
||||
evaluation.
|
||||
pred_obj_scores (bool): Whether to predict if there is an object in the frame.
|
||||
pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
|
||||
fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
|
||||
|
|
@ -428,25 +424,23 @@ class SAM2Model(torch.nn.Module):
|
|||
high_res_features=None,
|
||||
multimask_output=False,
|
||||
):
|
||||
"""
|
||||
Forward pass through SAM prompt encoders and mask heads.
|
||||
"""Forward pass through SAM prompt encoders and mask heads.
|
||||
|
||||
This method processes image features and optional point/mask inputs to generate object masks and scores.
|
||||
|
||||
Args:
|
||||
backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
|
||||
point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
|
||||
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
|
||||
pixel-unit coordinates in (x, y) format for P input points.
|
||||
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
|
||||
0 means negative clicks, and -1 means padding.
|
||||
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
|
||||
same spatial size as the image.
|
||||
high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes
|
||||
(B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps for
|
||||
SAM decoder.
|
||||
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
|
||||
output only 1 mask and its IoU estimate.
|
||||
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
|
||||
(x, y) format for P input points.
|
||||
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
|
||||
clicks, and -1 means padding.
|
||||
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
|
||||
size as the image.
|
||||
high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
|
||||
C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
|
||||
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
|
||||
mask and its IoU estimate.
|
||||
|
||||
Returns:
|
||||
low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@ from ultralytics.utils.instance import to_2tuple
|
|||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
"""
|
||||
A sequential container that performs 2D convolution followed by batch normalization.
|
||||
"""A sequential container that performs 2D convolution followed by batch normalization.
|
||||
|
||||
This module combines a 2D convolution layer with batch normalization, providing a common building block for
|
||||
convolutional neural networks. The batch normalization weights and biases are initialized to specific values for
|
||||
|
|
@ -52,8 +51,7 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
groups: int = 1,
|
||||
bn_weight_init: float = 1,
|
||||
):
|
||||
"""
|
||||
Initialize a sequential container with 2D convolution followed by batch normalization.
|
||||
"""Initialize a sequential container with 2D convolution followed by batch normalization.
|
||||
|
||||
Args:
|
||||
a (int): Number of input channels.
|
||||
|
|
@ -74,8 +72,7 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Embed images into patches and project them into a specified embedding dimension.
|
||||
"""Embed images into patches and project them into a specified embedding dimension.
|
||||
|
||||
This module converts input images into patch embeddings using a sequence of convolutional layers, effectively
|
||||
downsampling the spatial dimensions while increasing the channel dimension.
|
||||
|
|
@ -97,8 +94,7 @@ class PatchEmbed(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
|
||||
"""
|
||||
Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
|
||||
"""Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
|
|
@ -125,8 +121,7 @@ class PatchEmbed(nn.Module):
|
|||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
"""
|
||||
Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
|
||||
"""Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
|
||||
|
||||
This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution, and
|
||||
projection phases, along with residual connections for improved gradient flow.
|
||||
|
|
@ -153,8 +148,7 @@ class MBConv(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
|
||||
"""
|
||||
Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
|
||||
"""Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
|
|
@ -195,8 +189,7 @@ class MBConv(nn.Module):
|
|||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
"""
|
||||
Merge neighboring patches in the feature map and project to a new dimension.
|
||||
"""Merge neighboring patches in the feature map and project to a new dimension.
|
||||
|
||||
This class implements a patch merging operation that combines spatial information and adjusts the feature dimension
|
||||
using a series of convolutional layers with batch normalization. It effectively reduces spatial resolution while
|
||||
|
|
@ -221,8 +214,7 @@ class PatchMerging(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, input_resolution: tuple[int, int], dim: int, out_dim: int, activation):
|
||||
"""
|
||||
Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
|
||||
"""Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
|
||||
|
|
@ -259,8 +251,7 @@ class PatchMerging(nn.Module):
|
|||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
|
||||
"""Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
|
||||
|
||||
This layer optionally applies downsample operations to the output and supports gradient checkpointing for memory
|
||||
efficiency during training.
|
||||
|
|
@ -293,8 +284,7 @@ class ConvLayer(nn.Module):
|
|||
out_dim: int | None = None,
|
||||
conv_expand_ratio: float = 4.0,
|
||||
):
|
||||
"""
|
||||
Initialize the ConvLayer with the given dimensions and settings.
|
||||
"""Initialize the ConvLayer with the given dimensions and settings.
|
||||
|
||||
This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and optionally
|
||||
applies downsampling to the output.
|
||||
|
|
@ -345,8 +335,7 @@ class ConvLayer(nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
Multi-layer Perceptron (MLP) module for transformer architectures.
|
||||
"""Multi-layer Perceptron (MLP) module for transformer architectures.
|
||||
|
||||
This module applies layer normalization, two fully-connected layers with an activation function in between, and
|
||||
dropout. It is commonly used in transformer-based architectures for processing token embeddings.
|
||||
|
|
@ -376,8 +365,7 @@ class MLP(nn.Module):
|
|||
activation=nn.GELU,
|
||||
drop: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
|
||||
"""Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
|
||||
|
||||
Args:
|
||||
in_features (int): Number of input features.
|
||||
|
|
@ -406,8 +394,7 @@ class MLP(nn.Module):
|
|||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
"""
|
||||
Multi-head attention module with spatial awareness and trainable attention biases.
|
||||
"""Multi-head attention module with spatial awareness and trainable attention biases.
|
||||
|
||||
This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
|
||||
biases based on spatial resolution. It includes trainable attention biases for each unique offset between spatial
|
||||
|
|
@ -444,8 +431,7 @@ class Attention(torch.nn.Module):
|
|||
attn_ratio: float = 4,
|
||||
resolution: tuple[int, int] = (14, 14),
|
||||
):
|
||||
"""
|
||||
Initialize the Attention module for multi-head attention with spatial awareness.
|
||||
"""Initialize the Attention module for multi-head attention with spatial awareness.
|
||||
|
||||
This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
|
||||
biases based on spatial resolution. It includes trainable attention biases for each unique offset between
|
||||
|
|
@ -521,8 +507,7 @@ class Attention(torch.nn.Module):
|
|||
|
||||
|
||||
class TinyViTBlock(nn.Module):
|
||||
"""
|
||||
TinyViT Block that applies self-attention and a local convolution to the input.
|
||||
"""TinyViT Block that applies self-attention and a local convolution to the input.
|
||||
|
||||
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
|
||||
convolutions to process input features efficiently. It supports windowed attention for computational efficiency and
|
||||
|
|
@ -559,8 +544,7 @@ class TinyViTBlock(nn.Module):
|
|||
local_conv_size: int = 3,
|
||||
activation=nn.GELU,
|
||||
):
|
||||
"""
|
||||
Initialize a TinyViT block with self-attention and local convolution.
|
||||
"""Initialize a TinyViT block with self-attention and local convolution.
|
||||
|
||||
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
|
||||
convolutions to process input features efficiently.
|
||||
|
|
@ -644,8 +628,7 @@ class TinyViTBlock(nn.Module):
|
|||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""
|
||||
Return a string representation of the TinyViTBlock's parameters.
|
||||
"""Return a string representation of the TinyViTBlock's parameters.
|
||||
|
||||
This method provides a formatted string containing key information about the TinyViTBlock, including its
|
||||
dimension, input resolution, number of attention heads, window size, and MLP ratio.
|
||||
|
|
@ -665,8 +648,7 @@ class TinyViTBlock(nn.Module):
|
|||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
"""
|
||||
A basic TinyViT layer for one stage in a TinyViT architecture.
|
||||
"""A basic TinyViT layer for one stage in a TinyViT architecture.
|
||||
|
||||
This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks and an optional
|
||||
downsampling operation. It processes features at a specific resolution and dimensionality within the overall
|
||||
|
|
@ -704,8 +686,7 @@ class BasicLayer(nn.Module):
|
|||
activation=nn.GELU,
|
||||
out_dim: int | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize a BasicLayer in the TinyViT architecture.
|
||||
"""Initialize a BasicLayer in the TinyViT architecture.
|
||||
|
||||
This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to process
|
||||
feature maps at a specific resolution and dimensionality within the TinyViT model.
|
||||
|
|
@ -770,8 +751,7 @@ class BasicLayer(nn.Module):
|
|||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
"""
|
||||
TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
|
||||
"""TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
|
||||
|
||||
This class implements the TinyViT model, which combines elements of vision transformers and convolutional neural
|
||||
networks for improved efficiency and performance on vision tasks. It features hierarchical processing with patch
|
||||
|
|
@ -815,8 +795,7 @@ class TinyViT(nn.Module):
|
|||
local_conv_size: int = 3,
|
||||
layer_lr_decay: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize the TinyViT model.
|
||||
"""Initialize the TinyViT model.
|
||||
|
||||
This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of attention and
|
||||
convolution blocks, and a classification head.
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from ultralytics.nn.modules import MLPBlock
|
|||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
"""
|
||||
A Two-Way Transformer module for simultaneous attention to image and query points.
|
||||
"""A Two-Way Transformer module for simultaneous attention to image and query points.
|
||||
|
||||
This class implements a specialized transformer decoder that attends to an input image using queries with supplied
|
||||
positional embeddings. It's useful for tasks like object detection, image segmentation, and point cloud processing.
|
||||
|
|
@ -47,8 +46,7 @@ class TwoWayTransformer(nn.Module):
|
|||
activation: type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
|
||||
"""Initialize a Two-Way Transformer for simultaneous attention to image and query points.
|
||||
|
||||
Args:
|
||||
depth (int): Number of layers in the transformer.
|
||||
|
|
@ -86,8 +84,7 @@ class TwoWayTransformer(nn.Module):
|
|||
image_pe: torch.Tensor,
|
||||
point_embedding: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Process image and point embeddings through the Two-Way Transformer.
|
||||
"""Process image and point embeddings through the Two-Way Transformer.
|
||||
|
||||
Args:
|
||||
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
|
||||
|
|
@ -126,8 +123,7 @@ class TwoWayTransformer(nn.Module):
|
|||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
"""
|
||||
A two-way attention block for simultaneous attention to image and query points.
|
||||
"""A two-way attention block for simultaneous attention to image and query points.
|
||||
|
||||
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
|
||||
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense inputs to
|
||||
|
|
@ -166,8 +162,7 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
|
||||
"""Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
|
||||
|
||||
This block implements a specialized transformer layer with four main components: self-attention on sparse
|
||||
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of
|
||||
|
|
@ -199,8 +194,7 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
def forward(
|
||||
self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply two-way attention to process query and key embeddings in a transformer block.
|
||||
"""Apply two-way attention to process query and key embeddings in a transformer block.
|
||||
|
||||
Args:
|
||||
queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
|
||||
|
|
@ -244,8 +238,7 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer with downscaling capability for embedding size after projection.
|
||||
"""An attention layer with downscaling capability for embedding size after projection.
|
||||
|
||||
This class implements a multi-head attention mechanism with the option to downsample the internal dimension of
|
||||
queries, keys, and values.
|
||||
|
|
@ -281,8 +274,7 @@ class Attention(nn.Module):
|
|||
downsample_rate: int = 1,
|
||||
kv_in_dim: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Attention module with specified dimensions and settings.
|
||||
"""Initialize the Attention module with specified dimensions and settings.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): Dimensionality of input embeddings.
|
||||
|
|
@ -320,8 +312,7 @@ class Attention(nn.Module):
|
|||
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply multi-head attention to query, key, and value tensors with optional downsampling.
|
||||
"""Apply multi-head attention to query, key, and value tensors with optional downsampling.
|
||||
|
||||
Args:
|
||||
q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ import torch.nn.functional as F
|
|||
|
||||
|
||||
def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
|
||||
"""
|
||||
Select the closest conditioning frames to a given frame index.
|
||||
"""Select the closest conditioning frames to a given frame index.
|
||||
|
||||
Args:
|
||||
frame_idx (int): Current frame index.
|
||||
|
|
@ -62,8 +61,7 @@ def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any
|
|||
|
||||
|
||||
def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
|
||||
"""
|
||||
Generate 1D sinusoidal positional embeddings for given positions and dimensions.
|
||||
"""Generate 1D sinusoidal positional embeddings for given positions and dimensions.
|
||||
|
||||
Args:
|
||||
pos_inds (torch.Tensor): Position indices for which to generate embeddings.
|
||||
|
|
@ -89,8 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
|
|||
|
||||
|
||||
def init_t_xy(end_x: int, end_y: int):
|
||||
"""
|
||||
Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
|
||||
"""Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
|
||||
|
||||
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
|
||||
tensor and corresponding x and y coordinate tensors.
|
||||
|
|
@ -117,8 +114,7 @@ def init_t_xy(end_x: int, end_y: int):
|
|||
|
||||
|
||||
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
||||
"""
|
||||
Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
|
||||
"""Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
|
||||
|
||||
This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
|
||||
frequency components for the x and y dimensions.
|
||||
|
|
@ -150,8 +146,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
"""
|
||||
Reshape frequency tensor for broadcasting with input tensor.
|
||||
"""Reshape frequency tensor for broadcasting with input tensor.
|
||||
|
||||
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function
|
||||
is typically used in positional encoding operations.
|
||||
|
|
@ -179,8 +174,7 @@ def apply_rotary_enc(
|
|||
freqs_cis: torch.Tensor,
|
||||
repeat_freqs_k: bool = False,
|
||||
):
|
||||
"""
|
||||
Apply rotary positional encoding to query and key tensors.
|
||||
"""Apply rotary positional encoding to query and key tensors.
|
||||
|
||||
This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
|
||||
components. RoPE is a technique that injects relative position information into self-attention mechanisms.
|
||||
|
|
@ -188,10 +182,10 @@ def apply_rotary_enc(
|
|||
Args:
|
||||
xq (torch.Tensor): Query tensor to encode with positional information.
|
||||
xk (torch.Tensor): Key tensor to encode with positional information.
|
||||
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
|
||||
last two dimensions of xq.
|
||||
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
|
||||
to match key sequence length.
|
||||
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
|
||||
two dimensions of xq.
|
||||
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
|
||||
key sequence length.
|
||||
|
||||
Returns:
|
||||
xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
|
||||
|
|
@ -220,8 +214,7 @@ def apply_rotary_enc(
|
|||
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int):
|
||||
"""
|
||||
Partition input tensor into non-overlapping windows with padding if needed.
|
||||
"""Partition input tensor into non-overlapping windows with padding if needed.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape (B, H, W, C).
|
||||
|
|
@ -251,8 +244,7 @@ def window_partition(x: torch.Tensor, window_size: int):
|
|||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
|
||||
"""
|
||||
Unpartition windowed sequences into original sequences and remove padding.
|
||||
"""Unpartition windowed sequences into original sequences and remove padding.
|
||||
|
||||
This function reverses the windowing process, reconstructing the original input from windowed segments and removing
|
||||
any padding that was added during the windowing process.
|
||||
|
|
@ -266,8 +258,8 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
|
|||
hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
|
||||
are the original height and width, and C is the number of channels.
|
||||
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the
|
||||
original height and width, and C is the number of channels.
|
||||
|
||||
Examples:
|
||||
>>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
|
||||
|
|
@ -289,18 +281,16 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
|
|||
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Extract relative positional embeddings based on query and key sizes.
|
||||
"""Extract relative positional embeddings based on query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int): Size of the query.
|
||||
k_size (int): Size of the key.
|
||||
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
|
||||
distance and C is the embedding dimension.
|
||||
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative distance
|
||||
and C is the embedding dimension.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
|
||||
k_size, C).
|
||||
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).
|
||||
|
||||
Examples:
|
||||
>>> q_size, k_size = 8, 16
|
||||
|
|
@ -338,8 +328,7 @@ def add_decomposed_rel_pos(
|
|||
q_size: tuple[int, int],
|
||||
k_size: tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add decomposed Relative Positional Embeddings to the attention map.
|
||||
"""Add decomposed Relative Positional Embeddings to the attention map.
|
||||
|
||||
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
|
||||
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
|
||||
|
|
@ -354,8 +343,8 @@ def add_decomposed_rel_pos(
|
|||
k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
|
||||
(B, q_h * q_w, k_h * k_w).
|
||||
(torch.Tensor): Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h *
|
||||
k_w).
|
||||
|
||||
Examples:
|
||||
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
|
||||
|
|
|
|||
|
|
@ -38,8 +38,7 @@ from .amg import (
|
|||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
"""
|
||||
Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
|
||||
"""Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
|
||||
|
||||
This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image segmentation
|
||||
tasks. It supports various input prompts like points, bounding boxes, and masks for fine-grained control over
|
||||
|
|
@ -81,8 +80,7 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the Predictor with configuration, overrides, and callbacks.
|
||||
"""Initialize the Predictor with configuration, overrides, and callbacks.
|
||||
|
||||
Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
|
||||
callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True for
|
||||
|
|
@ -109,8 +107,7 @@ class Predictor(BasePredictor):
|
|||
self.segment_all = False
|
||||
|
||||
def preprocess(self, im):
|
||||
"""
|
||||
Preprocess the input image for model inference.
|
||||
"""Preprocess the input image for model inference.
|
||||
|
||||
This method prepares the input image by applying transformations and normalization. It supports both
|
||||
torch.Tensor and list of np.ndarray as input formats.
|
||||
|
|
@ -142,8 +139,7 @@ class Predictor(BasePredictor):
|
|||
return im
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
Perform initial transformations on the input image for preprocessing.
|
||||
"""Perform initial transformations on the input image for preprocessing.
|
||||
|
||||
This method applies transformations such as resizing to prepare the image for further preprocessing. Currently,
|
||||
batched inference is not supported; hence the list length should be 1.
|
||||
|
|
@ -169,8 +165,7 @@ class Predictor(BasePredictor):
|
|||
return [letterbox(image=x) for x in im]
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
|
||||
"""
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image.
|
||||
"""Perform image segmentation inference based on the given input cues, using the currently loaded image.
|
||||
|
||||
This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder,
|
||||
and mask decoder for real-time and promptable segmentation tasks.
|
||||
|
|
@ -208,8 +203,7 @@ class Predictor(BasePredictor):
|
|||
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
||||
|
||||
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
|
||||
"""
|
||||
Perform image segmentation inference based on input cues using SAM's specialized architecture.
|
||||
"""Perform image segmentation inference based on input cues using SAM's specialized architecture.
|
||||
|
||||
This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. It
|
||||
processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
|
||||
|
|
@ -248,8 +242,7 @@ class Predictor(BasePredictor):
|
|||
masks=None,
|
||||
multimask_output=False,
|
||||
):
|
||||
"""
|
||||
Perform inference on image features using the SAM model.
|
||||
"""Perform inference on image features using the SAM model.
|
||||
|
||||
Args:
|
||||
features (torch.Tensor): Extracted image features with shape (B, C, H, W) from the SAM model image encoder.
|
||||
|
|
@ -281,8 +274,7 @@ class Predictor(BasePredictor):
|
|||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||
|
||||
def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
|
||||
"""
|
||||
Prepare and transform the input prompts for processing based on the destination shape.
|
||||
"""Prepare and transform the input prompts for processing based on the destination shape.
|
||||
|
||||
Args:
|
||||
dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
|
||||
|
|
@ -346,8 +338,7 @@ class Predictor(BasePredictor):
|
|||
stability_score_offset=0.95,
|
||||
crop_nms_thresh=0.7,
|
||||
):
|
||||
"""
|
||||
Perform image segmentation using the Segment Anything Model (SAM).
|
||||
"""Perform image segmentation using the Segment Anything Model (SAM).
|
||||
|
||||
This method segments an entire image into constituent parts by leveraging SAM's advanced architecture and
|
||||
real-time performance capabilities. It can optionally work on image crops for finer segmentation.
|
||||
|
|
@ -445,8 +436,7 @@ class Predictor(BasePredictor):
|
|||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
def setup_model(self, model=None, verbose=True):
|
||||
"""
|
||||
Initialize the Segment Anything Model (SAM) for inference.
|
||||
"""Initialize the Segment Anything Model (SAM) for inference.
|
||||
|
||||
This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
|
||||
parameters for image normalization and other Ultralytics compatibility settings.
|
||||
|
|
@ -484,8 +474,7 @@ class Predictor(BasePredictor):
|
|||
return build_sam(self.args.model)
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Post-process SAM's inference outputs to generate object detection masks and bounding boxes.
|
||||
"""Post-process SAM's inference outputs to generate object detection masks and bounding boxes.
|
||||
|
||||
This method scales masks and boxes to the original image size and applies a threshold to the mask
|
||||
predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
|
||||
|
|
@ -499,8 +488,8 @@ class Predictor(BasePredictor):
|
|||
orig_imgs (list[np.ndarray] | torch.Tensor): The original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
(list[Results]): List of Results objects containing detection masks, bounding boxes, and other
|
||||
metadata for each processed image.
|
||||
(list[Results]): List of Results objects containing detection masks, bounding boxes, and other metadata for
|
||||
each processed image.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -537,15 +526,14 @@ class Predictor(BasePredictor):
|
|||
return results
|
||||
|
||||
def setup_source(self, source):
|
||||
"""
|
||||
Set up the data source for inference.
|
||||
"""Set up the data source for inference.
|
||||
|
||||
This method configures the data source from which images will be fetched for inference. It supports various
|
||||
input types such as image files, directories, video files, and other compatible data sources.
|
||||
|
||||
Args:
|
||||
source (str | Path | None): The path or identifier for the image data source. Can be a file path,
|
||||
directory path, URL, or other supported source types.
|
||||
source (str | Path | None): The path or identifier for the image data source. Can be a file path, directory
|
||||
path, URL, or other supported source types.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -562,16 +550,15 @@ class Predictor(BasePredictor):
|
|||
super().setup_source(source)
|
||||
|
||||
def set_image(self, image):
|
||||
"""
|
||||
Preprocess and set a single image for inference.
|
||||
"""Preprocess and set a single image for inference.
|
||||
|
||||
This method prepares the model for inference on a single image by setting up the model if not already
|
||||
initialized, configuring the data source, and preprocessing the image for feature extraction. It ensures that
|
||||
only one image is set at a time and extracts image features for subsequent use.
|
||||
|
||||
Args:
|
||||
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
|
||||
an image read by cv2.
|
||||
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing an image read by
|
||||
cv2.
|
||||
|
||||
Raises:
|
||||
AssertionError: If more than one image is attempted to be set.
|
||||
|
|
@ -613,8 +600,7 @@ class Predictor(BasePredictor):
|
|||
|
||||
@staticmethod
|
||||
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
|
||||
"""
|
||||
Remove small disconnected regions and holes from segmentation masks.
|
||||
"""Remove small disconnected regions and holes from segmentation masks.
|
||||
|
||||
This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). It
|
||||
removes small disconnected regions and holes from the input masks, and then performs Non-Maximum Suppression
|
||||
|
|
@ -675,8 +661,7 @@ class Predictor(BasePredictor):
|
|||
masks=None,
|
||||
multimask_output=False,
|
||||
):
|
||||
"""
|
||||
Perform prompts preprocessing and inference on provided image features using the SAM model.
|
||||
"""Perform prompts preprocessing and inference on provided image features using the SAM model.
|
||||
|
||||
Args:
|
||||
features (torch.Tensor | dict[str, Any]): Extracted image features from the SAM/SAM2 model image encoder.
|
||||
|
|
@ -714,8 +699,7 @@ class Predictor(BasePredictor):
|
|||
|
||||
|
||||
class SAM2Predictor(Predictor):
|
||||
"""
|
||||
SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
|
||||
"""SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
|
||||
|
||||
This class extends the base Predictor class to implement SAM2-specific functionality for image segmentation tasks.
|
||||
It provides methods for model initialization, feature extraction, and prompt-based inference.
|
||||
|
|
@ -755,8 +739,7 @@ class SAM2Predictor(Predictor):
|
|||
return build_sam(self.args.model)
|
||||
|
||||
def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
|
||||
"""
|
||||
Prepare and transform the input prompts for processing based on the destination shape.
|
||||
"""Prepare and transform the input prompts for processing based on the destination shape.
|
||||
|
||||
Args:
|
||||
dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
|
||||
|
|
@ -790,8 +773,7 @@ class SAM2Predictor(Predictor):
|
|||
return points, labels, masks
|
||||
|
||||
def set_image(self, image):
|
||||
"""
|
||||
Preprocess and set a single image for inference using the SAM2 model.
|
||||
"""Preprocess and set a single image for inference using the SAM2 model.
|
||||
|
||||
This method initializes the model if not already done, configures the data source to the specified image, and
|
||||
preprocesses the image for feature extraction. It supports setting only one image at a time.
|
||||
|
|
@ -847,8 +829,7 @@ class SAM2Predictor(Predictor):
|
|||
multimask_output=False,
|
||||
img_idx=-1,
|
||||
):
|
||||
"""
|
||||
Perform inference on image features using the SAM2 model.
|
||||
"""Perform inference on image features using the SAM2 model.
|
||||
|
||||
Args:
|
||||
features (torch.Tensor | dict[str, Any]): Extracted image features with shape (B, C, H, W) from the SAM2
|
||||
|
|
@ -892,8 +873,7 @@ class SAM2Predictor(Predictor):
|
|||
|
||||
|
||||
class SAM2VideoPredictor(SAM2Predictor):
|
||||
"""
|
||||
SAM2VideoPredictor to handle user interactions with videos and manage inference states.
|
||||
"""SAM2VideoPredictor to handle user interactions with videos and manage inference states.
|
||||
|
||||
This class extends the functionality of SAM2Predictor to support video processing and maintains the state of
|
||||
inference operations. It includes configurations for managing non-overlapping masks, clearing memory for
|
||||
|
|
@ -929,8 +909,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
# fill_hole_area = 8 # not used
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the predictor with configuration and optional overrides.
|
||||
"""Initialize the predictor with configuration and optional overrides.
|
||||
|
||||
This constructor initializes the SAM2VideoPredictor with a given configuration, applies any specified overrides,
|
||||
and sets up the inference state along with certain flags that control the behavior of the predictor.
|
||||
|
|
@ -953,8 +932,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
self.callbacks["on_predict_start"].append(self.init_state)
|
||||
|
||||
def get_model(self):
|
||||
"""
|
||||
Retrieve and configure the model with binarization enabled.
|
||||
"""Retrieve and configure the model with binarization enabled.
|
||||
|
||||
Notes:
|
||||
This method overrides the base class implementation to set the binarize flag to True.
|
||||
|
|
@ -964,10 +942,9 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
return model
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
|
||||
"""
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
|
||||
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
|
||||
mask decoder for real-time and promptable segmentation tasks.
|
||||
"""Perform image segmentation inference based on the given input cues, using the currently loaded image. This
|
||||
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
|
||||
encoder, and mask decoder for real-time and promptable segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
|
|
@ -1037,8 +1014,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
return pred_masks, torch.ones(pred_masks.shape[0], dtype=pred_masks.dtype, device=pred_masks.device)
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Post-process the predictions to apply non-overlapping constraints if required.
|
||||
"""Post-process the predictions to apply non-overlapping constraints if required.
|
||||
|
||||
This method extends the post-processing functionality by applying non-overlapping constraints to the predicted
|
||||
masks if the `non_overlap_masks` flag is set to True. This ensures that the masks do not overlap, which can be
|
||||
|
|
@ -1072,8 +1048,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
masks=None,
|
||||
frame_idx=0,
|
||||
):
|
||||
"""
|
||||
Add new points or masks to a specific frame for a given object ID.
|
||||
"""Add new points or masks to a specific frame for a given object ID.
|
||||
|
||||
This method updates the inference state with new prompts (points or masks) for a specified object and frame
|
||||
index. It ensures that the prompts are either points or masks, but not both, and updates the internal state
|
||||
|
|
@ -1169,8 +1144,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
|
||||
@smart_inference_mode()
|
||||
def propagate_in_video_preflight(self):
|
||||
"""
|
||||
Prepare inference_state and consolidate temporary outputs before tracking.
|
||||
"""Prepare inference_state and consolidate temporary outputs before tracking.
|
||||
|
||||
This method marks the start of tracking, disallowing the addition of new objects until the session is reset. It
|
||||
consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. Additionally,
|
||||
|
|
@ -1240,8 +1214,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
|
||||
@staticmethod
|
||||
def init_state(predictor):
|
||||
"""
|
||||
Initialize an inference state for the predictor.
|
||||
"""Initialize an inference state for the predictor.
|
||||
|
||||
This function sets up the initial state required for performing inference on video data. It includes
|
||||
initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata
|
||||
|
|
@ -1287,8 +1260,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
predictor.inference_state = inference_state
|
||||
|
||||
def get_im_features(self, im, batch=1):
|
||||
"""
|
||||
Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
|
||||
"""Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The input image tensor.
|
||||
|
|
@ -1315,8 +1287,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
return vis_feats, vis_pos_embed, feat_sizes
|
||||
|
||||
def _obj_id_to_idx(self, obj_id):
|
||||
"""
|
||||
Map client-side object id to model-side object index.
|
||||
"""Map client-side object id to model-side object index.
|
||||
|
||||
Args:
|
||||
obj_id (int): The unique identifier of the object provided by the client side.
|
||||
|
|
@ -1378,8 +1349,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
run_mem_encoder,
|
||||
prev_sam_mask_logits=None,
|
||||
):
|
||||
"""
|
||||
Run tracking on a single frame based on current inputs and previous memory.
|
||||
"""Run tracking on a single frame based on current inputs and previous memory.
|
||||
|
||||
Args:
|
||||
output_dict (dict): The dictionary containing the output states of the tracking process.
|
||||
|
|
@ -1442,8 +1412,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
return current_out
|
||||
|
||||
def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
|
||||
"""
|
||||
Cache and manage the positional encoding for mask memory across frames and objects.
|
||||
"""Cache and manage the positional encoding for mask memory across frames and objects.
|
||||
|
||||
This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for mask memory, which is
|
||||
constant across frames and objects, thus reducing the amount of redundant information stored during an inference
|
||||
|
|
@ -1452,8 +1421,8 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
batch size.
|
||||
|
||||
Args:
|
||||
out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory.
|
||||
Should be a list of tensors or None.
|
||||
out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory. Should be a list
|
||||
of tensors or None.
|
||||
|
||||
Returns:
|
||||
(list[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
|
||||
|
|
@ -1486,8 +1455,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
is_cond=False,
|
||||
run_mem_encoder=False,
|
||||
):
|
||||
"""
|
||||
Consolidate per-object temporary outputs into a single output for all objects.
|
||||
"""Consolidate per-object temporary outputs into a single output for all objects.
|
||||
|
||||
This method combines the temporary outputs for each object on a given frame into a unified
|
||||
output. It fills in any missing objects either from the main output dictionary or leaves
|
||||
|
|
@ -1497,8 +1465,8 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
Args:
|
||||
frame_idx (int): The index of the frame for which to consolidate outputs.
|
||||
is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.
|
||||
run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after
|
||||
consolidating the outputs.
|
||||
run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after consolidating the
|
||||
outputs.
|
||||
|
||||
Returns:
|
||||
(dict): A consolidated output dictionary containing the combined results for all objects.
|
||||
|
|
@ -1587,8 +1555,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
return consolidated_out
|
||||
|
||||
def _get_empty_mask_ptr(self, frame_idx):
|
||||
"""
|
||||
Get a dummy object pointer based on an empty mask on the current frame.
|
||||
"""Get a dummy object pointer based on an empty mask on the current frame.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
|
||||
|
|
@ -1618,8 +1585,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
return current_out["obj_ptr"]
|
||||
|
||||
def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
|
||||
"""
|
||||
Run the memory encoder on masks.
|
||||
"""Run the memory encoder on masks.
|
||||
|
||||
This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
|
||||
memory also needs to be computed again with the memory encoder.
|
||||
|
|
@ -1651,8 +1617,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
), maskmem_pos_enc
|
||||
|
||||
def _add_output_per_object(self, frame_idx, current_out, storage_key):
|
||||
"""
|
||||
Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
|
||||
"""Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
|
||||
|
||||
The resulting slices share the same tensor storage.
|
||||
|
||||
|
|
@ -1682,8 +1647,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
obj_output_dict[storage_key][frame_idx] = obj_out
|
||||
|
||||
def _clear_non_cond_mem_around_input(self, frame_idx):
|
||||
"""
|
||||
Remove the non-conditioning memory around the input frame.
|
||||
"""Remove the non-conditioning memory around the input frame.
|
||||
|
||||
When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain
|
||||
outdated object appearance information and could confuse the model. This method clears those non-conditioning
|
||||
|
|
@ -1703,8 +1667,7 @@ class SAM2VideoPredictor(SAM2Predictor):
|
|||
|
||||
|
||||
class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
||||
"""
|
||||
SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
|
||||
"""SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
|
||||
sequence of images.
|
||||
|
||||
Attributes:
|
||||
|
|
@ -1736,8 +1699,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
|||
max_obj_num: int = 3,
|
||||
_callbacks: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the predictor with configuration and optional overrides.
|
||||
"""Initialize the predictor with configuration and optional overrides.
|
||||
|
||||
This constructor initializes the SAM2DynamicInteractivePredictor with a given configuration, applies any
|
||||
specified overrides
|
||||
|
|
@ -1783,12 +1745,11 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
|||
obj_ids: list[int] | None = None,
|
||||
update_memory: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Perform inference on a single image with optional bounding boxes, masks, points and object IDs. It has two
|
||||
modes: one is to run inference on a single image without updating the memory, and the other is to update the
|
||||
memory with the provided prompts and object IDs. When update_memory is True, it will update the memory with the
|
||||
provided prompts and obj_ids. When update_memory is False, it will only run inference on the provided image
|
||||
without updating the memory.
|
||||
"""Perform inference on a single image with optional bounding boxes, masks, points and object IDs. It has two
|
||||
modes: one is to run inference on a single image without updating the memory, and the other is to update
|
||||
the memory with the provided prompts and object IDs. When update_memory is True, it will update the
|
||||
memory with the provided prompts and obj_ids. When update_memory is False, it will only run inference on
|
||||
the provided image without updating the memory.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | np.ndarray): The input image tensor or numpy array.
|
||||
|
|
@ -1842,8 +1803,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
|||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||
|
||||
def get_im_features(self, img: torch.Tensor | np.ndarray) -> None:
|
||||
"""
|
||||
Initialize the image state by processing the input image and extracting features.
|
||||
"""Initialize the image state by processing the input image and extracting features.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
|
||||
|
|
@ -1866,8 +1826,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
|||
labels: torch.Tensor | None = None,
|
||||
masks: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Append the imgState to the memory_bank and update the memory for the model.
|
||||
"""Append the imgState to the memory_bank and update the memory for the model.
|
||||
|
||||
Args:
|
||||
obj_ids (list[int]): List of object IDs corresponding to the prompts.
|
||||
|
|
@ -1941,12 +1900,11 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
|||
self.memory_bank.append(consolidated_out)
|
||||
|
||||
def _prepare_memory_conditioned_features(self, obj_idx: int | None) -> torch.Tensor:
|
||||
"""
|
||||
Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
|
||||
prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features for all
|
||||
objects in the image. If there is no memory, it will directly add a no-memory embedding to the current vision
|
||||
features. If there is memory, it will use the memory features from previous frames to condition the current
|
||||
vision features using a transformer attention mechanism.
|
||||
"""Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it is meant to
|
||||
prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features
|
||||
for all objects in the image. If there is no memory, it will directly add a no-memory embedding to the
|
||||
current vision features. If there is memory, it will use the memory features from previous frames to
|
||||
condition the current vision features using a transformer attention mechanism.
|
||||
|
||||
Args:
|
||||
obj_idx (int | None): The index of the object for which to prepare the features.
|
||||
|
|
@ -1989,8 +1947,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
|||
return memory, memory_pos_embed
|
||||
|
||||
def _obj_id_to_idx(self, obj_id: int) -> int | None:
|
||||
"""
|
||||
Map client-side object id to model-side object index.
|
||||
"""Map client-side object id to model-side object index.
|
||||
|
||||
Args:
|
||||
obj_id (int): The client-side object ID.
|
||||
|
|
@ -2007,8 +1964,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
|
|||
label: torch.Tensor | None = None,
|
||||
mask: torch.Tensor | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Tracking step for the current image state to predict masks.
|
||||
"""Tracking step for the current image state to predict masks.
|
||||
|
||||
This method processes the image features and runs the SAM heads to predict masks. If obj_idx is provided, it
|
||||
processes the features for a specific prompted object in the image. If obj_idx is None, it processes the
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from .ops import HungarianMatcher
|
|||
|
||||
|
||||
class DETRLoss(nn.Module):
|
||||
"""
|
||||
DETR (DEtection TRansformer) Loss class for calculating various loss components.
|
||||
"""DETR (DEtection TRansformer) Loss class for calculating various loss components.
|
||||
|
||||
This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
|
||||
object detection model.
|
||||
|
|
@ -47,8 +46,7 @@ class DETRLoss(nn.Module):
|
|||
gamma: float = 1.5,
|
||||
alpha: float = 0.25,
|
||||
):
|
||||
"""
|
||||
Initialize DETR loss function with customizable components and gains.
|
||||
"""Initialize DETR loss function with customizable components and gains.
|
||||
|
||||
Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
|
||||
losses and various loss types.
|
||||
|
|
@ -82,8 +80,7 @@ class DETRLoss(nn.Module):
|
|||
def _get_loss_class(
|
||||
self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Compute classification loss based on predictions, target values, and ground truth scores.
|
||||
"""Compute classification loss based on predictions, target values, and ground truth scores.
|
||||
|
||||
Args:
|
||||
pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
|
||||
|
|
@ -124,8 +121,7 @@ class DETRLoss(nn.Module):
|
|||
def _get_loss_bbox(
|
||||
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
|
||||
"""Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
|
||||
|
|
@ -199,8 +195,7 @@ class DETRLoss(nn.Module):
|
|||
masks: torch.Tensor | None = None,
|
||||
gt_mask: torch.Tensor | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Get auxiliary losses for intermediate decoder layers.
|
||||
"""Get auxiliary losses for intermediate decoder layers.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
|
||||
|
|
@ -261,8 +256,7 @@ class DETRLoss(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
"""
|
||||
Extract batch indices, source indices, and destination indices from match indices.
|
||||
"""Extract batch indices, source indices, and destination indices from match indices.
|
||||
|
||||
Args:
|
||||
match_indices (list[tuple]): List of tuples containing matched indices.
|
||||
|
|
@ -279,8 +273,7 @@ class DETRLoss(nn.Module):
|
|||
def _get_assigned_bboxes(
|
||||
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
|
||||
"""Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
||||
|
|
@ -317,8 +310,7 @@ class DETRLoss(nn.Module):
|
|||
postfix: str = "",
|
||||
match_indices: list[tuple] | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate losses for a single prediction layer.
|
||||
"""Calculate losses for a single prediction layer.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
||||
|
|
@ -364,8 +356,7 @@ class DETRLoss(nn.Module):
|
|||
postfix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate loss for predicted bounding boxes and scores.
|
||||
"""Calculate loss for predicted bounding boxes and scores.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
|
||||
|
|
@ -400,8 +391,7 @@ class DETRLoss(nn.Module):
|
|||
|
||||
|
||||
class RTDETRDetectionLoss(DETRLoss):
|
||||
"""
|
||||
Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
|
||||
"""Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.
|
||||
|
||||
This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
|
||||
an additional denoising training loss when provided with denoising metadata.
|
||||
|
|
@ -415,8 +405,7 @@ class RTDETRDetectionLoss(DETRLoss):
|
|||
dn_scores: torch.Tensor | None = None,
|
||||
dn_meta: dict[str, Any] | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Forward pass to compute detection loss with optional denoising loss.
|
||||
"""Forward pass to compute detection loss with optional denoising loss.
|
||||
|
||||
Args:
|
||||
preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
|
||||
|
|
@ -452,8 +441,7 @@ class RTDETRDetectionLoss(DETRLoss):
|
|||
def get_dn_match_indices(
|
||||
dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
|
||||
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Get match indices for denoising.
|
||||
"""Get match indices for denoising.
|
||||
|
||||
Args:
|
||||
dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
|
|||
|
||||
|
||||
class HungarianMatcher(nn.Module):
|
||||
"""
|
||||
A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
|
||||
"""A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
|
||||
|
||||
HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
|
||||
function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
|
||||
|
|
@ -56,8 +55,7 @@ class HungarianMatcher(nn.Module):
|
|||
alpha: float = 0.25,
|
||||
gamma: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
|
||||
"""Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
|
||||
|
||||
Args:
|
||||
cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
|
||||
|
|
@ -88,8 +86,7 @@ class HungarianMatcher(nn.Module):
|
|||
masks: torch.Tensor | None = None,
|
||||
gt_mask: list[torch.Tensor] | None = None,
|
||||
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
|
||||
"""Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
|
||||
|
||||
This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
|
||||
mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
|
||||
|
|
@ -105,9 +102,9 @@ class HungarianMatcher(nn.Module):
|
|||
gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
|
||||
|
||||
Returns:
|
||||
(list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
|
||||
(index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order) and
|
||||
index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
|
||||
(list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
|
||||
index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
|
||||
the tensor of indices of the corresponding selected ground truth targets (in order).
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
|
||||
"""
|
||||
bs, nq, nc = pred_scores.shape
|
||||
|
|
@ -198,16 +195,15 @@ def get_cdn_group(
|
|||
box_noise_scale: float = 1.0,
|
||||
training: bool = False,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
|
||||
"""
|
||||
Generate contrastive denoising training group with positive and negative samples from ground truths.
|
||||
"""Generate contrastive denoising training group with positive and negative samples from ground truths.
|
||||
|
||||
This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
|
||||
boxes and class labels. It generates both positive and negative samples to improve model robustness.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
|
||||
'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground
|
||||
truths per image.
|
||||
batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)), 'gt_bboxes'
|
||||
(torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground truths
|
||||
per image.
|
||||
num_classes (int): Total number of object classes.
|
||||
num_queries (int): Number of object queries.
|
||||
class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|||
|
||||
|
||||
class ClassificationPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on a classification model.
|
||||
"""A class extending the BasePredictor class for prediction based on a classification model.
|
||||
|
||||
This predictor handles the specific requirements of classification models, including preprocessing images and
|
||||
postprocessing predictions to generate classification results.
|
||||
|
|
@ -36,8 +35,7 @@ class ClassificationPredictor(BasePredictor):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
|
||||
"""Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
|
||||
|
||||
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
|
||||
tasks. It ensures the task is set to 'classify' regardless of input configuration.
|
||||
|
|
@ -72,8 +70,7 @@ class ClassificationPredictor(BasePredictor):
|
|||
return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Process predictions to return Results objects with classification probabilities.
|
||||
"""Process predictions to return Results objects with classification probabilities.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
|
|
|
|||
|
|
@ -17,8 +17,7 @@ from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_fi
|
|||
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
"""
|
||||
A trainer class extending BaseTrainer for training image classification models.
|
||||
"""A trainer class extending BaseTrainer for training image classification models.
|
||||
|
||||
This trainer handles the training process for image classification tasks, supporting both YOLO classification models
|
||||
and torchvision models with comprehensive dataset handling and validation.
|
||||
|
|
@ -51,8 +50,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a ClassificationTrainer object.
|
||||
"""Initialize a ClassificationTrainer object.
|
||||
|
||||
Args:
|
||||
cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
|
||||
|
|
@ -71,8 +69,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
self.model.names = self.data["names"]
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
||||
"""
|
||||
Return a modified PyTorch model configured for training YOLO classification.
|
||||
"""Return a modified PyTorch model configured for training YOLO classification.
|
||||
|
||||
Args:
|
||||
cfg (Any, optional): Model configuration.
|
||||
|
|
@ -96,8 +93,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
return model
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
Load, create or download model for classification tasks.
|
||||
"""Load, create or download model for classification tasks.
|
||||
|
||||
Returns:
|
||||
(Any): Model checkpoint if applicable, otherwise None.
|
||||
|
|
@ -115,8 +111,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
return ckpt
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
|
||||
"""
|
||||
Create a ClassificationDataset instance given an image path and mode.
|
||||
"""Create a ClassificationDataset instance given an image path and mode.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the dataset images.
|
||||
|
|
@ -129,8 +124,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
||||
|
||||
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
|
||||
"""
|
||||
Return PyTorch DataLoader with transforms to preprocess images.
|
||||
"""Return PyTorch DataLoader with transforms to preprocess images.
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path to the dataset.
|
||||
|
|
@ -177,8 +171,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
)
|
||||
|
||||
def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
|
||||
"""
|
||||
Return a loss dict with labeled training loss items tensor.
|
||||
"""Return a loss dict with labeled training loss items tensor.
|
||||
|
||||
Args:
|
||||
loss_items (torch.Tensor, optional): Loss tensor items.
|
||||
|
|
@ -195,8 +188,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
return dict(zip(keys, loss_items))
|
||||
|
||||
def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
|
||||
"""
|
||||
Plot training samples with their annotations.
|
||||
"""Plot training samples with their annotations.
|
||||
|
||||
Args:
|
||||
batch (dict[str, torch.Tensor]): Batch containing images and class labels.
|
||||
|
|
|
|||
|
|
@ -16,8 +16,7 @@ from ultralytics.utils.plotting import plot_images
|
|||
|
||||
|
||||
class ClassificationValidator(BaseValidator):
|
||||
"""
|
||||
A class extending the BaseValidator class for validation based on a classification model.
|
||||
"""A class extending the BaseValidator class for validation based on a classification model.
|
||||
|
||||
This validator handles the validation process for classification models, including metrics calculation, confusion
|
||||
matrix generation, and visualization of results.
|
||||
|
|
@ -55,8 +54,7 @@ class ClassificationValidator(BaseValidator):
|
|||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
|
||||
"""Initialize ClassificationValidator with dataloader, save directory, and other parameters.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
||||
|
|
@ -96,8 +94,7 @@ class ClassificationValidator(BaseValidator):
|
|||
return batch
|
||||
|
||||
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update running metrics with model predictions and batch targets.
|
||||
"""Update running metrics with model predictions and batch targets.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
|
||||
|
|
@ -112,8 +109,7 @@ class ClassificationValidator(BaseValidator):
|
|||
self.targets.append(batch["cls"].type(torch.int32).cpu())
|
||||
|
||||
def finalize_metrics(self) -> None:
|
||||
"""
|
||||
Finalize metrics including confusion matrix and processing speed.
|
||||
"""Finalize metrics including confusion matrix and processing speed.
|
||||
|
||||
Examples:
|
||||
>>> validator = ClassificationValidator()
|
||||
|
|
@ -161,8 +157,7 @@ class ClassificationValidator(BaseValidator):
|
|||
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
|
||||
|
||||
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
|
||||
"""
|
||||
Build and return a data loader for classification validation.
|
||||
"""Build and return a data loader for classification validation.
|
||||
|
||||
Args:
|
||||
dataset_path (str | Path): Path to the dataset directory.
|
||||
|
|
@ -180,8 +175,7 @@ class ClassificationValidator(BaseValidator):
|
|||
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
||||
|
||||
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
|
||||
"""
|
||||
Plot validation image samples with their ground truth labels.
|
||||
"""Plot validation image samples with their ground truth labels.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
|
||||
|
|
@ -201,8 +195,7 @@ class ClassificationValidator(BaseValidator):
|
|||
)
|
||||
|
||||
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
|
||||
"""
|
||||
Plot images with their predicted class labels and save the visualization.
|
||||
"""Plot images with their predicted class labels and save the visualization.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch data containing images and other information.
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ from ultralytics.utils import nms, ops
|
|||
|
||||
|
||||
class DetectionPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on a detection model.
|
||||
"""A class extending the BasePredictor class for prediction based on a detection model.
|
||||
|
||||
This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
|
||||
with bounding boxes and class predictions.
|
||||
|
|
@ -32,8 +31,7 @@ class DetectionPredictor(BasePredictor):
|
|||
"""
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs, **kwargs):
|
||||
"""
|
||||
Post-process predictions and return a list of Results objects.
|
||||
"""Post-process predictions and return a list of Results objects.
|
||||
|
||||
This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
|
||||
further analysis.
|
||||
|
|
@ -92,8 +90,7 @@ class DetectionPredictor(BasePredictor):
|
|||
return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
|
||||
|
||||
def construct_results(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Construct a list of Results objects from model predictions.
|
||||
"""Construct a list of Results objects from model predictions.
|
||||
|
||||
Args:
|
||||
preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
|
||||
|
|
@ -109,8 +106,7 @@ class DetectionPredictor(BasePredictor):
|
|||
]
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct a single Results object from one image prediction.
|
||||
"""Construct a single Results object from one image prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_m
|
|||
|
||||
|
||||
class DetectionTrainer(BaseTrainer):
|
||||
"""
|
||||
A class extending the BaseTrainer class for training based on a detection model.
|
||||
"""A class extending the BaseTrainer class for training based on a detection model.
|
||||
|
||||
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
|
||||
object detection including dataset building, data loading, preprocessing, and model configuration.
|
||||
|
|
@ -54,8 +53,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a DetectionTrainer object for training YOLO object detection model training.
|
||||
"""Initialize a DetectionTrainer object for training YOLO object detection models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
||||
|
|
@ -65,8 +63,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
||||
"""
|
||||
Build YOLO Dataset for training or validation.
|
||||
"""Build YOLO Dataset for training or validation.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
|
|
@ -80,8 +77,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
||||
|
||||
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
|
||||
"""
|
||||
Construct and return dataloader for the specified mode.
|
||||
"""Construct and return dataloader for the specified mode.
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path to the dataset.
|
||||
|
|
@ -109,8 +105,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
)
|
||||
|
||||
def preprocess_batch(self, batch: dict) -> dict:
|
||||
"""
|
||||
Preprocess a batch of images by scaling and converting to float.
|
||||
"""Preprocess a batch of images by scaling and converting to float.
|
||||
|
||||
Args:
|
||||
batch (dict): Dictionary containing batch data with 'img' tensor.
|
||||
|
|
@ -150,8 +145,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
|
||||
def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
|
||||
"""
|
||||
Return a YOLO detection model.
|
||||
"""Return a YOLO detection model.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to model configuration file.
|
||||
|
|
@ -174,8 +168,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
)
|
||||
|
||||
def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
|
||||
"""
|
||||
Return a loss dict with labeled training loss items tensor.
|
||||
"""Return a loss dict with labeled training loss items tensor.
|
||||
|
||||
Args:
|
||||
loss_items (list[float], optional): List of loss values.
|
||||
|
|
@ -202,8 +195,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
)
|
||||
|
||||
def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
|
||||
"""
|
||||
Plot training samples with their annotations.
|
||||
"""Plot training samples with their annotations.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Dictionary containing batch data.
|
||||
|
|
@ -223,8 +215,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
||||
def auto_batch(self):
|
||||
"""
|
||||
Get optimal batch size by calculating memory occupation of model.
|
||||
"""Get optimal batch size by calculating memory occupation of model.
|
||||
|
||||
Returns:
|
||||
(int): Optimal batch size.
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ from ultralytics.utils.plotting import plot_images
|
|||
|
||||
|
||||
class DetectionValidator(BaseValidator):
|
||||
"""
|
||||
A class extending the BaseValidator class for validation based on a detection model.
|
||||
"""A class extending the BaseValidator class for validation based on a detection model.
|
||||
|
||||
This class implements validation functionality specific to object detection tasks, including metrics calculation,
|
||||
prediction processing, and visualization of results.
|
||||
|
|
@ -44,8 +43,7 @@ class DetectionValidator(BaseValidator):
|
|||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize detection validator with necessary variables and settings.
|
||||
"""Initialize detection validator with necessary variables and settings.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
||||
|
|
@ -63,8 +61,7 @@ class DetectionValidator(BaseValidator):
|
|||
self.metrics = DetMetrics()
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Preprocess batch of images for YOLO validation.
|
||||
"""Preprocess batch of images for YOLO validation.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
|
|
@ -79,8 +76,7 @@ class DetectionValidator(BaseValidator):
|
|||
return batch
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO detection validation.
|
||||
"""Initialize evaluation metrics for YOLO detection validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
|
|
@ -107,15 +103,14 @@ class DetectionValidator(BaseValidator):
|
|||
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Apply Non-maximum suppression to prediction outputs.
|
||||
"""Apply non-maximum suppression to prediction outputs.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
|
||||
Returns:
|
||||
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
|
||||
'bboxes', 'conf', 'cls', and 'extra' tensors.
|
||||
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains 'bboxes', 'conf',
|
||||
'cls', and 'extra' tensors.
|
||||
"""
|
||||
outputs = nms.non_max_suppression(
|
||||
preds,
|
||||
|
|
@ -131,8 +126,7 @@ class DetectionValidator(BaseValidator):
|
|||
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch of images and annotations for validation.
|
||||
"""Prepare a batch of images and annotations for validation.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
|
|
@ -159,8 +153,7 @@ class DetectionValidator(BaseValidator):
|
|||
}
|
||||
|
||||
def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Prepare predictions for evaluation against ground truth.
|
||||
"""Prepare predictions for evaluation against ground truth.
|
||||
|
||||
Args:
|
||||
pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
|
||||
|
|
@ -173,8 +166,7 @@ class DetectionValidator(BaseValidator):
|
|||
return pred
|
||||
|
||||
def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update metrics with new predictions and ground truth.
|
||||
"""Update metrics with new predictions and ground truth.
|
||||
|
||||
Args:
|
||||
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
|
||||
|
|
@ -250,8 +242,7 @@ class DetectionValidator(BaseValidator):
|
|||
self.metrics.clear_stats()
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Calculate and return metrics statistics.
|
||||
"""Calculate and return metrics statistics.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Dictionary containing metrics results.
|
||||
|
|
@ -281,8 +272,7 @@ class DetectionValidator(BaseValidator):
|
|||
)
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Return correct prediction matrix.
|
||||
"""Return correct prediction matrix.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
|
||||
|
|
@ -298,8 +288,7 @@ class DetectionValidator(BaseValidator):
|
|||
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
|
||||
"""
|
||||
Build YOLO Dataset.
|
||||
"""Build YOLO Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
|
|
@ -312,8 +301,7 @@ class DetectionValidator(BaseValidator):
|
|||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
|
||||
|
||||
def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
|
||||
"""
|
||||
Construct and return dataloader.
|
||||
"""Construct and return dataloader.
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path to the dataset.
|
||||
|
|
@ -334,8 +322,7 @@ class DetectionValidator(BaseValidator):
|
|||
)
|
||||
|
||||
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
|
||||
"""
|
||||
Plot validation image samples.
|
||||
"""Plot validation image samples.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
|
|
@ -352,8 +339,7 @@ class DetectionValidator(BaseValidator):
|
|||
def plot_predictions(
|
||||
self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Plot predicted bounding boxes on input images and save the result.
|
||||
"""Plot predicted bounding boxes on input images and save the result.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
|
|
@ -379,8 +365,7 @@ class DetectionValidator(BaseValidator):
|
|||
) # pred
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
||||
"""Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
|
||||
|
|
@ -398,12 +383,11 @@ class DetectionValidator(BaseValidator):
|
|||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Serialize YOLO predictions to COCO json format.
|
||||
"""Serialize YOLO predictions to COCO json format.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
|
||||
with bounding box coordinates, confidence scores, and class predictions.
|
||||
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
|
||||
bounding box coordinates, confidence scores, and class predictions.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Examples:
|
||||
|
|
@ -444,8 +428,7 @@ class DetectionValidator(BaseValidator):
|
|||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate YOLO output in JSON format and return performance statistics.
|
||||
"""Evaluate YOLO output in JSON format and return performance statistics.
|
||||
|
||||
Args:
|
||||
stats (dict[str, Any]): Current statistics dictionary.
|
||||
|
|
@ -469,8 +452,7 @@ class DetectionValidator(BaseValidator):
|
|||
iou_types: str | list[str] = "bbox",
|
||||
suffix: str | list[str] = "Box",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate COCO/LVIS metrics using faster-coco-eval library.
|
||||
"""Evaluate COCO/LVIS metrics using faster-coco-eval library.
|
||||
|
||||
Performs evaluation using the faster-coco-eval library to compute mAP metrics for object detection. Updates the
|
||||
provided stats dictionary with computed metrics including mAP50, mAP50-95, and LVIS-specific metrics if
|
||||
|
|
@ -480,10 +462,10 @@ class DetectionValidator(BaseValidator):
|
|||
stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
|
||||
pred_json (str | Path): Path to JSON file containing predictions in COCO format.
|
||||
anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
|
||||
iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings.
|
||||
Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
|
||||
suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond
|
||||
to iou_types if multiple types provided. Defaults to "Box".
|
||||
iou_types (str | list[str]): IoU type(s) for evaluation. Can be single string or list of strings. Common
|
||||
values include "bbox", "segm", "keypoints". Defaults to "bbox".
|
||||
suffix (str | list[str]): Suffix to append to metric names in stats dictionary. Should correspond to
|
||||
iou_types if multiple types provided. Defaults to "Box".
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ from ultralytics.utils import ROOT, YAML
|
|||
|
||||
|
||||
class YOLO(Model):
|
||||
"""
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
"""YOLO (You Only Look Once) object detection model.
|
||||
|
||||
This class provides a unified interface for YOLO models, automatically switching to specialized model types
|
||||
(YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
|
||||
|
|
@ -52,16 +51,15 @@ class YOLO(Model):
|
|||
"""
|
||||
|
||||
def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
|
||||
"""
|
||||
Initialize a YOLO model.
|
||||
"""Initialize a YOLO model.
|
||||
|
||||
This constructor initializes a YOLO model, automatically switching to specialized model types (YOLOWorld or
|
||||
YOLOE) based on the model filename.
|
||||
|
||||
Args:
|
||||
model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
|
||||
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
|
||||
Defaults to auto-detection based on model.
|
||||
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'. Defaults
|
||||
to auto-detection based on model.
|
||||
verbose (bool): Display model info on load.
|
||||
|
||||
Examples:
|
||||
|
|
@ -126,8 +124,7 @@ class YOLO(Model):
|
|||
|
||||
|
||||
class YOLOWorld(Model):
|
||||
"""
|
||||
YOLO-World object detection model.
|
||||
"""YOLO-World object detection model.
|
||||
|
||||
YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions without
|
||||
requiring training on specific classes. It extends the YOLO architecture to support real-time open-vocabulary
|
||||
|
|
@ -152,8 +149,7 @@ class YOLOWorld(Model):
|
|||
"""
|
||||
|
||||
def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
|
||||
"""
|
||||
Initialize YOLOv8-World model with a pre-trained model file.
|
||||
"""Initialize YOLOv8-World model with a pre-trained model file.
|
||||
|
||||
Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default COCO
|
||||
class names.
|
||||
|
|
@ -181,8 +177,7 @@ class YOLOWorld(Model):
|
|||
}
|
||||
|
||||
def set_classes(self, classes: list[str]) -> None:
|
||||
"""
|
||||
Set the model's class names for detection.
|
||||
"""Set the model's class names for detection.
|
||||
|
||||
Args:
|
||||
classes (list[str]): A list of categories, e.g. ["person"].
|
||||
|
|
@ -200,8 +195,7 @@ class YOLOWorld(Model):
|
|||
|
||||
|
||||
class YOLOE(Model):
|
||||
"""
|
||||
YOLOE object detection and segmentation model.
|
||||
"""YOLOE object detection and segmentation model.
|
||||
|
||||
YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with improved
|
||||
performance and additional features like visual and text positional embeddings.
|
||||
|
|
@ -235,8 +229,7 @@ class YOLOE(Model):
|
|||
"""
|
||||
|
||||
def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
|
||||
"""
|
||||
Initialize YOLOE model with a pre-trained model file.
|
||||
"""Initialize YOLOE model with a pre-trained model file.
|
||||
|
||||
Args:
|
||||
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
|
||||
|
|
@ -269,8 +262,7 @@ class YOLOE(Model):
|
|||
return self.model.get_text_pe(texts)
|
||||
|
||||
def get_visual_pe(self, img, visual):
|
||||
"""
|
||||
Get visual positional embeddings for the given image and visual features.
|
||||
"""Get visual positional embeddings for the given image and visual features.
|
||||
|
||||
This method extracts positional embeddings from visual features based on the input image. It requires that the
|
||||
model is an instance of YOLOEModel.
|
||||
|
|
@ -292,8 +284,7 @@ class YOLOE(Model):
|
|||
return self.model.get_visual_pe(img, visual)
|
||||
|
||||
def set_vocab(self, vocab: list[str], names: list[str]) -> None:
|
||||
"""
|
||||
Set vocabulary and class names for the YOLOE model.
|
||||
"""Set vocabulary and class names for the YOLOE model.
|
||||
|
||||
This method configures the vocabulary and class names used by the model for text processing and classification
|
||||
tasks. The model must be an instance of YOLOEModel.
|
||||
|
|
@ -318,8 +309,7 @@ class YOLOE(Model):
|
|||
return self.model.get_vocab(names)
|
||||
|
||||
def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
|
||||
"""
|
||||
Set the model's class names and embeddings for detection.
|
||||
"""Set the model's class names and embeddings for detection.
|
||||
|
||||
Args:
|
||||
classes (list[str]): A list of categories, e.g. ["person"].
|
||||
|
|
@ -344,8 +334,7 @@ class YOLOE(Model):
|
|||
refer_data: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Validate the model using text or visual prompts.
|
||||
"""Validate the model using text or visual prompts.
|
||||
|
||||
Args:
|
||||
validator (callable, optional): A callable validator function. If None, a default validator is loaded.
|
||||
|
|
@ -373,19 +362,18 @@ class YOLOE(Model):
|
|||
predictor=yolo.yoloe.YOLOEVPDetectPredictor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Run prediction on images, videos, directories, streams, etc.
|
||||
"""Run prediction on images, videos, directories, streams, etc.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
|
||||
directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
|
||||
stream (bool): Whether to stream the prediction results. If True, results are yielded as a
|
||||
generator as they are computed.
|
||||
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
|
||||
'bboxes' and 'cls' keys when non-empty.
|
||||
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths, directory
|
||||
paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
|
||||
stream (bool): Whether to stream the prediction results. If True, results are yielded as a generator as they
|
||||
are computed.
|
||||
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include 'bboxes'
|
||||
and 'cls' keys when non-empty.
|
||||
refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
|
||||
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
|
||||
loaded based on the task.
|
||||
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically loaded
|
||||
based on the task.
|
||||
**kwargs (Any): Additional keyword arguments passed to the predictor.
|
||||
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|||
|
||||
|
||||
class OBBPredictor(DetectionPredictor):
|
||||
"""
|
||||
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
||||
"""A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
|
||||
bounding boxes.
|
||||
|
|
@ -27,8 +26,7 @@ class OBBPredictor(DetectionPredictor):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize OBBPredictor with optional model and data configuration overrides.
|
||||
"""Initialize OBBPredictor with optional model and data configuration overrides.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration for the predictor.
|
||||
|
|
@ -45,12 +43,11 @@ class OBBPredictor(DetectionPredictor):
|
|||
self.args.task = "obb"
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct the result object from the prediction.
|
||||
"""Construct the result object from the prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
|
||||
the last dimension contains [x, y, w, h, confidence, class_id, angle].
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
|
||||
last dimension contains [x, y, w, h, confidence, class_id, angle].
|
||||
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
|
||||
orig_img (np.ndarray): The original image before preprocessing.
|
||||
img_path (str): The path to the original image.
|
||||
|
|
|
|||
|
|
@ -12,15 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK
|
|||
|
||||
|
||||
class OBBTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
||||
"""A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
|
||||
objects at arbitrary angles rather than just axis-aligned rectangles.
|
||||
|
||||
Attributes:
|
||||
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
|
||||
and dfl_loss.
|
||||
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
|
||||
dfl_loss.
|
||||
|
||||
Methods:
|
||||
get_model: Return OBBModel initialized with specified config and weights.
|
||||
|
|
@ -34,14 +33,13 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
|
||||
"""
|
||||
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
||||
"""Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
|
||||
model configuration.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
|
||||
will take precedence over those in cfg.
|
||||
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
|
||||
configuration.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
|
||||
take precedence over those in cfg.
|
||||
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
|
||||
"""
|
||||
if overrides is None:
|
||||
|
|
@ -52,8 +50,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|||
def get_model(
|
||||
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
|
||||
) -> OBBModel:
|
||||
"""
|
||||
Return OBBModel initialized with specified config and weights.
|
||||
"""Return OBBModel initialized with specified config and weights.
|
||||
|
||||
Args:
|
||||
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from ultralytics.utils.nms import TorchNMS
|
|||
|
||||
|
||||
class OBBValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
||||
"""A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
|
||||
satellite imagery where objects can appear at various orientations.
|
||||
|
|
@ -44,8 +43,7 @@ class OBBValidator(DetectionValidator):
|
|||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
||||
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
||||
|
||||
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
|
||||
extends the DetectionValidator class and configures it specifically for the OBB task.
|
||||
|
|
@ -61,8 +59,7 @@ class OBBValidator(DetectionValidator):
|
|||
self.metrics = OBBMetrics()
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO obb validation.
|
||||
"""Initialize evaluation metrics for YOLO obb validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
|
|
@ -73,18 +70,17 @@ class OBBValidator(DetectionValidator):
|
|||
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
||||
"""Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
|
||||
class labels and bounding boxes.
|
||||
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
|
||||
class labels and bounding boxes.
|
||||
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
|
||||
labels and bounding boxes.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
|
||||
array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
|
||||
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
|
||||
with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
|
||||
predictions compared to the ground truth.
|
||||
|
||||
Examples:
|
||||
|
|
@ -99,7 +95,8 @@ class OBBValidator(DetectionValidator):
|
|||
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
"""Postprocess OBB predictions.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
|
||||
|
|
@ -112,8 +109,7 @@ class OBBValidator(DetectionValidator):
|
|||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare batch data for OBB validation with proper scaling and formatting.
|
||||
"""Prepare batch data for OBB validation with proper scaling and formatting.
|
||||
|
||||
Args:
|
||||
si (int): Batch index to process.
|
||||
|
|
@ -146,8 +142,7 @@ class OBBValidator(DetectionValidator):
|
|||
}
|
||||
|
||||
def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
|
||||
"""
|
||||
Plot predicted bounding boxes on input images and save the result.
|
||||
"""Plot predicted bounding boxes on input images and save the result.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
|
||||
|
|
@ -166,12 +161,11 @@ class OBBValidator(DetectionValidator):
|
|||
super().plot_predictions(batch, preds, ni) # plot bboxes
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
||||
"""Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
|
||||
with bounding box coordinates, confidence scores, and class predictions.
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
|
||||
bounding box coordinates, confidence scores, and class predictions.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Notes:
|
||||
|
|
@ -197,8 +191,7 @@ class OBBValidator(DetectionValidator):
|
|||
)
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO OBB detections to a text file in normalized coordinates.
|
||||
"""Save YOLO OBB detections to a text file in normalized coordinates.
|
||||
|
||||
Args:
|
||||
predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
|
||||
|
|
@ -233,8 +226,7 @@ class OBBValidator(DetectionValidator):
|
|||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
||||
"""Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
||||
|
||||
Args:
|
||||
stats (dict[str, Any]): Performance statistics dictionary.
|
||||
|
|
|
|||
|
|
@ -5,8 +5,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
|
|||
|
||||
|
||||
class PosePredictor(DetectionPredictor):
|
||||
"""
|
||||
A class extending the DetectionPredictor class for prediction based on a pose model.
|
||||
"""A class extending the DetectionPredictor class for prediction based on a pose model.
|
||||
|
||||
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
|
||||
capabilities inherited from DetectionPredictor.
|
||||
|
|
@ -27,8 +26,7 @@ class PosePredictor(DetectionPredictor):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize PosePredictor for pose estimation tasks.
|
||||
"""Initialize PosePredictor for pose estimation tasks.
|
||||
|
||||
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
|
||||
for Apple MPS.
|
||||
|
|
@ -54,8 +52,7 @@ class PosePredictor(DetectionPredictor):
|
|||
)
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct the result object from the prediction, including keypoints.
|
||||
"""Construct the result object from the prediction, including keypoints.
|
||||
|
||||
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
|
||||
result object.
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER
|
|||
|
||||
|
||||
class PoseTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
||||
"""A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
||||
|
||||
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
|
||||
of pose keypoints alongside bounding boxes.
|
||||
|
|
@ -39,8 +38,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
||||
"""Initialize a PoseTrainer object for training YOLO pose estimation models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
||||
|
|
@ -68,8 +66,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|||
weights: str | Path | None = None,
|
||||
verbose: bool = True,
|
||||
) -> PoseModel:
|
||||
"""
|
||||
Get pose estimation model with specified configuration and weights.
|
||||
"""Get pose estimation model with specified configuration and weights.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
|
||||
|
|
@ -105,8 +102,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|||
)
|
||||
|
||||
def get_dataset(self) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
||||
"""Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the training/validation/test dataset and category names.
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
|
|||
|
||||
|
||||
class PoseValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on a pose model.
|
||||
"""A class extending the DetectionValidator class for validation based on a pose model.
|
||||
|
||||
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
|
||||
metrics for pose evaluation.
|
||||
|
|
@ -33,8 +32,8 @@ class PoseValidator(DetectionValidator):
|
|||
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
|
||||
dimensions.
|
||||
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
|
||||
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
|
||||
detections and ground truth.
|
||||
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
|
||||
and ground truth.
|
||||
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
||||
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
|
||||
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
|
||||
|
|
@ -49,8 +48,7 @@ class PoseValidator(DetectionValidator):
|
|||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize a PoseValidator object for pose estimation validation.
|
||||
"""Initialize a PoseValidator object for pose estimation validation.
|
||||
|
||||
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
||||
specialized metrics for pose evaluation.
|
||||
|
|
@ -106,8 +104,7 @@ class PoseValidator(DetectionValidator):
|
|||
)
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO pose validation.
|
||||
"""Initialize evaluation metrics for YOLO pose validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
|
|
@ -119,16 +116,15 @@ class PoseValidator(DetectionValidator):
|
|||
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
||||
"""Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
||||
|
||||
This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
|
||||
predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
|
||||
flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
|
||||
bounding boxes, confidence scores, class predictions, and keypoint data.
|
||||
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
|
||||
scores, class predictions, and keypoint data.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Dict of processed prediction dictionaries, each containing:
|
||||
|
|
@ -148,8 +144,7 @@ class PoseValidator(DetectionValidator):
|
|||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
||||
"""Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
|
|
@ -172,18 +167,18 @@ class PoseValidator(DetectionValidator):
|
|||
return pbatch
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
|
||||
"""Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
|
||||
truth.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
||||
and 'keypoints' for keypoint predictions.
|
||||
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
||||
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
|
||||
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
|
||||
for bounding boxes, and 'keypoints' for keypoint annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
||||
true positives across 10 IoU levels.
|
||||
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
|
||||
positives across 10 IoU levels.
|
||||
|
||||
Notes:
|
||||
`0.53` scale factor used in area computation is referenced from
|
||||
|
|
@ -202,8 +197,7 @@ class PoseValidator(DetectionValidator):
|
|||
return tp
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO pose detections to a text file in normalized coordinates.
|
||||
"""Save YOLO pose detections to a text file in normalized coordinates.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
|
||||
|
|
@ -226,15 +220,14 @@ class PoseValidator(DetectionValidator):
|
|||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert YOLO predictions to COCO JSON format.
|
||||
"""Convert YOLO predictions to COCO JSON format.
|
||||
|
||||
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
|
||||
format, and appends the results to the internal JSON dictionary (self.jdict).
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
||||
and 'keypoints' tensors.
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
|
||||
tensors.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Notes:
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|||
|
||||
|
||||
class SegmentationPredictor(DetectionPredictor):
|
||||
"""
|
||||
A class extending the DetectionPredictor class for prediction based on a segmentation model.
|
||||
"""A class extending the DetectionPredictor class for prediction based on a segmentation model.
|
||||
|
||||
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
|
||||
prediction results.
|
||||
|
|
@ -31,8 +30,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
|
||||
"""Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
|
||||
|
||||
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
|
||||
prediction results.
|
||||
|
|
@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
self.args.task = "segment"
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Apply non-max suppression and process segmentation detections for each image in the input batch.
|
||||
"""Apply non-max suppression and process segmentation detections for each image in the input batch.
|
||||
|
||||
Args:
|
||||
preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
|
||||
|
|
@ -55,8 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
|
||||
|
||||
Returns:
|
||||
(list): List of Results objects containing the segmentation predictions for each image in the batch.
|
||||
Each Results object includes both bounding boxes and segmentation masks.
|
||||
(list): List of Results objects containing the segmentation predictions for each image in the batch. Each
|
||||
Results object includes both bounding boxes and segmentation masks.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
|
||||
|
|
@ -67,8 +64,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
|
||||
|
||||
def construct_results(self, preds, img, orig_imgs, protos):
|
||||
"""
|
||||
Construct a list of result objects from the predictions.
|
||||
"""Construct a list of result objects from the predictions.
|
||||
|
||||
Args:
|
||||
preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
|
||||
|
|
@ -77,8 +73,8 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
protos (list[torch.Tensor]): List of prototype masks.
|
||||
|
||||
Returns:
|
||||
(list[Results]): List of result objects containing the original images, image paths, class names,
|
||||
bounding boxes, and masks.
|
||||
(list[Results]): List of result objects containing the original images, image paths, class names, bounding
|
||||
boxes, and masks.
|
||||
"""
|
||||
return [
|
||||
self.construct_result(pred, img, orig_img, img_path, proto)
|
||||
|
|
@ -86,8 +82,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
]
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path, proto):
|
||||
"""
|
||||
Construct a single result object from the prediction.
|
||||
"""Construct a single result object from the prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, RANK
|
|||
|
||||
|
||||
class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training based on a segmentation model.
|
||||
"""A class extending the DetectionTrainer class for training based on a segmentation model.
|
||||
|
||||
This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
|
||||
functionality including model initialization, validation, and visualization.
|
||||
|
|
@ -28,8 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a SegmentationTrainer object.
|
||||
"""Initialize a SegmentationTrainer object.
|
||||
|
||||
Args:
|
||||
cfg (dict): Configuration dictionary with default training settings.
|
||||
|
|
@ -42,8 +40,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
|
||||
"""
|
||||
Initialize and return a SegmentationModel with specified configuration and weights.
|
||||
"""Initialize and return a SegmentationModel with specified configuration and weights.
|
||||
|
||||
Args:
|
||||
cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
|
||||
|
|
|
|||
|
|
@ -17,8 +17,7 @@ from ultralytics.utils.metrics import SegmentMetrics, mask_iou
|
|||
|
||||
|
||||
class SegmentationValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on a segmentation model.
|
||||
"""A class extending the DetectionValidator class for validation based on a segmentation model.
|
||||
|
||||
This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
|
||||
compute metrics such as mAP for both detection and segmentation tasks.
|
||||
|
|
@ -38,8 +37,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
|
||||
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
||||
|
|
@ -53,8 +51,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
self.metrics = SegmentMetrics()
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Preprocess batch of images for YOLO segmentation validation.
|
||||
"""Preprocess batch of images for YOLO segmentation validation.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
|
|
@ -67,8 +64,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
return batch
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize metrics and select mask processing function based on save_json flag.
|
||||
"""Initialize metrics and select mask processing function based on save_json flag.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
|
|
@ -96,8 +92,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
)
|
||||
|
||||
def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Post-process YOLO predictions and return output detections with proto.
|
||||
"""Post-process YOLO predictions and return output detections with proto.
|
||||
|
||||
Args:
|
||||
preds (list[torch.Tensor]): Raw predictions from the model.
|
||||
|
|
@ -122,8 +117,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch for training or inference by processing images and targets.
|
||||
"""Prepare a batch for training or inference by processing images and targets.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
|
|
@ -149,8 +143,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
return prepared_batch
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
||||
"""Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
|
||||
|
|
@ -179,8 +172,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
return tp
|
||||
|
||||
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
|
||||
"""
|
||||
Plot batch predictions with masks and bounding boxes.
|
||||
"""Plot batch predictions with masks and bounding boxes.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
|
|
@ -195,8 +187,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
super().plot_predictions(batch, preds, ni, max_det=self.args.max_det) # plot bboxes
|
||||
|
||||
def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
||||
"""Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
||||
|
||||
Args:
|
||||
predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
|
||||
|
|
@ -215,8 +206,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Save one JSON result for COCO evaluation.
|
||||
"""Save one JSON result for COCO evaluation.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ def on_pretrain_routine_end(trainer) -> None:
|
|||
|
||||
|
||||
class WorldTrainer(DetectionTrainer):
|
||||
"""
|
||||
A trainer class for fine-tuning YOLO World models on close-set datasets.
|
||||
"""A trainer class for fine-tuning YOLO World models on closed-set datasets.
|
||||
|
||||
This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
|
||||
features for improved object detection and understanding. It handles text embedding generation and caching to
|
||||
|
|
@ -54,8 +53,7 @@ class WorldTrainer(DetectionTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a WorldTrainer object with given arguments.
|
||||
"""Initialize a WorldTrainer object with given arguments.
|
||||
|
||||
Args:
|
||||
cfg (dict[str, Any]): Configuration for the trainer.
|
||||
|
|
@ -69,8 +67,7 @@ class WorldTrainer(DetectionTrainer):
|
|||
self.text_embeddings = None
|
||||
|
||||
def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
|
||||
"""
|
||||
Return WorldModel initialized with specified config and weights.
|
||||
"""Return WorldModel initialized with specified config and weights.
|
||||
|
||||
Args:
|
||||
cfg (dict[str, Any] | str, optional): Model configuration.
|
||||
|
|
@ -95,8 +92,7 @@ class WorldTrainer(DetectionTrainer):
|
|||
return model
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
||||
"""
|
||||
Build YOLO Dataset for training or validation.
|
||||
"""Build YOLO Dataset for training or validation.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
|
|
@ -115,8 +111,7 @@ class WorldTrainer(DetectionTrainer):
|
|||
return dataset
|
||||
|
||||
def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
|
||||
"""
|
||||
Set text embeddings for datasets to accelerate training by caching category names.
|
||||
"""Set text embeddings for datasets to accelerate training by caching category names.
|
||||
|
||||
This method collects unique category names from all datasets, then generates and caches text embeddings for
|
||||
these categories to improve training efficiency.
|
||||
|
|
@ -141,8 +136,7 @@ class WorldTrainer(DetectionTrainer):
|
|||
self.text_embeddings = text_embeddings
|
||||
|
||||
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Generate text embeddings for a list of text samples.
|
||||
"""Generate text embeddings for a list of text samples.
|
||||
|
||||
Args:
|
||||
texts (list[str]): List of text samples to encode.
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ from ultralytics.utils.torch_utils import unwrap_model
|
|||
|
||||
|
||||
class WorldTrainerFromScratch(WorldTrainer):
|
||||
"""
|
||||
A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
|
||||
"""A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
|
||||
|
||||
This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
|
||||
supporting training YOLO-World models with combined vision-language capabilities.
|
||||
|
|
@ -53,8 +52,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize a WorldTrainerFromScratch object.
|
||||
"""Initialize a WorldTrainerFromScratch object.
|
||||
|
||||
This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
|
||||
detection and grounding datasets for vision-language capabilities.
|
||||
|
|
@ -87,8 +85,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def build_dataset(self, img_path, mode="train", batch=None):
|
||||
"""
|
||||
Build YOLO Dataset for training or validation.
|
||||
"""Build YOLO Dataset for training or validation.
|
||||
|
||||
This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
|
||||
datasets and grounding datasets with different formats.
|
||||
|
|
@ -122,8 +119,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
||||
|
||||
def get_dataset(self):
|
||||
"""
|
||||
Get train and validation paths from data dictionary.
|
||||
"""Get train and validation paths from data dictionary.
|
||||
|
||||
Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
|
||||
detection datasets and grounding datasets.
|
||||
|
|
@ -187,8 +183,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
pass
|
||||
|
||||
def final_eval(self):
|
||||
"""
|
||||
Perform final evaluation and validation for the YOLO-World model.
|
||||
"""Perform final evaluation and validation for the YOLO-World model.
|
||||
|
||||
Configures the validator with appropriate dataset and split information before running evaluation.
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from ultralytics.models.yolo.segment import SegmentationPredictor
|
|||
|
||||
|
||||
class YOLOEVPDetectPredictor(DetectionPredictor):
|
||||
"""
|
||||
A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
|
||||
"""A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
|
||||
|
||||
This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
|
||||
handling, and preprocessing transformations.
|
||||
|
|
@ -29,8 +28,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|||
"""
|
||||
|
||||
def setup_model(self, model, verbose: bool = True):
|
||||
"""
|
||||
Set up the model for prediction.
|
||||
"""Set up the model for prediction.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to load or use.
|
||||
|
|
@ -40,18 +38,16 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|||
self.done_warmup = True
|
||||
|
||||
def set_prompts(self, prompts):
|
||||
"""
|
||||
Set the visual prompts for the model.
|
||||
"""Set the visual prompts for the model.
|
||||
|
||||
Args:
|
||||
prompts (dict): Dictionary containing class indices and bounding boxes or masks.
|
||||
Must include a 'cls' key with class indices.
|
||||
prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
|
||||
with class indices.
|
||||
"""
|
||||
self.prompts = prompts
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
Preprocess images and prompts before inference.
|
||||
"""Preprocess images and prompts before inference.
|
||||
|
||||
This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
|
||||
accordingly.
|
||||
|
|
@ -94,8 +90,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|||
return img
|
||||
|
||||
def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
|
||||
"""
|
||||
Process a single image by resizing bounding boxes or masks and generating visuals.
|
||||
"""Process a single image by resizing bounding boxes or masks and generating visuals.
|
||||
|
||||
Args:
|
||||
dst_shape (tuple): The target shape (height, width) of the image.
|
||||
|
|
@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|||
return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
|
||||
|
||||
def inference(self, im, *args, **kwargs):
|
||||
"""
|
||||
Run inference with visual prompts.
|
||||
"""Run inference with visual prompts.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): Input image tensor.
|
||||
|
|
@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|||
return super().inference(im, vpe=self.prompts, *args, **kwargs)
|
||||
|
||||
def get_vpe(self, source):
|
||||
"""
|
||||
Process the source to get the visual prompt embeddings (VPE).
|
||||
"""Process the source to get the visual prompt embeddings (VPE).
|
||||
|
||||
Args:
|
||||
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
|
||||
of the image to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy
|
||||
arrays, and torch tensors.
|
||||
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
|
||||
make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
|
||||
torch tensors.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The visual prompt embeddings (VPE) from the model.
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ from .val import YOLOEDetectValidator
|
|||
|
||||
|
||||
class YOLOETrainer(DetectionTrainer):
|
||||
"""
|
||||
A trainer class for YOLOE object detection models.
|
||||
"""A trainer class for YOLOE object detection models.
|
||||
|
||||
This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
|
||||
model initialization, validation, and dataset building with multi-modal support.
|
||||
|
|
@ -35,8 +34,7 @@ class YOLOETrainer(DetectionTrainer):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize the YOLOE Trainer with specified configurations.
|
||||
"""Initialize the YOLOE Trainer with specified configurations.
|
||||
|
||||
Args:
|
||||
cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
|
||||
|
|
@ -50,12 +48,11 @@ class YOLOETrainer(DetectionTrainer):
|
|||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
||||
"""
|
||||
Return a YOLOEModel initialized with the specified configuration and weights.
|
||||
"""Return a YOLOEModel initialized with the specified configuration and weights.
|
||||
|
||||
Args:
|
||||
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
|
||||
a direct path to a YAML file, or None to use default configuration.
|
||||
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
|
||||
path to a YAML file, or None to use default configuration.
|
||||
weights (str | Path, optional): Path to pretrained weights file to load into the model.
|
||||
verbose (bool): Whether to display model information during initialization.
|
||||
|
||||
|
|
@ -88,8 +85,7 @@ class YOLOETrainer(DetectionTrainer):
|
|||
)
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
||||
"""
|
||||
Build YOLO Dataset.
|
||||
"""Build YOLO Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
|
|
@ -106,8 +102,7 @@ class YOLOETrainer(DetectionTrainer):
|
|||
|
||||
|
||||
class YOLOEPETrainer(DetectionTrainer):
|
||||
"""
|
||||
Fine-tune YOLOE model using linear probing approach.
|
||||
"""Fine-tune a YOLOE model using a linear probing approach.
|
||||
|
||||
This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
|
||||
datasets while preserving pretrained features.
|
||||
|
|
@ -117,8 +112,7 @@ class YOLOEPETrainer(DetectionTrainer):
|
|||
"""
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
||||
"""
|
||||
Return YOLOEModel initialized with specified config and weights.
|
||||
"""Return YOLOEModel initialized with specified config and weights.
|
||||
|
||||
Args:
|
||||
cfg (dict | str, optional): Model configuration.
|
||||
|
|
@ -160,8 +154,7 @@ class YOLOEPETrainer(DetectionTrainer):
|
|||
|
||||
|
||||
class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
||||
"""
|
||||
Train YOLOE models from scratch with text embedding support.
|
||||
"""Train YOLOE models from scratch with text embedding support.
|
||||
|
||||
This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
|
||||
text embeddings and grounding datasets.
|
||||
|
|
@ -172,8 +165,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|||
"""
|
||||
|
||||
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
|
||||
"""
|
||||
Build YOLO Dataset for training or validation.
|
||||
"""Build YOLO Dataset for training or validation.
|
||||
|
||||
This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
|
||||
datasets and grounding datasets with different formats.
|
||||
|
|
@ -189,8 +181,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|||
return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
|
||||
|
||||
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
|
||||
"""
|
||||
Generate text embeddings for a list of text samples.
|
||||
"""Generate text embeddings for a list of text samples.
|
||||
|
||||
Args:
|
||||
texts (list[str]): List of text samples to encode.
|
||||
|
|
@ -216,8 +207,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|||
|
||||
|
||||
class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
||||
"""
|
||||
Train prompt-free YOLOE model.
|
||||
"""Train prompt-free YOLOE model.
|
||||
|
||||
This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
|
||||
require text prompts during inference.
|
||||
|
|
@ -240,8 +230,7 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
|||
return DetectionTrainer.preprocess_batch(self, batch)
|
||||
|
||||
def set_text_embeddings(self, datasets, batch: int):
|
||||
"""
|
||||
Set text embeddings for datasets to accelerate training by caching category names.
|
||||
"""Set text embeddings for datasets to accelerate training by caching category names.
|
||||
|
||||
This method collects unique category names from all datasets, generates text embeddings for them, and caches
|
||||
these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
|
||||
|
|
@ -260,8 +249,7 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
|||
|
||||
|
||||
class YOLOEVPTrainer(YOLOETrainerFromScratch):
|
||||
"""
|
||||
Train YOLOE model with visual prompts.
|
||||
"""Train YOLOE model with visual prompts.
|
||||
|
||||
This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
|
||||
alongside images to guide the detection process.
|
||||
|
|
@ -271,8 +259,7 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
|
|||
"""
|
||||
|
||||
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
|
||||
"""
|
||||
Build YOLO Dataset for training or validation with visual prompts.
|
||||
"""Build YOLO Dataset for training or validation with visual prompts.
|
||||
|
||||
Args:
|
||||
img_path (list[str] | str): Path to the folder containing images or list of paths.
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from .val import YOLOESegValidator
|
|||
|
||||
|
||||
class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
||||
"""
|
||||
Trainer class for YOLOE segmentation models.
|
||||
"""Trainer class for YOLOE segmentation models.
|
||||
|
||||
This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
|
||||
segmentation models, enabling both object detection and instance segmentation capabilities.
|
||||
|
|
@ -24,8 +23,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|||
"""
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""
|
||||
Return YOLOESegModel initialized with specified config and weights.
|
||||
"""Return YOLOESegModel initialized with specified config and weights.
|
||||
|
||||
Args:
|
||||
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
|
||||
|
|
@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|||
return model
|
||||
|
||||
def get_validator(self):
|
||||
"""
|
||||
Create and return a validator for YOLOE segmentation model evaluation.
|
||||
"""Create and return a validator for YOLOE segmentation model evaluation.
|
||||
|
||||
Returns:
|
||||
(YOLOESegValidator): Validator for YOLOE segmentation models.
|
||||
|
|
@ -62,8 +59,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|||
|
||||
|
||||
class YOLOEPESegTrainer(SegmentationTrainer):
|
||||
"""
|
||||
Fine-tune YOLOESeg model in linear probing way.
|
||||
"""Fine-tune a YOLOESeg model using a linear probing approach.
|
||||
|
||||
This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
|
||||
most of the model and only training specific layers for efficient adaptation to new tasks.
|
||||
|
|
@ -73,8 +69,7 @@ class YOLOEPESegTrainer(SegmentationTrainer):
|
|||
"""
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""
|
||||
Return YOLOESegModel initialized with specified config and weights for linear probing.
|
||||
"""Return YOLOESegModel initialized with specified config and weights for linear probing.
|
||||
|
||||
Args:
|
||||
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
|||
|
||||
|
||||
class YOLOEDetectValidator(DetectionValidator):
|
||||
"""
|
||||
A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
|
||||
"""A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
|
||||
|
||||
This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
|
||||
validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
|
||||
|
|
@ -50,8 +49,7 @@ class YOLOEDetectValidator(DetectionValidator):
|
|||
|
||||
@smart_inference_mode()
|
||||
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
|
||||
"""
|
||||
Extract visual prompt embeddings from training samples.
|
||||
"""Extract visual prompt embeddings from training samples.
|
||||
|
||||
This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
|
||||
normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
|
||||
|
|
@ -99,8 +97,7 @@ class YOLOEDetectValidator(DetectionValidator):
|
|||
return visual_pe.unsqueeze(0)
|
||||
|
||||
def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
|
||||
"""
|
||||
Create a dataloader for LVIS training visual prompt samples.
|
||||
"""Create a dataloader for LVIS training visual prompt samples.
|
||||
|
||||
This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
|
||||
necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
|
||||
|
|
@ -140,8 +137,7 @@ class YOLOEDetectValidator(DetectionValidator):
|
|||
refer_data: str | None = None,
|
||||
load_vp: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run validation on the model using either text or visual prompt embeddings.
|
||||
"""Run validation on the model using either text or visual prompt embeddings.
|
||||
|
||||
This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
|
||||
supports validation during training (using a trainer object) or standalone validation with a provided model. For
|
||||
|
|
|
|||
|
|
@ -23,8 +23,7 @@ from ultralytics.utils.nms import non_max_suppression
|
|||
|
||||
|
||||
def check_class_names(names: list | dict) -> dict[int, str]:
|
||||
"""
|
||||
Check class names and convert to dict format if needed.
|
||||
"""Check class names and convert to dict format if needed.
|
||||
|
||||
Args:
|
||||
names (list | dict): Class names as list or dict format.
|
||||
|
|
@ -53,8 +52,7 @@ def check_class_names(names: list | dict) -> dict[int, str]:
|
|||
|
||||
|
||||
def default_class_names(data: str | Path | None = None) -> dict[int, str]:
|
||||
"""
|
||||
Apply default class names to an input YAML file or return numerical class names.
|
||||
"""Apply default class names to an input YAML file or return numerical class names.
|
||||
|
||||
Args:
|
||||
data (str | Path, optional): Path to YAML file containing class names.
|
||||
|
|
@ -71,8 +69,7 @@ def default_class_names(data: str | Path | None = None) -> dict[int, str]:
|
|||
|
||||
|
||||
class AutoBackend(nn.Module):
|
||||
"""
|
||||
Handle dynamic backend selection for running inference using Ultralytics YOLO models.
|
||||
"""Handle dynamic backend selection for running inference using Ultralytics YOLO models.
|
||||
|
||||
The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
|
||||
range of formats, each with specific naming conventions as outlined below:
|
||||
|
|
@ -148,8 +145,7 @@ class AutoBackend(nn.Module):
|
|||
fuse: bool = True,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the AutoBackend for inference.
|
||||
"""Initialize the AutoBackend for inference.
|
||||
|
||||
Args:
|
||||
model (str | torch.nn.Module): Path to the model weights file or a module instance.
|
||||
|
|
@ -639,8 +635,7 @@ class AutoBackend(nn.Module):
|
|||
embed: list | None = None,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
"""
|
||||
Run inference on an AutoBackend model.
|
||||
"""Run inference on an AutoBackend model.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The image tensor to perform inference on.
|
||||
|
|
@ -860,8 +855,7 @@ class AutoBackend(nn.Module):
|
|||
return self.from_numpy(y)
|
||||
|
||||
def from_numpy(self, x: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Convert a numpy array to a tensor.
|
||||
"""Convert a numpy array to a tensor.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): The array to be converted.
|
||||
|
|
@ -872,8 +866,7 @@ class AutoBackend(nn.Module):
|
|||
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
|
||||
|
||||
def warmup(self, imgsz: tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
|
||||
"""
|
||||
Warm up the model by running one forward pass with a dummy input.
|
||||
"""Warm up the model by running one forward pass with a dummy input.
|
||||
|
||||
Args:
|
||||
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
|
||||
|
|
@ -889,8 +882,7 @@ class AutoBackend(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _model_type(p: str = "path/to/model.pt") -> list[bool]:
|
||||
"""
|
||||
Take a path to a model file and return the model type.
|
||||
"""Take a path to a model file and return the model type.
|
||||
|
||||
Args:
|
||||
p (str): Path to the model file.
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ import torch.nn as nn
|
|||
|
||||
|
||||
class AGLU(nn.Module):
|
||||
"""
|
||||
Unified activation function module from AGLU.
|
||||
"""Unified activation function module from AGLU.
|
||||
|
||||
This class implements a parameterized activation function with learnable parameters lambda and kappa, based on the
|
||||
AGLU (Adaptive Gated Linear Unit) approach.
|
||||
|
|
@ -40,8 +39,7 @@ class AGLU(nn.Module):
|
|||
self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # kappa parameter
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the Adaptive Gated Linear Unit (AGLU) activation function.
|
||||
"""Apply the Adaptive Gated Linear Unit (AGLU) activation function.
|
||||
|
||||
This forward method implements the AGLU activation function with learnable parameters lambda and kappa. The
|
||||
function applies a transformation that adaptively combines linear and non-linear components.
|
||||
|
|
|
|||
|
|
@ -56,15 +56,13 @@ __all__ = (
|
|||
|
||||
|
||||
class DFL(nn.Module):
|
||||
"""
|
||||
Integral module of Distribution Focal Loss (DFL).
|
||||
"""Integral module of Distribution Focal Loss (DFL).
|
||||
|
||||
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
||||
"""
|
||||
|
||||
def __init__(self, c1: int = 16):
|
||||
"""
|
||||
Initialize a convolutional layer with a given number of input channels.
|
||||
"""Initialize a convolutional layer with a given number of input channels.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -86,8 +84,7 @@ class Proto(nn.Module):
|
|||
"""Ultralytics YOLO models mask Proto module for segmentation models."""
|
||||
|
||||
def __init__(self, c1: int, c_: int = 256, c2: int = 32):
|
||||
"""
|
||||
Initialize the Ultralytics YOLO models mask Proto module with specified number of protos and masks.
|
||||
"""Initialize the Ultralytics YOLO models mask Proto module with specified number of protos and masks.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -106,15 +103,13 @@ class Proto(nn.Module):
|
|||
|
||||
|
||||
class HGStem(nn.Module):
|
||||
"""
|
||||
StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
|
||||
"""StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
|
||||
|
||||
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
||||
"""
|
||||
|
||||
def __init__(self, c1: int, cm: int, c2: int):
|
||||
"""
|
||||
Initialize the StemBlock of PPHGNetV2.
|
||||
"""Initialize the StemBlock of PPHGNetV2.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -144,8 +139,7 @@ class HGStem(nn.Module):
|
|||
|
||||
|
||||
class HGBlock(nn.Module):
|
||||
"""
|
||||
HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
|
||||
"""HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
|
||||
|
||||
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
||||
"""
|
||||
|
|
@ -161,8 +155,7 @@ class HGBlock(nn.Module):
|
|||
shortcut: bool = False,
|
||||
act: nn.Module = nn.ReLU(),
|
||||
):
|
||||
"""
|
||||
Initialize HGBlock with specified parameters.
|
||||
"""Initialize HGBlock with specified parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -193,8 +186,7 @@ class SPP(nn.Module):
|
|||
"""Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, k: tuple[int, ...] = (5, 9, 13)):
|
||||
"""
|
||||
Initialize the SPP layer with input/output channels and pooling kernel sizes.
|
||||
"""Initialize the SPP layer with input/output channels and pooling kernel sizes.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -217,8 +209,7 @@ class SPPF(nn.Module):
|
|||
"""Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, k: int = 5):
|
||||
"""
|
||||
Initialize the SPPF layer with given input/output channels and kernel size.
|
||||
"""Initialize the SPPF layer with given input/output channels and kernel size.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -245,8 +236,7 @@ class C1(nn.Module):
|
|||
"""CSP Bottleneck with 1 convolution."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1):
|
||||
"""
|
||||
Initialize the CSP Bottleneck with 1 convolution.
|
||||
"""Initialize the CSP Bottleneck with 1 convolution.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -267,8 +257,7 @@ class C2(nn.Module):
|
|||
"""CSP Bottleneck with 2 convolutions."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize a CSP Bottleneck with 2 convolutions.
|
||||
"""Initialize a CSP Bottleneck with 2 convolutions.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -295,8 +284,7 @@ class C2f(nn.Module):
|
|||
"""Faster Implementation of CSP Bottleneck with 2 convolutions."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize a CSP bottleneck with 2 convolutions.
|
||||
"""Initialize a CSP bottleneck with 2 convolutions.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -330,8 +318,7 @@ class C3(nn.Module):
|
|||
"""CSP Bottleneck with 3 convolutions."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize the CSP Bottleneck with 3 convolutions.
|
||||
"""Initialize the CSP Bottleneck with 3 convolutions.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -357,8 +344,7 @@ class C3x(C3):
|
|||
"""C3 module with cross-convolutions."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize C3 module with cross-convolutions.
|
||||
"""Initialize C3 module with cross-convolutions.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -377,8 +363,7 @@ class RepC3(nn.Module):
|
|||
"""Rep C3."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 3, e: float = 1.0):
|
||||
"""
|
||||
Initialize CSP Bottleneck with a single convolution.
|
||||
"""Initialize CSP Bottleneck with a single convolution.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -402,8 +387,7 @@ class C3TR(C3):
|
|||
"""C3 module with TransformerBlock()."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize C3 module with TransformerBlock.
|
||||
"""Initialize C3 module with TransformerBlock.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -422,8 +406,7 @@ class C3Ghost(C3):
|
|||
"""C3 module with GhostBottleneck()."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize C3 module with GhostBottleneck.
|
||||
"""Initialize C3 module with GhostBottleneck.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -442,8 +425,7 @@ class GhostBottleneck(nn.Module):
|
|||
"""Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, k: int = 3, s: int = 1):
|
||||
"""
|
||||
Initialize Ghost Bottleneck module.
|
||||
"""Initialize Ghost Bottleneck module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -473,8 +455,7 @@ class Bottleneck(nn.Module):
|
|||
def __init__(
|
||||
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: tuple[int, int] = (3, 3), e: float = 0.5
|
||||
):
|
||||
"""
|
||||
Initialize a standard bottleneck module.
|
||||
"""Initialize a standard bottleneck module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -499,8 +480,7 @@ class BottleneckCSP(nn.Module):
|
|||
"""CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize CSP Bottleneck.
|
||||
"""Initialize CSP Bottleneck.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -531,8 +511,7 @@ class ResNetBlock(nn.Module):
|
|||
"""ResNet block with standard convolution layers."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, s: int = 1, e: int = 4):
|
||||
"""
|
||||
Initialize ResNet block.
|
||||
"""Initialize ResNet block.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -556,8 +535,7 @@ class ResNetLayer(nn.Module):
|
|||
"""ResNet layer with multiple ResNet blocks."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, s: int = 1, is_first: bool = False, n: int = 1, e: int = 4):
|
||||
"""
|
||||
Initialize ResNet layer.
|
||||
"""Initialize ResNet layer.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -588,8 +566,7 @@ class MaxSigmoidAttnBlock(nn.Module):
|
|||
"""Max Sigmoid attention block."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, nh: int = 1, ec: int = 128, gc: int = 512, scale: bool = False):
|
||||
"""
|
||||
Initialize MaxSigmoidAttnBlock.
|
||||
"""Initialize MaxSigmoidAttnBlock.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -609,8 +586,7 @@ class MaxSigmoidAttnBlock(nn.Module):
|
|||
self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0
|
||||
|
||||
def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of MaxSigmoidAttnBlock.
|
||||
"""Forward pass of MaxSigmoidAttnBlock.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -653,8 +629,7 @@ class C2fAttn(nn.Module):
|
|||
g: int = 1,
|
||||
e: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Initialize C2f module with attention mechanism.
|
||||
"""Initialize C2f module with attention mechanism.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -675,8 +650,7 @@ class C2fAttn(nn.Module):
|
|||
self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh)
|
||||
|
||||
def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through C2f layer with attention.
|
||||
"""Forward pass through C2f layer with attention.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -691,8 +665,7 @@ class C2fAttn(nn.Module):
|
|||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
def forward_split(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass using split() instead of chunk().
|
||||
"""Forward pass using split() instead of chunk().
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -713,8 +686,7 @@ class ImagePoolingAttn(nn.Module):
|
|||
def __init__(
|
||||
self, ec: int = 256, ch: tuple[int, ...] = (), ct: int = 512, nh: int = 8, k: int = 3, scale: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize ImagePoolingAttn module.
|
||||
"""Initialize ImagePoolingAttn module.
|
||||
|
||||
Args:
|
||||
ec (int): Embedding channels.
|
||||
|
|
@ -741,8 +713,7 @@ class ImagePoolingAttn(nn.Module):
|
|||
self.k = k
|
||||
|
||||
def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of ImagePoolingAttn.
|
||||
"""Forward pass of ImagePoolingAttn.
|
||||
|
||||
Args:
|
||||
x (list[torch.Tensor]): List of input feature maps.
|
||||
|
|
@ -785,8 +756,7 @@ class ContrastiveHead(nn.Module):
|
|||
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
|
||||
|
||||
def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward function of contrastive learning.
|
||||
"""Forward function of contrastive learning.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Image features.
|
||||
|
|
@ -802,16 +772,14 @@ class ContrastiveHead(nn.Module):
|
|||
|
||||
|
||||
class BNContrastiveHead(nn.Module):
|
||||
"""
|
||||
Batch Norm Contrastive Head using batch norm instead of l2-normalization.
|
||||
"""Batch Norm Contrastive Head using batch norm instead of l2-normalization.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Embed dimensions of text and image features.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dims: int):
|
||||
"""
|
||||
Initialize BNContrastiveHead.
|
||||
"""Initialize BNContrastiveHead.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Embedding dimensions for features.
|
||||
|
|
@ -835,8 +803,7 @@ class BNContrastiveHead(nn.Module):
|
|||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward function of contrastive learning with batch normalization.
|
||||
"""Forward function of contrastive learning with batch normalization.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Image features.
|
||||
|
|
@ -858,8 +825,7 @@ class RepBottleneck(Bottleneck):
|
|||
def __init__(
|
||||
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: tuple[int, int] = (3, 3), e: float = 0.5
|
||||
):
|
||||
"""
|
||||
Initialize RepBottleneck.
|
||||
"""Initialize RepBottleneck.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -878,8 +844,7 @@ class RepCSP(C3):
|
|||
"""Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize RepCSP layer.
|
||||
"""Initialize RepCSP layer.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -898,8 +863,7 @@ class RepNCSPELAN4(nn.Module):
|
|||
"""CSP-ELAN."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, c3: int, c4: int, n: int = 1):
|
||||
"""
|
||||
Initialize CSP-ELAN layer.
|
||||
"""Initialize CSP-ELAN layer.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -932,8 +896,7 @@ class ELAN1(RepNCSPELAN4):
|
|||
"""ELAN1 module with 4 convolutions."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, c3: int, c4: int):
|
||||
"""
|
||||
Initialize ELAN1 layer.
|
||||
"""Initialize ELAN1 layer.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -953,8 +916,7 @@ class AConv(nn.Module):
|
|||
"""AConv."""
|
||||
|
||||
def __init__(self, c1: int, c2: int):
|
||||
"""
|
||||
Initialize AConv module.
|
||||
"""Initialize AConv module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -973,8 +935,7 @@ class ADown(nn.Module):
|
|||
"""ADown."""
|
||||
|
||||
def __init__(self, c1: int, c2: int):
|
||||
"""
|
||||
Initialize ADown module.
|
||||
"""Initialize ADown module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -999,8 +960,7 @@ class SPPELAN(nn.Module):
|
|||
"""SPP-ELAN."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, c3: int, k: int = 5):
|
||||
"""
|
||||
Initialize SPP-ELAN block.
|
||||
"""Initialize SPP-ELAN block.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1027,8 +987,7 @@ class CBLinear(nn.Module):
|
|||
"""CBLinear."""
|
||||
|
||||
def __init__(self, c1: int, c2s: list[int], k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
|
||||
"""
|
||||
Initialize CBLinear module.
|
||||
"""Initialize CBLinear module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1051,8 +1010,7 @@ class CBFuse(nn.Module):
|
|||
"""CBFuse."""
|
||||
|
||||
def __init__(self, idx: list[int]):
|
||||
"""
|
||||
Initialize CBFuse module.
|
||||
"""Initialize CBFuse module.
|
||||
|
||||
Args:
|
||||
idx (list[int]): Indices for feature selection.
|
||||
|
|
@ -1061,8 +1019,7 @@ class CBFuse(nn.Module):
|
|||
self.idx = idx
|
||||
|
||||
def forward(self, xs: list[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through CBFuse layer.
|
||||
"""Forward pass through CBFuse layer.
|
||||
|
||||
Args:
|
||||
xs (list[torch.Tensor]): List of input tensors.
|
||||
|
|
@ -1079,8 +1036,7 @@ class C3f(nn.Module):
|
|||
"""Faster Implementation of CSP Bottleneck with 2 convolutions."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize CSP bottleneck layer with two convolutions.
|
||||
"""Initialize CSP bottleneck layer with two convolutions.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1110,8 +1066,7 @@ class C3k2(C2f):
|
|||
def __init__(
|
||||
self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize C3k2 module.
|
||||
"""Initialize C3k2 module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1132,8 +1087,7 @@ class C3k(C3):
|
|||
"""C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks."""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):
|
||||
"""
|
||||
Initialize C3k module.
|
||||
"""Initialize C3k module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1154,8 +1108,7 @@ class RepVGGDW(torch.nn.Module):
|
|||
"""RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
|
||||
|
||||
def __init__(self, ed: int) -> None:
|
||||
"""
|
||||
Initialize RepVGGDW module.
|
||||
"""Initialize RepVGGDW module.
|
||||
|
||||
Args:
|
||||
ed (int): Input and output channels.
|
||||
|
|
@ -1167,8 +1120,7 @@ class RepVGGDW(torch.nn.Module):
|
|||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Perform a forward pass of the RepVGGDW block.
|
||||
"""Perform a forward pass of the RepVGGDW block.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1179,8 +1131,7 @@ class RepVGGDW(torch.nn.Module):
|
|||
return self.act(self.conv(x) + self.conv1(x))
|
||||
|
||||
def forward_fuse(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Perform a forward pass of the RepVGGDW block without fusing the convolutions.
|
||||
"""Perform a forward pass of the RepVGGDW block without fusing the convolutions.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1192,8 +1143,7 @@ class RepVGGDW(torch.nn.Module):
|
|||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
"""
|
||||
Fuse the convolutional layers in the RepVGGDW block.
|
||||
"""Fuse the convolutional layers in the RepVGGDW block.
|
||||
|
||||
This method fuses the convolutional layers and updates the weights and biases accordingly.
|
||||
"""
|
||||
|
|
@ -1218,8 +1168,7 @@ class RepVGGDW(torch.nn.Module):
|
|||
|
||||
|
||||
class CIB(nn.Module):
|
||||
"""
|
||||
Conditional Identity Block (CIB) module.
|
||||
"""Conditional Identity Block (CIB) module.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -1230,8 +1179,7 @@ class CIB(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1: int, c2: int, shortcut: bool = True, e: float = 0.5, lk: bool = False):
|
||||
"""
|
||||
Initialize the CIB module.
|
||||
"""Initialize the CIB module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1253,8 +1201,7 @@ class CIB(nn.Module):
|
|||
self.add = shortcut and c1 == c2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the CIB module.
|
||||
"""Forward pass of the CIB module.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1266,8 +1213,7 @@ class CIB(nn.Module):
|
|||
|
||||
|
||||
class C2fCIB(C2f):
|
||||
"""
|
||||
C2fCIB class represents a convolutional block with C2f and CIB modules.
|
||||
"""C2fCIB class represents a convolutional block with C2f and CIB modules.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -1282,8 +1228,7 @@ class C2fCIB(C2f):
|
|||
def __init__(
|
||||
self, c1: int, c2: int, n: int = 1, shortcut: bool = False, lk: bool = False, g: int = 1, e: float = 0.5
|
||||
):
|
||||
"""
|
||||
Initialize C2fCIB module.
|
||||
"""Initialize C2fCIB module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1299,8 +1244,7 @@ class C2fCIB(C2f):
|
|||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Attention module that performs self-attention on the input tensor.
|
||||
"""Attention module that performs self-attention on the input tensor.
|
||||
|
||||
Args:
|
||||
dim (int): The input tensor dimension.
|
||||
|
|
@ -1318,8 +1262,7 @@ class Attention(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
|
||||
"""
|
||||
Initialize multi-head attention module.
|
||||
"""Initialize multi-head attention module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
|
|
@ -1338,8 +1281,7 @@ class Attention(nn.Module):
|
|||
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the Attention module.
|
||||
"""Forward pass of the Attention module.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
|
@ -1362,8 +1304,7 @@ class Attention(nn.Module):
|
|||
|
||||
|
||||
class PSABlock(nn.Module):
|
||||
"""
|
||||
PSABlock class implementing a Position-Sensitive Attention block for neural networks.
|
||||
"""PSABlock class implementing a Position-Sensitive Attention block for neural networks.
|
||||
|
||||
This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers
|
||||
with optional shortcut connections.
|
||||
|
|
@ -1384,8 +1325,7 @@ class PSABlock(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:
|
||||
"""
|
||||
Initialize the PSABlock.
|
||||
"""Initialize the PSABlock.
|
||||
|
||||
Args:
|
||||
c (int): Input and output channels.
|
||||
|
|
@ -1400,8 +1340,7 @@ class PSABlock(nn.Module):
|
|||
self.add = shortcut
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Execute a forward pass through PSABlock.
|
||||
"""Execute a forward pass through PSABlock.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1415,8 +1354,7 @@ class PSABlock(nn.Module):
|
|||
|
||||
|
||||
class PSA(nn.Module):
|
||||
"""
|
||||
PSA class for implementing Position-Sensitive Attention in neural networks.
|
||||
"""PSA class for implementing Position-Sensitive Attention in neural networks.
|
||||
|
||||
This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to
|
||||
input tensors, enhancing feature extraction and processing capabilities.
|
||||
|
|
@ -1439,8 +1377,7 @@ class PSA(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1: int, c2: int, e: float = 0.5):
|
||||
"""
|
||||
Initialize PSA module.
|
||||
"""Initialize PSA module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1457,8 +1394,7 @@ class PSA(nn.Module):
|
|||
self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Execute forward pass in PSA module.
|
||||
"""Execute forward pass in PSA module.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1473,8 +1409,7 @@ class PSA(nn.Module):
|
|||
|
||||
|
||||
class C2PSA(nn.Module):
|
||||
"""
|
||||
C2PSA module with attention mechanism for enhanced feature extraction and processing.
|
||||
"""C2PSA module with attention mechanism for enhanced feature extraction and processing.
|
||||
|
||||
This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing
|
||||
capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.
|
||||
|
|
@ -1498,8 +1433,7 @@ class C2PSA(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize C2PSA module.
|
||||
"""Initialize C2PSA module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1516,8 +1450,7 @@ class C2PSA(nn.Module):
|
|||
self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Process the input tensor through a series of PSA blocks.
|
||||
"""Process the input tensor through a series of PSA blocks.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1531,8 +1464,7 @@ class C2PSA(nn.Module):
|
|||
|
||||
|
||||
class C2fPSA(C2f):
|
||||
"""
|
||||
C2fPSA module with enhanced feature extraction using PSA blocks.
|
||||
"""C2fPSA module with enhanced feature extraction using PSA blocks.
|
||||
|
||||
This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature
|
||||
extraction.
|
||||
|
|
@ -1557,8 +1489,7 @@ class C2fPSA(C2f):
|
|||
"""
|
||||
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
|
||||
"""
|
||||
Initialize C2fPSA module.
|
||||
"""Initialize C2fPSA module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1572,8 +1503,7 @@ class C2fPSA(C2f):
|
|||
|
||||
|
||||
class SCDown(nn.Module):
|
||||
"""
|
||||
SCDown module for downsampling with separable convolutions.
|
||||
"""SCDown module for downsampling with separable convolutions.
|
||||
|
||||
This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in
|
||||
efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.
|
||||
|
|
@ -1596,8 +1526,7 @@ class SCDown(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1: int, c2: int, k: int, s: int):
|
||||
"""
|
||||
Initialize SCDown module.
|
||||
"""Initialize SCDown module.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channels.
|
||||
|
|
@ -1610,8 +1539,7 @@ class SCDown(nn.Module):
|
|||
self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply convolution and downsampling to the input tensor.
|
||||
"""Apply convolution and downsampling to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1623,8 +1551,7 @@ class SCDown(nn.Module):
|
|||
|
||||
|
||||
class TorchVision(nn.Module):
|
||||
"""
|
||||
TorchVision module to allow loading any torchvision model.
|
||||
"""TorchVision module to allow loading any torchvision model.
|
||||
|
||||
This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and
|
||||
customize the model by truncating or unwrapping layers.
|
||||
|
|
@ -1643,8 +1570,7 @@ class TorchVision(nn.Module):
|
|||
def __init__(
|
||||
self, model: str, weights: str = "DEFAULT", unwrap: bool = True, truncate: int = 2, split: bool = False
|
||||
):
|
||||
"""
|
||||
Load the model and weights from torchvision.
|
||||
"""Load the model and weights from torchvision.
|
||||
|
||||
Args:
|
||||
model (str): Name of the torchvision model to load.
|
||||
|
|
@ -1671,8 +1597,7 @@ class TorchVision(nn.Module):
|
|||
self.m.head = self.m.heads = nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through the model.
|
||||
"""Forward pass through the model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1689,8 +1614,7 @@ class TorchVision(nn.Module):
|
|||
|
||||
|
||||
class AAttn(nn.Module):
|
||||
"""
|
||||
Area-attention module for YOLO models, providing efficient attention mechanisms.
|
||||
"""Area-attention module for YOLO models, providing efficient attention mechanisms.
|
||||
|
||||
This module implements an area-based attention mechanism that processes input features in a spatially-aware manner,
|
||||
making it particularly effective for object detection tasks.
|
||||
|
|
@ -1715,8 +1639,7 @@ class AAttn(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int, area: int = 1):
|
||||
"""
|
||||
Initialize an Area-attention module for YOLO models.
|
||||
"""Initialize an Area-attention module for YOLO models.
|
||||
|
||||
Args:
|
||||
dim (int): Number of hidden channels.
|
||||
|
|
@ -1735,8 +1658,7 @@ class AAttn(nn.Module):
|
|||
self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Process the input tensor through the area-attention.
|
||||
"""Process the input tensor through the area-attention.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1775,8 +1697,7 @@ class AAttn(nn.Module):
|
|||
|
||||
|
||||
class ABlock(nn.Module):
|
||||
"""
|
||||
Area-attention block module for efficient feature extraction in YOLO models.
|
||||
"""Area-attention block module for efficient feature extraction in YOLO models.
|
||||
|
||||
This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.
|
||||
It uses a novel area-based attention approach that is more efficient than traditional self-attention while
|
||||
|
|
@ -1799,8 +1720,7 @@ class ABlock(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1.2, area: int = 1):
|
||||
"""
|
||||
Initialize an Area-attention block module.
|
||||
"""Initialize an Area-attention block module.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
|
|
@ -1817,8 +1737,7 @@ class ABlock(nn.Module):
|
|||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m: nn.Module):
|
||||
"""
|
||||
Initialize weights using a truncated normal distribution.
|
||||
"""Initialize weights using a truncated normal distribution.
|
||||
|
||||
Args:
|
||||
m (nn.Module): Module to initialize.
|
||||
|
|
@ -1829,8 +1748,7 @@ class ABlock(nn.Module):
|
|||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through ABlock.
|
||||
"""Forward pass through ABlock.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1843,8 +1761,7 @@ class ABlock(nn.Module):
|
|||
|
||||
|
||||
class A2C2f(nn.Module):
|
||||
"""
|
||||
Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
|
||||
"""Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
|
||||
|
||||
This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature
|
||||
processing. It supports both area-attention and standard convolution modes.
|
||||
|
|
@ -1879,8 +1796,7 @@ class A2C2f(nn.Module):
|
|||
g: int = 1,
|
||||
shortcut: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize Area-Attention C2f module.
|
||||
"""Initialize Area-Attention C2f module.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -1910,8 +1826,7 @@ class A2C2f(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through A2C2f layer.
|
||||
"""Forward pass through A2C2f layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1931,8 +1846,7 @@ class SwiGLUFFN(nn.Module):
|
|||
"""SwiGLU Feed-Forward Network for transformer-based architectures."""
|
||||
|
||||
def __init__(self, gc: int, ec: int, e: int = 4) -> None:
|
||||
"""
|
||||
Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor.
|
||||
"""Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor.
|
||||
|
||||
Args:
|
||||
gc (int): Guide channels.
|
||||
|
|
@ -1955,8 +1869,7 @@ class Residual(nn.Module):
|
|||
"""Residual connection wrapper for neural network modules."""
|
||||
|
||||
def __init__(self, m: nn.Module) -> None:
|
||||
"""
|
||||
Initialize residual module with the wrapped module.
|
||||
"""Initialize residual module with the wrapped module.
|
||||
|
||||
Args:
|
||||
m (nn.Module): Module to wrap with residual connection.
|
||||
|
|
@ -1977,8 +1890,7 @@ class SAVPE(nn.Module):
|
|||
"""Spatial-Aware Visual Prompt Embedding module for feature enhancement."""
|
||||
|
||||
def __init__(self, ch: list[int], c3: int, embed: int):
|
||||
"""
|
||||
Initialize SAVPE module with channels, intermediate channels, and embedding dimension.
|
||||
"""Initialize SAVPE module with channels, intermediate channels, and embedding dimension.
|
||||
|
||||
Args:
|
||||
ch (list[int]): List of input channel dimensions.
|
||||
|
|
|
|||
|
|
@ -37,8 +37,7 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
|
|||
|
||||
|
||||
class Conv(nn.Module):
|
||||
"""
|
||||
Standard convolution module with batch normalization and activation.
|
||||
"""Standard convolution module with batch normalization and activation.
|
||||
|
||||
Attributes:
|
||||
conv (nn.Conv2d): Convolutional layer.
|
||||
|
|
@ -50,8 +49,7 @@ class Conv(nn.Module):
|
|||
default_act = nn.SiLU() # default activation
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
||||
"""
|
||||
Initialize Conv layer with given parameters.
|
||||
"""Initialize Conv layer with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -69,8 +67,7 @@ class Conv(nn.Module):
|
|||
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply convolution, batch normalization and activation to input tensor.
|
||||
"""Apply convolution, batch normalization and activation to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -81,8 +78,7 @@ class Conv(nn.Module):
|
|||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
def forward_fuse(self, x):
|
||||
"""
|
||||
Apply convolution and activation without batch normalization.
|
||||
"""Apply convolution and activation without batch normalization.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -94,8 +90,7 @@ class Conv(nn.Module):
|
|||
|
||||
|
||||
class Conv2(Conv):
|
||||
"""
|
||||
Simplified RepConv module with Conv fusing.
|
||||
"""Simplified RepConv module with Conv fusing.
|
||||
|
||||
Attributes:
|
||||
conv (nn.Conv2d): Main 3x3 convolutional layer.
|
||||
|
|
@ -105,8 +100,7 @@ class Conv2(Conv):
|
|||
"""
|
||||
|
||||
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
|
||||
"""
|
||||
Initialize Conv2 layer with given parameters.
|
||||
"""Initialize Conv2 layer with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -122,8 +116,7 @@ class Conv2(Conv):
|
|||
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply convolution, batch normalization and activation to input tensor.
|
||||
"""Apply convolution, batch normalization and activation to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -134,8 +127,7 @@ class Conv2(Conv):
|
|||
return self.act(self.bn(self.conv(x) + self.cv2(x)))
|
||||
|
||||
def forward_fuse(self, x):
|
||||
"""
|
||||
Apply fused convolution, batch normalization and activation to input tensor.
|
||||
"""Apply fused convolution, batch normalization and activation to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -156,8 +148,7 @@ class Conv2(Conv):
|
|||
|
||||
|
||||
class LightConv(nn.Module):
|
||||
"""
|
||||
Light convolution module with 1x1 and depthwise convolutions.
|
||||
"""Light convolution module with 1x1 and depthwise convolutions.
|
||||
|
||||
This implementation is based on the PaddleDetection HGNetV2 backbone.
|
||||
|
||||
|
|
@ -167,8 +158,7 @@ class LightConv(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1, c2, k=1, act=nn.ReLU()):
|
||||
"""
|
||||
Initialize LightConv layer with given parameters.
|
||||
"""Initialize LightConv layer with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -181,8 +171,7 @@ class LightConv(nn.Module):
|
|||
self.conv2 = DWConv(c2, c2, k, act=act)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply 2 convolutions to input tensor.
|
||||
"""Apply 2 convolutions to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -197,8 +186,7 @@ class DWConv(Conv):
|
|||
"""Depth-wise convolution module."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
|
||||
"""
|
||||
Initialize depth-wise convolution with given parameters.
|
||||
"""Initialize depth-wise convolution with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -215,8 +203,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d):
|
|||
"""Depth-wise transpose convolution module."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):
|
||||
"""
|
||||
Initialize depth-wise transpose convolution with given parameters.
|
||||
"""Initialize depth-wise transpose convolution with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -230,8 +217,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d):
|
|||
|
||||
|
||||
class ConvTranspose(nn.Module):
|
||||
"""
|
||||
Convolution transpose module with optional batch normalization and activation.
|
||||
"""Convolution transpose module with optional batch normalization and activation.
|
||||
|
||||
Attributes:
|
||||
conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.
|
||||
|
|
@ -243,8 +229,7 @@ class ConvTranspose(nn.Module):
|
|||
default_act = nn.SiLU() # default activation
|
||||
|
||||
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
|
||||
"""
|
||||
Initialize ConvTranspose layer with given parameters.
|
||||
"""Initialize ConvTranspose layer with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -261,8 +246,7 @@ class ConvTranspose(nn.Module):
|
|||
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply transposed convolution, batch normalization and activation to input.
|
||||
"""Apply transposed convolution, batch normalization and activation to input.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -273,8 +257,7 @@ class ConvTranspose(nn.Module):
|
|||
return self.act(self.bn(self.conv_transpose(x)))
|
||||
|
||||
def forward_fuse(self, x):
|
||||
"""
|
||||
Apply activation and convolution transpose operation to input.
|
||||
"""Apply activation and convolution transpose operation to input.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -286,8 +269,7 @@ class ConvTranspose(nn.Module):
|
|||
|
||||
|
||||
class Focus(nn.Module):
|
||||
"""
|
||||
Focus module for concentrating feature information.
|
||||
"""Focus module for concentrating feature information.
|
||||
|
||||
Slices input tensor into 4 parts and concatenates them in the channel dimension.
|
||||
|
||||
|
|
@ -296,8 +278,7 @@ class Focus(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
|
||||
"""
|
||||
Initialize Focus module with given parameters.
|
||||
"""Initialize Focus module with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -313,8 +294,7 @@ class Focus(nn.Module):
|
|||
# self.contract = Contract(gain=2)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply Focus operation and convolution to input tensor.
|
||||
"""Apply Focus operation and convolution to input tensor.
|
||||
|
||||
Input shape is (B, C, W, H) and output shape is (B, 4C, W/2, H/2).
|
||||
|
||||
|
|
@ -329,8 +309,7 @@ class Focus(nn.Module):
|
|||
|
||||
|
||||
class GhostConv(nn.Module):
|
||||
"""
|
||||
Ghost Convolution module.
|
||||
"""Ghost Convolution module.
|
||||
|
||||
Generates more features with fewer parameters by using cheap operations.
|
||||
|
||||
|
|
@ -343,8 +322,7 @@ class GhostConv(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
|
||||
"""
|
||||
Initialize Ghost Convolution module with given parameters.
|
||||
"""Initialize Ghost Convolution module with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -360,8 +338,7 @@ class GhostConv(nn.Module):
|
|||
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply Ghost Convolution to input tensor.
|
||||
"""Apply Ghost Convolution to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -374,8 +351,7 @@ class GhostConv(nn.Module):
|
|||
|
||||
|
||||
class RepConv(nn.Module):
|
||||
"""
|
||||
RepConv module with training and deploy modes.
|
||||
"""RepConv module with training and deploy modes.
|
||||
|
||||
This module is used in RT-DETR and can fuse convolutions during inference for efficiency.
|
||||
|
||||
|
|
@ -393,8 +369,7 @@ class RepConv(nn.Module):
|
|||
default_act = nn.SiLU() # default activation
|
||||
|
||||
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
|
||||
"""
|
||||
Initialize RepConv module with given parameters.
|
||||
"""Initialize RepConv module with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -420,8 +395,7 @@ class RepConv(nn.Module):
|
|||
self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
|
||||
|
||||
def forward_fuse(self, x):
|
||||
"""
|
||||
Forward pass for deploy mode.
|
||||
"""Forward pass for deploy mode.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -432,8 +406,7 @@ class RepConv(nn.Module):
|
|||
return self.act(self.conv(x))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass for training mode.
|
||||
"""Forward pass for training mode.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -445,8 +418,7 @@ class RepConv(nn.Module):
|
|||
return self.act(self.conv1(x) + self.conv2(x) + id_out)
|
||||
|
||||
def get_equivalent_kernel_bias(self):
|
||||
"""
|
||||
Calculate equivalent kernel and bias by fusing convolutions.
|
||||
"""Calculate equivalent kernel and bias by fusing convolutions.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Equivalent kernel
|
||||
|
|
@ -459,8 +431,7 @@ class RepConv(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _pad_1x1_to_3x3_tensor(kernel1x1):
|
||||
"""
|
||||
Pad a 1x1 kernel to 3x3 size.
|
||||
"""Pad a 1x1 kernel to 3x3 size.
|
||||
|
||||
Args:
|
||||
kernel1x1 (torch.Tensor): 1x1 convolution kernel.
|
||||
|
|
@ -474,8 +445,7 @@ class RepConv(nn.Module):
|
|||
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
||||
|
||||
def _fuse_bn_tensor(self, branch):
|
||||
"""
|
||||
Fuse batch normalization with convolution weights.
|
||||
"""Fuse batch normalization with convolution weights.
|
||||
|
||||
Args:
|
||||
branch (Conv | nn.BatchNorm2d | None): Branch to fuse.
|
||||
|
|
@ -540,8 +510,7 @@ class RepConv(nn.Module):
|
|||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
"""
|
||||
Channel-attention module for feature recalibration.
|
||||
"""Channel-attention module for feature recalibration.
|
||||
|
||||
Applies attention weights to channels based on global average pooling.
|
||||
|
||||
|
|
@ -555,8 +524,7 @@ class ChannelAttention(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, channels: int) -> None:
|
||||
"""
|
||||
Initialize Channel-attention module.
|
||||
"""Initialize Channel-attention module.
|
||||
|
||||
Args:
|
||||
channels (int): Number of input channels.
|
||||
|
|
@ -567,8 +535,7 @@ class ChannelAttention(nn.Module):
|
|||
self.act = nn.Sigmoid()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply channel attention to input tensor.
|
||||
"""Apply channel attention to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -580,8 +547,7 @@ class ChannelAttention(nn.Module):
|
|||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
"""
|
||||
Spatial-attention module for feature recalibration.
|
||||
"""Spatial-attention module for feature recalibration.
|
||||
|
||||
Applies attention weights to spatial dimensions based on channel statistics.
|
||||
|
||||
|
|
@ -591,8 +557,7 @@ class SpatialAttention(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, kernel_size=7):
|
||||
"""
|
||||
Initialize Spatial-attention module.
|
||||
"""Initialize Spatial-attention module.
|
||||
|
||||
Args:
|
||||
kernel_size (int): Size of the convolutional kernel (3 or 7).
|
||||
|
|
@ -604,8 +569,7 @@ class SpatialAttention(nn.Module):
|
|||
self.act = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply spatial attention to input tensor.
|
||||
"""Apply spatial attention to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -617,8 +581,7 @@ class SpatialAttention(nn.Module):
|
|||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
"""
|
||||
Convolutional Block Attention Module.
|
||||
"""Convolutional Block Attention Module.
|
||||
|
||||
Combines channel and spatial attention mechanisms for comprehensive feature refinement.
|
||||
|
||||
|
|
@ -628,8 +591,7 @@ class CBAM(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1, kernel_size=7):
|
||||
"""
|
||||
Initialize CBAM with given parameters.
|
||||
"""Initialize CBAM with given parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -640,8 +602,7 @@ class CBAM(nn.Module):
|
|||
self.spatial_attention = SpatialAttention(kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply channel and spatial attention sequentially to input tensor.
|
||||
"""Apply channel and spatial attention sequentially to input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -653,16 +614,14 @@ class CBAM(nn.Module):
|
|||
|
||||
|
||||
class Concat(nn.Module):
|
||||
"""
|
||||
Concatenate a list of tensors along specified dimension.
|
||||
"""Concatenate a list of tensors along specified dimension.
|
||||
|
||||
Attributes:
|
||||
d (int): Dimension along which to concatenate tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension=1):
|
||||
"""
|
||||
Initialize Concat module.
|
||||
"""Initialize Concat module.
|
||||
|
||||
Args:
|
||||
dimension (int): Dimension along which to concatenate tensors.
|
||||
|
|
@ -671,8 +630,7 @@ class Concat(nn.Module):
|
|||
self.d = dimension
|
||||
|
||||
def forward(self, x: list[torch.Tensor]):
|
||||
"""
|
||||
Concatenate input tensors along specified dimension.
|
||||
"""Concatenate input tensors along specified dimension.
|
||||
|
||||
Args:
|
||||
x (list[torch.Tensor]): List of input tensors.
|
||||
|
|
@ -684,16 +642,14 @@ class Concat(nn.Module):
|
|||
|
||||
|
||||
class Index(nn.Module):
|
||||
"""
|
||||
Returns a particular index of the input.
|
||||
"""Returns a particular index of the input.
|
||||
|
||||
Attributes:
|
||||
index (int): Index to select from input.
|
||||
"""
|
||||
|
||||
def __init__(self, index=0):
|
||||
"""
|
||||
Initialize Index module.
|
||||
"""Initialize Index module.
|
||||
|
||||
Args:
|
||||
index (int): Index to select from input.
|
||||
|
|
@ -702,8 +658,7 @@ class Index(nn.Module):
|
|||
self.index = index
|
||||
|
||||
def forward(self, x: list[torch.Tensor]):
|
||||
"""
|
||||
Select and return a particular index from input.
|
||||
"""Select and return a particular index from input.
|
||||
|
||||
Args:
|
||||
x (list[torch.Tensor]): List of input tensors.
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ __all__ = "OBB", "Classify", "Detect", "Pose", "RTDETRDecoder", "Segment", "YOLO
|
|||
|
||||
|
||||
class Detect(nn.Module):
|
||||
"""
|
||||
YOLO Detect head for object detection models.
|
||||
"""YOLO Detect head for object detection models.
|
||||
|
||||
This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.
|
||||
It supports both training and inference modes, with optional end-to-end detection capabilities.
|
||||
|
|
@ -78,8 +77,7 @@ class Detect(nn.Module):
|
|||
xyxy = False # xyxy or xywh output
|
||||
|
||||
def __init__(self, nc: int = 80, ch: tuple = ()):
|
||||
"""
|
||||
Initialize the YOLO detection layer with specified number of classes and channels.
|
||||
"""Initialize the YOLO detection layer with specified number of classes and channels.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -126,15 +124,14 @@ class Detect(nn.Module):
|
|||
return y if self.export else (y, x)
|
||||
|
||||
def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
|
||||
"""
|
||||
Perform forward pass of the v10Detect module.
|
||||
"""Perform forward pass of the v10Detect module.
|
||||
|
||||
Args:
|
||||
x (list[torch.Tensor]): Input feature maps from different levels.
|
||||
|
||||
Returns:
|
||||
outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.
|
||||
Inference mode returns processed detections or tuple with detections and raw outputs.
|
||||
outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs. Inference mode returns
|
||||
processed detections or tuple with detections and raw outputs.
|
||||
"""
|
||||
x_detach = [xi.detach() for xi in x]
|
||||
one2one = [
|
||||
|
|
@ -150,8 +147,7 @@ class Detect(nn.Module):
|
|||
return y if self.export else (y, {"one2many": x, "one2one": one2one})
|
||||
|
||||
def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
|
||||
"""Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
|
||||
|
||||
Args:
|
||||
x (list[torch.Tensor]): List of feature maps from different detection layers.
|
||||
|
|
@ -194,8 +190,7 @@ class Detect(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
|
||||
"""
|
||||
Post-process YOLO model predictions.
|
||||
"""Post-process YOLO model predictions.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
|
||||
|
|
@ -218,8 +213,7 @@ class Detect(nn.Module):
|
|||
|
||||
|
||||
class Segment(Detect):
|
||||
"""
|
||||
YOLO Segment head for segmentation models.
|
||||
"""YOLO Segment head for segmentation models.
|
||||
|
||||
This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
|
||||
|
||||
|
|
@ -240,8 +234,7 @@ class Segment(Detect):
|
|||
"""
|
||||
|
||||
def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
|
||||
"""
|
||||
Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
|
||||
"""Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -270,8 +263,7 @@ class Segment(Detect):
|
|||
|
||||
|
||||
class OBB(Detect):
|
||||
"""
|
||||
YOLO OBB detection head for detection with rotation models.
|
||||
"""YOLO OBB detection head for detection with rotation models.
|
||||
|
||||
This class extends the Detect head to include oriented bounding box prediction with rotation angles.
|
||||
|
||||
|
|
@ -292,8 +284,7 @@ class OBB(Detect):
|
|||
"""
|
||||
|
||||
def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
|
||||
"""
|
||||
Initialize OBB with number of classes `nc` and layer channels `ch`.
|
||||
"""Initialize OBB with number of classes `nc` and layer channels `ch`.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -326,8 +317,7 @@ class OBB(Detect):
|
|||
|
||||
|
||||
class Pose(Detect):
|
||||
"""
|
||||
YOLO Pose head for keypoints models.
|
||||
"""YOLO Pose head for keypoints models.
|
||||
|
||||
This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
|
||||
|
||||
|
|
@ -348,8 +338,7 @@ class Pose(Detect):
|
|||
"""
|
||||
|
||||
def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
|
||||
"""
|
||||
Initialize YOLO network with default parameters and Convolutional Layers.
|
||||
"""Initialize YOLO network with default parameters and Convolutional Layers.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -396,8 +385,7 @@ class Pose(Detect):
|
|||
|
||||
|
||||
class Classify(nn.Module):
|
||||
"""
|
||||
YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
|
||||
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
|
||||
|
||||
This class implements a classification head that transforms feature maps into class predictions.
|
||||
|
||||
|
|
@ -421,8 +409,7 @@ class Classify(nn.Module):
|
|||
export = False # export mode
|
||||
|
||||
def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
|
||||
"""
|
||||
Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
|
||||
"""Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
|
||||
|
||||
Args:
|
||||
c1 (int): Number of input channels.
|
||||
|
|
@ -451,8 +438,7 @@ class Classify(nn.Module):
|
|||
|
||||
|
||||
class WorldDetect(Detect):
|
||||
"""
|
||||
Head for integrating YOLO detection models with semantic understanding from text embeddings.
|
||||
"""Head for integrating YOLO detection models with semantic understanding from text embeddings.
|
||||
|
||||
This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding in
|
||||
object detection tasks.
|
||||
|
|
@ -474,8 +460,7 @@ class WorldDetect(Detect):
|
|||
"""
|
||||
|
||||
def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
|
||||
"""
|
||||
Initialize YOLO detection layer with nc classes and layer channels ch.
|
||||
"""Initialize YOLO detection layer with nc classes and layer channels ch.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -509,8 +494,7 @@ class WorldDetect(Detect):
|
|||
|
||||
|
||||
class LRPCHead(nn.Module):
|
||||
"""
|
||||
Lightweight Region Proposal and Classification Head for efficient object detection.
|
||||
"""Lightweight Region Proposal and Classification Head for efficient object detection.
|
||||
|
||||
This head combines region proposal filtering with classification to enable efficient detection with dynamic
|
||||
vocabulary support.
|
||||
|
|
@ -534,8 +518,7 @@ class LRPCHead(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):
|
||||
"""
|
||||
Initialize LRPCHead with vocabulary, proposal filter, and localization components.
|
||||
"""Initialize LRPCHead with vocabulary, proposal filter, and localization components.
|
||||
|
||||
Args:
|
||||
vocab (nn.Module): Vocabulary/classification module.
|
||||
|
|
@ -574,8 +557,7 @@ class LRPCHead(nn.Module):
|
|||
|
||||
|
||||
class YOLOEDetect(Detect):
|
||||
"""
|
||||
Head for integrating YOLO detection models with semantic understanding from text embeddings.
|
||||
"""Head for integrating YOLO detection models with semantic understanding from text embeddings.
|
||||
|
||||
This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding
|
||||
through text embeddings and visual prompt embeddings.
|
||||
|
|
@ -607,8 +589,7 @@ class YOLOEDetect(Detect):
|
|||
is_fused = False
|
||||
|
||||
def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
|
||||
"""
|
||||
Initialize YOLO detection layer with nc classes and layer channels ch.
|
||||
"""Initialize YOLO detection layer with nc classes and layer channels ch.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -762,8 +743,7 @@ class YOLOEDetect(Detect):
|
|||
|
||||
|
||||
class YOLOESegment(YOLOEDetect):
|
||||
"""
|
||||
YOLO segmentation head with text embedding capabilities.
|
||||
"""YOLO segmentation head with text embedding capabilities.
|
||||
|
||||
This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks with
|
||||
text-guided semantic understanding.
|
||||
|
|
@ -788,8 +768,7 @@ class YOLOESegment(YOLOEDetect):
|
|||
def __init__(
|
||||
self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
|
||||
):
|
||||
"""
|
||||
Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
|
||||
"""Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -830,8 +809,7 @@ class YOLOESegment(YOLOEDetect):
|
|||
|
||||
|
||||
class RTDETRDecoder(nn.Module):
|
||||
"""
|
||||
Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
|
||||
"""Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
|
||||
|
||||
This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
|
||||
and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
|
||||
|
|
@ -895,8 +873,7 @@ class RTDETRDecoder(nn.Module):
|
|||
box_noise_scale: float = 1.0,
|
||||
learnt_init_query: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the RTDETRDecoder module with the given parameters.
|
||||
"""Initialize the RTDETRDecoder module with the given parameters.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
@ -956,8 +933,7 @@ class RTDETRDecoder(nn.Module):
|
|||
self._reset_parameters()
|
||||
|
||||
def forward(self, x: list[torch.Tensor], batch: dict | None = None) -> tuple | torch.Tensor:
|
||||
"""
|
||||
Run the forward pass of the module, returning bounding box and classification scores for the input.
|
||||
"""Run the forward pass of the module, returning bounding box and classification scores for the input.
|
||||
|
||||
Args:
|
||||
x (list[torch.Tensor]): List of feature maps from the backbone.
|
||||
|
|
@ -1013,8 +989,7 @@ class RTDETRDecoder(nn.Module):
|
|||
device: str = "cpu",
|
||||
eps: float = 1e-2,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Generate anchor bounding boxes for given shapes with specific grid size and validate them.
|
||||
"""Generate anchor bounding boxes for given shapes with specific grid size and validate them.
|
||||
|
||||
Args:
|
||||
shapes (list): List of feature map shapes.
|
||||
|
|
@ -1046,8 +1021,7 @@ class RTDETRDecoder(nn.Module):
|
|||
return anchors, valid_mask
|
||||
|
||||
def _get_encoder_input(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]]]:
|
||||
"""
|
||||
Process and return encoder inputs by getting projection features from input and concatenating them.
|
||||
"""Process and return encoder inputs by getting projection features from input and concatenating them.
|
||||
|
||||
Args:
|
||||
x (list[torch.Tensor]): List of feature maps from the backbone.
|
||||
|
|
@ -1079,8 +1053,7 @@ class RTDETRDecoder(nn.Module):
|
|||
dn_embed: torch.Tensor | None = None,
|
||||
dn_bbox: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Generate and prepare the input required for the decoder from the provided features and shapes.
|
||||
"""Generate and prepare the input required for the decoder from the provided features and shapes.
|
||||
|
||||
Args:
|
||||
feats (torch.Tensor): Processed features from encoder.
|
||||
|
|
@ -1158,8 +1131,7 @@ class RTDETRDecoder(nn.Module):
|
|||
|
||||
|
||||
class v10Detect(Detect):
|
||||
"""
|
||||
v10 Detection head from https://arxiv.org/pdf/2405.14458.
|
||||
"""v10 Detection head from https://arxiv.org/pdf/2405.14458.
|
||||
|
||||
This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions for
|
||||
improved efficiency and performance.
|
||||
|
|
@ -1186,8 +1158,7 @@ class v10Detect(Detect):
|
|||
end2end = True
|
||||
|
||||
def __init__(self, nc: int = 80, ch: tuple = ()):
|
||||
"""
|
||||
Initialize the v10Detect object with the specified number of classes and input channels.
|
||||
"""Initialize the v10Detect object with the specified number of classes and input channels.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
|
|
|
|||
|
|
@ -30,8 +30,7 @@ __all__ = (
|
|||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
A single layer of the transformer encoder.
|
||||
"""A single layer of the transformer encoder.
|
||||
|
||||
This class implements a standard transformer encoder layer with multi-head attention and feedforward network,
|
||||
supporting both pre-normalization and post-normalization configurations.
|
||||
|
|
@ -58,8 +57,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||
act: nn.Module = nn.GELU(),
|
||||
normalize_before: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the TransformerEncoderLayer with specified parameters.
|
||||
"""Initialize the TransformerEncoderLayer with specified parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Input dimension.
|
||||
|
|
@ -102,8 +100,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||
src_key_padding_mask: torch.Tensor | None = None,
|
||||
pos: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform forward pass with post-normalization.
|
||||
"""Perform forward pass with post-normalization.
|
||||
|
||||
Args:
|
||||
src (torch.Tensor): Input tensor.
|
||||
|
|
@ -129,8 +126,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||
src_key_padding_mask: torch.Tensor | None = None,
|
||||
pos: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform forward pass with pre-normalization.
|
||||
"""Perform forward pass with pre-normalization.
|
||||
|
||||
Args:
|
||||
src (torch.Tensor): Input tensor.
|
||||
|
|
@ -156,8 +152,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||
src_key_padding_mask: torch.Tensor | None = None,
|
||||
pos: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward propagate the input through the encoder module.
|
||||
"""Forward propagate the input through the encoder module.
|
||||
|
||||
Args:
|
||||
src (torch.Tensor): Input tensor.
|
||||
|
|
@ -174,8 +169,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||
|
||||
|
||||
class AIFI(TransformerEncoderLayer):
|
||||
"""
|
||||
AIFI transformer layer for 2D data with positional embeddings.
|
||||
"""AIFI transformer layer for 2D data with positional embeddings.
|
||||
|
||||
This class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional
|
||||
embeddings and handling the spatial dimensions appropriately.
|
||||
|
|
@ -190,8 +184,7 @@ class AIFI(TransformerEncoderLayer):
|
|||
act: nn.Module = nn.GELU(),
|
||||
normalize_before: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the AIFI instance with specified parameters.
|
||||
"""Initialize the AIFI instance with specified parameters.
|
||||
|
||||
Args:
|
||||
c1 (int): Input dimension.
|
||||
|
|
@ -204,8 +197,7 @@ class AIFI(TransformerEncoderLayer):
|
|||
super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the AIFI transformer layer.
|
||||
"""Forward pass for the AIFI transformer layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape [B, C, H, W].
|
||||
|
|
@ -223,8 +215,7 @@ class AIFI(TransformerEncoderLayer):
|
|||
def build_2d_sincos_position_embedding(
|
||||
w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Build 2D sine-cosine position embedding.
|
||||
"""Build 2D sine-cosine position embedding.
|
||||
|
||||
Args:
|
||||
w (int): Width of the feature map.
|
||||
|
|
@ -253,8 +244,7 @@ class TransformerLayer(nn.Module):
|
|||
"""Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
|
||||
|
||||
def __init__(self, c: int, num_heads: int):
|
||||
"""
|
||||
Initialize a self-attention mechanism using linear transformations and multi-head attention.
|
||||
"""Initialize a self-attention mechanism using linear transformations and multi-head attention.
|
||||
|
||||
Args:
|
||||
c (int): Input and output channel dimension.
|
||||
|
|
@ -269,8 +259,7 @@ class TransformerLayer(nn.Module):
|
|||
self.fc2 = nn.Linear(c, c, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply a transformer block to the input x and return the output.
|
||||
"""Apply a transformer block to the input x and return the output.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -283,8 +272,7 @@ class TransformerLayer(nn.Module):
|
|||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""
|
||||
Vision Transformer block based on https://arxiv.org/abs/2010.11929.
|
||||
"""Vision Transformer block based on https://arxiv.org/abs/2010.11929.
|
||||
|
||||
This class implements a complete transformer block with optional convolution layer for channel adjustment, learnable
|
||||
position embedding, and multiple transformer layers.
|
||||
|
|
@ -297,8 +285,7 @@ class TransformerBlock(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, c1: int, c2: int, num_heads: int, num_layers: int):
|
||||
"""
|
||||
Initialize a Transformer module with position embedding and specified number of heads and layers.
|
||||
"""Initialize a Transformer module with position embedding and specified number of heads and layers.
|
||||
|
||||
Args:
|
||||
c1 (int): Input channel dimension.
|
||||
|
|
@ -315,8 +302,7 @@ class TransformerBlock(nn.Module):
|
|||
self.c2 = c2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward propagate the input through the transformer block.
|
||||
"""Forward propagate the input through the transformer block.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape [b, c1, w, h].
|
||||
|
|
@ -335,8 +321,7 @@ class MLPBlock(nn.Module):
|
|||
"""A single block of a multi-layer perceptron."""
|
||||
|
||||
def __init__(self, embedding_dim: int, mlp_dim: int, act=nn.GELU):
|
||||
"""
|
||||
Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.
|
||||
"""Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): Input and output dimension.
|
||||
|
|
@ -349,8 +334,7 @@ class MLPBlock(nn.Module):
|
|||
self.act = act()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the MLPBlock.
|
||||
"""Forward pass for the MLPBlock.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -362,8 +346,7 @@ class MLPBlock(nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
A simple multi-layer perceptron (also called FFN).
|
||||
"""A simple multi-layer perceptron (also called FFN).
|
||||
|
||||
This class implements a configurable MLP with multiple linear layers, activation functions, and optional sigmoid
|
||||
output activation.
|
||||
|
|
@ -378,8 +361,7 @@ class MLP(nn.Module):
|
|||
def __init__(
|
||||
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize the MLP with specified input, hidden, output dimensions and number of layers.
|
||||
"""Initialize the MLP with specified input, hidden, output dimensions and number of layers.
|
||||
|
||||
Args:
|
||||
input_dim (int): Input dimension.
|
||||
|
|
@ -397,8 +379,7 @@ class MLP(nn.Module):
|
|||
self.act = act()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the entire MLP.
|
||||
"""Forward pass for the entire MLP.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -412,8 +393,7 @@ class MLP(nn.Module):
|
|||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
"""
|
||||
2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.
|
||||
"""2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.
|
||||
|
||||
This class implements layer normalization for 2D feature maps, normalizing across the channel dimension while
|
||||
preserving spatial dimensions.
|
||||
|
|
@ -429,8 +409,7 @@ class LayerNorm2d(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6):
|
||||
"""
|
||||
Initialize LayerNorm2d with the given parameters.
|
||||
"""Initialize LayerNorm2d with the given parameters.
|
||||
|
||||
Args:
|
||||
num_channels (int): Number of channels in the input.
|
||||
|
|
@ -442,8 +421,7 @@ class LayerNorm2d(nn.Module):
|
|||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Perform forward pass for 2D layer normalization.
|
||||
"""Perform forward pass for 2D layer normalization.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -458,8 +436,7 @@ class LayerNorm2d(nn.Module):
|
|||
|
||||
|
||||
class MSDeformAttn(nn.Module):
|
||||
"""
|
||||
Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.
|
||||
"""Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.
|
||||
|
||||
This module implements multiscale deformable attention that can attend to features at multiple scales with learnable
|
||||
sampling locations and attention weights.
|
||||
|
|
@ -480,8 +457,7 @@ class MSDeformAttn(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):
|
||||
"""
|
||||
Initialize MSDeformAttn with the given parameters.
|
||||
"""Initialize MSDeformAttn with the given parameters.
|
||||
|
||||
Args:
|
||||
d_model (int): Model dimension.
|
||||
|
|
@ -539,13 +515,12 @@ class MSDeformAttn(nn.Module):
|
|||
value_shapes: list,
|
||||
value_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform forward pass for multiscale deformable attention.
|
||||
"""Perform forward pass for multiscale deformable attention.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor with shape [bs, query_length, C].
|
||||
refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2],
|
||||
range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area.
|
||||
refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2], range in [0,
|
||||
1], top-left (0,0), bottom-right (1, 1), including padding area.
|
||||
value (torch.Tensor): Value tensor with shape [bs, value_length, C].
|
||||
value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
|
||||
value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for non-padding
|
||||
|
|
@ -584,8 +559,7 @@ class MSDeformAttn(nn.Module):
|
|||
|
||||
|
||||
class DeformableTransformerDecoderLayer(nn.Module):
|
||||
"""
|
||||
Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.
|
||||
"""Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.
|
||||
|
||||
This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable
|
||||
attention, and a feedforward network.
|
||||
|
|
@ -619,8 +593,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|||
n_levels: int = 4,
|
||||
n_points: int = 4,
|
||||
):
|
||||
"""
|
||||
Initialize the DeformableTransformerDecoderLayer with the given parameters.
|
||||
"""Initialize the DeformableTransformerDecoderLayer with the given parameters.
|
||||
|
||||
Args:
|
||||
d_model (int): Model dimension.
|
||||
|
|
@ -657,8 +630,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Perform forward pass through the Feed-Forward Network part of the layer.
|
||||
"""Perform forward pass through the Feed-Forward Network part of the layer.
|
||||
|
||||
Args:
|
||||
tgt (torch.Tensor): Input tensor.
|
||||
|
|
@ -680,8 +652,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|||
attn_mask: torch.Tensor | None = None,
|
||||
query_pos: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform the forward pass through the entire decoder layer.
|
||||
"""Perform the forward pass through the entire decoder layer.
|
||||
|
||||
Args:
|
||||
embed (torch.Tensor): Input embeddings.
|
||||
|
|
@ -715,8 +686,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|||
|
||||
|
||||
class DeformableTransformerDecoder(nn.Module):
|
||||
"""
|
||||
Deformable Transformer Decoder based on PaddleDetection implementation.
|
||||
"""Deformable Transformer Decoder based on PaddleDetection implementation.
|
||||
|
||||
This class implements a complete deformable transformer decoder with multiple decoder layers and prediction heads
|
||||
for bounding box regression and classification.
|
||||
|
|
@ -732,8 +702,7 @@ class DeformableTransformerDecoder(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1):
|
||||
"""
|
||||
Initialize the DeformableTransformerDecoder with the given parameters.
|
||||
"""Initialize the DeformableTransformerDecoder with the given parameters.
|
||||
|
||||
Args:
|
||||
hidden_dim (int): Hidden dimension.
|
||||
|
|
@ -759,8 +728,7 @@ class DeformableTransformerDecoder(nn.Module):
|
|||
attn_mask: torch.Tensor | None = None,
|
||||
padding_mask: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
Perform the forward pass through the entire decoder.
|
||||
"""Perform the forward pass through the entire decoder.
|
||||
|
||||
Args:
|
||||
embed (torch.Tensor): Decoder embeddings.
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ __all__ = "inverse_sigmoid", "multi_scale_deformable_attn_pytorch"
|
|||
|
||||
|
||||
def _get_clones(module, n):
|
||||
"""
|
||||
Create a list of cloned modules from the given module.
|
||||
"""Create a list of cloned modules from the given module.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to be cloned.
|
||||
|
|
@ -34,8 +33,7 @@ def _get_clones(module, n):
|
|||
|
||||
|
||||
def bias_init_with_prob(prior_prob=0.01):
|
||||
"""
|
||||
Initialize conv/fc bias value according to a given probability value.
|
||||
"""Initialize conv/fc bias value according to a given probability value.
|
||||
|
||||
This function calculates the bias initialization value based on a prior probability using the inverse error
|
||||
function. It's commonly used in object detection models to initialize classification layers with a specific positive
|
||||
|
|
@ -56,8 +54,7 @@ def bias_init_with_prob(prior_prob=0.01):
|
|||
|
||||
|
||||
def linear_init(module):
|
||||
"""
|
||||
Initialize the weights and biases of a linear module.
|
||||
"""Initialize the weights and biases of a linear module.
|
||||
|
||||
This function initializes the weights of a linear module using a uniform distribution within bounds calculated from
|
||||
the input dimension. If the module has a bias, it is also initialized.
|
||||
|
|
@ -80,8 +77,7 @@ def linear_init(module):
|
|||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-5):
|
||||
"""
|
||||
Calculate the inverse sigmoid function for a tensor.
|
||||
"""Calculate the inverse sigmoid function for a tensor.
|
||||
|
||||
This function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network
|
||||
operations, particularly in attention mechanisms and coordinate transformations.
|
||||
|
|
@ -110,8 +106,7 @@ def multi_scale_deformable_attn_pytorch(
|
|||
sampling_locations: torch.Tensor,
|
||||
attention_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Implement multi-scale deformable attention in PyTorch.
|
||||
"""Implement multi-scale deformable attention in PyTorch.
|
||||
|
||||
This function performs deformable attention across multiple feature map scales, allowing the model to attend to
|
||||
different spatial locations with learned offsets.
|
||||
|
|
@ -119,10 +114,10 @@ def multi_scale_deformable_attn_pytorch(
|
|||
Args:
|
||||
value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).
|
||||
value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).
|
||||
sampling_locations (torch.Tensor): The sampling locations with shape
|
||||
(bs, num_queries, num_heads, num_levels, num_points, 2).
|
||||
attention_weights (torch.Tensor): The attention weights with shape
|
||||
(bs, num_queries, num_heads, num_levels, num_points).
|
||||
sampling_locations (torch.Tensor): The sampling locations with shape (bs, num_queries, num_heads, num_levels,
|
||||
num_points, 2).
|
||||
attention_weights (torch.Tensor): The attention weights with shape (bs, num_queries, num_heads, num_levels,
|
||||
num_points).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The output tensor with shape (bs, num_queries, embed_dims).
|
||||
|
|
|
|||
|
|
@ -95,8 +95,7 @@ from ultralytics.utils.torch_utils import (
|
|||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""
|
||||
Base class for all YOLO models in the Ultralytics family.
|
||||
"""Base class for all YOLO models in the Ultralytics family.
|
||||
|
||||
This class provides common functionality for YOLO models including forward pass handling, model fusion, information
|
||||
display, and weight loading capabilities.
|
||||
|
|
@ -121,8 +120,7 @@ class BaseModel(torch.nn.Module):
|
|||
"""
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
"""
|
||||
Perform forward pass of the model for either training or inference.
|
||||
"""Perform forward pass of the model for either training or inference.
|
||||
|
||||
If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
|
||||
|
||||
|
|
@ -139,8 +137,7 @@ class BaseModel(torch.nn.Module):
|
|||
return self.predict(x, *args, **kwargs)
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the network.
|
||||
"""Perform a forward pass through the network.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor to the model.
|
||||
|
|
@ -157,8 +154,7 @@ class BaseModel(torch.nn.Module):
|
|||
return self._predict_once(x, profile, visualize, embed)
|
||||
|
||||
def _predict_once(self, x, profile=False, visualize=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the network.
|
||||
"""Perform a forward pass through the network.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor to the model.
|
||||
|
|
@ -196,8 +192,7 @@ class BaseModel(torch.nn.Module):
|
|||
return self._predict_once(x)
|
||||
|
||||
def _profile_one_layer(self, m, x, dt):
|
||||
"""
|
||||
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
||||
"""Profile the computation time and FLOPs of a single layer of the model on a given input.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): The layer to be profiled.
|
||||
|
|
@ -222,8 +217,7 @@ class BaseModel(torch.nn.Module):
|
|||
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
||||
|
||||
def fuse(self, verbose=True):
|
||||
"""
|
||||
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
|
||||
"""Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
|
||||
efficiency.
|
||||
|
||||
Returns:
|
||||
|
|
@ -254,8 +248,7 @@ class BaseModel(torch.nn.Module):
|
|||
return self
|
||||
|
||||
def is_fused(self, thresh=10):
|
||||
"""
|
||||
Check if the model has less than a certain threshold of BatchNorm layers.
|
||||
"""Check if the model has less than a certain threshold of BatchNorm layers.
|
||||
|
||||
Args:
|
||||
thresh (int, optional): The threshold number of BatchNorm layers.
|
||||
|
|
@ -267,8 +260,7 @@ class BaseModel(torch.nn.Module):
|
|||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||
|
||||
def info(self, detailed=False, verbose=True, imgsz=640):
|
||||
"""
|
||||
Print model information.
|
||||
"""Print model information.
|
||||
|
||||
Args:
|
||||
detailed (bool): If True, prints out detailed information about the model.
|
||||
|
|
@ -278,8 +270,7 @@ class BaseModel(torch.nn.Module):
|
|||
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
|
||||
|
||||
def _apply(self, fn):
|
||||
"""
|
||||
Apply a function to all tensors in the model that are not parameters or registered buffers.
|
||||
"""Apply a function to all tensors in the model that are not parameters or registered buffers.
|
||||
|
||||
Args:
|
||||
fn (function): The function to apply to the model.
|
||||
|
|
@ -298,8 +289,7 @@ class BaseModel(torch.nn.Module):
|
|||
return self
|
||||
|
||||
def load(self, weights, verbose=True):
|
||||
"""
|
||||
Load weights into the model.
|
||||
"""Load weights into the model.
|
||||
|
||||
Args:
|
||||
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
|
||||
|
|
@ -324,8 +314,7 @@ class BaseModel(torch.nn.Module):
|
|||
LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
Compute loss.
|
||||
"""Compute loss.
|
||||
|
||||
Args:
|
||||
batch (dict): Batch to compute loss on.
|
||||
|
|
@ -344,8 +333,7 @@ class BaseModel(torch.nn.Module):
|
|||
|
||||
|
||||
class DetectionModel(BaseModel):
|
||||
"""
|
||||
YOLO detection model.
|
||||
"""YOLO detection model.
|
||||
|
||||
This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
|
||||
inference, and loss computation for object detection tasks.
|
||||
|
|
@ -373,8 +361,7 @@ class DetectionModel(BaseModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize the YOLO detection model with the given config and parameters.
|
||||
"""Initialize the YOLO detection model with the given config and parameters.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -429,8 +416,7 @@ class DetectionModel(BaseModel):
|
|||
LOGGER.info("")
|
||||
|
||||
def _predict_augment(self, x):
|
||||
"""
|
||||
Perform augmentations on input image x and return augmented inference and train outputs.
|
||||
"""Perform augmentations on input image x and return augmented inference and train outputs.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input image tensor.
|
||||
|
|
@ -455,8 +441,7 @@ class DetectionModel(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _descale_pred(p, flips, scale, img_size, dim=1):
|
||||
"""
|
||||
De-scale predictions following augmented inference (inverse operation).
|
||||
"""De-scale predictions following augmented inference (inverse operation).
|
||||
|
||||
Args:
|
||||
p (torch.Tensor): Predictions tensor.
|
||||
|
|
@ -477,8 +462,7 @@ class DetectionModel(BaseModel):
|
|||
return torch.cat((x, y, wh, cls), dim)
|
||||
|
||||
def _clip_augmented(self, y):
|
||||
"""
|
||||
Clip YOLO augmented inference tails.
|
||||
"""Clip YOLO augmented inference tails.
|
||||
|
||||
Args:
|
||||
y (list[torch.Tensor]): List of detection tensors.
|
||||
|
|
@ -501,8 +485,7 @@ class DetectionModel(BaseModel):
|
|||
|
||||
|
||||
class OBBModel(DetectionModel):
|
||||
"""
|
||||
YOLO Oriented Bounding Box (OBB) model.
|
||||
"""YOLO Oriented Bounding Box (OBB) model.
|
||||
|
||||
This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
|
||||
computation for rotated object detection.
|
||||
|
|
@ -518,8 +501,7 @@ class OBBModel(DetectionModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize YOLO OBB model with given config and parameters.
|
||||
"""Initialize YOLO OBB model with given config and parameters.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -535,8 +517,7 @@ class OBBModel(DetectionModel):
|
|||
|
||||
|
||||
class SegmentationModel(DetectionModel):
|
||||
"""
|
||||
YOLO segmentation model.
|
||||
"""YOLO segmentation model.
|
||||
|
||||
This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
|
||||
pixel-level object detection and segmentation.
|
||||
|
|
@ -552,8 +533,7 @@ class SegmentationModel(DetectionModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize Ultralytics YOLO segmentation model with given config and parameters.
|
||||
"""Initialize Ultralytics YOLO segmentation model with given config and parameters.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -569,8 +549,7 @@ class SegmentationModel(DetectionModel):
|
|||
|
||||
|
||||
class PoseModel(DetectionModel):
|
||||
"""
|
||||
YOLO pose model.
|
||||
"""YOLO pose model.
|
||||
|
||||
This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
|
||||
keypoint detection and pose estimation.
|
||||
|
|
@ -589,8 +568,7 @@ class PoseModel(DetectionModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
||||
"""
|
||||
Initialize Ultralytics YOLO Pose model.
|
||||
"""Initialize Ultralytics YOLO Pose model.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -612,8 +590,7 @@ class PoseModel(DetectionModel):
|
|||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
"""
|
||||
YOLO classification model.
|
||||
"""YOLO classification model.
|
||||
|
||||
This class implements the YOLO classification architecture for image classification tasks, providing model
|
||||
initialization, configuration, and output reshaping capabilities.
|
||||
|
|
@ -637,8 +614,7 @@ class ClassificationModel(BaseModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
|
||||
"""Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -650,8 +626,7 @@ class ClassificationModel(BaseModel):
|
|||
self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
def _from_yaml(self, cfg, ch, nc, verbose):
|
||||
"""
|
||||
Set Ultralytics YOLO model configurations and define the model architecture.
|
||||
"""Set Ultralytics YOLO model configurations and define the model architecture.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -675,8 +650,7 @@ class ClassificationModel(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def reshape_outputs(model, nc):
|
||||
"""
|
||||
Update a TorchVision classification model to class count 'nc' if required.
|
||||
"""Update a TorchVision classification model to class count 'nc' if required.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to update.
|
||||
|
|
@ -708,8 +682,7 @@ class ClassificationModel(BaseModel):
|
|||
|
||||
|
||||
class RTDETRDetectionModel(DetectionModel):
|
||||
"""
|
||||
RTDETR (Real-Time DEtection TRansformer) Detection Model class.
|
||||
"""RTDETR (Real-Time DEtection TRansformer) Detection Model class.
|
||||
|
||||
This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
|
||||
the training and inference processes. RTDETR is an object detection and tracking model that extends from the
|
||||
|
|
@ -732,8 +705,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize the RTDETRDetectionModel.
|
||||
"""Initialize the RTDETRDetectionModel.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Configuration file name or path.
|
||||
|
|
@ -744,8 +716,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def _apply(self, fn):
|
||||
"""
|
||||
Apply a function to all tensors in the model that are not parameters or registered buffers.
|
||||
"""Apply a function to all tensors in the model that are not parameters or registered buffers.
|
||||
|
||||
Args:
|
||||
fn (function): The function to apply to the model.
|
||||
|
|
@ -766,8 +737,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
Compute the loss for the given batch of data.
|
||||
"""Compute the loss for the given batch of data.
|
||||
|
||||
Args:
|
||||
batch (dict): Dictionary containing image and label data.
|
||||
|
|
@ -813,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
)
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the model.
|
||||
"""Perform a forward pass through the model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
|
@ -849,8 +818,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
|
||||
|
||||
class WorldModel(DetectionModel):
|
||||
"""
|
||||
YOLOv8 World Model.
|
||||
"""YOLOv8 World Model.
|
||||
|
||||
This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
|
||||
specification and CLIP model integration for zero-shot detection capabilities.
|
||||
|
|
@ -874,8 +842,7 @@ class WorldModel(DetectionModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize YOLOv8 world model with given config and parameters.
|
||||
"""Initialize YOLOv8 world model with given config and parameters.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -888,8 +855,7 @@ class WorldModel(DetectionModel):
|
|||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def set_classes(self, text, batch=80, cache_clip_model=True):
|
||||
"""
|
||||
Set classes in advance so that model could do offline-inference without clip model.
|
||||
"""Set classes in advance so that model could do offline-inference without clip model.
|
||||
|
||||
Args:
|
||||
text (list[str]): List of class names.
|
||||
|
|
@ -900,8 +866,7 @@ class WorldModel(DetectionModel):
|
|||
self.model[-1].nc = len(text)
|
||||
|
||||
def get_text_pe(self, text, batch=80, cache_clip_model=True):
|
||||
"""
|
||||
Compute text embeddings for the given list of class names.
|
||||
"""Compute text embeddings for the given list of class names.
|
||||
|
||||
Args:
|
||||
text (list[str]): List of class names.
|
||||
|
|
@ -924,8 +889,7 @@ class WorldModel(DetectionModel):
|
|||
return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the model.
|
||||
"""Perform a forward pass through the model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
|
@ -969,8 +933,7 @@ class WorldModel(DetectionModel):
|
|||
return x
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
Compute loss.
|
||||
"""Compute loss.
|
||||
|
||||
Args:
|
||||
batch (dict): Batch to compute loss on.
|
||||
|
|
@ -985,8 +948,7 @@ class WorldModel(DetectionModel):
|
|||
|
||||
|
||||
class YOLOEModel(DetectionModel):
|
||||
"""
|
||||
YOLOE detection model.
|
||||
"""YOLOE detection model.
|
||||
|
||||
This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
|
||||
both prompt-based and prompt-free inference modes.
|
||||
|
|
@ -1013,8 +975,7 @@ class YOLOEModel(DetectionModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize YOLOE model with given config and parameters.
|
||||
"""Initialize YOLOE model with given config and parameters.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -1026,8 +987,7 @@ class YOLOEModel(DetectionModel):
|
|||
|
||||
@smart_inference_mode()
|
||||
def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
|
||||
"""
|
||||
Compute text embeddings for the given list of class names.
|
||||
"""Compute text embeddings for the given list of class names.
|
||||
|
||||
Args:
|
||||
text (list[str]): List of class names.
|
||||
|
|
@ -1059,8 +1019,7 @@ class YOLOEModel(DetectionModel):
|
|||
|
||||
@smart_inference_mode()
|
||||
def get_visual_pe(self, img, visual):
|
||||
"""
|
||||
Get visual embeddings.
|
||||
"""Get visual embeddings.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Input image tensor.
|
||||
|
|
@ -1072,8 +1031,7 @@ class YOLOEModel(DetectionModel):
|
|||
return self(img, vpe=visual, return_vpe=True)
|
||||
|
||||
def set_vocab(self, vocab, names):
|
||||
"""
|
||||
Set vocabulary for the prompt-free model.
|
||||
"""Set vocabulary for the prompt-free model.
|
||||
|
||||
Args:
|
||||
vocab (nn.ModuleList): List of vocabulary items.
|
||||
|
|
@ -1101,8 +1059,7 @@ class YOLOEModel(DetectionModel):
|
|||
self.names = check_class_names(names)
|
||||
|
||||
def get_vocab(self, names):
|
||||
"""
|
||||
Get fused vocabulary layer from the model.
|
||||
"""Get fused vocabulary layer from the model.
|
||||
|
||||
Args:
|
||||
names (list): List of class names.
|
||||
|
|
@ -1127,8 +1084,7 @@ class YOLOEModel(DetectionModel):
|
|||
return vocab
|
||||
|
||||
def set_classes(self, names, embeddings):
|
||||
"""
|
||||
Set classes in advance so that model could do offline-inference without clip model.
|
||||
"""Set classes in advance so that model could do offline-inference without clip model.
|
||||
|
||||
Args:
|
||||
names (list[str]): List of class names.
|
||||
|
|
@ -1143,8 +1099,7 @@ class YOLOEModel(DetectionModel):
|
|||
self.names = check_class_names(names)
|
||||
|
||||
def get_cls_pe(self, tpe, vpe):
|
||||
"""
|
||||
Get class positional embeddings.
|
||||
"""Get class positional embeddings.
|
||||
|
||||
Args:
|
||||
tpe (torch.Tensor, optional): Text positional embeddings.
|
||||
|
|
@ -1167,8 +1122,7 @@ class YOLOEModel(DetectionModel):
|
|||
def predict(
|
||||
self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the model.
|
||||
"""Perform a forward pass through the model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
|
@ -1215,8 +1169,7 @@ class YOLOEModel(DetectionModel):
|
|||
return x
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
Compute loss.
|
||||
"""Compute loss.
|
||||
|
||||
Args:
|
||||
batch (dict): Batch to compute loss on.
|
||||
|
|
@ -1234,8 +1187,7 @@ class YOLOEModel(DetectionModel):
|
|||
|
||||
|
||||
class YOLOESegModel(YOLOEModel, SegmentationModel):
|
||||
"""
|
||||
YOLOE segmentation model.
|
||||
"""YOLOE segmentation model.
|
||||
|
||||
This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
|
||||
specialized loss computation for pixel-level object detection and segmentation.
|
||||
|
|
@ -1251,8 +1203,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
|
|||
"""
|
||||
|
||||
def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize YOLOE segmentation model with given config and parameters.
|
||||
"""Initialize YOLOE segmentation model with given config and parameters.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Model configuration file path or dictionary.
|
||||
|
|
@ -1263,8 +1214,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
|
|||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
Compute loss.
|
||||
"""Compute loss.
|
||||
|
||||
Args:
|
||||
batch (dict): Batch to compute loss on.
|
||||
|
|
@ -1282,8 +1232,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
|
|||
|
||||
|
||||
class Ensemble(torch.nn.ModuleList):
|
||||
"""
|
||||
Ensemble of models.
|
||||
"""Ensemble of models.
|
||||
|
||||
This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
|
||||
or other ensemble techniques.
|
||||
|
|
@ -1305,8 +1254,7 @@ class Ensemble(torch.nn.ModuleList):
|
|||
super().__init__()
|
||||
|
||||
def forward(self, x, augment=False, profile=False, visualize=False):
|
||||
"""
|
||||
Generate the YOLO network's final layer.
|
||||
"""Generate the YOLO network's final layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
|
@ -1330,8 +1278,7 @@ class Ensemble(torch.nn.ModuleList):
|
|||
|
||||
@contextlib.contextmanager
|
||||
def temporary_modules(modules=None, attributes=None):
|
||||
"""
|
||||
Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
|
||||
"""Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
|
||||
|
||||
This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
|
||||
moved a module from one location to another, but you still want to support the old import paths for backwards
|
||||
|
|
@ -1393,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
|
|||
"""Custom Unpickler that replaces unknown classes with SafeClass."""
|
||||
|
||||
def find_class(self, module, name):
|
||||
"""
|
||||
Attempt to find a class, returning SafeClass if not among safe modules.
|
||||
"""Attempt to find a class, returning SafeClass if not among safe modules.
|
||||
|
||||
Args:
|
||||
module (str): Module name.
|
||||
|
|
@ -1419,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):
|
|||
|
||||
|
||||
def torch_safe_load(weight, safe_only=False):
|
||||
"""
|
||||
Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
|
||||
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
|
||||
After installation, the function again attempts to load the model using torch.load().
|
||||
"""Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
|
||||
the error, logs a warning message, and attempts to install the missing module via the check_requirements()
|
||||
function. After installation, the function again attempts to load the model using torch.load().
|
||||
|
||||
Args:
|
||||
weight (str): The file path of the PyTorch model.
|
||||
|
|
@ -1501,8 +1446,7 @@ def torch_safe_load(weight, safe_only=False):
|
|||
|
||||
|
||||
def load_checkpoint(weight, device=None, inplace=True, fuse=False):
|
||||
"""
|
||||
Load a single model weights.
|
||||
"""Load a single model weights.
|
||||
|
||||
Args:
|
||||
weight (str | Path): Model weight path.
|
||||
|
|
@ -1539,8 +1483,7 @@ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
|
|||
|
||||
|
||||
def parse_model(d, ch, verbose=True):
|
||||
"""
|
||||
Parse a YOLO model.yaml dictionary into a PyTorch model.
|
||||
"""Parse a YOLO model.yaml dictionary into a PyTorch model.
|
||||
|
||||
Args:
|
||||
d (dict): Model dictionary.
|
||||
|
|
@ -1718,8 +1661,7 @@ def parse_model(d, ch, verbose=True):
|
|||
|
||||
|
||||
def yaml_model_load(path):
|
||||
"""
|
||||
Load a YOLOv8 model from a YAML file.
|
||||
"""Load a YOLOv8 model from a YAML file.
|
||||
|
||||
Args:
|
||||
path (str | Path): Path to the YAML file.
|
||||
|
|
@ -1742,8 +1684,7 @@ def yaml_model_load(path):
|
|||
|
||||
|
||||
def guess_model_scale(model_path):
|
||||
"""
|
||||
Extract the size character n, s, m, l, or x of the model's scale from the model path.
|
||||
"""Extract the size character n, s, m, l, or x of the model's scale from the model path.
|
||||
|
||||
Args:
|
||||
model_path (str | Path): The path to the YOLO model's YAML file.
|
||||
|
|
@ -1758,8 +1699,7 @@ def guess_model_scale(model_path):
|
|||
|
||||
|
||||
def guess_model_task(model):
|
||||
"""
|
||||
Guess the task of a PyTorch model from its architecture or configuration.
|
||||
"""Guess the task of a PyTorch model from its architecture or configuration.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
|
||||
|
|
|
|||
|
|
@ -20,8 +20,7 @@ except ImportError:
|
|||
|
||||
|
||||
class TextModel(nn.Module):
|
||||
"""
|
||||
Abstract base class for text encoding models.
|
||||
"""Abstract base class for text encoding models.
|
||||
|
||||
This class defines the interface for text encoding models used in vision-language tasks. Subclasses must implement
|
||||
the tokenize and encode_text methods to provide text tokenization and encoding functionality.
|
||||
|
|
@ -47,8 +46,7 @@ class TextModel(nn.Module):
|
|||
|
||||
|
||||
class CLIP(TextModel):
|
||||
"""
|
||||
Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
|
||||
"""Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
|
||||
|
||||
This class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors that
|
||||
are aligned with corresponding image features in a shared embedding space.
|
||||
|
|
@ -71,8 +69,7 @@ class CLIP(TextModel):
|
|||
"""
|
||||
|
||||
def __init__(self, size: str, device: torch.device) -> None:
|
||||
"""
|
||||
Initialize the CLIP text encoder.
|
||||
"""Initialize the CLIP text encoder.
|
||||
|
||||
This class implements the TextModel interface using OpenAI's CLIP model for text encoding. It loads a
|
||||
pre-trained CLIP model of the specified size and prepares it for text encoding tasks.
|
||||
|
|
@ -93,8 +90,7 @@ class CLIP(TextModel):
|
|||
self.eval()
|
||||
|
||||
def tokenize(self, texts: str | list[str]) -> torch.Tensor:
|
||||
"""
|
||||
Convert input texts to CLIP tokens.
|
||||
"""Convert input texts to CLIP tokens.
|
||||
|
||||
Args:
|
||||
texts (str | list[str]): Input text or list of texts to tokenize.
|
||||
|
|
@ -111,8 +107,7 @@ class CLIP(TextModel):
|
|||
|
||||
@smart_inference_mode()
|
||||
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
"""
|
||||
Encode tokenized texts into normalized feature vectors.
|
||||
"""Encode tokenized texts into normalized feature vectors.
|
||||
|
||||
This method processes tokenized text inputs through the CLIP model to generate feature vectors, which are then
|
||||
normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
|
||||
|
|
@ -137,15 +132,14 @@ class CLIP(TextModel):
|
|||
|
||||
@smart_inference_mode()
|
||||
def encode_image(self, image: Image.Image | torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
"""
|
||||
Encode preprocessed images into normalized feature vectors.
|
||||
"""Encode preprocessed images into normalized feature vectors.
|
||||
|
||||
This method processes preprocessed image inputs through the CLIP model to generate feature vectors, which are
|
||||
then normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
|
||||
|
||||
Args:
|
||||
image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be
|
||||
converted to a tensor using the model's image preprocessing function.
|
||||
image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be converted
|
||||
to a tensor using the model's image preprocessing function.
|
||||
dtype (torch.dtype, optional): Data type for output features.
|
||||
|
||||
Returns:
|
||||
|
|
@ -169,8 +163,7 @@ class CLIP(TextModel):
|
|||
|
||||
|
||||
class MobileCLIP(TextModel):
|
||||
"""
|
||||
Implement Apple's MobileCLIP text encoder for efficient text encoding.
|
||||
"""Implement Apple's MobileCLIP text encoder for efficient text encoding.
|
||||
|
||||
This class implements the TextModel interface using Apple's MobileCLIP model, providing efficient text encoding
|
||||
capabilities for vision-language tasks with reduced computational requirements compared to standard CLIP models.
|
||||
|
|
@ -195,8 +188,7 @@ class MobileCLIP(TextModel):
|
|||
config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}
|
||||
|
||||
def __init__(self, size: str, device: torch.device) -> None:
|
||||
"""
|
||||
Initialize the MobileCLIP text encoder.
|
||||
"""Initialize the MobileCLIP text encoder.
|
||||
|
||||
This class implements the TextModel interface using Apple's MobileCLIP model for efficient text encoding.
|
||||
|
||||
|
|
@ -236,8 +228,7 @@ class MobileCLIP(TextModel):
|
|||
self.eval()
|
||||
|
||||
def tokenize(self, texts: list[str]) -> torch.Tensor:
|
||||
"""
|
||||
Convert input texts to MobileCLIP tokens.
|
||||
"""Convert input texts to MobileCLIP tokens.
|
||||
|
||||
Args:
|
||||
texts (list[str]): List of text strings to tokenize.
|
||||
|
|
@ -253,8 +244,7 @@ class MobileCLIP(TextModel):
|
|||
|
||||
@smart_inference_mode()
|
||||
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
"""
|
||||
Encode tokenized texts into normalized feature vectors.
|
||||
"""Encode tokenized texts into normalized feature vectors.
|
||||
|
||||
Args:
|
||||
texts (torch.Tensor): Tokenized text inputs.
|
||||
|
|
@ -276,8 +266,7 @@ class MobileCLIP(TextModel):
|
|||
|
||||
|
||||
class MobileCLIPTS(TextModel):
|
||||
"""
|
||||
Load a TorchScript traced version of MobileCLIP.
|
||||
"""Load a TorchScript traced version of MobileCLIP.
|
||||
|
||||
This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format, providing
|
||||
efficient text encoding capabilities for vision-language tasks with optimized inference performance.
|
||||
|
|
@ -299,8 +288,7 @@ class MobileCLIPTS(TextModel):
|
|||
"""
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
"""
|
||||
Initialize the MobileCLIP TorchScript text encoder.
|
||||
"""Initialize the MobileCLIP TorchScript text encoder.
|
||||
|
||||
This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format for efficient
|
||||
text encoding with optimized inference performance.
|
||||
|
|
@ -321,8 +309,7 @@ class MobileCLIPTS(TextModel):
|
|||
self.device = device
|
||||
|
||||
def tokenize(self, texts: list[str]) -> torch.Tensor:
|
||||
"""
|
||||
Convert input texts to MobileCLIP tokens.
|
||||
"""Convert input texts to MobileCLIP tokens.
|
||||
|
||||
Args:
|
||||
texts (list[str]): List of text strings to tokenize.
|
||||
|
|
@ -338,8 +325,7 @@ class MobileCLIPTS(TextModel):
|
|||
|
||||
@smart_inference_mode()
|
||||
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
"""
|
||||
Encode tokenized texts into normalized feature vectors.
|
||||
"""Encode tokenized texts into normalized feature vectors.
|
||||
|
||||
Args:
|
||||
texts (torch.Tensor): Tokenized text inputs.
|
||||
|
|
@ -360,8 +346,7 @@ class MobileCLIPTS(TextModel):
|
|||
|
||||
|
||||
def build_text_model(variant: str, device: torch.device = None) -> TextModel:
|
||||
"""
|
||||
Build a text encoding model based on the specified variant.
|
||||
"""Build a text encoding model based on the specified variant.
|
||||
|
||||
Args:
|
||||
variant (str): Model variant in format "base:size" (e.g., "clip:ViT-B/32" or "mobileclip:s0").
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, Sol
|
|||
|
||||
|
||||
class AIGym(BaseSolution):
|
||||
"""
|
||||
A class to manage gym steps of people in a real-time video stream based on their poses.
|
||||
"""A class to manage gym steps of people in a real-time video stream based on their poses.
|
||||
|
||||
This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
|
||||
repetitions of exercises based on predefined angle thresholds for up and down positions.
|
||||
|
|
@ -32,8 +31,7 @@ class AIGym(BaseSolution):
|
|||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""
|
||||
Initialize AIGym for workout monitoring using pose estimation and predefined angles.
|
||||
"""Initialize AIGym for workout monitoring using pose estimation and predefined angles.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Keyword arguments passed to the parent class constructor including:
|
||||
|
|
@ -49,8 +47,7 @@ class AIGym(BaseSolution):
|
|||
self.kpts = self.CFG["kpts"] # User selected kpts of workouts storage for further usage
|
||||
|
||||
def process(self, im0) -> SolutionResults:
|
||||
"""
|
||||
Monitor workouts using Ultralytics YOLO Pose Model.
|
||||
"""Monitor workouts using Ultralytics YOLO Pose Model.
|
||||
|
||||
This function processes an input image to track and analyze human poses for workout monitoring. It uses the YOLO
|
||||
Pose model to detect keypoints, estimate angles, and count repetitions based on predefined angle thresholds.
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from ultralytics.solutions.solutions import BaseSolution, SolutionResults # Imp
|
|||
|
||||
|
||||
class Analytics(BaseSolution):
|
||||
"""
|
||||
A class for creating and updating various types of charts for visual analytics.
|
||||
"""A class for creating and updating various types of charts for visual analytics.
|
||||
|
||||
This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts based on
|
||||
object detection and tracking data.
|
||||
|
|
@ -92,8 +91,7 @@ class Analytics(BaseSolution):
|
|||
self.ax.axis("equal")
|
||||
|
||||
def process(self, im0: np.ndarray, frame_number: int) -> SolutionResults:
|
||||
"""
|
||||
Process image data and run object tracking to update analytics charts.
|
||||
"""Process image data and run object tracking to update analytics charts.
|
||||
|
||||
Args:
|
||||
im0 (np.ndarray): Input image for processing.
|
||||
|
|
@ -139,13 +137,12 @@ class Analytics(BaseSolution):
|
|||
def update_graph(
|
||||
self, frame_number: int, count_dict: dict[str, int] | None = None, plot: str = "line"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Update the graph with new data for single or multiple classes.
|
||||
"""Update the graph with new data for single or multiple classes.
|
||||
|
||||
Args:
|
||||
frame_number (int): The current frame number.
|
||||
count_dict (dict[str, int], optional): Dictionary with class names as keys and counts as values for
|
||||
multiple classes. If None, updates a single line graph.
|
||||
count_dict (dict[str, int], optional): Dictionary with class names as keys and counts as values for multiple
|
||||
classes. If None, updates a single line graph.
|
||||
plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
|
||||
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ import cv2
|
|||
|
||||
@dataclass
|
||||
class SolutionConfig:
|
||||
"""
|
||||
Manages configuration parameters for Ultralytics Vision AI solutions.
|
||||
"""Manages configuration parameters for Ultralytics Vision AI solutions.
|
||||
|
||||
The SolutionConfig class serves as a centralized configuration container for all the Ultralytics solution modules:
|
||||
https://docs.ultralytics.com/solutions/#solutions. It leverages Python `dataclass` for clear, type-safe, and
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ from ultralytics.utils.plotting import colors
|
|||
|
||||
|
||||
class DistanceCalculation(BaseSolution):
|
||||
"""
|
||||
A class to calculate distance between two objects in a real-time video stream based on their tracks.
|
||||
"""A class to calculate distance between two objects in a real-time video stream based on their tracks.
|
||||
|
||||
This class extends BaseSolution to provide functionality for selecting objects and calculating the distance between
|
||||
them in a video stream using YOLO object detection and tracking.
|
||||
|
|
@ -43,8 +42,7 @@ class DistanceCalculation(BaseSolution):
|
|||
self.centroids: list[list[int]] = [] # Store centroids of selected objects
|
||||
|
||||
def mouse_event_for_distance(self, event: int, x: int, y: int, flags: int, param: Any) -> None:
|
||||
"""
|
||||
Handle mouse events to select regions in a real-time video stream for distance calculation.
|
||||
"""Handle mouse events to select regions in a real-time video stream for distance calculation.
|
||||
|
||||
Args:
|
||||
event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
|
||||
|
|
@ -69,8 +67,7 @@ class DistanceCalculation(BaseSolution):
|
|||
self.left_mouse_count = 0
|
||||
|
||||
def process(self, im0) -> SolutionResults:
|
||||
"""
|
||||
Process a video frame and calculate the distance between two selected bounding boxes.
|
||||
"""Process a video frame and calculate the distance between two selected bounding boxes.
|
||||
|
||||
This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance between
|
||||
two user-selected objects if they have been chosen.
|
||||
|
|
@ -79,8 +76,8 @@ class DistanceCalculation(BaseSolution):
|
|||
im0 (np.ndarray): The input image frame to process.
|
||||
|
||||
Returns:
|
||||
(SolutionResults): Contains processed image `plot_im`, `total_tracks` (int) representing the total number
|
||||
of tracked objects, and `pixels_distance` (float) representing the distance between selected objects
|
||||
(SolutionResults): Contains processed image `plot_im`, `total_tracks` (int) representing the total number of
|
||||
tracked objects, and `pixels_distance` (float) representing the distance between selected objects
|
||||
in pixels.
|
||||
|
||||
Examples:
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from ultralytics.solutions.solutions import SolutionAnnotator, SolutionResults
|
|||
|
||||
|
||||
class Heatmap(ObjectCounter):
|
||||
"""
|
||||
A class to draw heatmaps in real-time video streams based on object tracks.
|
||||
"""A class to draw heatmaps in real-time video streams based on object tracks.
|
||||
|
||||
This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
|
||||
streams. It uses tracked object positions to create a cumulative heatmap effect over time.
|
||||
|
|
@ -36,8 +35,7 @@ class Heatmap(ObjectCounter):
|
|||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""
|
||||
Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.
|
||||
"""Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Keyword arguments passed to the parent ObjectCounter class.
|
||||
|
|
@ -53,8 +51,7 @@ class Heatmap(ObjectCounter):
|
|||
self.heatmap = None
|
||||
|
||||
def heatmap_effect(self, box: list[float]) -> None:
|
||||
"""
|
||||
Efficiently calculate heatmap area and effect location for applying colormap.
|
||||
"""Efficiently calculate heatmap area and effect location for applying colormap.
|
||||
|
||||
Args:
|
||||
box (list[float]): Bounding box coordinates [x0, y0, x1, y1].
|
||||
|
|
@ -75,17 +72,15 @@ class Heatmap(ObjectCounter):
|
|||
self.heatmap[y0:y1, x0:x1][within_radius] += 2
|
||||
|
||||
def process(self, im0: np.ndarray) -> SolutionResults:
|
||||
"""
|
||||
Generate heatmap for each frame using Ultralytics tracking.
|
||||
"""Generate heatmap for each frame using Ultralytics tracking.
|
||||
|
||||
Args:
|
||||
im0 (np.ndarray): Input image array for processing.
|
||||
|
||||
Returns:
|
||||
(SolutionResults): Contains processed image `plot_im`,
|
||||
'in_count' (int, count of objects entering the region), 'out_count' (int, count of objects exiting the
|
||||
region), 'classwise_count' (dict, per-class object count), and 'total_tracks' (int, total number of
|
||||
tracked objects).
|
||||
(SolutionResults): Contains processed image `plot_im`, 'in_count' (int, count of objects entering the
|
||||
region), 'out_count' (int, count of objects exiting the region), 'classwise_count' (dict, per-class
|
||||
object count), and 'total_tracks' (int, total number of tracked objects).
|
||||
"""
|
||||
if not self.initialized:
|
||||
self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue