diff --git a/docs/en/models/sam-3.md b/docs/en/models/sam-3.md
index 9ab1a5eb75..4eb1f9ab30 100644
--- a/docs/en/models/sam-3.md
+++ b/docs/en/models/sam-3.md
@@ -114,13 +114,17 @@ SAM 3 will be available directly in the Ultralytics package once integration lan
pip install ultralytics
```
-Models will download automatically when first used. You can then use standard [predict mode](../modes/predict.md) and later [export](../modes/export.md) models to formats like [ONNX](../integrations/onnx.md) and [TensorRT](../integrations/tensorrt.md) for deployment. Watch for a package update with SAM-3 weights and configs soon.
+!!! warning "SAM 3 Model Weights Required"
+
+ Unlike other Ultralytics models, SAM 3 weights (`sam3.pt`) are **not automatically downloaded**. You must manually download the model weights from the [official SAM 3 repository](https://github.com/facebookresearch/sam3) before using SAM 3. Place the downloaded `sam3.pt` file in your working directory or specify the full path when loading the model.
+
+!!! note "BPE Vocabulary for Text Prompts"
+
+ If you plan to use text-based concept segmentation with `SAM3SemanticPredictor`, you also need to download the BPE vocabulary file `bpe_simple_vocab_16e6.txt.gz` from the [SAM 3 assets](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz).
## How to Use SAM 3: Versatility in Concept Segmentation
-!!! warning "Ultralytics API preview"
-
- The following examples show the intended Ultralytics API once SAM 3 ships in the package. Until integration lands, details may change.
+SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual Segmentation (PVS) tasks through different predictor interfaces.
### Supported Tasks and Models
@@ -138,143 +142,176 @@ SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual
!!! example "Text-based Concept Segmentation"
- Find and segment all instances of a concept using a text description.
+ Find and segment all instances of a concept using a text description. Text prompts require the `SAM3SemanticPredictor` interface.
=== "Python"
```python
- from ultralytics import SAM
+ from ultralytics.models.sam.predict import SAM3SemanticPredictor
- # Load SAM 3 model
- model = SAM("sam3.pt")
+ # Initialize predictor with configuration
+ overrides = dict(
+ conf=0.25,
+ task="segment",
+ mode="predict",
+ model="sam3.pt",
+ half=True, # Use FP16 for faster inference
+ )
+ predictor = SAM3SemanticPredictor(
+ overrides=overrides,
+ bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz", # Required for text encoding
+ )
- # Segment all instances of a concept
- results = model("path/to/image.jpg", prompt="yellow school bus")
+ # Set image once for multiple queries
+ predictor.set_image("path/to/image.jpg")
+
+ # Query with multiple text prompts
+ results = predictor(text=["person", "bus", "glasses"], save=True)
# Works with descriptive phrases
- results = model("path/to/image.jpg", prompt="person wearing a red hat")
+ results = predictor(text=["person with red cloth", "person with blue cloth"], save=True)
- # Or simple object names
- results = model("path/to/image.jpg", prompt="striped cat")
+ # Query with a single concept
+ results = predictor(text=["a person"], save=True)
```
- === "CLI"
+ !!! note "Text Encoding Requirement"
- ```bash
- # Segment all matching concepts in an image
- yolo segment model=sam3.pt source=path/to/image.jpg prompt="yellow school bus"
- ```
-
- !!! warning "API Preview"
-
- This example shows intended usage. Actual implementation pending Ultralytics integration.
+ The `bpe_path` parameter is required for text prompt encoding. Download the BPE vocabulary file from the [bpe_simple_vocab_16e6.txt.gz](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz).
#### Segment with Image Exemplars
!!! example "Image Exemplar-based Segmentation"
- Use one or more example objects to find all similar instances.
+ Use bounding boxes as visual prompts to find all similar instances. This also requires `SAM3SemanticPredictor` for concept-based matching.
=== "Python"
```python
- from ultralytics import SAM
+ from ultralytics.models.sam.predict import SAM3SemanticPredictor
- model = SAM("sam3.pt")
+ # Initialize predictor
+ overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True)
+ predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")
- # Provide a positive example box - finds all similar objects
- results = model("path/to/image.jpg", bboxes=[100, 150, 300, 400], labels=[1])
+ # Set image
+ predictor.set_image("path/to/image.jpg")
- # Add negative examples to exclude certain instances
- results = model(
- "path/to/image.jpg",
- bboxes=[[100, 150, 300, 400], [500, 200, 600, 350]], # Two boxes
- labels=[1, 0], # First is positive, second is negative
- )
+ # Provide bounding box examples to segment similar objects
+ results = predictor(bboxes=[[480.0, 290.0, 590.0, 650.0]], save=True)
- # Combine text and image exemplars for precision
- results = model("path/to/image.jpg", prompt="dog", bboxes=[100, 150, 300, 400], labels=[1])
+ # Multiple bounding boxes for different concepts
+ results = predictor(bboxes=[[539, 599, 589, 639], [343, 267, 499, 662]], save=True)
```
- !!! warning "API Preview"
+#### Feature-based Inference for Efficiency
- This example shows intended usage. Actual implementation pending Ultralytics integration.
+!!! example "Reusing Image Features for Multiple Queries"
-#### Interactive Refinement
-
-!!! example "Iterative Refinement with Exemplars"
-
- Progressively improve results by adding exemplar prompts based on initial output.
+ Extract image features once and reuse them for multiple segmentation queries to improve efficiency.
=== "Python"
```python
- from ultralytics import SAM
+ import cv2
- model = SAM("sam3.pt")
+ from ultralytics.models.sam.predict import SAM3SemanticPredictor
+ from ultralytics.utils.plotting import Annotator, colors
- # Initial segmentation with text
- results = model("path/to/image.jpg", prompt="car")
+ # Initialize predictors
+ overrides = dict(conf=0.50, task="segment", mode="predict", model="sam3.pt", verbose=False)
+ predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")
+ predictor2 = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")
- # If some cars are missed, add a positive exemplar
- results = model(
- "path/to/image.jpg",
- prompt="car",
- bboxes=[missed_car_box],
- labels=[1], # Positive example
- )
+ # Extract features from the first predictor
+ source = "path/to/image.jpg"
+ predictor.set_image(source)
+ src_shape = cv2.imread(source).shape[:2]
- # If false positives appear, add negative exemplars
- results = model(
- "path/to/image.jpg",
- prompt="car",
- bboxes=[false_positive_box],
- labels=[0], # Negative example
- )
+ # Setup second predictor and reuse features
+ predictor2.setup_model()
+
+ # Perform inference using shared features with text prompt
+ masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, text=["person"])
+
+ # Perform inference using shared features with bounding box prompt
+ masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, bboxes=[439, 437, 524, 709])
+
+ # Visualize results
+ masks, boxes = masks.cpu().numpy(), boxes.cpu().numpy()
+ im = cv2.imread(source)
+ annotator = Annotator(im, pil=False)
+ annotator.masks(masks, [colors(x, True) for x in range(len(masks))])
+
+ cv2.imshow("result", annotator.result())
+ cv2.waitKey(0)
```
- !!! warning "API Preview"
-
- This example shows intended usage. Actual implementation pending Ultralytics integration.
-
### Video Concept Segmentation
-!!! example "Track Concepts Across Video"
+#### Track Concepts Across Video with Bounding Boxes
- Detect and track all instances of a concept throughout a video.
+!!! example "Video Tracking with Visual Prompts"
+
+ Detect and track object instances across video frames using bounding box prompts.
=== "Python"
```python
- from ultralytics.models.sam import SAM3VideoPredictor
+ from ultralytics.models.sam.predict import SAM3VideoPredictor
# Create video predictor
- predictor = SAM3VideoPredictor(model="sam3.pt", imgsz=1024, conf=0.25)
+ overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True)
+ predictor = SAM3VideoPredictor(overrides=overrides)
- # Track all instances of a concept
- results = predictor(source="video.mp4", prompt="person wearing blue shirt")
+ # Track objects using bounding box prompts
+ results = predictor(source="path/to/video.mp4", bboxes=[[706.5, 442.5, 905.25, 555], [598, 635, 725, 750]], stream=True)
- # Combine text with exemplar for precision
+ # Process and display results
+ for r in results:
+ r.show() # Display frame with segmentation masks
+ ```
+
+#### Track Concepts with Text Prompts
+
+!!! example "Video Tracking with Semantic Queries"
+
+ Track all instances of concepts specified by text across video frames.
+
+ === "Python"
+
+ ```python
+ from ultralytics.models.sam.sam3_video_model import SAM3VideoSemanticPredictor
+
+ # Initialize semantic video predictor
+ overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=640, model="sam3.pt", half=True)
+ predictor = SAM3VideoSemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")
+
+ # Track concepts using text prompts
+ results = predictor(source="path/to/video.mp4", text=["person", "bicycle"], stream=True, save=True)
+
+ # Process results
+ for r in results:
+ r.show() # Display frame with tracked objects
+
+ # Alternative: Track with bounding box prompts
results = predictor(
- source="video.mp4",
- prompt="kangaroo",
- bboxes=[initial_box], # Exemplar from first frame
- labels=[1],
+ source="path/to/video.mp4",
+ bboxes=[[864, 383, 975, 620], [705, 229, 782, 402]],
+ labels=[1, 1], # Positive labels
+ stream=True,
+ save=True,
)
```
- !!! warning "API Preview"
-
- This example shows intended usage. Actual implementation pending Ultralytics integration.
-
-For broader streaming and production setups, see [object tracking](../guides/object-counting.md) and [view results in terminal](../guides/view-results-in-terminal.md).
-
### Visual Prompts (SAM 2 Compatibility)
-SAM 3 maintains full backward compatibility with SAM 2's visual prompting:
+SAM 3 maintains full backward compatibility with SAM 2's visual prompting for single-object segmentation:
!!! example "SAM 2 Style Visual Prompts"
+ The basic `SAM` interface behaves exactly like SAM 2, segmenting only the specific area indicated by visual prompts (points, boxes, or masks).
+
=== "Python"
```python
@@ -282,19 +319,21 @@ SAM 3 maintains full backward compatibility with SAM 2's visual prompting:
model = SAM("sam3.pt")
- # Single point prompt (SAM 2 style)
- results = model(points=[900, 370], labels=[1])
+ # Single point prompt - segments object at specific location
+ results = model.predict(source="path/to/image.jpg", points=[900, 370], labels=[1])
+ results[0].show()
- # Multiple points
- results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
+ # Multiple points - segments single object with multiple point hints
+ results = model.predict(source="path/to/image.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])
- # Box prompt
- results = model(bboxes=[100, 150, 300, 400])
+ # Box prompt - segments object within bounding box
+ results = model.predict(source="path/to/image.jpg", bboxes=[100, 150, 300, 400])
+ results[0].show()
```
- !!! warning "API Preview"
+ !!! warning "Visual Prompts vs Concept Segmentation"
- This example shows intended usage. Actual implementation pending Ultralytics integration.
+ Using `SAM("sam3.pt")` with visual prompts (points/boxes/masks) will segment **only the specific object** at that location, just like SAM 2. To segment **all instances of a concept**, use `SAM3SemanticPredictor` with text or exemplar prompts as shown above.
## Performance Benchmarks
diff --git a/docs/en/reference/models/sam/build.md b/docs/en/reference/models/sam/build.md
index ce233acfb3..a2dbc09cdf 100644
--- a/docs/en/reference/models/sam/build.md
+++ b/docs/en/reference/models/sam/build.md
@@ -11,6 +11,10 @@ keywords: Ultralytics, SAM model, Segment Anything Model, SAM 2 model, Segment A
+## ::: ultralytics.models.sam.build._load_checkpoint
+
+
+
## ::: ultralytics.models.sam.build.build_sam_vit_h
diff --git a/docs/en/reference/models/sam/build_sam3.md b/docs/en/reference/models/sam/build_sam3.md
new file mode 100644
index 0000000000..8b2ec7298d
--- /dev/null
+++ b/docs/en/reference/models/sam/build_sam3.md
@@ -0,0 +1,32 @@
+---
+description: Reference for SAM 3 model builder functions in Ultralytics, covering vision backbone creation, transformer assembly, image model building, and checkpoint loading.
+keywords: Ultralytics, SAM 3, build_sam3, model builder, vision backbone, transformer, checkpoint loading, segmentation
+---
+
+# Reference for `ultralytics/models/sam/build_sam3.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.build_sam3._create_vision_backbone
+
+
+
+## ::: ultralytics.models.sam.build_sam3._create_sam3_transformer
+
+
+
+## ::: ultralytics.models.sam.build_sam3.build_sam3_image_model
+
+
+
+## ::: ultralytics.models.sam.build_sam3.build_interactive_sam3
+
+
+
+## ::: ultralytics.models.sam.build_sam3._load_checkpoint
+
+
diff --git a/docs/en/reference/models/sam/modules/sam.md b/docs/en/reference/models/sam/modules/sam.md
index 0a1b61e6e9..d69ea2b75f 100644
--- a/docs/en/reference/models/sam/modules/sam.md
+++ b/docs/en/reference/models/sam/modules/sam.md
@@ -17,4 +17,8 @@ keywords: Ultralytics, SAM Module, SAM 2 Module, object segmentation, image enco
## ::: ultralytics.models.sam.modules.sam.SAM2Model
+
+
+## ::: ultralytics.models.sam.modules.sam.SAM3Model
+
diff --git a/docs/en/reference/models/sam/modules/utils.md b/docs/en/reference/models/sam/modules/utils.md
index 97ead6dc67..7241ef0404 100644
--- a/docs/en/reference/models/sam/modules/utils.md
+++ b/docs/en/reference/models/sam/modules/utils.md
@@ -49,4 +49,12 @@ keywords: Ultralytics, SAM, SAM 2, API Reference, models, window partition, data
## ::: ultralytics.models.sam.modules.utils.add_decomposed_rel_pos
+
+
+## ::: ultralytics.models.sam.modules.utils.get_abs_pos
+
+
+
+## ::: ultralytics.models.sam.modules.utils.concat_rel_pos
+
diff --git a/docs/en/reference/models/sam/predict.md b/docs/en/reference/models/sam/predict.md
index c5b08777a9..bb5096c048 100644
--- a/docs/en/reference/models/sam/predict.md
+++ b/docs/en/reference/models/sam/predict.md
@@ -25,4 +25,20 @@ keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Mode
## ::: ultralytics.models.sam.predict.SAM2DynamicInteractivePredictor
+
+
+## ::: ultralytics.models.sam.predict.SAM3Predictor
+
+
+
+## ::: ultralytics.models.sam.predict.SAM3SemanticPredictor
+
+
+
+## ::: ultralytics.models.sam.predict.SAM3VideoPredictor
+
+
+
+## ::: ultralytics.models.sam.predict.SAM3VideoSemanticPredictor
+
diff --git a/docs/en/reference/models/sam/sam3/decoder.md b/docs/en/reference/models/sam/sam3/decoder.md
new file mode 100644
index 0000000000..fdbf8c4958
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/decoder.md
@@ -0,0 +1,20 @@
+---
+description: Reference for the SAM 3 transformer decoder in Ultralytics, including the TransformerDecoderLayer and TransformerDecoder classes.
+keywords: Ultralytics, SAM 3, decoder, TransformerDecoder, TransformerDecoderLayer, segmentation, deep learning
+---
+
+# Reference for `ultralytics/models/sam/sam3/decoder.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer
+
+
+
+## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoder
+
+
diff --git a/docs/en/reference/models/sam/sam3/encoder.md b/docs/en/reference/models/sam/sam3/encoder.md
new file mode 100644
index 0000000000..2751822db4
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/encoder.md
@@ -0,0 +1,28 @@
+---
+description: Reference for the SAM 3 transformer encoder in Ultralytics, covering TransformerEncoderLayer, TransformerEncoder, TransformerEncoderFusion, and text feature pooling.
+keywords: Ultralytics, SAM 3, encoder, TransformerEncoder, TransformerEncoderFusion, pool_text_feat, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/encoder.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderLayer
+
+
+
+## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoder
+
+
+
+## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderFusion
+
+
+
+## ::: ultralytics.models.sam.sam3.encoder.pool_text_feat
+
+
diff --git a/docs/en/reference/models/sam/sam3/geometry_encoders.md b/docs/en/reference/models/sam/sam3/geometry_encoders.md
new file mode 100644
index 0000000000..1b3eb75139
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/geometry_encoders.md
@@ -0,0 +1,28 @@
+---
+description: Reference for SAM 3 geometry encoders in Ultralytics, including the Prompt container, SequenceGeometryEncoder, and padded sequence utilities.
+keywords: Ultralytics, SAM 3, geometry encoders, SequenceGeometryEncoder, Prompt, padded sequences, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/geometry_encoders.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.geometry_encoders.Prompt
+
+
+
+## ::: ultralytics.models.sam.sam3.geometry_encoders.SequenceGeometryEncoder
+
+
+
+## ::: ultralytics.models.sam.sam3.geometry_encoders.is_right_padded
+
+
+
+## ::: ultralytics.models.sam.sam3.geometry_encoders.concat_padded_sequences
+
+
diff --git a/docs/en/reference/models/sam/sam3/maskformer_segmentation.md b/docs/en/reference/models/sam/sam3/maskformer_segmentation.md
new file mode 100644
index 0000000000..7218d86ecb
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/maskformer_segmentation.md
@@ -0,0 +1,32 @@
+---
+description: Reference for SAM 3 MaskFormer-style segmentation modules in Ultralytics, including mask prediction, presence heads, pixel decoding, and the universal segmentation head.
+keywords: Ultralytics, SAM 3, maskformer, MaskPredictor, SegmentationHead, PixelDecoder, UniversalSegmentationHead, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/maskformer_segmentation.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.maskformer_segmentation.LinearPresenceHead
+
+
+
+## ::: ultralytics.models.sam.sam3.maskformer_segmentation.MaskPredictor
+
+
+
+## ::: ultralytics.models.sam.sam3.maskformer_segmentation.SegmentationHead
+
+
+
+## ::: ultralytics.models.sam.sam3.maskformer_segmentation.PixelDecoder
+
+
+
+## ::: ultralytics.models.sam.sam3.maskformer_segmentation.UniversalSegmentationHead
+
+
diff --git a/docs/en/reference/models/sam/sam3/model_misc.md b/docs/en/reference/models/sam/sam3/model_misc.md
new file mode 100644
index 0000000000..bef97aa7d5
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/model_misc.md
@@ -0,0 +1,32 @@
+---
+description: Reference for miscellaneous SAM 3 model components in Ultralytics, including dot product scoring, layer scaling, transformer wrappers, and positional embedding utilities.
+keywords: Ultralytics, SAM 3, model_misc, DotProductScoring, LayerScale, TransformerWrapper, sine embedding, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/model_misc.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.model_misc.DotProductScoring
+
+
+
+## ::: ultralytics.models.sam.sam3.model_misc.LayerScale
+
+
+
+## ::: ultralytics.models.sam.sam3.model_misc.TransformerWrapper
+
+
+
+## ::: ultralytics.models.sam.sam3.model_misc.get_valid_ratio
+
+
+
+## ::: ultralytics.models.sam.sam3.model_misc.gen_sineembed_for_position
+
+
diff --git a/docs/en/reference/models/sam/sam3/necks.md b/docs/en/reference/models/sam/sam3/necks.md
new file mode 100644
index 0000000000..f9c13e2ffb
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/necks.md
@@ -0,0 +1,16 @@
+---
+description: Reference for SAM 3 neck modules in Ultralytics, covering the Sam3DualViTDetNeck used to fuse vision backbone features.
+keywords: Ultralytics, SAM 3, necks, Sam3DualViTDetNeck, ViTDet, feature fusion, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/necks.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.necks.Sam3DualViTDetNeck
+
+
diff --git a/docs/en/reference/models/sam/sam3/sam3_image.md b/docs/en/reference/models/sam/sam3/sam3_image.md
new file mode 100644
index 0000000000..c773796697
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/sam3_image.md
@@ -0,0 +1,20 @@
+---
+description: Reference for the SAM 3 image model in Ultralytics, including the SAM3SemanticModel class for promptable concept segmentation on images.
+keywords: Ultralytics, SAM 3, sam3_image, SAM3SemanticModel, concept segmentation, image segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/sam3_image.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.sam3_image.SAM3SemanticModel
+
+
+
+## ::: ultralytics.models.sam.sam3.sam3_image._update_out
+
+
diff --git a/docs/en/reference/models/sam/sam3/text_encoder_ve.md b/docs/en/reference/models/sam/sam3/text_encoder_ve.md
new file mode 100644
index 0000000000..91e8a3a92b
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/text_encoder_ve.md
@@ -0,0 +1,32 @@
+---
+description: Reference for the SAM 3 text encoder in Ultralytics, including residual attention blocks, text transformers, the VETextEncoder, and text pooling utilities.
+keywords: Ultralytics, SAM 3, text encoder, VETextEncoder, TextTransformer, ResidualAttentionBlock, text prompts, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/text_encoder_ve.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.text_encoder_ve.ResidualAttentionBlock
+
+
+
+## ::: ultralytics.models.sam.sam3.text_encoder_ve.Transformer
+
+
+
+## ::: ultralytics.models.sam.sam3.text_encoder_ve.TextTransformer
+
+
+
+## ::: ultralytics.models.sam.sam3.text_encoder_ve.VETextEncoder
+
+
+
+## ::: ultralytics.models.sam.sam3.text_encoder_ve.text_global_pool
+
+
diff --git a/docs/en/reference/models/sam/sam3/tokenizer_ve.md b/docs/en/reference/models/sam/sam3/tokenizer_ve.md
new file mode 100644
index 0000000000..d38207bd92
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/tokenizer_ve.md
@@ -0,0 +1,52 @@
+---
+description: Reference for the SAM 3 BPE tokenizer in Ultralytics, covering the SimpleTokenizer class and text cleaning, canonicalization, and byte-pair encoding utilities.
+keywords: Ultralytics, SAM 3, tokenizer, SimpleTokenizer, BPE, byte-pair encoding, text cleaning, text prompts
+---
+
+# Reference for `ultralytics/models/sam/sam3/tokenizer_ve.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve.SimpleTokenizer
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve.bytes_to_unicode
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_pairs
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve.basic_clean
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve.whitespace_clean
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_canonicalize
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_lower
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_whitespace
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_clean_fn
+
+
+
+## ::: ultralytics.models.sam.sam3.tokenizer_ve.canonicalize_text
+
+
diff --git a/docs/en/reference/models/sam/sam3/vitdet.md b/docs/en/reference/models/sam/sam3/vitdet.md
new file mode 100644
index 0000000000..5632234de8
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/vitdet.md
@@ -0,0 +1,24 @@
+---
+description: Reference for the SAM 3 ViTDet vision backbone in Ultralytics, including the Attention, Block, and ViT classes.
+keywords: Ultralytics, SAM 3, vitdet, ViT, vision transformer, Attention, backbone, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/vitdet.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.vitdet.Attention
+
+
+
+## ::: ultralytics.models.sam.sam3.vitdet.Block
+
+
+
+## ::: ultralytics.models.sam.sam3.vitdet.ViT
+
+
diff --git a/docs/en/reference/models/sam/sam3/vl_combiner.md b/docs/en/reference/models/sam/sam3/vl_combiner.md
new file mode 100644
index 0000000000..f5a7d84ff5
--- /dev/null
+++ b/docs/en/reference/models/sam/sam3/vl_combiner.md
@@ -0,0 +1,16 @@
+---
+description: Reference for the SAM 3 vision-language combiner in Ultralytics, including the SAM3VLBackbone that joins visual and text features.
+keywords: Ultralytics, SAM 3, vl_combiner, SAM3VLBackbone, vision-language, multimodal, segmentation
+---
+
+# Reference for `ultralytics/models/sam/sam3/vl_combiner.py`
+
+!!! success "Improvements"
+
+ This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ thank you! ๐
+
+
+
+## ::: ultralytics.models.sam.sam3.vl_combiner.SAM3VLBackbone
+
+
diff --git a/mkdocs.yml b/mkdocs.yml
index 30d7967a34..3c53138a5d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -592,6 +592,7 @@ nav:
- sam:
- amg: reference/models/sam/amg.md
- build: reference/models/sam/build.md
+ - build_sam3: reference/models/sam/build_sam3.md
- model: reference/models/sam/model.md
- modules:
- blocks: reference/models/sam/modules/blocks.md
@@ -603,6 +604,18 @@ nav:
- transformer: reference/models/sam/modules/transformer.md
- utils: reference/models/sam/modules/utils.md
- predict: reference/models/sam/predict.md
+ - sam3:
+ - decoder: reference/models/sam/sam3/decoder.md
+ - encoder: reference/models/sam/sam3/encoder.md
+ - geometry_encoders: reference/models/sam/sam3/geometry_encoders.md
+ - maskformer_segmentation: reference/models/sam/sam3/maskformer_segmentation.md
+ - model_misc: reference/models/sam/sam3/model_misc.md
+ - necks: reference/models/sam/sam3/necks.md
+ - sam3_image: reference/models/sam/sam3/sam3_image.md
+ - text_encoder_ve: reference/models/sam/sam3/text_encoder_ve.md
+ - tokenizer_ve: reference/models/sam/sam3/tokenizer_ve.md
+ - vitdet: reference/models/sam/sam3/vitdet.md
+ - vl_combiner: reference/models/sam/sam3/vl_combiner.md
- utils:
- loss: reference/models/utils/loss.md
- ops: reference/models/utils/ops.md
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index dbc57048ef..50fc67bba9 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
-__version__ = "8.3.236"
+__version__ = "8.3.237"
import importlib
import os
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index 1b11ca12be..dfe0328a05 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -244,14 +244,15 @@ class BasePredictor:
for _ in gen: # sourcery skip: remove-empty-nested-block, noqa
pass
- def setup_source(self, source):
+ def setup_source(self, source, stride: int | None = None):
"""Set up source and inference mode.
Args:
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
inference.
+ stride (int, optional): Model stride for image size checking.
"""
- self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
+ self.imgsz = check_imgsz(self.args.imgsz, stride=stride or self.model.stride, min_dim=2) # check image size
self.dataset = load_inference_source(
source=source,
batch=self.args.batch,
diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py
index 8188c74d9e..e8723bcd97 100644
--- a/ultralytics/models/sam/__init__.py
+++ b/ultralytics/models/sam/__init__.py
@@ -1,7 +1,16 @@
# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
from .model import SAM
-from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
+from .predict import (
+ Predictor,
+ SAM2DynamicInteractivePredictor,
+ SAM2Predictor,
+ SAM2VideoPredictor,
+ SAM3Predictor,
+ SAM3SemanticPredictor,
+ SAM3VideoPredictor,
+ SAM3VideoSemanticPredictor,
+)
__all__ = (
"SAM",
@@ -9,4 +18,8 @@ __all__ = (
"SAM2DynamicInteractivePredictor",
"SAM2Predictor",
"SAM2VideoPredictor",
+ "SAM3Predictor",
+ "SAM3SemanticPredictor",
+ "SAM3VideoPredictor",
+ "SAM3VideoSemanticPredictor",
) # tuple or list of exportable items
diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py
index 8e47502cda..ff9f0f56aa 100644
--- a/ultralytics/models/sam/build.py
+++ b/ultralytics/models/sam/build.py
@@ -21,6 +21,21 @@ from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
+def _load_checkpoint(model, checkpoint):
+ """Load checkpoint into model from file path."""
+ if checkpoint is None:
+ return model
+
+ checkpoint = attempt_download_asset(checkpoint)
+ with open(checkpoint, "rb") as f:
+ state_dict = torch_load(f)
+ # Handle nested "model" key
+ if "model" in state_dict and isinstance(state_dict["model"], dict):
+ state_dict = state_dict["model"]
+ model.load_state_dict(state_dict)
+ return model
+
+
def build_sam_vit_h(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
return _build_sam(
@@ -205,10 +220,7 @@ def _build_sam(
pixel_std=[58.395, 57.12, 57.375],
)
if checkpoint is not None:
- checkpoint = attempt_download_asset(checkpoint)
- with open(checkpoint, "rb") as f:
- state_dict = torch_load(f)
- sam.load_state_dict(state_dict)
+ sam = _load_checkpoint(sam, checkpoint)
sam.eval()
return sam
@@ -299,10 +311,7 @@ def _build_sam2(
)
if checkpoint is not None:
- checkpoint = attempt_download_asset(checkpoint)
- with open(checkpoint, "rb") as f:
- state_dict = torch_load(f)["model"]
- sam2.load_state_dict(state_dict)
+ sam2 = _load_checkpoint(sam2, checkpoint)
sam2.eval()
return sam2
diff --git a/ultralytics/models/sam/build_sam3.py b/ultralytics/models/sam/build_sam3.py
new file mode 100644
index 0000000000..5e939c1849
--- /dev/null
+++ b/ultralytics/models/sam/build_sam3.py
@@ -0,0 +1,374 @@
+# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+import torch.nn as nn
+
+from ultralytics.nn.modules.transformer import MLP
+from ultralytics.utils.patches import torch_load
+
+from .modules.blocks import PositionEmbeddingSine, RoPEAttention
+from .modules.encoders import MemoryEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM3Model
+from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
+from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
+from .sam3.geometry_encoders import SequenceGeometryEncoder
+from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
+from .sam3.model_misc import DotProductScoring, TransformerWrapper
+from .sam3.necks import Sam3DualViTDetNeck
+from .sam3.sam3_image import SAM3SemanticModel
+from .sam3.text_encoder_ve import VETextEncoder
+from .sam3.tokenizer_ve import SimpleTokenizer
+from .sam3.vitdet import ViT
+from .sam3.vl_combiner import SAM3VLBackbone
+
+
def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
    """Assemble the SAM3 visual backbone: a ViT trunk wrapped by a dual ViTDet neck.

    Args:
        compile_mode: torch.compile mode for the ViT trunk, or None to disable compilation.
        enable_inst_interactivity: When True, the neck additionally builds the SAM2-style branch
            used for instance-level interactive prompting.

    Returns:
        Sam3DualViTDetNeck: The configured backbone module.
    """
    # Sine positional encoding shared by the neck (built first to keep module-construction order stable).
    pos_enc = PositionEmbeddingSine(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
    )

    # ViT trunk: 1008 px input at patch size 14, with windowed attention and periodic global blocks.
    trunk = ViT(
        img_size=1008,
        pretrain_img_size=336,
        patch_size=14,
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        use_abs_pos=True,
        tile_abs_pos=True,
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),
        use_rope=True,
        use_interp_rope=True,
        window_size=24,
        pretrain_use_cls_token=True,
        retain_cls_token=False,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
        bias_patch_embed=False,
        compile_mode=compile_mode,
    )

    return Sam3DualViTDetNeck(
        position_encoding=pos_enc,
        d_model=256,
        scale_factors=[4.0, 2.0, 1.0, 0.5],
        trunk=trunk,
        add_sam2_neck=enable_inst_interactivity,
    )
+
+
def _create_sam3_transformer() -> TransformerWrapper:
    """Assemble the SAM3 detector transformer: a fusion encoder plus a DETR-style decoder.

    Returns:
        TransformerWrapper: Encoder/decoder pair operating at d_model=256.
    """
    # NOTE: submodules are constructed in the same order as before so global-RNG parameter
    # initialization is unchanged.
    enc_self_attn = nn.MultiheadAttention(
        num_heads=8,
        dropout=0.1,
        embed_dim=256,
        batch_first=True,
    )
    enc_cross_attn = nn.MultiheadAttention(
        num_heads=8,
        dropout=0.1,
        embed_dim=256,
        batch_first=True,
    )
    enc_layer = TransformerEncoderLayer(
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        pre_norm=True,
        self_attention=enc_self_attn,
        cross_attention=enc_cross_attn,
    )
    encoder: TransformerEncoderFusion = TransformerEncoderFusion(
        layer=enc_layer,
        num_layers=6,
        d_model=256,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )

    dec_cross_attn = nn.MultiheadAttention(
        num_heads=8,
        dropout=0.1,
        embed_dim=256,
    )
    dec_layer = TransformerDecoderLayer(
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        cross_attention=dec_cross_attn,
        n_heads=8,
        use_text_cross_attention=True,
    )
    decoder: TransformerDecoder = TransformerDecoder(
        layer=dec_layer,
        num_layers=6,
        num_queries=200,
        return_intermediate=True,
        box_refine=True,
        num_o2m_queries=0,
        dac=True,
        boxRPB="log",
        d_model=256,
        frozen=False,
        interaction_layer=None,
        dac_use_selfatt_ln=True,
        use_act_checkpoint=True,
        presence_token=True,
    )

    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
+
+
def build_sam3_image_model(
    checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False
):
    """Build the SAM3 image (promptable concept segmentation) model.

    Args:
        checkpoint_path (str | None): Optional path to a model checkpoint. If None, the model is
            returned with randomly initialized weights.
        bpe_path (str): Path to the BPE tokenizer vocabulary used by the text encoder.
        enable_segmentation (bool): Whether to attach the universal segmentation head.
        compile (bool): If True, enable compilation ("default" mode) for supported submodules.

    Returns:
        A SAM3 image model (SAM3SemanticModel) in eval mode.
    """
    # Create visual components
    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)

    # Create text components
    text_encoder = VETextEncoder(
        tokenizer=SimpleTokenizer(bpe_path=bpe_path),
        d_model=256,
        width=1024,
        heads=16,
        layers=24,
    )

    # Create visual-language backbone
    backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)

    # Create transformer components
    transformer = _create_sam3_transformer()

    # Create dot product scoring
    dot_prod_scoring = DotProductScoring(
        d_model=256,
        d_proj=256,
        prompt_mlp=MLP(
            input_dim=256,
            hidden_dim=2048,
            output_dim=256,
            num_layers=2,
            residual=True,
            out_norm=nn.LayerNorm(256),
        ),
    )

    # Create segmentation head if enabled
    segmentation_head = (
        UniversalSegmentationHead(
            hidden_dim=256,
            upsampling_stages=3,
            aux_masks=False,
            presence_head=False,
            dot_product_scorer=None,
            act_ckpt=True,
            cross_attend_prompt=nn.MultiheadAttention(
                num_heads=8,
                dropout=0,
                embed_dim=256,
            ),
            pixel_decoder=PixelDecoder(
                num_upsampling_stages=3,
                interpolation_mode="nearest",
                hidden_dim=256,
                compile_mode=compile_mode,
            ),
        )
        if enable_segmentation
        else None
    )

    # Create geometry encoder for box/point prompts
    input_geometry_encoder = SequenceGeometryEncoder(
        pos_enc=PositionEmbeddingSine(
            num_pos_feats=256,
            normalize=True,
            scale=None,
            temperature=10000,
        ),
        encode_boxes_as_points=False,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=256,
        num_layers=3,
        layer=TransformerEncoderLayer(
            d_model=256,
            dim_feedforward=2048,
            dropout=0.1,
            pos_enc_at_attn=False,
            pre_norm=True,
            pos_enc_at_cross_attn_queries=False,
            pos_enc_at_cross_attn_keys=True,
        ),
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )

    # Create the SAM3SemanticModel model
    model = SAM3SemanticModel(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
    )

    # The checkpoint is documented as optional; _load_checkpoint opens the file directly,
    # so guard against None here instead of crashing on open(None).
    if checkpoint_path is not None:
        model = _load_checkpoint(model, checkpoint_path)
    model.eval()
    return model
+
+
def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
    """Build the SAM3 model for interactive (promptable visual) video segmentation and tracking.

    Args:
        checkpoint_path (str | None): Optional path to a SAM3 checkpoint; loading is skipped when None.
        compile (str | None): torch.compile mode for the vision backbone (e.g. "default"), or None to disable.
        with_backbone (bool): Whether to build the vision-language backbone. Set False when image
            features are computed externally.

    Returns:
        SAM3Model: The assembled SAM3 tracker model in eval mode.
    """
    # Memory components: masks are encoded at 1152x1152 and fused with image features.
    memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
    memory_attention = MemoryAttention(
        batch_first=True,
        d_model=256,
        pos_enc_at_input=True,
        layer=MemoryAttentionLayer(
            dim_feedforward=2048,
            dropout=0.1,
            pos_enc_at_attn=False,
            pos_enc_at_cross_attn_keys=True,
            pos_enc_at_cross_attn_queries=False,
            self_attn=RoPEAttention(
                embedding_dim=256,
                num_heads=1,
                downsample_rate=1,
                rope_theta=10000.0,
                feat_sizes=[72, 72],
            ),
            d_model=256,
            cross_attn=RoPEAttention(
                embedding_dim=256,
                num_heads=1,
                downsample_rate=1,
                kv_in_dim=64,
                rope_theta=10000.0,
                feat_sizes=[72, 72],
                rope_k_repeat=True,
            ),
        ),
        num_layers=4,
    )

    # Vision backbone is optional: the tracker can consume externally computed features.
    backbone = (
        SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
        if with_backbone
        else None
    )
    model = SAM3Model(
        image_size=1008,
        image_encoder=backbone,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        backbone_stride=14,
        num_maskmem=7,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        no_obj_embed_spatial=True,
        proj_tpos_enc_in_obj_ptrs=True,
        use_signed_tpos_enc_to_obj_ptrs=True,
        sam_mask_decoder_extra_args=dict(
            dynamic_multimask_via_stability=True,
            dynamic_multimask_stability_delta=0.05,
            dynamic_multimask_stability_thresh=0.98,
        ),
    )

    # Load checkpoint only if provided; _load_checkpoint opens the path directly.
    if checkpoint_path is not None:
        model = _load_checkpoint(model, checkpoint_path, interactive=True)

    # Setup device and mode
    model.eval()
    return model
+
+
def _load_checkpoint(model, checkpoint, interactive=False):
    """Load a SAM3 checkpoint into `model`, remapping key prefixes to the Ultralytics module layout.

    Args:
        model (nn.Module): Model to load weights into.
        checkpoint (str | None): Path to the checkpoint file. If None, `model` is returned unchanged
            (both builder functions document the checkpoint as optional).
        interactive (bool): If True, additionally remap tracker/memory-component keys used by the
            interactive (video) model.

    Returns:
        (nn.Module): The model with weights loaded non-strictly (unmatched keys are ignored).
    """
    if checkpoint is None:
        return model
    with open(checkpoint, "rb") as f:
        ckpt = torch_load(f)
    # Some checkpoints nest the state dict under a "model" key.
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]
    # Detector weights map onto the image model once the "detector." prefix is stripped.
    sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
    if interactive:
        # Vision backbone lives under "image_encoder" in the interactive model.
        sam3_image_ckpt.update(
            {
                k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
                for k, v in sam3_image_ckpt.items()
                if "backbone.vision_backbone" in k
            }
        )
        # Tracker transformer encoder -> memory attention.
        sam3_image_ckpt.update(
            {
                k.replace("tracker.transformer.encoder", "memory_attention"): v
                for k, v in ckpt.items()
                if "tracker.transformer" in k
            }
        )
        # Tracker mask-memory backbone -> memory encoder.
        sam3_image_ckpt.update(
            {
                k.replace("tracker.maskmem_backbone", "memory_encoder"): v
                for k, v in ckpt.items()
                if "tracker.maskmem_backbone" in k
            }
        )
        # Remaining tracker weights map directly once the "tracker." prefix is stripped.
        sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
    model.load_state_dict(sam3_image_ckpt, strict=False)
    return model
diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py
index dd366c115e..1026f6b49c 100644
--- a/ultralytics/models/sam/model.py
+++ b/ultralytics/models/sam/model.py
@@ -21,7 +21,7 @@ from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
-from .predict import Predictor, SAM2Predictor
+from .predict import Predictor, SAM2Predictor, SAM3Predictor
class SAM(Model):
@@ -59,6 +59,7 @@ class SAM(Model):
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
self.is_sam2 = "sam2" in Path(model).stem
+ self.is_sam3 = "sam3" in Path(model).stem
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
@@ -72,9 +73,14 @@ class SAM(Model):
>>> sam = SAM("sam_b.pt")
>>> sam._load("path/to/custom_weights.pt")
"""
- from .build import build_sam # slow import
+ if self.is_sam3:
+ from .build_sam3 import build_interactive_sam3
- self.model = build_sam(weights)
+ self.model = build_interactive_sam3(weights)
+ else:
+ from .build import build_sam # slow import
+
+ self.model = build_sam(weights)
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""Perform segmentation prediction on the given image or video source.
@@ -158,4 +164,6 @@ class SAM(Model):
>>> print(task_map)
{'segment': {'predictor': }}
"""
- return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
+ return {
+ "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
+ }
diff --git a/ultralytics/models/sam/modules/blocks.py b/ultralytics/models/sam/modules/blocks.py
index e27f0d0fcf..6ff9ece752 100644
--- a/ultralytics/models/sam/modules/blocks.py
+++ b/ultralytics/models/sam/modules/blocks.py
@@ -79,6 +79,7 @@ class MaskDownSampler(nn.Module):
padding: int = 0,
total_stride: int = 16,
activation: type[nn.Module] = nn.GELU,
+ interpol_size: tuple[int, int] | None = None,
):
"""Initialize a mask downsampler module for progressive downsampling and channel expansion."""
super().__init__()
@@ -102,9 +103,24 @@ class MaskDownSampler(nn.Module):
mask_in_chans = mask_out_chans
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+ self.interpol_size = interpol_size
+ if self.interpol_size is not None:
+ assert isinstance(self.interpol_size, (list, tuple)), (
+ f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
+ )
+ self.interpol_size = list(interpol_size)
+ assert len(self.interpol_size) == 2
def forward(self, x: Tensor) -> Tensor:
"""Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+ if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
+ x = F.interpolate(
+ x.float(),
+ size=self.interpol_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True,
+ ).to(x.dtype)
return self.encoder(x)
@@ -429,13 +445,7 @@ class RoPEAttention(Attention):
)
# Attention
- _, _, _, c_per_head = q.shape
- attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
- attn = attn / math.sqrt(c_per_head)
- attn = torch.softmax(attn, dim=-1)
-
- # Get output
- out = attn @ v
+ out = F.scaled_dot_product_attention(q, k, v)
out = self._recombine_heads(out)
out = self.out_proj(out)
@@ -1033,6 +1043,7 @@ class PatchEmbed(nn.Module):
padding: tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
+ bias: bool = True,
) -> None:
"""Initialize the PatchEmbed module for converting image patches to embeddings.
@@ -1045,10 +1056,11 @@ class PatchEmbed(nn.Module):
padding (tuple[int, int]): Padding applied to the input before convolution.
in_chans (int): Number of input image channels.
embed_dim (int): Dimensionality of the output patch embeddings.
+ bias (bool): Whether to include a bias term in the convolutional layer.
"""
super().__init__()
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Compute patch embedding by applying convolution and transposing resulting tensor."""
diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py
index 69adf48b5c..b845f6b6e2 100644
--- a/ultralytics/models/sam/modules/decoders.py
+++ b/ultralytics/models/sam/modules/decoders.py
@@ -436,9 +436,8 @@ class SAM2MaskDecoder(nn.Module):
def _get_stability_scores(self, mask_logits):
"""Compute mask stability scores based on IoU between upper and lower thresholds."""
mask_logits = mask_logits.flatten(-2)
- stability_delta = self.dynamic_multimask_stability_delta
- area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
- area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float()
return torch.where(area_u > 0, area_i / area_u, 1.0)
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py
index d5066ce1bd..8fc3da75d5 100644
--- a/ultralytics/models/sam/modules/encoders.py
+++ b/ultralytics/models/sam/modules/encoders.py
@@ -361,6 +361,7 @@ class MemoryEncoder(nn.Module):
self,
out_dim,
in_dim=256, # in_dim of pix_feats
+ interpol_size: tuple[int, int] | None = None,
):
"""Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
@@ -370,10 +371,12 @@ class MemoryEncoder(nn.Module):
Args:
out_dim (int): Output dimension of the encoded features.
in_dim (int): Input dimension of the pixel features.
+ interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel
+ features.
"""
super().__init__()
- self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+ self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size)
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
diff --git a/ultralytics/models/sam/modules/memory_attention.py b/ultralytics/models/sam/modules/memory_attention.py
index d229a4f326..63c573d697 100644
--- a/ultralytics/models/sam/modules/memory_attention.py
+++ b/ultralytics/models/sam/modules/memory_attention.py
@@ -59,6 +59,8 @@ class MemoryAttentionLayer(nn.Module):
pos_enc_at_attn: bool = False,
pos_enc_at_cross_attn_keys: bool = True,
pos_enc_at_cross_attn_queries: bool = False,
+ self_attn: nn.Module | None = None,
+ cross_attn: nn.Module | None = None,
):
"""Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
@@ -69,13 +71,15 @@ class MemoryAttentionLayer(nn.Module):
pos_enc_at_attn (bool): Whether to add positional encoding at attention.
pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
+ self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used.
+ cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used.
"""
super().__init__()
self.d_model = d_model
self.dim_feedforward = dim_feedforward
self.dropout_value = dropout
- self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
- self.cross_attn_image = RoPEAttention(
+ self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+ self.cross_attn_image = cross_attn or RoPEAttention(
rope_k_repeat=True,
embedding_dim=256,
num_heads=1,
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 9b5efe8068..701369a15b 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_
from ultralytics.nn.modules import MLP
from ultralytics.utils import LOGGER
-from .blocks import SAM2TwoWayTransformer
+from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
from .decoders import MaskDecoder, SAM2MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
from .utils import get_1d_sine_pe, select_closest_cond_frames
@@ -329,6 +329,7 @@ class SAM2Model(torch.nn.Module):
self._build_sam_heads()
self.max_cond_frames_in_attn = max_cond_frames_in_attn
+ self.add_all_frames_to_correct_as_cond = True
# Model compilation
if compile_image_encoder:
@@ -473,7 +474,7 @@ class SAM2Model(torch.nn.Module):
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
- mask_inputs.float(),
+ mask_inputs.to(backbone_features.dtype),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
@@ -571,7 +572,7 @@ class SAM2Model(torch.nn.Module):
# produce an object pointer using the SAM decoder from the mask input
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
backbone_features=backbone_features,
- mask_inputs=self.mask_downsample(mask_inputs_float),
+ mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
high_res_features=high_res_features,
)
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
@@ -818,7 +819,6 @@ class SAM2Model(torch.nn.Module):
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
maskmem_features = maskmem_out["vision_features"]
- maskmem_pos_enc = maskmem_out["vision_pos_enc"]
# add a no-object embedding to the spatial memory to indicate that the frame
# is predicted to be occluded (i.e. no object is appearing in the frame)
if self.no_obj_embed_spatial is not None:
@@ -827,7 +827,7 @@ class SAM2Model(torch.nn.Module):
..., None, None
].expand(*maskmem_features.shape)
- return maskmem_features, maskmem_pos_enc
+ return maskmem_features, maskmem_out["vision_pos_enc"]
def _track_step(
self,
@@ -1005,7 +1005,151 @@ class SAM2Model(torch.nn.Module):
def set_imgsz(self, imgsz):
"""Set image size to make model compatible with different image sizes."""
+ if hasattr(self.image_encoder, "set_imgsz"):
+ self.image_encoder.set_imgsz(imgsz)
self.image_size = imgsz[0]
self.sam_prompt_encoder.input_image_size = imgsz
- self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16
+ self.sam_prompt_encoder.image_embedding_size = [
+ x // self.backbone_stride for x in imgsz
+ ] # fixed ViT patch size of 16
+ self.sam_prompt_encoder.mask_input_size = [
+ x // self.backbone_stride * 4 for x in imgsz
+ ] # fixed ViT patch size of 16
self.sam_image_embedding_size = self.image_size // self.backbone_stride # update image embedding size
+
+
class SAM3Model(SAM2Model):
    """Segment Anything Model 3 with memory-based video object segmentation capabilities.

    Extends SAM2Model with a backbone stride of 14 (vs 16), a mask decoder rebuilt around the plain
    TwoWayTransformer, and utilities for suppressing noisy masks after pixel-wise non-overlap
    constraints during multi-object tracking.
    """

    def __init__(
        self,
        image_encoder,
        memory_attention,
        memory_encoder,
        num_maskmem=7,
        image_size=1008,
        backbone_stride=14,
        sigmoid_scale_for_mem_enc=1,
        sigmoid_bias_for_mem_enc=0,
        binarize_mask_from_pts_for_mem_enc=False,
        use_mask_input_as_output_without_sam=False,
        max_cond_frames_in_attn=-1,
        directly_add_no_mem_embed=False,
        use_high_res_features_in_sam=False,
        multimask_output_in_sam=False,
        multimask_min_pt_num=1,
        multimask_max_pt_num=1,
        multimask_output_for_tracking=False,
        use_multimask_token_for_obj_ptr: bool = False,
        iou_prediction_use_sigmoid=False,
        memory_temporal_stride_for_eval=1,
        non_overlap_masks_for_mem_enc=False,
        use_obj_ptrs_in_encoder=False,
        max_obj_ptrs_in_encoder=16,
        add_tpos_enc_to_obj_ptrs=True,
        proj_tpos_enc_in_obj_ptrs=False,
        use_signed_tpos_enc_to_obj_ptrs=False,
        only_obj_ptrs_in_the_past_for_eval=False,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        fixed_no_obj_ptr: bool = False,
        soft_no_obj_ptr: bool = False,
        use_mlp_for_obj_ptr_proj: bool = False,
        no_obj_embed_spatial: bool = False,
        sam_mask_decoder_extra_args=None,
        compile_image_encoder: bool = False,
    ):
        """Initialize SAM3Model; all arguments are forwarded unchanged to SAM2Model.__init__."""
        super().__init__(
            image_encoder,
            memory_attention,
            memory_encoder,
            num_maskmem,
            image_size,
            backbone_stride,
            sigmoid_scale_for_mem_enc,
            sigmoid_bias_for_mem_enc,
            binarize_mask_from_pts_for_mem_enc,
            use_mask_input_as_output_without_sam,
            max_cond_frames_in_attn,
            directly_add_no_mem_embed,
            use_high_res_features_in_sam,
            multimask_output_in_sam,
            multimask_min_pt_num,
            multimask_max_pt_num,
            multimask_output_for_tracking,
            use_multimask_token_for_obj_ptr,
            iou_prediction_use_sigmoid,
            memory_temporal_stride_for_eval,
            non_overlap_masks_for_mem_enc,
            use_obj_ptrs_in_encoder,
            max_obj_ptrs_in_encoder,
            add_tpos_enc_to_obj_ptrs,
            proj_tpos_enc_in_obj_ptrs,
            use_signed_tpos_enc_to_obj_ptrs,
            only_obj_ptrs_in_the_past_for_eval,
            pred_obj_scores,
            pred_obj_scores_mlp,
            fixed_no_obj_ptr,
            soft_no_obj_ptr,
            use_mlp_for_obj_ptr_proj,
            no_obj_embed_spatial,
            sam_mask_decoder_extra_args,
            compile_image_encoder,
        )
        # Rebuild the SAM mask decoder: SAM3 uses the plain TwoWayTransformer instead of the
        # SAM2TwoWayTransformer that the base class installs.
        self.sam_mask_decoder = SAM2MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            **(self.sam_mask_decoder_extra_args or {}),
        )

    def forward_image(self, img_batch: torch.Tensor):
        """Process an image batch through the encoder to extract multi-level features for the SAM heads."""
        backbone_out = self.image_encoder.forward_image_sam2(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
        return backbone_out

    def set_imgsz(self, imgsz: tuple[int, int]):
        """Set the image size for the model and the memory-encoder mask downsampler."""
        super().set_imgsz(imgsz)
        # Memory masks are encoded at 16 px per backbone token; use the configured stride rather
        # than hard-coding 14 so non-default strides stay consistent (e.g. 1008 // 14 * 16 = 1152).
        self.memory_encoder.mask_downsampler.interpol_size = [
            size // self.backbone_stride * 16 for size in imgsz
        ]

    @staticmethod
    def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
        """Suppress masks whose positive area shrank below `shrink_threshold` after non-overlap constraints.

        Args:
            pred_masks (torch.Tensor): Original mask logits.
            new_pred_masks (torch.Tensor): Mask logits after pixel-wise non-overlapping constraints.
            shrink_threshold (float): Minimum (after / before) area ratio for a mask to be kept.

        Returns:
            (torch.Tensor): `pred_masks` with heavily-shrunken masks clamped to logits <= -10 (suppressed).
        """
        # Cast areas to float explicitly so the clamp and ratio are computed in floating point
        # regardless of integer-promotion behavior of the installed torch version.
        area_before = (pred_masks > 0).sum(dim=(-1, -2)).float()
        area_after = (new_pred_masks > 0).sum(dim=(-1, -2)).float()
        area_before = torch.clamp(area_before, min=1.0)  # avoid division by zero for empty masks
        area_ratio = area_after / area_before
        keep = area_ratio >= shrink_threshold
        keep_mask = keep[..., None, None].expand_as(pred_masks)
        pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks_after

    def _suppress_object_pw_area_shrinkage(self, pred_masks):
        """Suppress masks that shrink in area after applying pixel-wise non-overlapping constraints.

        Note that the final output can still be overlapping.
        """
        # Apply pixel-wise non-overlapping constraint based on mask scores
        pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
        # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping
        # constraints. NOTE: this is a no-op if none of the masks shrank by a large factor.
        pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
        return pred_masks
diff --git a/ultralytics/models/sam/modules/utils.py b/ultralytics/models/sam/modules/utils.py
index e934817934..86ea23ac9f 100644
--- a/ultralytics/models/sam/modules/utils.py
+++ b/ultralytics/models/sam/modules/utils.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import math
from typing import Any
import torch
@@ -86,7 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
return pos_embed
-def init_t_xy(end_x: int, end_y: int):
+def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
"""Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
This function creates coordinate tensors for a grid with dimensions end_x ร end_y. It generates a linear index
@@ -95,6 +96,8 @@ def init_t_xy(end_x: int, end_y: int):
Args:
end_x (int): Width of the grid (number of columns).
end_y (int): Height of the grid (number of rows).
+ scale (float): Scaling factor to apply to the coordinates.
+ offset (int): Offset to add to the coordinates.
Returns:
t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
@@ -110,10 +113,10 @@ def init_t_xy(end_x: int, end_y: int):
t = torch.arange(end_x * end_y, dtype=torch.float32)
t_x = (t % end_x).float()
t_y = torch.div(t, end_x, rounding_mode="floor").float()
- return t_x, t_y
+ return t_x * scale + offset, t_y * scale + offset
-def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
"""Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
@@ -124,6 +127,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
end_x (int): Width of the 2D grid.
end_y (int): Height of the 2D grid.
theta (float, optional): Scaling factor for frequency computation.
+ scale_pos (float, optional): Scaling factor for position coordinates.
Returns:
(torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
@@ -137,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
- t_x, t_y = init_t_xy(end_x, end_y)
+ t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
freqs_x = torch.outer(t_x, freqs_x)
freqs_y = torch.outer(t_y, freqs_y)
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
@@ -375,3 +379,129 @@ def add_decomposed_rel_pos(
)
return attn
+
+
def get_abs_pos(
    abs_pos: torch.Tensor,
    has_cls_token: bool,
    hw: tuple[int, int],
    retain_cls_token: bool = False,
    tiling: bool = False,
) -> torch.Tensor:
    """Resize absolute positional embeddings to a target token grid.

    Args:
        abs_pos (torch.Tensor): Absolute positional embeddings with shape (1, num_position, C).
        has_cls_token (bool): If True, abs_pos carries one leading cls-token embedding.
        hw (tuple[int, int]): Target (height, width) of the image token grid.
        retain_cls_token (bool): Whether to keep the cls-token embedding in the output
            (requires has_cls_token).
        tiling (bool): Tile the source grid to cover the target *instead* of interpolating
            (a la abs_win).

    Returns:
        (torch.Tensor): Embeddings of shape (1, H, W, C) when retain_cls_token is False,
            otherwise (1, 1 + H * W, C).
    """
    if retain_cls_token:
        assert has_cls_token

    h, w = hw
    cls_pos = None
    if has_cls_token:
        cls_pos, abs_pos = abs_pos[:, :1], abs_pos[:, 1:]

    num_tokens = abs_pos.shape[1]
    side = int(math.sqrt(num_tokens))
    assert side * side == num_tokens

    if side == h and side == w:
        # Source grid already matches the target resolution; no resampling needed.
        if retain_cls_token:
            assert has_cls_token
            return torch.cat([cls_pos, abs_pos], dim=1)
        return abs_pos.reshape(1, h, w, -1)

    grid = abs_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2)
    if tiling:
        # Repeat the source grid enough times to cover (h, w), then crop.
        reps = [target // src + 1 for target, src in zip((h, w), grid.shape[2:])]
        grid = grid.tile([1, 1] + reps)[:, :, :h, :w]
    else:
        grid = F.interpolate(grid, size=(h, w), mode="bicubic", align_corners=False)

    grid = grid.permute(0, 2, 3, 1)
    if retain_cls_token:
        # Re-attach cls_token and flatten the spatial dimensions.
        assert has_cls_token
        return torch.cat([cls_pos, grid.reshape(1, h * w, -1)], dim=1)
    return grid
+
+
def concat_rel_pos(
    q: torch.Tensor,
    k: torch.Tensor,
    q_hw: tuple[int, int],
    k_hw: tuple[int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rescale: bool = False,
    relative_coords: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad q and k with relative-position coefficients so qk^T includes the rel-pos biases.

    Args:
        q (torch.Tensor): Query tensor with shape (B, L_q, C).
        k (torch.Tensor): Key tensor with shape (B, L_k, C).
        q_hw: Spatial size (height, width) of the query tokens.
        k_hw: Spatial size (height, width) of the key tokens.
        rel_pos_h: Relative positional embeddings/params along height.
        rel_pos_w: Relative positional embeddings/params along width.
        rescale (bool): Compensate for sdpa's scaling, which would otherwise use the wrong factor
            because of the concatenated extra channels.
        relative_coords (torch.Tensor, optional): Precomputed relative coords index tensor.

    Returns:
        q, k: Padded so that qk^T accounts for the relative positional biases.
    """
    q_h, q_w = q_hw
    k_h, k_w = k_hw
    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"

    if relative_coords is None:
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)
    else:
        Rh, Rw = rel_pos_h[relative_coords], rel_pos_w[relative_coords]

    batch, _, dim = q.shape
    q_grid = q.reshape(batch, q_h, q_w, dim)

    # sdpa divides by new_scale; pre-multiplying q by scale_ratio keeps its effective
    # scaling at old_scale, while the bias columns end up unscaled.
    old_scale = dim**0.5
    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale
    scale_ratio = new_scale / old_scale

    rel_h = torch.einsum("bhwc,hkc->bhwk", q_grid, Rh) * new_scale  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", q_grid, Rw) * new_scale  # (B, q_h, q_w, k_w)

    # One-hot selectors on the key side pick out the matching bias column in qk^T.
    eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device).view(1, k_h, 1, k_h).expand([batch, k_h, k_w, k_h])
    eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device).view(1, 1, k_w, k_w).expand([batch, k_h, k_w, k_w])

    q_padded = torch.cat([q_grid * scale_ratio, rel_h, rel_w], dim=-1).view(batch, q_h * q_w, -1)
    k_padded = torch.cat([k.view(batch, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(batch, k_h * k_w, -1)

    return q_padded, k_padded
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 8f421d977a..b236b3d9d3 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -10,7 +10,8 @@ segmentation tasks.
from __future__ import annotations
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
from typing import Any
import cv2
@@ -21,7 +22,8 @@ import torch.nn.functional as F
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
-from ultralytics.utils import DEFAULT_CFG, ops
+from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
+from ultralytics.utils.metrics import box_iou, mask_iou
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
from .amg import (
@@ -35,6 +37,7 @@ from .amg import (
uncrop_boxes_xyxy,
uncrop_masks,
)
+from .sam3.geometry_encoders import Prompt
class Predictor(BasePredictor):
@@ -79,6 +82,8 @@ class Predictor(BasePredictor):
>>> results = predictor(bboxes=bboxes)
"""
+ stride = 16
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize the Predictor with configuration, overrides, and callbacks.
@@ -156,7 +161,7 @@ class Predictor(BasePredictor):
1
"""
assert len(im) == 1, "SAM model does not currently support batched inference"
- letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
+ letterbox = LetterBox(self.imgsz, auto=False, center=False)
return [letterbox(image=x) for x in im]
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
@@ -520,30 +525,6 @@ class Predictor(BasePredictor):
self.segment_all = False
return results
- def setup_source(self, source):
- """Set up the data source for inference.
-
- This method configures the data source from which images will be fetched for inference. It supports various
- input types such as image files, directories, video files, and other compatible data sources.
-
- Args:
- source (str | Path | None): The path or identifier for the image data source. Can be a file path, directory
- path, URL, or other supported source types.
-
- Examples:
- >>> predictor = Predictor()
- >>> predictor.setup_source("path/to/images")
- >>> predictor.setup_source("video.mp4")
- >>> predictor.setup_source(None) # Uses default source if available
-
- Notes:
- - If source is None, the method may use a default source if configured.
- - The method adapts to different source types and prepares them for subsequent inference steps.
- - Supported source types may include local files, directories, URLs, and video streams.
- """
- if source is not None:
- super().setup_source(source)
-
def set_image(self, image):
"""Preprocess and set a single image for inference.
@@ -576,12 +557,18 @@ class Predictor(BasePredictor):
self.features = self.get_im_features(im)
break
- def get_im_features(self, im):
- """Extract image features using the SAM model's image encoder for subsequent mask prediction."""
+ def setup_source(self, source):
+ """Set up the data source for SAM inference."""
+ if source is None: # handle the case where set_imgsz was already called in advance
+ return
+ super().setup_source(source, self.stride)
assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
f"SAM models only support square image size, but got {self.imgsz}."
)
self.model.set_imgsz(self.imgsz)
+
+ def get_im_features(self, im):
+ """Extract image features using the SAM model's image encoder for subsequent mask prediction."""
return self.model.image_encoder(im)
def set_prompts(self, prompts):
@@ -726,6 +713,7 @@ class SAM2Predictor(Predictor):
(128, 128),
(64, 64),
]
+ stride = 16
def get_model(self):
"""Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
@@ -767,45 +755,13 @@ class SAM2Predictor(Predictor):
points, labels = bboxes, bbox_labels
return points, labels, masks
- def set_image(self, image):
- """Preprocess and set a single image for inference using the SAM2 model.
-
- This method initializes the model if not already done, configures the data source to the specified image, and
- preprocesses the image for feature extraction. It supports setting only one image at a time.
-
- Args:
- image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
-
- Raises:
- AssertionError: If more than one image is attempted to be set.
-
- Examples:
- >>> predictor = SAM2Predictor()
- >>> predictor.set_image("path/to/image.jpg")
- >>> predictor.set_image(np.array([...])) # Using a numpy array
-
- Notes:
- - This method must be called before performing any inference on a new image.
- - The method caches the extracted features for efficient subsequent inferences on the same image.
- - Only one image can be set at a time. To process multiple images, call this method for each new image.
- """
- if self.model is None:
- self.setup_model(model=None)
- self.setup_source(image)
- assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
- for batch in self.dataset:
- im = self.preprocess(batch[1])
- self.features = self.get_im_features(im)
- break
+ def setup_source(self, source):
+ """Set up the data source and image size for SAM2 inference."""
+ super().setup_source(source)
+ self._bb_feat_sizes = [[int(x / (self.stride * i)) for x in self.imgsz] for i in [1 / 4, 1 / 2, 1]]
def get_im_features(self, im):
"""Extract image features from the SAM image encoder for subsequent processing."""
- assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
- f"SAM 2 models only support square image size, but got {self.imgsz}."
- )
- self.model.set_imgsz(self.imgsz)
- self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
-
backbone_out = self.model.forward_image(im)
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
if self.model.directly_add_no_mem_embed:
@@ -1037,6 +993,7 @@ class SAM2VideoPredictor(SAM2Predictor):
labels=None,
masks=None,
frame_idx=0,
+ inference_state: dict[str, Any] | None = None,
):
"""Add new points or masks to a specific frame for a given object ID.
@@ -1051,6 +1008,8 @@ class SAM2VideoPredictor(SAM2Predictor):
labels (torch.Tensor, optional): The labels corresponding to the points.
masks (torch.Tensor, optional): Binary masks for the object.
frame_idx (int, optional): The index of the frame to which the prompts are applied.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
Returns:
pred_masks (torch.Tensor): The flattened predicted masks.
@@ -1064,24 +1023,25 @@ class SAM2VideoPredictor(SAM2Predictor):
- If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
- The method handles the consolidation of outputs and resizing of masks to the original video resolution.
"""
+ inference_state = inference_state or self.inference_state
assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
- obj_idx = self._obj_id_to_idx(obj_id)
+ obj_idx = self._obj_id_to_idx(obj_id, inference_state)
point_inputs = None
pop_key = "point_inputs_per_obj"
if points is not None:
point_inputs = {"point_coords": points, "point_labels": labels}
- self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
+ inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
pop_key = "mask_inputs_per_obj"
- self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
- self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
+ inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
+ inference_state[pop_key][obj_idx].pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the inputs points are to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
- is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
- obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
- obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
@@ -1119,6 +1079,7 @@ class SAM2VideoPredictor(SAM2Predictor):
# them into memory.
run_mem_encoder=False,
prev_sam_mask_logits=prev_sam_mask_logits,
+ inference_state=inference_state,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
@@ -1128,31 +1089,37 @@ class SAM2VideoPredictor(SAM2Predictor):
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
+ inference_state=inference_state,
)
pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
@smart_inference_mode()
- def propagate_in_video_preflight(self):
+ def propagate_in_video_preflight(self, inference_state: dict[str, Any] | None = None):
"""Prepare inference_state and consolidate temporary outputs before tracking.
This method marks the start of tracking, disallowing the addition of new objects until the session is reset. It
consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. Additionally,
it clears non-conditioning memory around input frames and ensures that the state is consistent with the provided
inputs.
+
+ Args:
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
"""
+ inference_state = inference_state or self.inference_state
# Tracking has started and we don't allow adding new objects until session is reset.
- self.inference_state["tracking_has_started"] = True
- batch_size = len(self.inference_state["obj_idx_to_id"])
+ inference_state["tracking_has_started"] = True
+ batch_size = len(inference_state["obj_idx_to_id"])
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
- temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
- output_dict = self.inference_state["output_dict"]
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ output_dict = inference_state["output_dict"]
# "consolidated_frame_inds" contains indices of those frames where consolidated
# temporary outputs have been added (either in this call or any previous calls
# to `propagate_in_video_preflight`).
- consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in {False, True}:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
@@ -1166,11 +1133,11 @@ class SAM2VideoPredictor(SAM2Predictor):
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
- frame_idx, is_cond=is_cond, run_mem_encoder=True
+ frame_idx, is_cond=is_cond, run_mem_encoder=True, inference_state=inference_state
)
# merge them into "output_dict" and also create per-object slices
output_dict[storage_key][frame_idx] = consolidated_out
- self._add_output_per_object(frame_idx, consolidated_out, storage_key)
+ self._add_output_per_object(frame_idx, consolidated_out, storage_key, inference_state=inference_state)
if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(frame_idx)
@@ -1183,7 +1150,7 @@ class SAM2VideoPredictor(SAM2Predictor):
# output on the same frame in "non_cond_frame_outputs"
for frame_idx in output_dict["cond_frame_outputs"]:
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
- for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
for frame_idx in obj_output_dict["cond_frame_outputs"]:
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
@@ -1196,9 +1163,9 @@ class SAM2VideoPredictor(SAM2Predictor):
consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
)
input_frames_inds = set()
- for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
input_frames_inds.update(point_inputs_per_frame.keys())
- for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
input_frames_inds.update(mask_inputs_per_frame.keys())
assert all_consolidated_frame_inds == input_frames_inds
@@ -1217,9 +1184,21 @@ class SAM2VideoPredictor(SAM2Predictor):
return
assert predictor.dataset is not None
assert predictor.dataset.mode == "video"
+ predictor.inference_state = predictor._init_state(predictor.dataset.frames)
+ @staticmethod
+ def _init_state(num_frames):
+ """Initialize an inference state.
+
+ This function sets up the initial state required for performing inference on video data. It includes
+ initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata
+ relevant to the tracking process.
+
+ Args:
+ num_frames (int): The number of frames in the video.
+ """
inference_state = {
- "num_frames": predictor.dataset.frames,
+ "num_frames": num_frames, # TODO: see if there's any chance to remove it
"point_inputs_per_obj": {}, # inputs points on each frame
"mask_inputs_per_obj": {}, # inputs mask on each frame
"constants": {}, # values that don't change across frames (so we only need to hold one copy of them)
@@ -1247,7 +1226,7 @@ class SAM2VideoPredictor(SAM2Predictor):
"tracking_has_started": False,
"frames_already_tracked": [],
}
- predictor.inference_state = inference_state
+ return inference_state
def get_im_features(self, im, batch=1):
"""Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
@@ -1265,7 +1244,6 @@ class SAM2VideoPredictor(SAM2Predictor):
- If `batch` is greater than 1, the features are expanded to fit the batch size.
- The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
"""
- self.model.set_imgsz(self.imgsz)
backbone_out = self.model.forward_image(im)
if batch > 1: # expand features if there's more than one prompt
for i, feat in enumerate(backbone_out["backbone_fpn"]):
@@ -1276,11 +1254,13 @@ class SAM2VideoPredictor(SAM2Predictor):
_, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
return vis_feats, vis_pos_embed, feat_sizes
- def _obj_id_to_idx(self, obj_id):
+ def _obj_id_to_idx(self, obj_id, inference_state: dict[str, Any] | None = None):
"""Map client-side object id to model-side object index.
Args:
obj_id (int): The unique identifier of the object provided by the client side.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
Returns:
(int): The index of the object on the model side.
@@ -1295,27 +1275,28 @@ class SAM2VideoPredictor(SAM2Predictor):
- It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
- Additional data structures are initialized for the new object to store inputs and outputs.
"""
- obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
+ inference_state = inference_state or self.inference_state
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
if obj_idx is not None:
return obj_idx
# This is a new object id not sent to the server before. We only allow adding
# new objects *before* the tracking starts.
- allow_new_object = not self.inference_state["tracking_has_started"]
+ allow_new_object = not inference_state["tracking_has_started"]
if allow_new_object:
# get the next object slot
- obj_idx = len(self.inference_state["obj_id_to_idx"])
- self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
- self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
- self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
+ obj_idx = len(inference_state["obj_id_to_idx"])
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
# set up input and output structures for this object
- self.inference_state["point_inputs_per_obj"][obj_idx] = {}
- self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
- self.inference_state["output_dict_per_obj"][obj_idx] = {
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
+ inference_state["output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: }
"non_cond_frame_outputs": {}, # dict containing {frame_idx: }
}
- self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: }
"non_cond_frame_outputs": {}, # dict containing {frame_idx: }
}
@@ -1323,7 +1304,7 @@ class SAM2VideoPredictor(SAM2Predictor):
else:
raise RuntimeError(
f"Cannot add new object id {obj_id} after tracking starts. "
- f"All existing object ids: {self.inference_state['obj_ids']}. "
+ f"All existing object ids: {inference_state['obj_ids']}. "
f"Please call 'reset_state' to restart from scratch."
)
@@ -1338,6 +1319,7 @@ class SAM2VideoPredictor(SAM2Predictor):
reverse,
run_mem_encoder,
prev_sam_mask_logits=None,
+ inference_state: dict[str, Any] | None = None,
):
"""Run tracking on a single frame based on current inputs and previous memory.
@@ -1351,6 +1333,8 @@ class SAM2VideoPredictor(SAM2Predictor):
reverse (bool): Indicates if the tracking should be performed in reverse order.
run_mem_encoder (bool): Indicates if the memory encoder should be executed.
prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
Returns:
(dict): A dictionary containing the output of the tracking step, including updated features and predictions.
@@ -1364,9 +1348,10 @@ class SAM2VideoPredictor(SAM2Predictor):
- The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
- The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
"""
+ inference_state = inference_state or self.inference_state
# Retrieve correct image features
current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
- self.inference_state["im"], batch_size
+ inference_state["im"], batch_size
)
# point and mask should not appear as input simultaneously on the same frame
@@ -1380,7 +1365,7 @@ class SAM2VideoPredictor(SAM2Predictor):
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=output_dict,
- num_frames=self.inference_state["num_frames"],
+ num_frames=inference_state["num_frames"],
track_in_reverse=reverse,
run_mem_encoder=run_mem_encoder,
prev_sam_mask_logits=prev_sam_mask_logits,
@@ -1398,10 +1383,10 @@ class SAM2VideoPredictor(SAM2Predictor):
# pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
- current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
+ current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"], inference_state)
return current_out
- def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
+ def _get_maskmem_pos_enc(self, out_maskmem_pos_enc, inference_state: dict[str, Any] | None = None):
"""Cache and manage the positional encoding for mask memory across frames and objects.
This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for mask memory, which is
@@ -1413,6 +1398,8 @@ class SAM2VideoPredictor(SAM2Predictor):
Args:
out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory. Should be a list
of tensors or None.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
Returns:
(list[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
@@ -1423,7 +1410,8 @@ class SAM2VideoPredictor(SAM2Predictor):
- The method checks if the positional encoding has already been cached in the session's constants.
- If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
"""
- model_constants = self.inference_state["constants"]
+ inference_state = inference_state or self.inference_state
+ model_constants = inference_state["constants"]
# "out_maskmem_pos_enc" should be either a list of tensors or None
if out_maskmem_pos_enc is not None:
if "maskmem_pos_enc" not in model_constants:
@@ -1444,6 +1432,7 @@ class SAM2VideoPredictor(SAM2Predictor):
frame_idx,
is_cond=False,
run_mem_encoder=False,
+ inference_state: dict[str, Any] | None = None,
):
"""Consolidate per-object temporary outputs into a single output for all objects.
@@ -1457,6 +1446,8 @@ class SAM2VideoPredictor(SAM2Predictor):
is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.
run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after consolidating the
outputs.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
Returns:
(dict): A consolidated output dictionary containing the combined results for all objects.
@@ -1467,7 +1458,8 @@ class SAM2VideoPredictor(SAM2Predictor):
- If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
- The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
"""
- batch_size = len(self.inference_state["obj_idx_to_id"])
+ inference_state = inference_state or self.inference_state
+ batch_size = len(inference_state["obj_idx_to_id"])
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
@@ -1478,7 +1470,8 @@ class SAM2VideoPredictor(SAM2Predictor):
"maskmem_features": None,
"maskmem_pos_enc": None,
"pred_masks": torch.full(
- size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+ # size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+ size=(batch_size, 1, *self._bb_feat_sizes[0]),
fill_value=-1024.0,
dtype=self.torch_dtype,
device=self.device,
@@ -1499,8 +1492,8 @@ class SAM2VideoPredictor(SAM2Predictor):
),
}
for obj_idx in range(batch_size):
- obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
- obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = (
obj_temp_output_dict[storage_key].get(frame_idx)
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
@@ -1540,21 +1533,25 @@ class SAM2VideoPredictor(SAM2Predictor):
high_res_masks=high_res_masks,
is_mask_from_pts=True, # these frames are what the user interacted with
object_score_logits=consolidated_out["object_score_logits"],
+ inference_state=inference_state,
)
return consolidated_out
- def _get_empty_mask_ptr(self, frame_idx):
+ def _get_empty_mask_ptr(self, frame_idx, inference_state: dict[str, Any] | None = None):
"""Get a dummy object pointer based on an empty mask on the current frame.
Args:
frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
Returns:
(torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
"""
+ inference_state = inference_state or self.inference_state
# Retrieve correct image features
- current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])
+ current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(inference_state["im"])
# Feed the empty mask and image feature above to get a dummy object pointer
current_out = self.model.track_step(
@@ -1567,14 +1564,21 @@ class SAM2VideoPredictor(SAM2Predictor):
# A dummy (empty) mask with a single object
mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=self.torch_dtype, device=self.device),
output_dict={},
- num_frames=self.inference_state["num_frames"],
+ num_frames=inference_state["num_frames"],
track_in_reverse=False,
run_mem_encoder=False,
prev_sam_mask_logits=None,
)
return current_out["obj_ptr"]
- def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
+ def _run_memory_encoder(
+ self,
+ batch_size,
+ high_res_masks,
+ object_score_logits,
+ is_mask_from_pts,
+ inference_state: dict[str, Any] | None = None,
+ ):
"""Run the memory encoder on masks.
This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
@@ -1585,13 +1589,16 @@ class SAM2VideoPredictor(SAM2Predictor):
high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
object_score_logits (torch.Tensor): Logits representing the object scores.
is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
Returns:
maskmem_features (torch.Tensor): The encoded mask features.
maskmem_pos_enc (torch.Tensor): The positional encoding.
"""
+ inference_state = inference_state or self.inference_state
# Retrieve correct image features
- current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
+ current_vision_feats, _, feat_sizes = self.get_im_features(inference_state["im"], batch_size)
maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
@@ -1601,12 +1608,14 @@ class SAM2VideoPredictor(SAM2Predictor):
)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
- maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
+ maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc, inference_state)
return maskmem_features.to(
dtype=torch.float16, device=self.device, non_blocking=self.device.type == "cuda"
), maskmem_pos_enc
- def _add_output_per_object(self, frame_idx, current_out, storage_key):
+ def _add_output_per_object(
+ self, frame_idx, current_out, storage_key, inference_state: dict[str, Any] | None = None
+ ):
"""Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
The resulting slices share the same tensor storage.
@@ -1615,14 +1624,17 @@ class SAM2VideoPredictor(SAM2Predictor):
frame_idx (int): The index of the current frame.
current_out (dict): The current output dictionary containing multi-object outputs.
storage_key (str): The key used to store the output in the per-object output dictionary.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
"""
+ inference_state = inference_state or self.inference_state
maskmem_features = current_out["maskmem_features"]
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
maskmem_pos_enc = current_out["maskmem_pos_enc"]
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
- for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
+ for obj_idx, obj_output_dict in inference_state["output_dict_per_obj"].items():
obj_slice = slice(obj_idx, obj_idx + 1)
obj_out = {
"maskmem_features": None,
@@ -1636,7 +1648,7 @@ class SAM2VideoPredictor(SAM2Predictor):
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
obj_output_dict[storage_key][frame_idx] = obj_out
- def _clear_non_cond_mem_around_input(self, frame_idx):
+ def _clear_non_cond_mem_around_input(self, frame_idx, inference_state: dict[str, Any] | None = None):
"""Remove the non-conditioning memory around the input frame.
When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain
@@ -1646,15 +1658,179 @@ class SAM2VideoPredictor(SAM2Predictor):
Args:
frame_idx (int): The index of the current frame where user interaction occurred.
+ inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's
+ inference state.
"""
+ inference_state = inference_state or self.inference_state
r = self.model.memory_temporal_stride_for_eval
frame_idx_begin = frame_idx - r * self.model.num_maskmem
frame_idx_end = frame_idx + r * self.model.num_maskmem
for t in range(frame_idx_begin, frame_idx_end + 1):
- self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
- for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
+ inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+ @smart_inference_mode()
+ def remove_object(self, inference_state, obj_id, strict=False):
+ """Remove an object id from the tracking state. If strict is True, we check whether the object id actually
+ exists and raise an error if it doesn't exist.
+ """
+ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
+ # Check whether this object_id to remove actually exists and possibly raise an error.
+ if old_obj_idx_to_rm is None:
+ if not strict:
+ return inference_state["obj_ids"]
+ raise RuntimeError(
+ f"Cannot remove object id {obj_id} as it doesn't exist. "
+ f"All existing object ids: {inference_state['obj_ids']}."
+ )
+
+ # If this is the only remaining object id, we simply reset the state.
+ if len(inference_state["obj_id_to_idx"]) == 1:
+ self.clear_all_points_in_video(inference_state)
+ return inference_state["obj_ids"]
+
+ # There are still remaining objects after removing this object id. In this case,
+ # we need to delete the object storage from inference state tensors.
+ # Step 0: clear the input on those frames where this object id has point or mask input
+ # (note that this step is required as it might downgrade conditioning frames to
+ # non-conditioning ones)
+ obj_input_frames_inds = set()
+ obj_input_frames_inds.update(inference_state["point_inputs_per_obj"][old_obj_idx_to_rm])
+ obj_input_frames_inds.update(inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm])
+ for frame_idx in obj_input_frames_inds:
+ self.clear_all_points_in_frame(inference_state, frame_idx, obj_id)
+
+ # Step 1: Update the object id mapping (note that it must be done after Step 0,
+ # since Step 0 still requires the old object id mappings in inference_state)
+ old_obj_ids = inference_state["obj_ids"]
+ old_obj_inds = list(range(len(old_obj_ids)))
+ remain_old_obj_inds = old_obj_inds.copy()
+ remain_old_obj_inds.remove(old_obj_idx_to_rm)
+ new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
+ new_obj_inds = list(range(len(new_obj_ids)))
+ # build new mappings
+ old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
+ inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
+ inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
+ inference_state["obj_ids"] = new_obj_ids
+
+ # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
+ # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
+ # it's already handled in Step 0)
+ def _map_keys(container):
+ new_kvs = []
+ for k in old_obj_inds:
+ v = container.pop(k)
+ if k in old_idx_to_new_idx:
+ new_kvs.append((old_idx_to_new_idx[k], v))
+ container.update(new_kvs)
+
+ _map_keys(inference_state["point_inputs_per_obj"])
+ _map_keys(inference_state["mask_inputs_per_obj"])
+ _map_keys(inference_state["output_dict_per_obj"])
+ _map_keys(inference_state["temp_output_dict_per_obj"])
+
+ # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
+ def _slice_state(output_dict, storage_key):
+ for frame_idx, out in output_dict[storage_key].items():
+ out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
+ out["maskmem_pos_enc"] = [x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]]
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(out["maskmem_pos_enc"], inference_state)
+ out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
+ out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
+ out["object_score_logits"] = out["object_score_logits"][remain_old_obj_inds]
+ # also update the per-object slices
+ self._add_output_per_object(frame_idx, out, storage_key, inference_state=inference_state)
+
+ _slice_state(inference_state["output_dict"], "cond_frame_outputs")
+ _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
+
+ return inference_state["obj_ids"]
+
+ @smart_inference_mode()
+ def clear_all_points_in_frame(self, inference_state, frame_idx, obj_id):
+ """Remove all input points or mask in a specific frame for a given object."""
+ obj_idx = self._obj_id_to_idx(obj_id, inference_state)
+
+ # Clear the conditioning information on the given frame
+ inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+ inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
+ temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
+
+ # Check and see if there are still any inputs left on this frame
+ batch_size = len(inference_state["obj_idx_to_id"])
+ frame_has_input = False
+ for obj_idx2 in range(batch_size):
+ if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
+ frame_has_input = True
+ break
+ if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
+ frame_has_input = True
+ break
+
+ # If this frame has no remaining inputs for any objects, we further clear its
+ # conditioning frame status
+ if not frame_has_input:
+ output_dict = inference_state["output_dict"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+ # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
+ out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
+ if out is not None:
+ # The frame is not a conditioning frame anymore since it's not receiving inputs,
+ # so we "downgrade" its output (if exists) to a non-conditioning frame output.
+ output_dict["non_cond_frame_outputs"][frame_idx] = out
+ inference_state["frames_already_tracked"].pop(frame_idx, None)
+ # Similarly, do it for the sliced output on each object.
+ for obj_idx2 in range(batch_size):
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
+ obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
+ if obj_out is not None:
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
+
+ # If all the conditioning frames have been removed, we also clear the tracking outputs
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ self._reset_tracking_results(inference_state)
+
+ @smart_inference_mode()
+ def clear_all_points_in_video(self, inference_state):
+ """Remove all input points or mask in all frames throughout the video."""
+ self._reset_tracking_results(inference_state)
+ # Remove all object ids
+ inference_state["obj_id_to_idx"].clear()
+ inference_state["obj_idx_to_id"].clear()
+ inference_state["obj_ids"].clear()
+ inference_state["point_inputs_per_obj"].clear()
+ inference_state["mask_inputs_per_obj"].clear()
+ inference_state["output_dict_per_obj"].clear()
+ inference_state["temp_output_dict_per_obj"].clear()
+
+ def _reset_tracking_results(self, inference_state):
+ """Reset all tracking inputs and results across the videos."""
+ for v in inference_state["point_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["mask_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ for v in inference_state["temp_output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"].clear()
+ inference_state["first_ann_frame_idx"] = None
+
class SAM2DynamicInteractivePredictor(SAM2Predictor):
"""SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
@@ -1986,3 +2162,1785 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
"obj_ptr": obj_ptr,
"object_score_logits": object_score_logits,
}
+
+
+class SAM3Predictor(SAM2Predictor):
+ """Segment Anything Model 3 (SAM3) Interactive Predictor for image segmentation tasks."""
+
+ _bb_feat_sizes = [
+ (288, 288),
+ (144, 144),
+ (72, 72),
+ ]
+ stride = 14
+
+ def setup_model(self, model=None, verbose=True):
+ """Setup the SAM3 model with appropriate mean and standard deviation for preprocessing."""
+ super().setup_model(model, verbose)
+ # update mean and std
+ self.mean = torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1).to(self.device)
+ self.std = torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1).to(self.device)
+
+ def get_model(self):
+ """Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
+ from .build_sam3 import build_interactive_sam3 # slow import
+
+ return build_interactive_sam3(self.args.model, compile=self.args.compile)
+
+
+class SAM3SemanticPredictor(SAM3Predictor):
+ """Segment Anything Model 3 (SAM3) Predictor for image segmentation tasks."""
+
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None, bpe_path=None):
+ """Initialize the SAM3SemanticPredictor with configuration and optional overrides."""
+ super().__init__(cfg, overrides, _callbacks)
+ self.bpe_path = bpe_path
+
+ def get_model(self):
+ """Retrieve and initialize the Segment Anything Model 3 (SAM3) for image segmentation tasks."""
+ from .build_sam3 import build_sam3_image_model # slow import
+
+ return build_sam3_image_model(self.args.model, bpe_path=self.bpe_path, compile=self.args.compile)
+
+ @smart_inference_mode()
+ def get_im_features(self, im):
+ """Extract image features using the model's backbone."""
+ return self.model.backbone.forward_image(im)
+
+ def pre_transform(self, im):
+ """Perform initial transformations on the input image for preprocessing.
+
+ This method applies transformations such as resizing to prepare the image for further preprocessing. Currently,
+ batched inference is not supported; hence the list length should be 1.
+
+ Args:
+ im (list[np.ndarray]): List containing a single image in HWC numpy array format.
+
+ Returns:
+ (list[np.ndarray]): List containing the transformed image.
+
+ Raises:
+ AssertionError: If the input list contains more than one image.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> image = np.random.rand(480, 640, 3) # Single HWC image
+ >>> transformed = predictor.pre_transform([image])
+ >>> print(len(transformed))
+ 1
+ """
+ assert len(im) == 1, "SAM model does not currently support batched inference"
+ letterbox = LetterBox(self.imgsz, auto=False, center=False, scale_fill=True) # hardcode here for sam3
+ return [letterbox(image=x) for x in im]
+
+ def _prepare_geometric_prompts(self, src_shape, bboxes=None, labels=None):
+ """Prepare prompts by normalizing bounding boxes and points to the destination shape."""
+ if bboxes is not None:
+ bboxes = torch.as_tensor(bboxes, dtype=self.torch_dtype, device=self.device)
+ bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+ # needs xywh as input
+ bboxes = ops.xyxy2xywh(bboxes)
+ bboxes[:, 0::2] /= src_shape[1]
+ bboxes[:, 1::2] /= src_shape[0]
+ # Assuming labels are all positive if users don't pass labels.
+ if labels is None:
+ labels = np.ones(bboxes.shape[:-1])
+ labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+ assert bboxes.shape[-2] == labels.shape[-1], (
+ f"Number of points {bboxes.shape[-2]} should match number of labels {labels.shape[-1]}."
+ )
+ bboxes = bboxes.view(-1, 1, 4) # (N, 1, 4)
+ labels = labels.view(-1, 1) # (N, 1)
+ return bboxes, labels
+
+ def _inference_features(self, features, bboxes=None, labels=None, text: list[str] | None = None):
+ """Run inference on the extracted features with optional bounding boxes and labels."""
+ # NOTE: priority: bboxes > text > pre-set classes
+ nc = 1 if bboxes is not None else len(text) if text is not None else len(self.model.names)
+ geometric_prompt = self._get_dummy_prompt(nc)
+ if bboxes is not None:
+ for i in range(len(bboxes)):
+ geometric_prompt.append_boxes(bboxes[[i]], labels[[i]])
+ if text is None:
+ text = ["visual"] # bboxes needs this `visual` text prompt if no text passed
+ if text is not None and self.model.names != text:
+ self.model.set_classes(text=text)
+ outputs = self.model.forward_grounding(
+ backbone_out=features,
+ text_ids=torch.arange(nc, device=self.device, dtype=torch.long),
+ geometric_prompt=geometric_prompt,
+ )
+ return outputs
+
+ def postprocess(self, preds, img, orig_imgs):
+ """Post-process the predictions to apply non-overlapping constraints if required."""
+ pred_boxes = preds["pred_boxes"] # (nc, num_query, 4)
+ pred_logits = preds["pred_logits"]
+ pred_masks = preds["pred_masks"]
+ pred_scores = pred_logits.sigmoid()
+ presence_score = preds["presence_logit_dec"].sigmoid().unsqueeze(1)
+ pred_scores = (pred_scores * presence_score).squeeze(-1)
+ pred_cls = torch.tensor(
+ list(range(pred_scores.shape[0])),
+ dtype=pred_scores.dtype,
+ device=pred_scores.device,
+ )[:, None].expand_as(pred_scores)
+ pred_boxes = torch.cat([pred_boxes, pred_scores[..., None], pred_cls[..., None]], dim=-1)
+
+ keep = pred_scores > self.args.conf
+ pred_masks = pred_masks[keep]
+ pred_boxes = pred_boxes[keep]
+ pred_boxes[:, :4] = ops.xywh2xyxy(pred_boxes[:, :4])
+
+ names = getattr(self.model, "names", [str(i) for i in range(pred_scores.shape[0])])
+ if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
+ orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+ results = []
+ for masks, boxes, orig_img, img_path in zip([pred_masks], [pred_boxes], orig_imgs, self.batch[0]):
+ if masks.shape[0] == 0:
+ masks, boxes = None, torch.zeros((0, 6), device=pred_masks.device)
+ else:
+ masks = F.interpolate(masks.float()[None], orig_img.shape[:2], mode="bilinear")[0] > 0.5
+ boxes[..., [0, 2]] *= orig_img.shape[1]
+ boxes[..., [1, 3]] *= orig_img.shape[0]
+ results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=boxes))
+ return results
+
+ def inference(self, im, bboxes=None, labels=None, text: list[str] | None = None, *args, **kwargs):
+ """Perform inference on a single image with optional prompts."""
+ bboxes = self.prompts.pop("bboxes", bboxes)
+ labels = self.prompts.pop("labels", labels)
+ text = self.prompts.pop("text", text)
+ features = self.get_im_features(im) if self.features is None else self.features
+ prompts = self._prepare_geometric_prompts(self.batch[1][0].shape[:2], bboxes, labels)
+ return self._inference_features(features, *prompts, text=text)
+
+ @smart_inference_mode()
+ def inference_features(
+ self,
+ features,
+ src_shape,
+ bboxes=None,
+ labels=None,
+ text: list[str] | None = None,
+ ):
+ """Perform prompts preprocessing and inference on provided image features using the SAM model.
+
+ Args:
+ features (dict[str, Any]): Extracted image features from the SAM3 model image encoder.
+ src_shape (tuple[int, int]): The source shape (height, width) of the input image.
+ bboxes (np.ndarray | list[list[float]] | None): Bounding boxes in xyxy format with shape (N, 4). pixels.
+ labels (np.ndarray | list[int] | None): Point prompt labels with shape (N, ).
+ text (list[str] | None): List of text prompts corresponding to the classes.
+
+ Returns:
+ pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
+ pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 6), where N is the number of boxes.
+ Each box is in xyxy format with additional columns for score and class.
+
+ Notes:
+ - The input features is a torch.Tensor of shape (B, C, H, W) if performing on SAM, or a dict[str, Any] if performing on SAM2.
+ """
+ prompts = self._prepare_geometric_prompts(src_shape[:2], bboxes, labels)
+ preds = self._inference_features(features, *prompts, text=text)
+ pred_boxes = preds["pred_boxes"] # (nc, num_query, 4)
+ pred_logits = preds["pred_logits"]
+ pred_masks = preds["pred_masks"]
+ pred_scores = pred_logits.sigmoid()
+ presence_score = preds["presence_logit_dec"].sigmoid().unsqueeze(1)
+ pred_scores = (pred_scores * presence_score).squeeze(-1)
+ pred_cls = torch.tensor(
+ list(range(pred_scores.shape[0])),
+ dtype=pred_scores.dtype,
+ device=pred_scores.device,
+ )[:, None].expand_as(pred_scores)
+ pred_boxes = torch.cat([pred_boxes, pred_scores[..., None], pred_cls[..., None]], dim=-1)
+
+ keep = pred_scores > self.args.conf
+ pred_masks = pred_masks[keep]
+ pred_boxes = pred_boxes[keep]
+ pred_boxes[:, :4] = ops.xywh2xyxy(pred_boxes[:, :4])
+
+ if pred_masks.shape[0] == 0:
+ pred_masks, pred_boxes = None, torch.zeros((0, 6), device=pred_masks.device)
+ else:
+ pred_masks = F.interpolate(pred_masks.float()[None], src_shape[:2], mode="bilinear")[0] > 0.5
+ pred_boxes[..., 0] *= src_shape[1]
+ pred_boxes[..., 1] *= src_shape[0]
+ pred_boxes[..., 2] *= src_shape[1]
+ pred_boxes[..., 3] *= src_shape[0]
+ return pred_masks, pred_boxes
+
+ def reset_prompts(self):
+ """Reset the prompts for the predictor."""
+ self.prompts = {}
+ self.model.text_embeddings = {}
+
+ def _get_dummy_prompt(self, num_prompts=1):
+ """Get a dummy geometric prompt with zero boxes."""
+ geometric_prompt = Prompt(
+ box_embeddings=torch.zeros(0, num_prompts, 4, device=self.device),
+ box_mask=torch.zeros(num_prompts, 0, device=self.device, dtype=torch.bool),
+ )
+ return geometric_prompt
+
+
+class SAM3VideoPredictor(SAM2VideoPredictor, SAM3Predictor):
+ """Segment Anything Model 3 (SAM3) Video Predictor for video segmentation tasks."""
+
+ def propagate_in_video(self, inference_state, frame_idx):
+ """Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+ method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
+ encoder, and mask decoder for real-time and promptable segmentation tasks.
+
+ Args:
+ inference_state (dict): The current state of inference, including input cues and previous outputs.
+ frame_idx (int): The index of the current frame in the video sequence.
+ """
+ frame = frame_idx
+ output_dict = inference_state["output_dict"]
+ obj_ids = inference_state["obj_ids"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ batch_size = len(inference_state["obj_idx_to_id"])
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ raise RuntimeError("No points are provided; please add points first")
+
+ if frame in consolidated_frame_inds["cond_frame_outputs"]:
+ storage_key = "cond_frame_outputs"
+ current_out = output_dict[storage_key][frame]
+ if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(frame)
+ elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
+ storage_key = "non_cond_frame_outputs"
+ current_out = output_dict[storage_key][frame]
+ else:
+ storage_key = "non_cond_frame_outputs"
+ current_out = self._run_single_frame_inference(
+ output_dict=output_dict,
+ frame_idx=frame,
+ batch_size=batch_size,
+ is_init_cond_frame=False,
+ point_inputs=None,
+ mask_inputs=None,
+ reverse=False,
+ run_mem_encoder=True,
+ inference_state=inference_state,
+ )
+ output_dict[storage_key][frame] = current_out
+ # Create slices of per-object outputs for subsequent interaction with each
+ # individual object after tracking.
+ self._add_output_per_object(frame, current_out, storage_key, inference_state=inference_state)
+ inference_state["frames_already_tracked"].append(frame)
+ pred_masks = current_out["pred_masks"].flatten(0, 1)
+ obj_scores = current_out["object_score_logits"]
+
+ return obj_ids, pred_masks, obj_scores
+
+ def get_im_features(self, im, batch=1):
+ """A wrapper to get image features, supporting pre-extracted backbone outputs."""
+ if getattr(self, "backbone_out", None):
+ backbone_out = self.backbone_out
+ if batch > 1: # expand features if there's more than one prompt
+ backbone_out = {
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+ }
+ for i, feat in enumerate(backbone_out["backbone_fpn"]):
+ backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
+ for i, pos in enumerate(backbone_out["vision_pos_enc"]):
+ pos = pos.expand(batch, -1, -1, -1)
+ backbone_out["vision_pos_enc"][i] = pos
+ _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
+ return vis_feats, vis_pos_embed, feat_sizes
+ return super().get_im_features(im, batch)
+
+
+class SAM3VideoSemanticPredictor(SAM3SemanticPredictor):
+ """Segment Anything Model 3 (SAM3) Video Semantic Predictor."""
+
+ HIGH_CONF_THRESH = 0.8
+ HIGH_IOU_THRESH = 0.8
+ NO_OBJ_LOGIT = -10.0
+ NEVER_OCCLUDED = -1
+ ALWAYS_OCCLUDED = 100000
+
+ UNCONFIRMED = 1 # newly added masklet, not confirmed by any detection yet
+ CONFIRMED = 2 # confirmed by at least one detection
+ _bb_feat_sizes = [
+ (288, 288),
+ (144, 144),
+ (72, 72),
+ ]
+ stride = 14
+
+ def __init__(
+ self,
+ cfg=DEFAULT_CFG,
+ overrides=None,
+ _callbacks=None,
+ bpe_path="bpe_simple_vocab_16e6.txt.gz",
+ # prob threshold for detection outputs -- only keep detections above this threshold
+ # enters NMS and det-to-track matching
+ score_threshold_detection=0.5,
+ # IoU threshold for detection NMS
+ det_nms_thresh=0.0,
+ # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it
+ # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1
+ assoc_iou_thresh=0.5,
+ # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched"
+ # by any detections -- it is often a stricter threshold like 0.5
+ trk_assoc_iou_thresh=0.5,
+ # prob threshold for a detection to be added as a new object
+ new_det_thresh=0.0,
+ # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and
+ # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh`
+ # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh`
+ hotstart_delay=0,
+ hotstart_unmatch_thresh=3,
+ hotstart_dup_thresh=3,
+ # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period.
+ suppress_unmatched_only_within_hotstart=True,
+ init_trk_keep_alive=0,
+ max_trk_keep_alive=8,
+ min_trk_keep_alive=-4,
+ # Threshold for suppressing overlapping objects based on recent occlusion
+ suppress_overlapping_based_on_recent_occlusion_threshold=0.0,
+ decrease_trk_keep_alive_for_empty_masklets=False,
+ o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets
+ suppress_det_close_to_boundary=False,
+ fill_hole_area=16,
+ # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1)
+ max_num_objects=-1,
+ recondition_every_nth_frame=-1,
+ # masket confirmation status (to suppress unconfirmed masklets)
+ masklet_confirmation_enable=False,
+ # a masklet is confirmed after being consecutively detected and matched for
+ # `masklet_confirmation_consecutive_det_thresh`
+ masklet_confirmation_consecutive_det_thresh=3,
+ # bbox heuristic parameters
+ reconstruction_bbox_iou_thresh=0.0,
+ reconstruction_bbox_det_score=0.0,
+ ):
+ """Initialize the SAM3VideoSemanticPredictor with configuration and optional overrides."""
+ super().__init__(cfg, overrides, _callbacks, bpe_path=bpe_path)
+ self.score_threshold_detection = score_threshold_detection
+ self.det_nms_thresh = det_nms_thresh
+ self.assoc_iou_thresh = assoc_iou_thresh
+ self.trk_assoc_iou_thresh = trk_assoc_iou_thresh
+ self.new_det_thresh = new_det_thresh
+
+ # hotstart parameters
+ if hotstart_delay > 0:
+ assert hotstart_unmatch_thresh <= hotstart_delay
+ assert hotstart_dup_thresh <= hotstart_delay
+ self.hotstart_delay = hotstart_delay
+ self.hotstart_unmatch_thresh = hotstart_unmatch_thresh
+ self.hotstart_dup_thresh = hotstart_dup_thresh
+ self.suppress_unmatched_only_within_hotstart = suppress_unmatched_only_within_hotstart
+ self.init_trk_keep_alive = init_trk_keep_alive
+ self.max_trk_keep_alive = max_trk_keep_alive
+ self.min_trk_keep_alive = min_trk_keep_alive
+ self.suppress_overlapping_based_on_recent_occlusion_threshold = (
+ suppress_overlapping_based_on_recent_occlusion_threshold
+ )
+ self.suppress_det_close_to_boundary = suppress_det_close_to_boundary
+ self.decrease_trk_keep_alive_for_empty_masklets = decrease_trk_keep_alive_for_empty_masklets
+ self.o2o_matching_masklets_enable = o2o_matching_masklets_enable
+ self.fill_hole_area = fill_hole_area
+ self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use)
+
+ max_num_objects = 10000 # no limit
+ num_obj_for_compile = 16
+ self.max_num_objects = max_num_objects
+ self.num_obj_for_compile = num_obj_for_compile
+ self.recondition_every_nth_frame = recondition_every_nth_frame
+ self.masklet_confirmation_enable = masklet_confirmation_enable
+ self.masklet_confirmation_consecutive_det_thresh = masklet_confirmation_consecutive_det_thresh
+ self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh
+ self.reconstruction_bbox_det_score = reconstruction_bbox_det_score
+
+ # build SAM3 tracker
+ self.tracker = SAM3VideoPredictor(overrides=overrides)
+
+ self.inference_state = {}
+ self.callbacks["on_predict_start"].append(self.init_state)
+
+ def setup_model(self, model=None, verbose=True):
+ """Setup the SAM3VideoSemanticPredictor model."""
+ super().setup_model(model, verbose)
+ from .build_sam3 import build_interactive_sam3
+
+ # Initialize the SAM3 tracker model without backbone (backbone is handled in the detector)
+ model = build_interactive_sam3(self.args.model, with_backbone=False)
+ self.tracker.setup_model(model=model, verbose=False)
+
+ def setup_source(self, source):
+ """Setup the source for the SAM3VideoSemanticPredictor model."""
+ super().setup_source(source)
+ self.tracker.imgsz = self.imgsz
+ self.tracker.model.set_imgsz(self.imgsz)
+ self.tracker._bb_feat_sizes = [[int(x / (self.stride * i)) for x in self.imgsz] for i in [1 / 4, 1 / 2, 1]]
+ self.interpol_size = self.tracker.model.memory_encoder.mask_downsampler.interpol_size
+
+ @staticmethod
+ def init_state(predictor):
+ """Initialize an inference state for the predictor.
+
+ This function sets up the initial state required for performing inference on video data. It includes
+ initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata
+ relevant to the tracking process.
+
+ Args:
+ predictor (SAM3VideoSemanticPredictor): The predictor object for which to initialize the state.
+ """
+ if len(predictor.inference_state) > 0: # means initialized
+ return
+ assert predictor.dataset is not None
+ assert predictor.dataset.mode == "video"
+ num_frames = predictor.dataset.frames
+ inference_state = {
+ "num_frames": num_frames,
+ "tracker_inference_states": [],
+ "tracker_metadata": {},
+ }
+ inference_state["text_prompt"] = None
+ inference_state["per_frame_geometric_prompt"] = [None] * num_frames
+ predictor.inference_state = inference_state
+
+ def inference(self, im, bboxes=None, labels=None, text: list[str] | None = None, *args, **kwargs):
+ """Perform inference on a video sequence with optional prompts."""
+ frame = self.dataset.frame - 1 # align frame index to be 0-based
+ self.inference_state["im"] = im # only pass image for subsequent frames
+ if "text_ids" not in self.inference_state: # first frame processing
+ self.add_prompt(frame_idx=frame, text=text, bboxes=bboxes, labels=labels)
+ return self._run_single_frame_inference(frame, reverse=False)
+
+ def postprocess(self, preds, img, orig_imgs):
+ """Post-process the predictions to apply non-overlapping constraints if required."""
+ obj_id_to_mask = preds["obj_id_to_mask"] # low res masks
+ curr_obj_ids = sorted(obj_id_to_mask.keys())
+ if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
+ orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+ if len(curr_obj_ids) == 0:
+ pred_masks, pred_boxes = None, torch.zeros((0, 7), device=self.device)
+ else:
+ pred_masks = torch.cat([obj_id_to_mask[obj_id] for obj_id in curr_obj_ids], dim=0)
+ pred_masks = F.interpolate(pred_masks.float()[None], orig_imgs[0].shape[:2], mode="bilinear")[0] > 0.5
+ pred_ids = torch.tensor(curr_obj_ids, dtype=torch.int32, device=pred_masks.device)
+ pred_scores = torch.tensor(
+ [preds["obj_id_to_score"][obj_id] for obj_id in curr_obj_ids], device=pred_masks.device
+ )
+ pred_cls = torch.tensor(
+ [preds["obj_id_to_cls"][obj_id] for obj_id in curr_obj_ids], device=pred_masks.device
+ )
+ keep = (pred_scores > self.args.conf) & pred_masks.any(dim=(1, 2))
+ pred_masks = pred_masks[keep]
+ pred_boxes = batched_mask_to_box(pred_masks)
+ pred_boxes = torch.cat(
+ [pred_boxes, pred_ids[keep][:, None], pred_scores[keep][..., None], pred_cls[keep][..., None]], dim=-1
+ )
+ if pred_masks.shape[0] > 1:
+ tracker_scores = torch.tensor(
+ [
+ (
+ preds["obj_id_to_tracker_score"][obj_id]
+ if obj_id in preds["obj_id_to_tracker_score"]
+ else 0.0
+ )
+ for obj_id in curr_obj_ids
+ ],
+ device=pred_masks.device,
+ )[keep]
+ pred_masks = (
+ self._apply_object_wise_non_overlapping_constraints(
+ pred_masks.unsqueeze(1),
+ tracker_scores.unsqueeze(1),
+ background_value=0,
+ ).squeeze(1)
+ ) > 0
+
+ # names = getattr(self.model, "names", [str(i) for i in range(pred_scores.shape[0])])
+ names = dict(enumerate(str(i) for i in range(pred_masks.shape[0])))
+ results = []
+ for masks, boxes, orig_img, img_path in zip([pred_masks], [pred_boxes], orig_imgs, self.batch[0]):
+ results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=boxes))
+ return results
+
    def _run_single_frame_inference(self, frame_idx, reverse=False, inference_state=None):
        """Perform inference on a single frame and get its inference results.

        Args:
            frame_idx (int): Index of the frame to run detection + tracking on.
            reverse (bool): Whether propagation runs backward through the video.
            inference_state (dict | None): State dict to use; defaults to `self.inference_state`.

        Returns:
            (dict): Per-frame outputs keyed by "obj_id_to_mask", "obj_id_to_score", "obj_id_to_cls",
                "obj_id_to_tracker_score", "removed_obj_ids", "suppressed_obj_ids", "frame_stats",
                and "unconfirmed_obj_ids".
        """
        inference_state = inference_state or self.inference_state
        # prepare inputs
        tracker_states_local = inference_state["tracker_inference_states"]
        has_text_prompt = inference_state["text_prompt"] is not None
        has_geometric_prompt = inference_state["per_frame_geometric_prompt"][frame_idx] is not None
        # run inference for the current frame
        (
            obj_id_to_mask,
            obj_id_to_score,
            obj_id_to_cls,
            tracker_states_local_new,
            tracker_metadata_new,
            frame_stats,
            _,
        ) = self._det_track_one_frame(
            frame_idx=frame_idx,
            num_frames=inference_state["num_frames"],
            reverse=reverse,
            im=inference_state["im"],
            text_ids=inference_state["text_ids"],
            # fall back to a dummy (empty) prompt when no geometric prompt exists for this frame
            geometric_prompt=(
                self._get_dummy_prompt(num_prompts=len(inference_state["text_ids"]))
                if not has_geometric_prompt
                else inference_state["per_frame_geometric_prompt"][frame_idx]
            ),
            tracker_states_local=tracker_states_local,
            tracker_metadata_prev=inference_state["tracker_metadata"],
            # only admit brand-new detections when the user supplied some prompt
            allow_new_detections=has_text_prompt or has_geometric_prompt,
        )
        # update inference state
        inference_state["tracker_inference_states"] = tracker_states_local_new
        inference_state["tracker_metadata"] = tracker_metadata_new

        out = {
            "obj_id_to_mask": obj_id_to_mask,
            "obj_id_to_score": obj_id_to_score,  # first frame detection score
            "obj_id_to_cls": obj_id_to_cls,  # first frame detection class id
            "obj_id_to_tracker_score": tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx],
        }
        # removed_obj_ids is only needed on rank 0 to handle hotstart delay buffer
        metadata = tracker_metadata_new["metadata"]
        removed_obj_ids = metadata["removed_obj_ids"]
        out["removed_obj_ids"] = removed_obj_ids
        out["suppressed_obj_ids"] = metadata["suppressed_obj_ids"][frame_idx]
        out["frame_stats"] = frame_stats
        if self.masklet_confirmation_enable:
            # expose masklets whose confirmation status is still UNCONFIRMED
            status = metadata["masklet_confirmation"]["status"]
            is_unconfirmed = status == self.UNCONFIRMED
            out["unconfirmed_obj_ids"] = tracker_metadata_new["obj_ids_all_gpu"][is_unconfirmed].tolist()
        else:
            out["unconfirmed_obj_ids"] = []
        return out
+
+ @smart_inference_mode()
+ def add_prompt(
+ self,
+ frame_idx,
+ text=None,
+ bboxes=None,
+ labels=None,
+ inference_state=None,
+ ):
+ """Add text, point or box prompts on a single frame. This method returns the inference outputs only on the
+ prompted frame.
+
+ Note that text prompts are NOT associated with a particular frame (i.e. they apply
+ to all frames). However, we only run inference on the frame specified in `frame_idx`.
+ """
+ inference_state = inference_state or self.inference_state
+ assert text is not None or bboxes is not None, "at least one type of prompt (text, boxes) must be provided"
+
+ # 1) handle text prompt
+ use_text = text is not None
+ text = text if use_text else "visual"
+ text_batch = [text] if isinstance(text, str) else text
+ inference_state["text_prompt"] = text if use_text else None
+ n = len(text_batch)
+ text_ids = torch.arange(n, device=self.device, dtype=torch.long)
+ inference_state["text_ids"] = text_ids
+ if text is not None and self.model.names != text:
+ self.model.set_classes(text=text)
+
+ # 2) handle box prompt
+ bboxes, labels = self._prepare_geometric_prompts(self.batch[1][0].shape[:2], bboxes, labels)
+ assert (bboxes is not None) == (labels is not None)
+ geometric_prompt = self._get_dummy_prompt(num_prompts=n)
+ if bboxes is not None:
+ for i in range(len(bboxes)):
+ geometric_prompt.append_boxes(bboxes[[i]], labels[[i]])
+ inference_state["per_frame_geometric_prompt"][frame_idx] = geometric_prompt
+ out = self._run_single_frame_inference(frame_idx, reverse=False, inference_state=inference_state)
+ return frame_idx, out
+
+ def _apply_object_wise_non_overlapping_constraints(self, pred_masks, obj_scores, background_value=-10.0):
+ """Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region)."""
+ # Replace pixel scores with object scores
+ pred_masks_single_score = torch.where(pred_masks > 0, obj_scores[..., None, None], background_value)
+ # Apply pixel-wise non-overlapping constraint based on mask scores
+ pixel_level_non_overlapping_masks = self.tracker.model._apply_non_overlapping_constraints(
+ pred_masks_single_score
+ )
+ # Replace object scores with pixel scores. Note, that now only one object can claim the overlapping region
+ pred_masks = torch.where(
+ pixel_level_non_overlapping_masks > 0,
+ pred_masks,
+ torch.clamp(pred_masks, max=background_value),
+ )
+ return pred_masks
+
    def _det_track_one_frame(
        self,
        im: torch.Tensor,
        text_ids: torch.Tensor,
        frame_idx: int,
        num_frames: int,
        reverse: bool,
        geometric_prompt: Prompt,
        tracker_states_local: list[Any],
        tracker_metadata_prev: dict[str, Any],
        allow_new_detections: bool = True,
    ):
        """Handle one-step inference for the DenseTracking model in an SPMD manner.

        At a high-level, all GPUs execute the same function calls as if it's done on a single GPU,
        while under the hood, some function calls involve distributed computation based on sharded
        SAM2 states.

        Args:
            im (torch.Tensor): Image input for the current frame; it should be identical across GPUs.
            text_ids (torch.Tensor): IDs of the text prompts to detect with.
            frame_idx (int): Current frame index.
            num_frames (int): Total number of frames in the video.
            reverse (bool): Whether propagation runs backward through the video.
            geometric_prompt (Prompt): Box/point prompt for this frame (possibly a dummy prompt).
            tracker_states_local (list[Any]): Holds the local masklet information in this GPU shard.
            tracker_metadata_prev (dict[str, Any]): Manages the metadata for SAM2 objects, such as
                which masklet is held on which GPU; it contains both global and local masklet
                information.
            allow_new_detections (bool): Whether brand-new detections may be added on this frame.

        Returns:
            (tuple): `(obj_id_to_mask, obj_id_to_score, obj_id_to_cls, tracker_states_local_new,
                tracker_metadata_new, frame_stats, tracker_obj_scores_global)`.
        """
        # Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU,
        # a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner.
        det_out = self.run_backbone_and_detection(
            im=im,
            text_ids=text_ids,
            geometric_prompt=geometric_prompt,
            allow_new_detections=allow_new_detections,
        )

        # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks.
        # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions
        # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only
        # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks;
        # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics.
        if tracker_metadata_prev == {}:
            # initialize masklet metadata if it's uninitialized (empty dict)
            tracker_metadata_prev.update(self._initialize_metadata())
        tracker_low_res_masks_global, tracker_obj_scores_global = self.run_tracker_propagation(
            frame_idx=frame_idx,
            tracker_states_local=tracker_states_local,
            tracker_metadata_prev=tracker_metadata_prev,
        )

        # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans
        # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc).
        # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints.
        # **This step should involve all the heuristics needed for any updates.** Most of the update
        # planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is
        # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the
        # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`).
        tracker_update_plan, tracker_metadata_new = self.run_tracker_update_planning_phase(
            frame_idx=frame_idx,
            reverse=reverse,
            det_out=det_out,
            tracker_low_res_masks_global=tracker_low_res_masks_global,
            tracker_obj_scores_global=tracker_obj_scores_global,
            tracker_metadata_prev=tracker_metadata_prev,
            tracker_states_local=tracker_states_local,
        )

        # Get reconditioning info from the update plan
        reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set())

        # Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states
        tracker_states_local_new = self.run_tracker_update_execution_phase(
            frame_idx=frame_idx,
            num_frames=num_frames,
            det_out=det_out,
            tracker_states_local=tracker_states_local,
            tracker_update_plan=tracker_update_plan,
        )

        # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since
        # only GPU 0 will send outputs to the server).
        obj_id_to_mask = self.build_outputs(
            det_out=det_out,
            tracker_low_res_masks_global=tracker_low_res_masks_global,
            tracker_metadata_prev=tracker_metadata_prev,
            tracker_update_plan=tracker_update_plan,
            reconditioned_obj_ids=reconditioned_obj_ids,
        )
        obj_id_to_score = tracker_metadata_new["obj_id_to_score"]
        obj_id_to_cls = tracker_metadata_new["obj_id_to_cls"]
        # a few statistics for the current frame as a part of the output
        frame_stats = {
            "num_obj_tracked": np.sum(tracker_metadata_new["num_obj"]),
            "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"],
        }
        # add tracker scores to metadata; this fires for every frame except the first
        # (no propagated objects exist yet on the first frame)
        if tracker_obj_scores_global.shape[0] > 0:
            # Convert tracker_obj_scores_global to sigmoid scores before updating
            tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist()
            tracker_obj_ids = tracker_metadata_prev["obj_ids"]
            tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx].update(
                dict(zip(tracker_obj_ids, tracker_obj_scores_global))
            )
        return (
            obj_id_to_mask,  # a dict: obj_id --> output mask
            obj_id_to_score,  # a dict: obj_id --> output score (prob)
            obj_id_to_cls,  # a dict: obj_id --> output cls (int)
            tracker_states_local_new,
            tracker_metadata_new,
            frame_stats,
            tracker_obj_scores_global,  # per-object tracker scores (list of sigmoid probs when non-empty, else empty tensor)
        )
+
+ def _suppress_detections_close_to_boundary(self, boxes, margin=0.025):
+ """Suppress detections too close to image edges (for normalized boxes).
+
+ boxes: (N, 4) in xyxy format, normalized [0,1]
+ margin: fraction of image
+ """
+ x_min, y_min, x_max, y_max = boxes.unbind(-1)
+ x_c = (x_min + x_max) / 2
+ y_c = (y_min + y_max) / 2
+ keep = (x_c > margin) & (x_c < 1.0 - margin) & (y_c > margin) & (y_c < 1.0 - margin)
+
+ return keep
+
+ def run_backbone_and_detection(
+ self, im: torch.Tensor, text_ids: torch.Tensor, geometric_prompt: Prompt, allow_new_detections: bool
+ ):
+ """Run backbone and detection for a single frame."""
+ features = self.get_im_features(im)
+ sam3_image_out = self.model.forward_grounding(
+ backbone_out=features, text_ids=text_ids, geometric_prompt=geometric_prompt
+ )
+ det_out = self._extract_detection_outputs(sam3_image_out, allow_new_detections)
+ self._cache_backbone_features(sam3_image_out)
+ return det_out
+
+ def _extract_detection_outputs(self, sam3_image_out, allow_new_detections):
+ """Extract and filter detection outputs."""
+ pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid()
+ if not allow_new_detections:
+ pred_probs = pred_probs - 1e8
+
+ pred_cls = torch.tensor(
+ list(range(pred_probs.shape[0])),
+ dtype=pred_probs.dtype,
+ device=pred_probs.device,
+ )[:, None].expand_as(pred_probs)
+
+ pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"]
+ pred_masks = sam3_image_out["pred_masks"]
+
+ keep = pred_probs > self.score_threshold_detection
+ return {
+ "bbox": pred_boxes_xyxy[keep],
+ "mask": pred_masks[keep],
+ "scores": pred_probs[keep],
+ "cls": pred_cls[keep],
+ }
+
+ def _cache_backbone_features(self, sam3_image_out):
+ """Build and cache SAM2 backbone features."""
+ sam_mask_decoder = self.tracker.model.sam_mask_decoder
+ feats = sam3_image_out["backbone_out"]["sam2_backbone_out"]
+ tracker_backbone_fpn = [
+ sam_mask_decoder.conv_s0(feats["backbone_fpn"][0]),
+ sam_mask_decoder.conv_s1(feats["backbone_fpn"][1]),
+ feats["backbone_fpn"][2],
+ ]
+ tracker_backbone_out = {
+ "vision_features": tracker_backbone_fpn[-1],
+ "vision_pos_enc": feats["vision_pos_enc"],
+ "backbone_fpn": tracker_backbone_fpn,
+ }
+ # cache the SAM2 backbone features for `frame_idx` in the tracker
+ self.tracker.backbone_out = tracker_backbone_out
+
+ def run_tracker_propagation(
+ self, frame_idx: int, tracker_states_local: list[Any], tracker_metadata_prev: dict[str, np.ndarray]
+ ):
+ """Run the tracker propagation phase for a single frame in an SPMD manner."""
+ # Step 1: propagate the local SAM2 states to get the current frame's prediction
+ # `low_res_masks_local` of the existing masklets on this GPU
+ # - obj_ids_local: list[int] -- list of object IDs
+ # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask)
+ obj_ids_local, low_res_masks_local, obj_scores_local = self._propogate_tracker_one_frame_local_gpu(
+ tracker_states_local, frame_idx=frame_idx
+ )
+
+ assert np.all(obj_ids_local == tracker_metadata_prev["obj_ids"]), "{} != {}".format(
+ obj_ids_local, tracker_metadata_prev["obj_ids"]
+ )
+
+ # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global`
+ # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask)
+ low_res_masks_global = low_res_masks_local
+ obj_scores_global = obj_scores_local
+ return low_res_masks_global, obj_scores_global
+
+ def _recondition_masklets(
+ self,
+ frame_idx,
+ det_out: dict[str, torch.Tensor],
+ trk_id_to_max_iou_high_conf_det: list[int],
+ tracker_states_local: list[Any],
+ tracker_metadata: dict[str, np.ndarray],
+ tracker_obj_scores_global: torch.Tensor,
+ ):
+ """Recondition masklets based on new high-confidence detections."""
+ # Recondition the masklets based on the new detections
+ for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
+ new_mask = det_out["mask"][det_idx : det_idx + 1]
+ new_mask_binary = (
+ F.interpolate(new_mask.unsqueeze(1), size=self.interpol_size, mode="bilinear", align_corners=False) > 0
+ )
+ HIGH_CONF_THRESH = 0.8
+ reconditioned_states_idx = set()
+ obj_idx = np.where(tracker_metadata["obj_ids"] == trk_obj_id)[0].item()
+ obj_score = tracker_obj_scores_global[obj_idx]
+ for state_idx, inference_state in enumerate(tracker_states_local):
+ if (
+ trk_obj_id in inference_state["obj_ids"]
+ # NOTE: Goal of this condition is to avoid reconditioning masks that are occluded/low qualiy.
+ # Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics.
+ and obj_score > HIGH_CONF_THRESH
+ ):
+ LOGGER.debug(
+ f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned."
+ )
+ self.tracker.add_new_prompts(
+ inference_state=inference_state,
+ frame_idx=frame_idx,
+ obj_id=trk_obj_id,
+ masks=new_mask_binary,
+ )
+ reconditioned_states_idx.add(state_idx)
+
+ for idx in reconditioned_states_idx:
+ self.tracker.propagate_in_video_preflight(tracker_states_local[idx])
+ return tracker_states_local
+
    def run_tracker_update_planning_phase(
        self,
        frame_idx: int,
        reverse: bool,
        det_out: dict[str, torch.Tensor],
        tracker_low_res_masks_global: torch.Tensor,
        tracker_obj_scores_global: torch.Tensor,
        tracker_metadata_prev: dict[str, np.ndarray],
        tracker_states_local: list[Any],
    ):
        """Run the tracker update planning phase for a single frame in an SPMD manner.

        Matches detections against the propagated masklets, applies hotstart removal heuristics,
        optionally reconditions drifting masklets with high-confidence detections, runs the SAM2
        memory encoder, and produces the updated masklet metadata.

        Returns:
            (tuple): `(tracker_update_plan, tracker_metadata_new)` — the per-frame add/remove/match
                plan (identical on all GPUs after broadcasting) and the updated bookkeeping.
        """
        # initialize new metadata from previous metadata (its values will be updated later)
        tracker_metadata_new = {
            "obj_ids": deepcopy(tracker_metadata_prev["obj_ids"]),
            "num_obj": deepcopy(tracker_metadata_prev["num_obj"]),
            "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]),
            "obj_id_to_cls": deepcopy(tracker_metadata_prev["obj_id_to_cls"]),
            "obj_id_to_tracker_score_frame_wise": deepcopy(tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"]),
            "obj_id_to_last_occluded": {},  # will be filled later
            "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]),
        }

        # Initialize reconditioned_obj_ids early to avoid UnboundLocalError
        reconditioned_obj_ids = set()

        # Step 1: make the update plan and resolve heuristics on GPU 0
        det_mask_preds: torch.Tensor = det_out["mask"]  # low-res mask logits
        det_scores_np: np.ndarray = det_out["scores"].float().cpu().numpy()
        det_cls_np: np.ndarray = det_out["cls"].float().cpu().numpy()
        det_bbox_xyxy: torch.Tensor = det_out["bbox"]
        # a) match detector and tracker masks and find new objects
        (
            new_det_fa_inds,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            trk_id_to_max_iou_high_conf_det,
            empty_trk_obj_ids,
        ) = self._associate_det_trk(
            det_masks=det_mask_preds,
            det_scores_np=det_scores_np,
            trk_masks=tracker_low_res_masks_global,
            trk_obj_ids=tracker_metadata_prev["obj_ids"],
        )
        if self.suppress_det_close_to_boundary:
            # drop new detections whose centers hug the image border
            keep = self._suppress_detections_close_to_boundary(det_bbox_xyxy[new_det_fa_inds])
            new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()]

        # check whether we've hit the maximum number of objects we can track (and if so, drop some detections)
        prev_obj_num = np.sum(tracker_metadata_prev["num_obj"])
        new_det_num = len(new_det_fa_inds)
        num_obj_dropped_due_to_limit = 0
        if prev_obj_num + new_det_num > self.max_num_objects:
            LOGGER.warning(f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}")
            new_det_num_to_keep = self.max_num_objects - prev_obj_num
            num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep
            new_det_fa_inds = self._drop_new_det_with_obj_limit(new_det_fa_inds, det_scores_np, new_det_num_to_keep)
            assert len(new_det_fa_inds) == new_det_num_to_keep
            new_det_num = len(new_det_fa_inds)

        # assign object IDs to new detections and decide which GPU to place them
        new_det_obj_ids = tracker_metadata_prev["max_obj_id"] + 1 + np.arange(new_det_num)

        # b) handle hotstart heuristics to remove objects
        # here `metadata` contains metadata stored on (and only accessible to) GPU 0;
        # we avoid broadcasting them to other GPUs to save communication cost, assuming
        # that `metadata` is not needed by other GPUs
        metadata_new = deepcopy(tracker_metadata_prev["metadata"])
        if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
            obj_ids_newly_removed, metadata_new = self._process_hotstart(
                frame_idx=frame_idx,
                reverse=reverse,
                det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
                new_det_obj_ids=new_det_obj_ids,
                empty_trk_obj_ids=empty_trk_obj_ids,
                unmatched_trk_obj_ids=unmatched_trk_obj_ids,
                metadata=metadata_new,
            )
        else:
            # if warm-up is not complete, we don't remove any objects
            obj_ids_newly_removed = set()
        tracker_metadata_new["metadata"] = metadata_new

        # `tracker_update_plan` should be identical on all GPUs after broadcasting
        # NOTE: "reconditioned_obj_ids" stores the same mutable set that is filled in Step 2 below,
        # so later additions are visible through this plan dict
        tracker_update_plan = {
            "new_det_fa_inds": new_det_fa_inds,  # np.ndarray
            "new_det_obj_ids": new_det_obj_ids,  # np.ndarray
            # "new_det_gpu_ids": new_det_gpu_ids,  # np.ndarray
            "unmatched_trk_obj_ids": unmatched_trk_obj_ids,  # np.ndarray
            "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids,  # dict
            "obj_ids_newly_removed": obj_ids_newly_removed,  # set
            "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit,  # int
            "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det,  # dict
            "reconditioned_obj_ids": reconditioned_obj_ids,  # set
        }

        # Step 2 (optional): recondition masklets based on high-confidence detections before memory encoding
        # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results
        should_recondition_iou = False

        # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections
        if self.reconstruction_bbox_iou_thresh > 0 and len(trk_id_to_max_iou_high_conf_det) > 0:
            for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
                det_box = det_out["bbox"][det_idx]
                det_score = det_out["scores"][det_idx]

                try:
                    trk_idx = list(tracker_metadata_prev["obj_ids"]).index(trk_obj_id)
                except ValueError:
                    continue  # Skip if tracklet not found

                tracker_mask = tracker_low_res_masks_global[trk_idx]
                mask_binary = tracker_mask > 0
                mask_area = mask_binary.sum().item()

                if mask_area == 0:
                    continue  # Skip tracklets with zero mask area

                # Get bounding box from SAM2 mask and convert to normalized coordinates
                tracker_box_pixels = batched_mask_to_box(mask_binary.unsqueeze(0)).squeeze(0)
                mask_height, mask_width = tracker_mask.shape[-2:]
                tracker_box_normalized = torch.tensor(
                    [
                        tracker_box_pixels[0] / mask_width,
                        tracker_box_pixels[1] / mask_height,
                        tracker_box_pixels[2] / mask_width,
                        tracker_box_pixels[3] / mask_height,
                    ],
                    device=tracker_box_pixels.device,
                )

                # Compute IoU between detection and SAM2 tracklet bounding boxes
                det_box_batch = det_box.unsqueeze(0)
                tracker_box_batch = tracker_box_normalized.unsqueeze(0)
                iou = box_iou(det_box_batch, tracker_box_batch)[0]

                # recondition when the tracklet drifted away from a confident detection
                if iou < self.reconstruction_bbox_iou_thresh and det_score >= self.reconstruction_bbox_det_score:
                    should_recondition_iou = True
                    reconditioned_obj_ids.add(trk_obj_id)

        should_recondition_periodic = (
            self.recondition_every_nth_frame > 0
            and frame_idx % self.recondition_every_nth_frame == 0
            and len(trk_id_to_max_iou_high_conf_det) > 0
        )

        # Recondition if periodic or IoU condition met
        if should_recondition_periodic or should_recondition_iou:
            self._recondition_masklets(
                frame_idx,
                det_out,
                trk_id_to_max_iou_high_conf_det,
                tracker_states_local,
                tracker_metadata_prev,
                tracker_obj_scores_global,
            )

        # Step 3: Run SAM2 memory encoder on the current frame's prediction masks
        # This is done on all GPUs
        batch_size = tracker_low_res_masks_global.size(0)
        if batch_size > 0:
            if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
                if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0:
                    # NOTE: tracker_low_res_masks_global is updated in-place then returned
                    tracker_low_res_masks_global = self._suppress_overlapping_based_on_recent_occlusion(
                        frame_idx,
                        tracker_low_res_masks_global,
                        tracker_metadata_prev,
                        tracker_metadata_new,
                        obj_ids_newly_removed,
                        reverse,
                    )

            self._tracker_update_memories(tracker_states_local, frame_idx, low_res_masks=tracker_low_res_masks_global)

        # Step 4: update the SAM2 metadata based on the update plan
        updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids"]
        if len(new_det_obj_ids) > 0:
            updated_obj_ids_this_gpu = np.concatenate([updated_obj_ids_this_gpu, new_det_obj_ids])
        if len(obj_ids_newly_removed) > 0:
            is_removed = np.isin(updated_obj_ids_this_gpu, list(obj_ids_newly_removed))
            updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed]
        tracker_metadata_new["obj_ids"] = updated_obj_ids_this_gpu
        tracker_metadata_new["num_obj"] = len(updated_obj_ids_this_gpu)
        # update object scores and the maximum object ID assigned so far
        if len(new_det_obj_ids) > 0:
            tracker_metadata_new["obj_id_to_score"].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]))
            tracker_metadata_new["obj_id_to_cls"].update(zip(new_det_obj_ids, det_cls_np[new_det_fa_inds]))
            # tracker scores are not available for new objects, use det score instead.
            tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx].update(
                zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])
            )
            tracker_metadata_new["max_obj_id"] = max(tracker_metadata_new["max_obj_id"], np.max(new_det_obj_ids))
        # for removed objects, we set their scores to a very low value (-1e4) but still
        # keep them in "obj_id_to_score" (it's easier to handle outputs this way)
        for obj_id in obj_ids_newly_removed:
            tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4
            tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][obj_id] = -1e4
            tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None)
        # "metadata" must be present here (it was populated above for this rank)
        assert "metadata" in tracker_metadata_new
        if self.masklet_confirmation_enable:
            metadata = self.update_masklet_confirmation_status(
                metadata=tracker_metadata_new["metadata"],
                obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids"],
                obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids"],
                det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
                new_det_obj_ids=new_det_obj_ids,
            )
            tracker_metadata_new["metadata"] = metadata

        return tracker_update_plan, tracker_metadata_new
+
    def _suppress_overlapping_based_on_recent_occlusion(
        self,
        frame_idx: int,
        tracker_low_res_masks_global: torch.Tensor,
        tracker_metadata_prev: dict[str, Any],
        tracker_metadata_new: dict[str, Any],
        obj_ids_newly_removed: set[int],
        reverse: bool = False,
    ):
        """Suppress overlapping masks based on the most recent occlusion information. If an object is removed by
        hotstart, we always suppress it if it overlaps with any other object.

        Note:
            `tracker_low_res_masks_global` is modified in place (suppressed rows are overwritten
            with `self.NO_OBJ_LOGIT`) and also returned; `tracker_metadata_new` gains an
            "obj_id_to_last_occluded" entry as a side effect.

        Args:
            frame_idx (int): The current frame index.
            tracker_low_res_masks_global (torch.Tensor): The low-resolution masks for the current frame.
            tracker_metadata_prev (dict[str, Any]): The metadata from the previous frame.
            tracker_metadata_new (dict[str, Any]): The metadata for the current frame.
            obj_ids_newly_removed (set[int]): The object IDs that have been removed.
            reverse (bool): Whether the tracking is in reverse order.

        Returns:
            (torch.Tensor): The updated low-resolution masks with some objects suppressed.
        """
        obj_ids_global = tracker_metadata_prev["obj_ids"]
        binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
        batch_size = tracker_low_res_masks_global.size(0)
        if batch_size > 0:
            assert len(obj_ids_global) == batch_size, (
                f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
            )
            # per-object "last occluded" frame marker; objects without history default to
            # NEVER_OCCLUDED, while hotstart-removed objects default to ALWAYS_OCCLUDED so that
            # any overlap suppresses them
            last_occluded_prev = torch.cat(
                [
                    tracker_metadata_prev["obj_id_to_last_occluded"].get(
                        obj_id,
                        torch.full(
                            (1,),
                            fill_value=(
                                self.NEVER_OCCLUDED if obj_id not in obj_ids_newly_removed else self.ALWAYS_OCCLUDED
                            ),
                            device=binary_tracker_low_res_masks_global.device,
                            dtype=torch.long,
                        ),
                    )
                    for obj_id in obj_ids_global
                ],
                dim=0,
            )
            to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded(
                binary_tracker_low_res_masks_global,
                last_occluded_prev,
                obj_ids_global,
                frame_idx,
                reverse,
            )

            # Update metadata with occlusion information
            # an object counts as occluded when its binary mask has no foreground pixel at all
            is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2)))
            is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress
            last_occluded_new = last_occluded_prev.clone()
            last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx
            # Slice out the last occluded frame for each object (kept as 1-element tensors)
            tracker_metadata_new["obj_id_to_last_occluded"] = {
                obj_id: last_occluded_new[obj_idx : obj_idx + 1] for obj_idx, obj_id in enumerate(obj_ids_global)
            }

            # Zero out suppressed masks before memory encoding
            tracker_low_res_masks_global[to_suppress] = self.NO_OBJ_LOGIT

        return tracker_low_res_masks_global
+
+ def run_tracker_update_execution_phase(
+ self,
+ frame_idx: int,
+ num_frames: int,
+ det_out: dict[str, torch.Tensor],
+ tracker_states_local: list[Any],
+ tracker_update_plan: dict[str, np.ndarray],
+ ):
+ """Execute the tracker update plan for a single frame in an SPMD manner."""
+ # initialize tracking scores with detection scores
+ new_det_fa_inds: np.ndarray = tracker_update_plan["new_det_fa_inds"]
+ new_det_obj_ids: np.ndarray = tracker_update_plan["new_det_obj_ids"]
+ # new_det_gpu_ids: np.ndarray = tracker_update_plan["new_det_gpu_ids"]
+ new_det_obj_ids_local: np.ndarray = new_det_obj_ids
+ new_det_fa_inds_local: np.ndarray = new_det_fa_inds
+ obj_ids_newly_removed: set[int] = tracker_update_plan["obj_ids_newly_removed"]
+
+ # Step 1: add new objects from the detector to SAM2 inference states
+ if len(new_det_fa_inds_local) > 0:
+ new_det_fa_inds_local_t = torch.from_numpy(new_det_fa_inds_local)
+ new_det_masks: torch.Tensor = det_out["mask"][new_det_fa_inds_local_t]
+ # initialize SAM2 with new object masks
+ tracker_states_local = self._tracker_add_new_objects(
+ frame_idx=frame_idx,
+ num_frames=num_frames,
+ new_obj_ids=new_det_obj_ids_local,
+ new_obj_masks=new_det_masks,
+ tracker_states_local=tracker_states_local,
+ )
+
+ # Step 2: remove from SAM2 inference states those objects removed by heuristics
+ if len(obj_ids_newly_removed) > 0:
+ self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed)
+
+ return tracker_states_local
+
+ def build_outputs(
+ self,
+ det_out: dict[str, torch.Tensor],
+ tracker_low_res_masks_global: torch.Tensor,
+ tracker_metadata_prev: dict[str, np.ndarray],
+ tracker_update_plan: dict[str, np.ndarray],
+ reconditioned_obj_ids: set | None = None,
+ ):
+ """Build the output masks for the current frame."""
+ new_det_fa_inds: np.ndarray = tracker_update_plan["new_det_fa_inds"]
+ new_det_obj_ids: np.ndarray = tracker_update_plan["new_det_obj_ids"]
+ obj_id_to_mask = {} # obj_id --> output mask tensor
+
+ # Part 1: masks from previous SAM2 propagation
+ existing_masklet_obj_ids = tracker_metadata_prev["obj_ids"]
+ existing_masklet_binary = tracker_low_res_masks_global.unsqueeze(1)
+ assert len(existing_masklet_obj_ids) == len(existing_masklet_binary)
+ for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary):
+ obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
+
+ # Part 2: masks from new detections
+ new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds)
+ new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1)
+ assert len(new_det_obj_ids) == len(new_det_low_res_masks)
+ for obj_id, mask in zip(new_det_obj_ids, new_det_low_res_masks):
+ obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
+
+ # Part 3: Override masks for reconditioned objects using detection masks
+ if reconditioned_obj_ids is not None and len(reconditioned_obj_ids) > 0:
+ trk_id_to_max_iou_high_conf_det = tracker_update_plan.get("trk_id_to_max_iou_high_conf_det", {})
+
+ for obj_id in reconditioned_obj_ids:
+ det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id)
+
+ if det_idx is not None:
+ obj_id_to_mask[obj_id] = det_out["mask"][det_idx].unsqueeze(0)
+
+ return obj_id_to_mask
+
    def _get_objects_to_suppress_based_on_most_recently_occluded(
        self,
        binary_low_res_masks: torch.Tensor,
        last_occluded: torch.Tensor,
        obj_ids: list[int],
        frame_idx: int | None = None,
        reverse: bool = False,
    ):
        """Select objects to suppress in favor of overlapping, less recently occluded objects.

        Args:
            binary_low_res_masks (torch.Tensor): (N, H, W) boolean masks for the current frame.
            last_occluded (torch.Tensor): (N,) per-object marker of the most recent occlusion frame.
            obj_ids (list[int]): Object IDs aligned with the mask batch (used for debug logging).
            frame_idx (int | None): Current frame index, only used for debug logging.
            reverse (bool): Whether tracking runs backward; flips the "more recent" comparison.

        Returns:
            (torch.Tensor): (N,) boolean tensor, True for objects to suppress.
        """
        # Suppress overlapping masks for objects that were most recently occluded
        assert binary_low_res_masks.dtype == torch.bool, f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
        to_suppress = torch.zeros(
            binary_low_res_masks.size(0),
            device=binary_low_res_masks.device,
            dtype=torch.bool,
        )
        if len(obj_ids) <= 1:
            # nothing can overlap when there are fewer than two objects
            return to_suppress

        iou = mask_iou(binary_low_res_masks.flatten(1), binary_low_res_masks.flatten(1))  # [N,N]

        # Create masks for upper triangular matrix (i < j) and IoU threshold
        mask_iou_thresh = iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold
        overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1)  # [N,N]

        last_occ_expanded_i = last_occluded.unsqueeze(1)  # (N, 1)
        last_occ_expanded_j = last_occluded.unsqueeze(0)  # (1, N)
        # Suppress most recently occluded
        cmp_op = torch.gt if not reverse else torch.lt
        # NOTE(review): the inline comments below say "i/j was previously occluded" but each guard
        # actually tests the OTHER object's history (`last_occ_expanded_j > -1` guards i's
        # suppression) — confirm which side was intended before relying on the prose.
        suppress_i_mask = (
            overlapping_pairs
            & cmp_op(last_occ_expanded_i, last_occ_expanded_j)  # (last_occ_expanded_i > last_occ_expanded_j)
            & (last_occ_expanded_j > -1)  # j can suppress i only if i was previously occluded
        )
        suppress_j_mask = (
            overlapping_pairs
            & cmp_op(last_occ_expanded_j, last_occ_expanded_i)
            & (last_occ_expanded_i > -1)  # i can suppress j only if j was previously occluded
        )
        # Apply suppression
        to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0)

        # Log for debugging (level 10 == DEBUG)
        if LOGGER.isEnabledFor(10) and frame_idx is not None:
            suppress_i_mask = suppress_i_mask.cpu().numpy()
            suppress_j_mask = suppress_j_mask.cpu().numpy()
            last_occluded = last_occluded.cpu().numpy()

            # Find all suppression pairs without using torch.where
            batch_size = suppress_i_mask.shape[0]

            # Log i-suppression cases (where i gets suppressed in favor of j)
            for i in range(batch_size):
                for j in range(batch_size):
                    if suppress_i_mask[i, j]:
                        LOGGER.debug(
                            f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}"
                        )

            # Log j-suppression cases (where j gets suppressed in favor of i)
            for i in range(batch_size):
                for j in range(batch_size):
                    if suppress_j_mask[i, j]:
                        LOGGER.debug(
                            f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}"
                        )

        return to_suppress
+
    def _propogate_tracker_one_frame_local_gpu(self, inference_states: list[Any], frame_idx: int):
        """Propagate every local tracker inference state one frame and concatenate the outputs.

        Args:
            inference_states (list[Any]): List of inference states; each state corresponds to a
                different set of objects.
            frame_idx (int): Frame index to propagate to.

        Returns:
            (tuple): `(obj_ids_local, low_res_masks_local, obj_scores_local)` where the masks are
                (num_local_obj, H_mask, W_mask) and the scores are (num_local_obj,).
        """
        obj_ids_local = []
        low_res_masks_list = []
        obj_scores_list = []
        for inference_state in inference_states:
            if len(inference_state["obj_ids"]) == 0:
                continue  # skip propagation on empty inference states

            out_obj_ids, out_low_res_masks, out_obj_scores = self.tracker.propagate_in_video(
                inference_state, frame_idx=frame_idx
            )
            assert isinstance(out_obj_ids, list)
            obj_ids_local.extend(out_obj_ids)
            low_res_masks_list.append(out_low_res_masks.squeeze(1))
            obj_scores_list.append(out_obj_scores.squeeze(1))

        # concatenate the output masklets from all local inference states
        if len(low_res_masks_list) > 0:
            low_res_masks_local = torch.cat(low_res_masks_list, dim=0)
            obj_scores_local = torch.cat(obj_scores_list, dim=0)
            # NOTE(review): entries were already squeezed at dim 1 before concatenation, so this
            # second squeeze is a no-op unless H_mask == 1 — confirm whether it can be removed.
            low_res_masks_local = low_res_masks_local.squeeze(1)
        else:
            # no objects on this GPU: return empty tensors shaped like the backbone feature map
            low_res_masks_local = torch.zeros(0, *self._bb_feat_sizes[0], device=self.device)
            obj_scores_local = torch.zeros(0, device=self.device)

        return obj_ids_local, low_res_masks_local, obj_scores_local
+
+    def _associate_det_trk(
+        self,
+        det_masks: torch.Tensor,
+        det_scores_np: np.ndarray,
+        trk_masks: torch.Tensor,
+        trk_obj_ids: np.ndarray,
+    ):
+        """Match detections on the current frame with the existing masklets.
+
+        Masks are expected as float logits where > 0 means foreground; binarization happens here.
+
+        Args:
+            det_masks: (N, H, W) float tensor of predicted detection mask logits.
+            det_scores_np: (N,) array of detection scores.
+            trk_masks: (M, H, W) float tensor of track (masklet) mask logits.
+            trk_obj_ids: (M,) array of object IDs corresponding to trk_masks.
+
+        Returns:
+            new_det_fa_inds: array of detection indices treated as new objects.
+            unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched to any
+                detection on this frame (only masklets with > 0 area are counted).
+            det_to_matched_trk_obj_ids: dict mapping each detection index to the array of masklet
+                object IDs it matched (IoU >= `assoc_iou_thresh`).
+            trk_id_to_max_iou_high_conf_det: dict mapping a masklet object ID to exactly one
+                high-confidence, high-IoU detection index (the detection's argmax-IoU track).
+            empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction.
+        """
+        iou_threshold = self.assoc_iou_thresh
+        iou_threshold_trk = self.trk_assoc_iou_thresh
+        new_det_thresh = self.new_det_thresh
+
+        assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
+        assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
+        assert trk_masks.size(0) == len(trk_obj_ids), (
+            f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
+        )
+        # Degenerate case 1: no existing tracks — every detection is new.
+        if trk_masks.size(0) == 0:
+            # all detections are new
+            new_det_fa_inds = np.arange(det_masks.size(0))
+            unmatched_trk_obj_ids = np.array([], np.int64)
+            empty_trk_obj_ids = np.array([], np.int64)
+            det_to_matched_trk_obj_ids = {}
+            trk_id_to_max_iou_high_conf_det = {}
+            return (
+                new_det_fa_inds,
+                unmatched_trk_obj_ids,
+                det_to_matched_trk_obj_ids,
+                trk_id_to_max_iou_high_conf_det,
+                empty_trk_obj_ids,
+            )
+        # Degenerate case 2: no detections — tracks are unmatched (if non-empty) or empty.
+        elif det_masks.size(0) == 0:
+            # all previous tracklets are unmatched if they have a non-zero area
+            new_det_fa_inds = np.array([], np.int64)
+            trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy()
+            unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty]
+            empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
+            det_to_matched_trk_obj_ids = {}
+            trk_id_to_max_iou_high_conf_det = {}
+            return (
+                new_det_fa_inds,
+                unmatched_trk_obj_ids,
+                det_to_matched_trk_obj_ids,
+                trk_id_to_max_iou_high_conf_det,
+                empty_trk_obj_ids,
+            )
+
+        if det_masks.shape[-2:] != trk_masks.shape[-2:]:
+            # resize to the smaller size to save GPU memory
+            if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]):
+                trk_masks = F.interpolate(
+                    trk_masks.unsqueeze(1),
+                    size=det_masks.shape[-2:],
+                    mode="bilinear",
+                    align_corners=False,
+                ).squeeze(1)
+            else:
+                # resize detections to track size
+                det_masks = F.interpolate(
+                    det_masks.unsqueeze(1),
+                    size=trk_masks.shape[-2:],
+                    mode="bilinear",
+                    align_corners=False,
+                ).squeeze(1)
+
+        # binarize logits at 0 and compute the full pairwise IoU matrix
+        det_masks_binary = det_masks > 0
+        trk_masks_binary = trk_masks > 0
+        ious = mask_iou(det_masks_binary.flatten(1).float(), trk_masks_binary.flatten(1).float())  # (N, M)
+
+        ious_np = ious.cpu().numpy()
+        if self.o2o_matching_masklets_enable:
+            from scipy.optimize import linear_sum_assignment
+
+            # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
+            cost_matrix = 1 - ious_np  # Hungarian solves for minimum cost
+            row_ind, col_ind = linear_sum_assignment(cost_matrix)
+            trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool)
+            for d, t in zip(row_ind, col_ind):
+                # only accept assignments above the track-association IoU threshold
+                if ious_np[d, t] >= iou_threshold_trk:
+                    trk_is_matched[t] = True
+        else:
+            # many-to-many: a track is matched if any detection overlaps it enough
+            trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0)
+        # Non-empty tracks not matched by Hungarian assignment above threshold are unmatched
+        trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy()
+        trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched)
+        unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched]
+        # also record masklets that have zero area in SAM 2 prediction
+        empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
+
+        # For detections: allow many tracks to match to the same detection (many-to-one)
+        # So, a detection is 'new' if it does not match any track above threshold
+        is_new_det = np.logical_and(
+            det_scores_np >= new_det_thresh,
+            np.logical_not(np.any(ious_np >= iou_threshold, axis=1)),
+        )
+        new_det_fa_inds = np.nonzero(is_new_det)[0]
+
+        # for each detection, which tracks it matched to (above threshold)
+        det_to_matched_trk_obj_ids = {}
+        trk_id_to_max_iou_high_conf_det = {}  # trk id --> exactly one detection idx
+        det_to_max_iou_trk_idx = np.argmax(ious_np, axis=1)
+        # high-confidence, already-tracked detections with strong overlap qualify for
+        # the one-to-one trk-id --> detection mapping used downstream
+        det_is_high_conf = (det_scores_np >= self.HIGH_CONF_THRESH) & ~is_new_det
+        det_is_high_iou = np.max(ious_np, axis=1) >= self.HIGH_IOU_THRESH
+        det_is_high_conf_and_iou = set(np.nonzero(det_is_high_conf & det_is_high_iou)[0])
+        for d in range(det_masks.size(0)):
+            det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold]
+            if d in det_is_high_conf_and_iou:
+                trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item()
+                trk_id_to_max_iou_high_conf_det[trk_obj_id] = d
+
+        return (
+            new_det_fa_inds,
+            unmatched_trk_obj_ids,
+            det_to_matched_trk_obj_ids,
+            trk_id_to_max_iou_high_conf_det,
+            empty_trk_obj_ids,
+        )
+
+    def _process_hotstart(
+        self,
+        frame_idx: int,
+        reverse: bool,
+        det_to_matched_trk_obj_ids: dict[int, np.ndarray],
+        new_det_obj_ids: np.ndarray,
+        empty_trk_obj_ids: np.ndarray,
+        unmatched_trk_obj_ids: np.ndarray,
+        metadata: dict[str, Any],
+    ):
+        """Handle hotstart heuristics to remove unmatched or duplicated objects.
+
+        Args:
+            frame_idx: current frame index.
+            reverse: True when the video is processed backwards (flips all frame comparisons).
+            det_to_matched_trk_obj_ids: detection index --> matched masklet object IDs.
+            new_det_obj_ids: object IDs newly created on this frame.
+            empty_trk_obj_ids: masklet IDs whose SAM2 mask has zero area on this frame.
+            unmatched_trk_obj_ids: masklet IDs with non-empty masks but no matched detection.
+            metadata: mutable bookkeeping dict (see `_initialize_metadata`); updated in place.
+
+        Returns:
+            Tuple of (obj_ids_newly_removed, metadata): the set of object IDs removed on this
+            frame, and the updated metadata dict.
+        """
+        # obj_id --> first frame index where the object was detected
+        obj_first_frame_idx = metadata["obj_first_frame_idx"]
+        # obj_id --> [mismatched frame indices]
+        unmatched_frame_inds = metadata["unmatched_frame_inds"]
+        trk_keep_alive = metadata["trk_keep_alive"]
+        # (first_appear_obj_id, obj_id) --> [overlap frame indices]
+        overlap_pair_to_frame_inds = metadata["overlap_pair_to_frame_inds"]
+        # removed_obj_ids: object IDs that are suppressed via hot-start
+        removed_obj_ids = metadata["removed_obj_ids"]
+        suppressed_obj_ids = metadata["suppressed_obj_ids"][frame_idx]
+
+        obj_ids_newly_removed = set()  # object IDs to be newly removed on this frame
+        # frame boundary of the hotstart window; comparisons flip in reverse mode
+        hotstart_diff = frame_idx - self.hotstart_delay if not reverse else frame_idx + self.hotstart_delay
+
+        # Step 1: log the frame index where each object ID first appears
+        for obj_id in new_det_obj_ids:
+            if obj_id not in obj_first_frame_idx:
+                obj_first_frame_idx[obj_id] = frame_idx
+            assert obj_id not in trk_keep_alive
+            trk_keep_alive[obj_id] = self.init_trk_keep_alive
+
+        matched_trks = set()
+        # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded
+        for matched_trks_per_det in det_to_matched_trk_obj_ids.values():
+            matched_trks.update(matched_trks_per_det)
+        for obj_id in matched_trks:
+            # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive
+            trk_keep_alive[obj_id] = min(self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1)
+        for obj_id in unmatched_trk_obj_ids:
+            unmatched_frame_inds[obj_id].append(frame_idx)
+            # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
+            # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough.
+            trk_keep_alive[obj_id] = max(self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1)
+        if self.decrease_trk_keep_alive_for_empty_masklets:
+            for obj_id in empty_trk_obj_ids:
+                # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
+                trk_keep_alive[obj_id] = max(self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1)
+
+        # Step 2: remove tracks that have not matched any detection for `hotstart_unmatch_thresh` frames within the hotstart period
+        # a) add unmatched frame indices for each existing object ID
+        # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask
+        # doesn't match any detection; it excludes those frames where SAM2 gives an empty mask
+        # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more
+        # than `self.hotstart_unmatch_thresh` frames
+        for obj_id, frame_indices in unmatched_frame_inds.items():
+            if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
+                continue  # skip if the object is already removed
+            if len(frame_indices) >= self.hotstart_unmatch_thresh:
+                is_within_hotstart = (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or (
+                    obj_first_frame_idx[obj_id] < hotstart_diff and reverse
+                )
+                if is_within_hotstart:
+                    obj_ids_newly_removed.add(obj_id)
+                    LOGGER.debug(
+                        f"Removing object {obj_id} at frame {frame_idx} "
+                        f"since it is unmatched for frames: {frame_indices}"
+                    )
+            # outside the hotstart window we only *suppress* (hide) long-unmatched objects,
+            # never remove them, unless configured otherwise
+            if (
+                trk_keep_alive[obj_id] <= 0  # Object has not been matched for too long
+                and not self.suppress_unmatched_only_within_hotstart
+                and obj_id not in removed_obj_ids
+                and obj_id not in obj_ids_newly_removed
+            ):
+                LOGGER.debug(f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched")
+                suppressed_obj_ids.add(obj_id)
+
+        # Step 3: remove tracks that overlap with another track for `hotstart_dup_thresh` frames
+        # a) find overlapping tracks -- we consider overlap if they match to the same detection
+        for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items():
+            if len(matched_trk_obj_ids) < 2:
+                continue  # only count detections that are matched to multiple (>=2) masklets
+            # if there are multiple matched track ids, we need to find the one that appeared first;
+            # these later appearing ids may be removed since they may be considered as duplicates
+            first_appear_obj_id = (
+                min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
+                if not reverse
+                else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
+            )
+            for obj_id in matched_trk_obj_ids:
+                if obj_id != first_appear_obj_id:
+                    key = (first_appear_obj_id, obj_id)
+                    overlap_pair_to_frame_inds[key].append(frame_idx)
+
+        # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another
+        # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames
+        for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items():
+            if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
+                continue  # skip if the object is already removed
+            if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or (
+                obj_first_frame_idx[obj_id] < hotstart_diff and reverse
+            ):
+                if len(frame_indices) >= self.hotstart_dup_thresh:
+                    obj_ids_newly_removed.add(obj_id)
+                    LOGGER.debug(
+                        f"Removing object {obj_id} at frame {frame_idx} "
+                        f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}"
+                    )
+
+        removed_obj_ids.update(obj_ids_newly_removed)
+        return obj_ids_newly_removed, metadata
+
+    def _tracker_update_memories(
+        self, tracker_inference_states: list[Any], frame_idx: int, low_res_masks: torch.Tensor
+    ):
+        """Run Sam2 memory encoder, enforcing non-overlapping constraints globally.
+
+        Args:
+            tracker_inference_states: local SAM2 inference states; `low_res_masks` is ordered as the
+                concatenation of objects over these states (same order as propagation output).
+            frame_idx: current frame index whose outputs get memory features attached.
+            low_res_masks: (N, h, w) float mask logits for all objects across all states.
+        """
+        if len(tracker_inference_states) == 0:
+            return
+        # NOTE: inspect this part if we observe OOMs in the demo
+        high_res_masks = F.interpolate(
+            low_res_masks.unsqueeze(1),
+            size=self.interpol_size,
+            mode="bilinear",
+            align_corners=False,
+        )
+        # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics.
+        # (skipped during warm-up if a warm-up phase is in progress)
+        if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
+            high_res_masks = self.tracker.model._suppress_object_pw_area_shrinkage(high_res_masks)
+        # Instead of gathering the predicted object scores, we use mask areas as a proxy.
+        object_score_logits = torch.where((high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0)
+
+        # Run the memory encoder on local slices for each GPU
+        # NOTE(review): `start_idx_gpu` only seeds `start_idx_state` here — looks vestigial
+        start_idx_gpu = 0
+        start_idx_state = start_idx_gpu
+        for tracker_state in tracker_inference_states:
+            num_obj_per_state = len(tracker_state["obj_ids"])
+            if num_obj_per_state == 0:
+                continue
+            # Get the local high-res masks and object score logits for this inference state
+            end_idx_state = start_idx_state + num_obj_per_state
+            local_high_res_masks = high_res_masks[start_idx_state:end_idx_state]
+            local_object_score_logits = object_score_logits[start_idx_state:end_idx_state]
+            local_batch_size = local_high_res_masks.size(0)
+            # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default
+
+            encoded_mem = self.tracker._run_memory_encoder(
+                local_batch_size,
+                local_high_res_masks,
+                local_object_score_logits,
+                is_mask_from_pts=False,
+                inference_state=tracker_state,
+            )
+            local_maskmem_features, local_maskmem_pos_enc = encoded_mem
+            # Store encoded memories in the local inference state
+            output_dict = tracker_state["output_dict"]
+            for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]:
+                if frame_idx not in output_dict[storage_key]:
+                    continue
+                output_dict[storage_key][frame_idx]["maskmem_features"] = local_maskmem_features
+                output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [pos for pos in local_maskmem_pos_enc]
+                # for batched inference state, we also need to add per-object
+                # memory slices to support instance interactivity
+                self.tracker._add_output_per_object(
+                    inference_state=tracker_state,
+                    frame_idx=frame_idx,
+                    current_out=output_dict[storage_key][frame_idx],
+                    storage_key=storage_key,
+                )
+            start_idx_state += num_obj_per_state
+
+    def _tracker_add_new_objects(
+        self,
+        frame_idx: int,
+        num_frames: int,
+        new_obj_ids: list[int],
+        new_obj_masks: torch.Tensor,
+        tracker_states_local: list[Any],
+    ):
+        """Add a new object to SAM2 inference states.
+
+        Objects first appearing on the same frame are batched into one fresh inference state,
+        which is appended to `tracker_states_local`.
+
+        Args:
+            frame_idx: frame on which the new objects appear (becomes their prompt frame).
+            num_frames: total number of frames in the video (sizes the new state).
+            new_obj_ids: object IDs to add; must align with `new_obj_masks` rows.
+            new_obj_masks: (K, h, w) float mask logits for the new objects.
+            tracker_states_local: existing local inference states; extended in place and returned.
+
+        Returns:
+            The updated `tracker_states_local` list (with the new state appended).
+        """
+        prev_tracker_state = tracker_states_local[0] if len(tracker_states_local) > 0 else None
+
+        # prepare inference_state
+        # batch objects that first appear on the same frame together
+        # Clear inference state. Keep the cached image features if available.
+        new_tracker_state = self.tracker._init_state(num_frames=num_frames)
+        # NOTE: adding image placeholder
+        new_tracker_state["im"] = None
+        # reuse cached backbone features from an existing state when available
+        new_tracker_state["backbone_out"] = (
+            prev_tracker_state.get("backbone_out", None) if prev_tracker_state is not None else None
+        )
+
+        assert len(new_obj_ids) == new_obj_masks.size(0)
+        assert new_obj_masks.is_floating_point()
+        # upsample logits to the prompt resolution, then binarize at 0
+        new_obj_masks = F.interpolate(
+            new_obj_masks.unsqueeze(0),
+            size=self.interpol_size,
+            mode="bilinear",
+            align_corners=False,
+        ).squeeze(0)
+        new_obj_masks = new_obj_masks > 0
+
+        # add object one by one
+        for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks):
+            self.tracker.add_new_prompts(
+                inference_state=new_tracker_state,
+                frame_idx=frame_idx,
+                obj_id=new_obj_id,
+                masks=new_mask[None, None],  # add bs, channel
+            )
+        # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects.
+        self.tracker.propagate_in_video_preflight(new_tracker_state)
+        tracker_states_local.append(new_tracker_state)
+        return tracker_states_local
+
+ def _tracker_remove_objects(self, tracker_states_local: list[Any], obj_ids: list[int]):
+ """Remove an object from SAM2 inference states. This would remove the object from all frames in the video."""
+ if not obj_ids:
+ return
+ # Filter out states that become empty after removal
+ active_states = []
+ for state in tracker_states_local:
+ for obj_id in obj_ids:
+ # we try to remove `obj_id` on every inference state with `strict=False`
+ # it will not do anything if an inference state doesn't contain `obj_id`
+ self.tracker.remove_object(state, obj_id, strict=False)
+
+ if len(state["obj_ids"]) > 0:
+ active_states.append(state)
+
+ # Update the list in-place
+ tracker_states_local[:] = active_states
+
+ def _initialize_metadata(self):
+ """Initialize metadata for the masklets."""
+ tracker_metadata = {
+ "obj_ids": np.array([], np.int32),
+ "num_obj": np.zeros(1, np.int32),
+ "max_obj_id": -1,
+ "obj_id_to_score": {},
+ "obj_id_to_cls": {},
+ "obj_id_to_tracker_score_frame_wise": defaultdict(dict),
+ "obj_id_to_last_occluded": {},
+ }
+ # "metadata" contains metadata that is only stored on (and accessible to) GPU 0
+ # - obj_first_frame_idx: obj_id --> first frame index where the object was detected
+ # - unmatched_frame_inds: obj_id --> [mismatched frame indices]
+ # - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices]
+ # - removed_obj_ids: object IDs that are suppressed via hot-start
+ metadata = {
+ "obj_first_frame_idx": {},
+ "unmatched_frame_inds": defaultdict(list),
+ "trk_keep_alive": defaultdict(int), # This is used only for object suppression not for removal
+ "overlap_pair_to_frame_inds": defaultdict(list),
+ "removed_obj_ids": set(),
+ # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked
+ "suppressed_obj_ids": defaultdict(set),
+ }
+ if self.masklet_confirmation_enable:
+ # all the following are np.ndarray with the same shape as `obj_ids_all_gpu`
+ metadata["masklet_confirmation"] = {
+ # "status" is the confirmation status of each masklet
+ "status": np.array([], np.int64),
+ # "consecutive_det_num" is the number of consecutive frames where the masklet is
+ # detected by the detector (with a matched detection)
+ "consecutive_det_num": np.array([], np.int64),
+ }
+ tracker_metadata["metadata"] = metadata
+
+ return tracker_metadata
+
+    def update_masklet_confirmation_status(
+        self,
+        metadata: dict[str, Any],
+        obj_ids_all_gpu_prev: np.ndarray,
+        obj_ids_all_gpu_updated: np.ndarray,
+        det_to_matched_trk_obj_ids: dict[int, np.ndarray],
+        new_det_obj_ids: np.ndarray,
+    ):
+        """Update the confirmation status of masklets based on the current frame's detection results.
+
+        Args:
+            metadata: bookkeeping dict holding "masklet_confirmation" arrays (mutated and returned).
+            obj_ids_all_gpu_prev: object IDs aligned with the previous-frame confirmation arrays.
+            obj_ids_all_gpu_updated: object IDs after this frame's additions/removals.
+            det_to_matched_trk_obj_ids: detection index --> matched masklet object IDs.
+            new_det_obj_ids: object IDs newly created on this frame (count as matched).
+
+        Returns:
+            The updated metadata dict.
+        """
+        confirmation_data = metadata["masklet_confirmation"]
+
+        # a) first, expand "confirmation_data" to include new masklets added in this frame
+        status_prev = confirmation_data["status"]
+        consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
+        assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
+            f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
+        )
+
+        # map surviving previous object IDs to their positions in the updated ID array
+        obj_id_to_updated_idx = {obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)}
+        prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated)
+        prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated]
+        prev_elem_inds_in_updated = np.array(
+            [obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated],
+            dtype=np.int64,
+        )
+        # newly added masklets are initialized to "UNCONFIRMED" status
+        # NOTE(review): `full_like`/`zeros_like` inherit the dtype of the obj-id array here,
+        # while `_initialize_metadata` declares these arrays as int64 — confirm they agree
+        unconfirmed_val = self.UNCONFIRMED
+        status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val)
+        status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated]
+        consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated)
+        consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[prev_elem_is_in_updated]
+
+        # b) update the confirmation status of all masklets based on the current frame
+        # b.1) update "consecutive_det_num"
+        # "is_matched": whether a masklet is matched to a detection on this frame
+        is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids)
+        for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values():
+            is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids)
+        # matched masklets extend their streak; unmatched ones reset to zero
+        consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0)
+
+        # b.2) update "status"
+        change_to_confirmed = consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh
+        status[change_to_confirmed] = self.CONFIRMED
+
+        confirmation_data["status"] = status
+        confirmation_data["consecutive_det_num"] = consecutive_det_num
+        return metadata
+
+    def _load_checkpoint(self, ckpt_path: str, strict: bool = True):
+        """Load model weights from a checkpoint file and log any key mismatches.
+
+        Args:
+            ckpt_path: path to a torch checkpoint whose "model" entry holds the state dict.
+            strict: forwarded to `load_state_dict`; when True, key mismatches raise.
+        """
+        # `weights_only=True` avoids unpickling arbitrary objects from the checkpoint
+        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
+        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=strict)
+        if len(missing_keys) > 0 or len(unexpected_keys) > 0:
+            LOGGER.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}")
+        else:
+            LOGGER.info("Loaded ckpt successfully without missing or unexpected keys")
+
+    def _encode_prompt(self, **kwargs):
+        """Thin pass-through wrapper: delegate prompt encoding to the underlying model."""
+        return self.model._encode_prompt(**kwargs)
+
+ def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep):
+ """Drop a few new detections based on the maximum number of objects. We drop new objects based on their
+ detection scores, keeping the high-scoring ones and dropping the low-scoring ones.
+ """
+ assert 0 <= num_to_keep <= len(new_det_fa_inds)
+ if num_to_keep == 0:
+ return np.array([], np.int64) # keep none
+ if num_to_keep == len(new_det_fa_inds):
+ return new_det_fa_inds # keep all
+
+ # keep the top-scoring detections
+ score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1]
+ new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]]
+ return new_det_fa_inds
diff --git a/ultralytics/models/sam/sam3/__init__.py b/ultralytics/models/sam/sam3/__init__.py
new file mode 100644
index 0000000000..fbe077b618
--- /dev/null
+++ b/ultralytics/models/sam/sam3/__init__.py
@@ -0,0 +1,3 @@
+# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
diff --git a/ultralytics/models/sam/sam3/decoder.py b/ultralytics/models/sam/sam3/decoder.py
new file mode 100644
index 0000000000..f9d789b48f
--- /dev/null
+++ b/ultralytics/models/sam/sam3/decoder.py
@@ -0,0 +1,546 @@
+# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+"""
+Transformer decoder.
+Inspired from Pytorch's version, adds the pre-norm variant.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from torch import nn
+from torchvision.ops.roi_align import RoIAlign
+
+from ultralytics.nn.modules.transformer import MLP
+from ultralytics.nn.modules.utils import _get_clones, inverse_sigmoid
+from ultralytics.utils.ops import xywh2xyxy
+
+from .model_misc import gen_sineembed_for_position
+
+
+class TransformerDecoderLayer(nn.Module):
+ """TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN)."""
+
+ def __init__(
+ self,
+ d_model: int,
+ dim_feedforward: int,
+ dropout: float,
+ cross_attention: nn.Module,
+ n_heads: int,
+ use_text_cross_attention: bool = False,
+ ):
+ """Initialize the TransformerDecoderLayer."""
+ super().__init__()
+ # cross attention
+ self.cross_attn = cross_attention
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # cross attention text
+ self.use_text_cross_attention = use_text_cross_attention
+ if use_text_cross_attention:
+ self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.catext_norm = nn.LayerNorm(d_model)
+
+ # self attention
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.activation = nn.ReLU()
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.norm3 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ """Add positional embedding to the tensor."""
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ """Feedforward network forward pass."""
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward(
+ self,
+ # for tgt
+ tgt: torch.Tensor, # nq, bs, d_model
+ tgt_query_pos: torch.Tensor = None, # pos for query. MLP(Sine(pos))
+ memory_text: torch.Tensor = None, # num_token, bs, d_model
+ text_attention_mask: torch.Tensor = None, # bs, num_token
+ # for memory
+ memory: torch.Tensor = None, # hw, bs, d_model
+ memory_key_padding_mask: torch.Tensor = None,
+ memory_pos: torch.Tensor = None, # pos for memory
+ # sa
+ self_attn_mask: torch.Tensor = None, # mask used for self-attention
+ cross_attn_mask: torch.Tensor = None, # mask used for cross-attention
+ # dac
+ dac=False,
+ dac_use_selfatt_ln=True,
+ presence_token=None,
+ # skip inside deformable attn
+ **kwargs, # additional kwargs for compatibility
+ ):
+ """Input: - tgt/tgt_query_pos: nq, bs, d_model. -."""
+ # self attention
+ tgt, tgt_query_pos = self._apply_self_attention(
+ tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask
+ )
+
+ if self.use_text_cross_attention:
+ tgt2 = self.ca_text(
+ self.with_pos_embed(tgt, tgt_query_pos),
+ memory_text.to(tgt.dtype),
+ memory_text.to(tgt.dtype),
+ key_padding_mask=text_attention_mask,
+ )[0]
+ tgt = tgt + self.catext_dropout(tgt2)
+ tgt = self.catext_norm(tgt)
+
+ if presence_token is not None:
+ presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
+ cross_attn_mask = torch.cat([presence_token_mask, cross_attn_mask], dim=1) # (bs*nheads, 1+nq, hw)
+
+ # Cross attention to image
+ tgt2 = self.cross_attn(
+ query=self.with_pos_embed(tgt, tgt_query_pos),
+ key=self.with_pos_embed(memory, memory_pos),
+ value=memory,
+ attn_mask=cross_attn_mask,
+ key_padding_mask=(memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None),
+ need_weights=False,
+ )[0]
+
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ # ffn
+ tgt = self.forward_ffn(tgt.to(memory.dtype))
+
+ presence_token_out = None
+ if presence_token is not None:
+ presence_token_out = tgt[:1]
+ tgt = tgt[1:]
+
+ return tgt, presence_token_out
+
+ def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask):
+ """Apply self-attention with optional DAC splitting."""
+ if self.self_attn is None:
+ return tgt
+
+ if dac:
+ # Split queries for DAC (detect-and-classify)
+ assert tgt.shape[0] % 2 == 0, "DAC requires even number of queries"
+ num_o2o_queries = tgt.shape[0] // 2
+ tgt_o2o = tgt[:num_o2o_queries]
+ tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
+ tgt_o2m = tgt[num_o2o_queries:]
+ else:
+ tgt_o2o = tgt
+ tgt_query_pos_o2o = tgt_query_pos
+
+ # Handle presence token
+ if presence_token is not None:
+ tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
+ tgt_query_pos_o2o = torch.cat([torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0).to(
+ tgt_o2o.dtype
+ )
+ tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0)
+
+ # Self-attention
+ q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
+ tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0].to(tgt.dtype)
+ tgt_o2o = tgt_o2o + self.dropout2(tgt2)
+
+ # Recombine and normalize
+ if dac:
+ if not dac_use_selfatt_ln:
+ tgt_o2o = self.norm2(tgt_o2o)
+ tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)
+ if dac_use_selfatt_ln:
+ tgt = self.norm2(tgt)
+ else:
+ tgt = tgt_o2o
+ tgt = self.norm2(tgt)
+
+ return tgt, tgt_query_pos
+
+
+class TransformerDecoder(nn.Module):
+ """Transformer Decoder consisting of multiple layers."""
+
+    def __init__(
+        self,
+        d_model: int,
+        frozen: bool,
+        interaction_layer,
+        layer,
+        num_layers: int,
+        num_queries: int,
+        return_intermediate: bool,
+        box_refine: bool = False,
+        num_o2m_queries: int = 0,
+        dac: bool = False,
+        boxRPB: str = "none",
+        # Experimental: An object query for SAM 2 tasks
+        instance_query: bool = False,
+        # Defines the number of additional instance queries,
+        # 1 or 4 are the most likely for single vs multi mask support
+        num_instances: int = 1,  # Irrelevant if instance_query is False
+        dac_use_selfatt_ln: bool = True,
+        use_act_checkpoint: bool = False,
+        compile_mode=None,
+        presence_token: bool = False,
+        clamp_presence_logits: bool = True,
+        clamp_presence_logit_max_val: float = 10.0,
+        use_normed_output_consistently: bool = True,
+        separate_box_head_instance: bool = False,
+        separate_norm_instance: bool = False,
+    ):
+        """Initialize the TransformerDecoder.
+
+        Args:
+            d_model: embedding dimension shared by queries, memory and heads.
+            frozen: if True, freeze all parameters of this decoder.
+            interaction_layer: optional ROI-interaction layer cloned per decoder layer;
+                must be None in this configuration (asserted below).
+            layer: decoder layer template, cloned `num_layers` times.
+            num_layers: number of stacked decoder layers.
+            num_queries: number of one-to-one object queries.
+            return_intermediate: must be True; outputs of all layers are returned.
+            box_refine: must be True; per-layer box refinement is enabled.
+            num_o2m_queries: extra one-to-many queries (ignored in DAC mode).
+            dac: detect-and-classify mode; o2m queries mirror the o2o set.
+            boxRPB: box relative-position-bias variant: "none", "log", "linear", or "both".
+            instance_query: experimental extra object queries for SAM 2 tasks.
+            num_instances: number of instance queries (only used if instance_query).
+            dac_use_selfatt_ln: where the self-attn LayerNorm is applied in DAC mode.
+            use_act_checkpoint: enable activation checkpointing in the layers.
+            compile_mode: optional torch.compile mode; compilation is deferred (see below).
+            presence_token: add a learned presence token and its prediction head.
+            clamp_presence_logits / clamp_presence_logit_max_val: clamp presence logits.
+            use_normed_output_consistently: use the normed output for downstream heads.
+            separate_box_head_instance: dedicated box head for instance queries.
+            separate_norm_instance: dedicated LayerNorm for instance queries.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.layers = _get_clones(layer, num_layers)
+        self.fine_layers = (
+            _get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers
+        )
+        self.num_layers = num_layers
+        self.num_queries = num_queries
+        self.dac = dac
+        if dac:
+            # DAC: the one-to-many set mirrors the one-to-one queries, so no extra embeddings
+            self.num_o2m_queries = num_queries
+            tot_num_queries = num_queries
+        else:
+            self.num_o2m_queries = num_o2m_queries
+            tot_num_queries = num_queries + num_o2m_queries
+        self.norm = nn.LayerNorm(d_model)
+        self.return_intermediate = return_intermediate
+        self.bbox_embed = MLP(d_model, d_model, 4, 3)
+        self.query_embed = nn.Embedding(tot_num_queries, d_model)
+        self.instance_query_embed = None
+        self.instance_query_reference_points = None
+        self.use_instance_query = instance_query
+        self.num_instances = num_instances
+        self.use_normed_output_consistently = use_normed_output_consistently
+
+        # optional per-instance-query norm / box head (fall back to the shared ones)
+        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
+        self.instance_bbox_embed = None
+        if separate_box_head_instance:
+            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
+        if instance_query:
+            self.instance_query_embed = nn.Embedding(num_instances, d_model)
+        self.box_refine = box_refine
+        if box_refine:
+            # zero-init the last box layer so refinement starts as identity
+            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+
+            self.reference_points = nn.Embedding(num_queries, 4)
+            if instance_query:
+                self.instance_reference_points = nn.Embedding(num_instances, 4)
+
+        assert boxRPB in ["none", "log", "linear", "both"]
+        self.boxRPB = boxRPB
+        if boxRPB != "none":
+            # the layer's cross-attention module may be named either way depending on variant
+            try:
+                nheads = self.layers[0].cross_attn_image.num_heads
+            except AttributeError:
+                nheads = self.layers[0].cross_attn.num_heads
+
+            n_input = 4 if boxRPB == "both" else 2
+            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
+            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
+            # single-entry cache usable under torch.compile, plus a dict cache for eager mode
+            self.compilable_cord_cache = None
+            self.compilable_stored_size = None
+            self.coord_cache = {}
+
+        self.roi_pooler = (
+            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
+            if interaction_layer is not None
+            else None
+        )
+        if frozen:
+            for p in self.parameters():
+                p.requires_grad_(False)
+
+        self.presence_token = None
+        self.clamp_presence_logits = clamp_presence_logits
+        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
+        if presence_token:
+            self.presence_token = nn.Embedding(1, d_model)
+            self.presence_token_head = MLP(d_model, d_model, 1, 3)
+            self.presence_token_out_norm = nn.LayerNorm(d_model)
+
+        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
+        self.dac_use_selfatt_ln = dac_use_selfatt_ln
+        self.use_act_checkpoint = use_act_checkpoint
+
+        nn.init.normal_(self.query_embed.weight.data)
+        if self.instance_query_embed is not None:
+            nn.init.normal_(self.instance_query_embed.weight.data)
+
+        # NOTE(review): this assert forces `interaction_layer is None`, which makes
+        # `fine_layers`/`roi_pooler` dead configuration paths — confirm intended
+        assert self.roi_pooler is None
+        assert self.return_intermediate, "support return_intermediate only"
+        assert self.box_refine, "support box refine only"
+
+        self.compile_mode = compile_mode
+        self.compiled = False
+        # We defer compilation till after the first forward, to first warm-up the boxRPB cache
+
+        # assign layer index to each layer so that some layers can decide what to do
+        # based on which layer index they are (e.g. cross attention to memory bank only
+        # in selected layers)
+        # NOTE(review): the loop variable reuses (shadows) the `layer` parameter name
+        for layer_idx, layer in enumerate(self.layers):
+            layer.layer_idx = layer_idx
+
+    @staticmethod
+    def _get_coords(H, W, device, dtype):
+        """Get normalized coordinates for height and width.
+
+        Returns:
+            Tuple (coords_h, coords_w) of 1-D tensors with values in [0, 1),
+            where coords_h has shape (H,) and coords_w has shape (W,).
+        """
+        coords_h = torch.arange(0, H, dtype=dtype, device=device) / H
+        coords_w = torch.arange(0, W, dtype=dtype, device=device) / W
+        return coords_h, coords_w
+
+ def _get_rpb_matrix(self, reference_boxes, feat_size):
+ """Get the relative position bias (RPB) matrix for box-relative position bias."""
+ H, W = feat_size
+ boxes_xyxy = xywh2xyxy(reference_boxes).transpose(0, 1)
+ bs, num_queries, _ = boxes_xyxy.shape
+ if self.compilable_cord_cache is None:
+ self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
+ self.compilable_stored_size = (H, W)
+
+ if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
+ H,
+ W,
+ ):
+ # good, hitting the cache, will be compilable
+ coords_h, coords_w = self.compilable_cord_cache
+ else:
+ # cache miss, will create compilation issue
+ # In case we're not compiling, we'll still rely on the dict-based cache
+ if feat_size not in self.coord_cache:
+ self.coord_cache[feat_size] = self._get_coords(H, W, reference_boxes.device)
+ coords_h, coords_w = self.coord_cache[feat_size]
+
+ assert coords_h.shape == (H,)
+ assert coords_w.shape == (W,)
+
+ deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
+ deltas_y = deltas_y.view(bs, num_queries, -1, 2)
+ deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
+ deltas_x = deltas_x.view(bs, num_queries, -1, 2)
+
+ if self.boxRPB in ["log", "both"]:
+ deltas_x_log = deltas_x * 8 # normalize to -8, 8
+ deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8)
+
+ deltas_y_log = deltas_y * 8 # normalize to -8, 8
+ deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8)
+ if self.boxRPB == "log":
+ deltas_x = deltas_x_log
+ deltas_y = deltas_y_log
+ else:
+ deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
+ deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
+
+ if self.training:
+ assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
+ deltas_x = self.boxRPB_embed_x(x=deltas_x) # bs, num_queries, W, n_heads
+ deltas_y = self.boxRPB_embed_y(x=deltas_y) # bs, num_queries, H, n_heads
+
+ if not torch.compiler.is_dynamo_compiling():
+ assert deltas_x.shape[:3] == (bs, num_queries, W)
+ assert deltas_y.shape[:3] == (bs, num_queries, H)
+
+ B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2) # bs, num_queries, H, W, n_heads
+ if not torch.compiler.is_dynamo_compiling():
+ assert B.shape[:4] == (bs, num_queries, H, W)
+ B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
+ B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
+ B = B.contiguous() # memeff attn likes ordered strides
+ if not torch.compiler.is_dynamo_compiling():
+ assert B.shape[2:] == (num_queries, H * W)
+ return B
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: torch.Tensor = None,
+ memory_mask: torch.Tensor = None,
+ memory_key_padding_mask: torch.Tensor = None,
+ pos: torch.Tensor = None,
+ reference_boxes: torch.Tensor = None, # num_queries, bs, 4
+ # for memory
+ spatial_shapes: torch.Tensor = None, # bs, num_levels, 2
+ valid_ratios: torch.Tensor = None,
+ # for text
+ memory_text: torch.Tensor = None,
+ text_attention_mask: torch.Tensor = None,
+ # if `apply_dac` is None, it will default to `self.dac`
+ apply_dac: bool | None = None,
+ is_instance_prompt=False,
+ decoder_extra_kwargs: dict | None = None,
+ # ROI memory bank
+ obj_roi_memory_feat=None,
+ obj_roi_memory_mask=None,
+ box_head_trk=None,
+ ):
+ """Forward pass of the TransformerDecoder."""
+ if memory_mask is not None:
+ assert self.boxRPB == "none", (
+ "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
+ )
+
+ apply_dac = apply_dac if apply_dac is not None else self.dac
+ if apply_dac:
+ assert (tgt.shape[0] == self.num_queries) or (
+ self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
+ )
+
+ tgt = tgt.repeat(2, 1, 1)
+ # note that we don't tile tgt_mask, since DAC doesn't
+ # use self-attention in o2m queries
+ if reference_boxes is not None:
+ assert (reference_boxes.shape[0] == self.num_queries) or (
+ self.use_instance_query and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings)
+ )
+ reference_boxes = reference_boxes.repeat(2, 1, 1)
+
+ bs = tgt.shape[1]
+ intermediate = []
+ intermediate_presence_logits = []
+ presence_feats = None
+
+ if self.box_refine:
+ if reference_boxes is None:
+ # In this case, we're in a one-stage model, so we generate the reference boxes
+ reference_boxes = self.reference_points.weight.unsqueeze(1)
+ reference_boxes = reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1)
+ reference_boxes = reference_boxes.sigmoid()
+ intermediate_ref_boxes = [reference_boxes]
+ else:
+ reference_boxes = None
+ intermediate_ref_boxes = None
+
+ output = tgt
+ presence_out = None
+ if self.presence_token is not None and is_instance_prompt is False:
+ # expand to batch dim
+ presence_out = self.presence_token.weight[None].expand(1, bs, -1)
+
+ box_head = self.bbox_embed
+ if is_instance_prompt and self.instance_bbox_embed is not None:
+ box_head = self.instance_bbox_embed
+
+ out_norm = self.norm
+ if is_instance_prompt and self.instance_norm is not None:
+ out_norm = self.instance_norm
+
+ for layer_idx, layer in enumerate(self.layers):
+ reference_points_input = (
+ reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
+ ) # nq, bs, nlevel, 4
+
+ query_sine_embed = gen_sineembed_for_position(
+ reference_points_input[:, :, 0, :], self.d_model
+ ) # nq, bs, d_model*2
+
+ # conditional query
+ query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
+
+ if self.boxRPB != "none" and reference_boxes is not None:
+ assert spatial_shapes.shape[0] == 1, "only single scale support implemented"
+ memory_mask = self._get_rpb_matrix(
+ reference_boxes,
+ (spatial_shapes[0, 0], spatial_shapes[0, 1]),
+ )
+ memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
+ if self.training:
+ assert self.use_act_checkpoint, "Activation checkpointing not enabled in the decoder"
+ output, presence_out = layer(
+ tgt=output,
+ tgt_query_pos=query_pos,
+ memory_text=memory_text,
+ text_attention_mask=text_attention_mask,
+ memory=memory,
+ memory_key_padding_mask=memory_key_padding_mask,
+ memory_pos=pos,
+ self_attn_mask=tgt_mask,
+ cross_attn_mask=memory_mask,
+ dac=apply_dac,
+ dac_use_selfatt_ln=self.dac_use_selfatt_ln,
+ presence_token=presence_out,
+ **(decoder_extra_kwargs or {}),
+ # ROI memory bank
+ obj_roi_memory_feat=obj_roi_memory_feat,
+ obj_roi_memory_mask=obj_roi_memory_mask,
+ )
+
+ # iter update
+ if self.box_refine:
+ reference_before_sigmoid = inverse_sigmoid(reference_boxes)
+ if box_head_trk is None:
+ # delta_unsig = self.bbox_embed(output)
+ if not self.use_normed_output_consistently:
+ delta_unsig = box_head(output)
+ else:
+ delta_unsig = box_head(out_norm(output))
+ else:
+ # box_head_trk use a separate box head for tracking queries
+ Q_det = decoder_extra_kwargs["Q_det"]
+ assert output.size(0) >= Q_det
+ delta_unsig_det = self.bbox_embed(output[:Q_det])
+ delta_unsig_trk = box_head_trk(output[Q_det:])
+ delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
+ outputs_unsig = delta_unsig + reference_before_sigmoid
+ new_reference_points = outputs_unsig.sigmoid()
+
+ reference_boxes = new_reference_points.detach()
+ if layer_idx != self.num_layers - 1:
+ intermediate_ref_boxes.append(new_reference_points)
+ else:
+ raise NotImplementedError("not implemented yet")
+
+ intermediate.append(out_norm(output))
+ if self.presence_token is not None and is_instance_prompt is False:
+ # norm, mlp head
+ intermediate_layer_presence_logits = self.presence_token_head(
+ self.presence_token_out_norm(presence_out)
+ ).squeeze(-1)
+
+ # clamp to mitigate numerical issues
+ if self.clamp_presence_logits:
+ intermediate_layer_presence_logits.clamp(
+ min=-self.clamp_presence_logit_max_val,
+ max=self.clamp_presence_logit_max_val,
+ )
+
+ intermediate_presence_logits.append(intermediate_layer_presence_logits)
+ presence_feats = presence_out.clone()
+
+ if not self.compiled and self.compile_mode is not None:
+ self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True)
+ self.compiled = True
+
+ return (
+ torch.stack(intermediate),
+ torch.stack(intermediate_ref_boxes),
+ (
+ torch.stack(intermediate_presence_logits)
+ if self.presence_token is not None and is_instance_prompt is False
+ else None
+ ),
+ presence_feats,
+ )
diff --git a/ultralytics/models/sam/sam3/encoder.py b/ultralytics/models/sam/sam3/encoder.py
new file mode 100644
index 0000000000..4300023e78
--- /dev/null
+++ b/ultralytics/models/sam/sam3/encoder.py
@@ -0,0 +1,535 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+# Based on https://github.com/IDEA-Research/GroundingDINO
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from ultralytics.nn.modules.utils import _get_clones
+
+from .model_misc import get_valid_ratio
+
+
+class TransformerEncoderLayer(nn.Module):
+    """Transformer encoder layer that performs self-attention followed by cross-attention.
+
+    This layer was previously called TransformerDecoderLayer but was renamed to better reflect its role in the
+    architecture. It processes input sequences through self-attention and then cross-attention with another input
+    (typically image features).
+
+    The layer supports both pre-norm and post-norm configurations, as well as positional encoding at different stages of
+    the attention mechanism.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        dim_feedforward: int,
+        dropout: float,
+        pos_enc_at_attn: bool,
+        pos_enc_at_cross_attn_keys: bool,
+        pos_enc_at_cross_attn_queries: bool,
+        pre_norm: bool,
+        self_attention: nn.Module = None,
+        cross_attention: nn.Module = None,
+    ):
+        """Initialize a transformer encoder layer.
+
+        Args:
+            cross_attention: Cross-attention module for attending to image features
+            d_model: Model dimension/hidden size
+            dim_feedforward: Dimension of the feedforward network
+            dropout: Dropout probability
+            pos_enc_at_attn: Whether to add positional encodings at self-attention
+            pos_enc_at_cross_attn_keys: Whether to add positional encodings to keys in cross-attention
+            pos_enc_at_cross_attn_queries: Whether to add positional encodings to queries in cross-attention
+            pre_norm: Whether to use pre-norm (True) or post-norm (False) architecture
+            self_attention: Self-attention module
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.dim_feedforward = dim_feedforward
+        self.dropout_value = dropout
+        # NOTE(review): the fallback attention modules hardcode embed_dim=256 / num_heads=8,
+        # which only matches d_model == 256 -- confirm callers pass explicit attention
+        # modules whenever d_model differs.
+        self.self_attn = self_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256)
+        self.cross_attn_image = cross_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256)
+
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = nn.ReLU()
+        self.pre_norm = pre_norm
+
+        self.pos_enc_at_attn = pos_enc_at_attn
+        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+        # Assigned by the enclosing encoder after cloning (see TransformerEncoder.__init__),
+        # so a layer can condition its behavior on its depth in the stack.
+        self.layer_idx = None
+
+    def forward_post(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: torch.Tensor = None,
+        memory_mask: torch.Tensor = None,
+        tgt_key_padding_mask: torch.Tensor = None,
+        memory_key_padding_mask: torch.Tensor = None,
+        pos: torch.Tensor = None,
+        query_pos: torch.Tensor = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Forward pass for post-norm architecture.
+
+        In post-norm architecture, normalization is applied after attention and feedforward operations.
+
+        Args:
+            tgt: Input tensor to be processed
+            memory: Memory tensor for cross-attention
+            tgt_mask: Mask for self-attention
+            memory_mask: Mask for cross-attention
+            tgt_key_padding_mask: Key padding mask for self-attention
+            memory_key_padding_mask: Key padding mask for cross-attention
+            pos: Positional encoding for memory
+            query_pos: Positional encoding for query
+            **kwargs: Additional keyword arguments (ignored; absorbs e.g. `dac` from forward())
+
+        Returns:
+            Processed tensor
+        """
+        q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
+
+        # Self attention
+        tgt2 = self.self_attn(
+            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask, need_weights=False
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # Cross attention to image
+        tgt2 = self.cross_attn_image(
+            query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
+            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+            need_weights=False,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # FFN
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        dac: bool = False,
+        tgt_mask: torch.Tensor = None,
+        memory_mask: torch.Tensor = None,
+        tgt_key_padding_mask: torch.Tensor = None,
+        memory_key_padding_mask: torch.Tensor = None,
+        pos: torch.Tensor = None,
+        query_pos: torch.Tensor = None,
+        # **kwargs,
+    ) -> torch.Tensor:
+        """Forward pass for pre-norm architecture.
+
+        In pre-norm architecture, normalization is applied before attention and feedforward operations.
+
+        Args:
+            tgt: Input tensor to be processed
+            memory: Memory tensor for cross-attention
+            dac: Whether to use Divide-and-Conquer attention (self-attention restricted to the first half of queries)
+            tgt_mask: Mask for self-attention
+            memory_mask: Mask for cross-attention
+            tgt_key_padding_mask: Key padding mask for self-attention
+            memory_key_padding_mask: Key padding mask for cross-attention
+            pos: Positional encoding for memory
+            query_pos: Positional encoding for query
+
+        Returns:
+            Processed tensor
+        """
+        if dac:
+            # we only apply self attention to the first half of the queries
+            assert tgt.shape[0] % 2 == 0
+            other_tgt = tgt[tgt.shape[0] // 2 :]
+            tgt = tgt[: tgt.shape[0] // 2]
+        tgt2 = self.norm1(tgt)
+        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        if dac:
+            # Recombine
+            tgt = torch.cat((tgt, other_tgt), dim=0)
+        tgt2 = self.norm2(tgt)
+        # memory is cast to the query dtype before attention (e.g. when it arrives in a
+        # different floating-point precision)
+        tgt2 = self.cross_attn_image(
+            query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+            key=memory.to(tgt2.dtype) + pos if self.pos_enc_at_cross_attn_keys else memory.to(tgt2.dtype),
+            value=memory.to(tgt2.dtype),
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+            # attn_bias=attn_bias,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        dac: bool = False,
+        tgt_mask: torch.Tensor = None,
+        memory_mask: torch.Tensor = None,
+        tgt_key_padding_mask: torch.Tensor = None,
+        memory_key_padding_mask: torch.Tensor = None,
+        pos: torch.Tensor = None,
+        query_pos: torch.Tensor = None,
+        # **kwds: Any,
+    ) -> torch.Tensor:
+        """Forward pass for the transformer encoder layer.
+
+        Args:
+            tgt: Input tensor to be processed
+            memory: Memory tensor (e.g., image features) for cross-attention
+            dac: Whether to use Divide-and-Conquer attention (only apply self-attention to first half)
+            tgt_mask: Mask for self-attention
+            memory_mask: Mask for cross-attention
+            tgt_key_padding_mask: Key padding mask for self-attention
+            memory_key_padding_mask: Key padding mask for cross-attention
+            pos: Positional encoding for memory
+            query_pos: Positional encoding for query
+
+        Returns:
+            Processed tensor after self-attention, cross-attention, and feedforward network
+        """
+        # Dispatch on norm placement; forward_post accepts `dac` via **kwargs and ignores it.
+        fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
+        return fwd_fn(
+            tgt,
+            memory,
+            dac=dac,
+            tgt_mask=tgt_mask,
+            memory_mask=memory_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            pos=pos,
+            query_pos=query_pos,
+            # attn_bias=attn_bias,
+            # **kwds,
+        )
+
+
+class TransformerEncoder(nn.Module):
+ """Transformer encoder that processes multi-level features.
+
+ This encoder takes multi-level features (e.g., from a backbone network) and processes them through a stack of
+ transformer encoder layers. It supports features from multiple levels (e.g., different resolutions) and can apply
+ activation checkpointing for memory efficiency during training.
+
+ Args:
+ layer: The encoder layer to be stacked multiple times
+ num_layers: Number of encoder layers to stack
+ d_model: Model dimension/hidden size
+ num_feature_levels: Number of feature levels to process
+ frozen: Whether to freeze the parameters of this module
+ use_act_checkpoint: Whether to use activation checkpointing during training
+ """
+
+ def __init__(
+ self,
+ layer: nn.Module,
+ num_layers: int,
+ d_model: int,
+ num_feature_levels: int,
+ frozen: bool = False,
+ use_act_checkpoint: bool = False,
+ ):
+ """Initialize the transformer encoder."""
+ super().__init__()
+ self.layers = _get_clones(layer, num_layers)
+ self.num_layers = num_layers
+
+ self.num_feature_levels = num_feature_levels
+ self.level_embed = None
+ if num_feature_levels > 1:
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+ if frozen:
+ for p in self.parameters():
+ p.requires_grad_(False)
+
+ self.use_act_checkpoint = use_act_checkpoint
+
+ # assign layer index to each layer so that some layers can decide what to do
+ # based on which layer index they are (e.g. cross attention to memory bank only
+ # in selected layers)
+ for layer_idx, layer in enumerate(self.layers):
+ layer.layer_idx = layer_idx
+
+ def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
+ """Prepare multi-level features for transformer encoder."""
+ assert len(srcs) == self.num_feature_levels, "mismatch between expected and received # of feature levels"
+
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ has_mask = masks is not None and masks[0] is not None
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ _, _, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
+ if has_mask:
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
+ if self.level_embed is not None:
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ else:
+ lvl_pos_embed = pos_embed
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ if has_mask:
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
+ mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None # bs, \sum{hxw}
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
+ spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat(
+ (
+ spatial_shapes.new_zeros((1,)),
+ spatial_shapes.prod(1).cumsum(0)[:-1],
+ )
+ )
+ if has_mask:
+ valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1)
+ else:
+ valid_ratios = torch.ones(
+ (src_flatten.shape[0], self.num_feature_levels, 2),
+ device=src_flatten.device,
+ dtype=src_flatten.dtype,
+ )
+
+ return (
+ src_flatten,
+ mask_flatten,
+ lvl_pos_embed_flatten,
+ level_start_index,
+ valid_ratios,
+ spatial_shapes,
+ )
+
+ def forward(
+ self,
+ src: list[torch.Tensor],
+ src_key_padding_masks: list[torch.Tensor] | None = None,
+ pos: list[torch.Tensor] | None = None,
+ prompt: torch.Tensor = None,
+ prompt_key_padding_mask: torch.Tensor = None,
+ encoder_extra_kwargs: dict | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Process multi-level features through the transformer encoder.
+
+ Args:
+ src: List of multi-level features, each with shape (batch_size, channels, height, width)
+ src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height,
+ width)
+ pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height,
+ width)
+ prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model)
+ prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len)
+ encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer
+
+ Returns:
+ A tuple containing:
+ - output: Processed features with shape (seq_len, batch_size, d_model)
+ - key_padding_masks_flatten: Flattened padding masks
+ - lvl_pos_embed_flatten: Flattened positional embeddings
+ - level_start_index: Starting indices for each feature level
+ - spatial_shapes: Spatial dimensions of each feature level
+ - valid_ratios: Valid ratios for each feature level
+ """
+ assert len(src) == self.num_feature_levels, "must be equal to num_feature_levels"
+ if src_key_padding_masks is not None:
+ assert len(src_key_padding_masks) == self.num_feature_levels
+ if pos is not None:
+ assert len(pos) == self.num_feature_levels
+ # Flatten multilevel feats and add level pos embeds
+ (
+ src_flatten,
+ key_padding_masks_flatten,
+ lvl_pos_embed_flatten,
+ level_start_index,
+ valid_ratios,
+ spatial_shapes,
+ ) = self._prepare_multilevel_features(src, src_key_padding_masks, pos)
+
+ output = src_flatten
+ for layer in self.layers:
+ layer_kwargs = {}
+
+ assert isinstance(layer, TransformerEncoderLayer)
+ layer_kwargs["memory"] = prompt
+ layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask
+ layer_kwargs["query_pos"] = lvl_pos_embed_flatten
+ layer_kwargs["tgt"] = output
+ layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten
+
+ if self.training:
+ assert self.use_act_checkpoint, "activation ckpt not enabled in encoder"
+ if encoder_extra_kwargs is not None:
+ layer_kwargs.update(encoder_extra_kwargs)
+ output = layer(**layer_kwargs)
+ # return as seq first
+ return (
+ output.transpose(0, 1),
+ (key_padding_masks_flatten.transpose(0, 1) if key_padding_masks_flatten is not None else None),
+ lvl_pos_embed_flatten.transpose(0, 1),
+ level_start_index,
+ spatial_shapes,
+ valid_ratios,
+ )
+
+
+class TransformerEncoderFusion(TransformerEncoder):
+ """Transformer encoder that fuses text and image features.
+
+ This encoder extends TransformerEncoder to handle both text and image features, with the ability to add pooled text
+ features to image features for better cross-modal fusion. It supports torch.compile for performance optimization.
+
+ Args:
+ layer: The encoder layer to be stacked multiple times
+ num_layers: Number of encoder layers to stack
+ d_model: Model dimension/hidden size
+ num_feature_levels: Number of feature levels to process
+ add_pooled_text_to_img_feat: Whether to add pooled text features to image features
+ pool_text_with_mask: Whether to use the mask when pooling text features
+ compile_mode: Mode for torch.compile, or None to disable compilation
+ **kwargs: Additional arguments to pass to the parent class
+ """
+
+ def __init__(
+ self,
+ layer: nn.Module,
+ num_layers: int,
+ d_model: int,
+ num_feature_levels: int,
+ add_pooled_text_to_img_feat: bool = True,
+ pool_text_with_mask: bool = False,
+ compile_mode: str | None = None,
+ **kwargs,
+ ):
+ """Initialize the transformer encoder with text-image fusion."""
+ super().__init__(
+ layer,
+ num_layers,
+ d_model,
+ num_feature_levels,
+ **kwargs,
+ )
+ self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat
+ if self.add_pooled_text_to_img_feat:
+ self.text_pooling_proj = nn.Linear(d_model, d_model)
+ self.pool_text_with_mask = pool_text_with_mask
+ if compile_mode is not None:
+ self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
+
+ def forward(
+ self,
+ src: list[torch.Tensor],
+ prompt: torch.Tensor,
+ src_key_padding_mask: list[torch.Tensor] | None = None,
+ src_pos: list[torch.Tensor] | None = None,
+ prompt_key_padding_mask: torch.Tensor = None,
+ feat_sizes: list[int] | None = None,
+ encoder_extra_kwargs: dict | None = None,
+ ):
+ """Forward pass for the transformer encoder with text-image fusion."""
+ # Restore spatial shapes of vision
+ bs = src[0].shape[1] # seq first
+ if feat_sizes is not None:
+ assert len(feat_sizes) == len(src)
+ if src_key_padding_mask is None:
+ src_key_padding_mask = [None] * len(src)
+ for i, (h, w) in enumerate(feat_sizes):
+ src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
+ src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
+ src_key_padding_mask[i] = (
+ src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1)
+ if src_key_padding_mask[i] is not None
+ else None
+ )
+ else:
+ assert all(x.dim == 4 for x in src), "expected list of (bs, c, h, w) tensors"
+
+ if self.add_pooled_text_to_img_feat:
+ # Fusion: Add mean pooled text to image features
+ pooled_text = pool_text_feat(prompt, prompt_key_padding_mask, self.pool_text_with_mask)
+ pooled_text = self.text_pooling_proj(pooled_text)[..., None, None] # prompt is seq first
+ src = [x.add_(pooled_text) for x in src]
+
+ (
+ out,
+ key_padding_masks_flatten,
+ lvl_pos_embed_flatten,
+ level_start_index,
+ spatial_shapes,
+ valid_ratios,
+ ) = super().forward(
+ src,
+ src_key_padding_masks=src_key_padding_mask,
+ pos=src_pos,
+ prompt=prompt.transpose(0, 1),
+ prompt_key_padding_mask=prompt_key_padding_mask,
+ encoder_extra_kwargs=encoder_extra_kwargs,
+ )
+
+ return {
+ "memory": out,
+ "padding_mask": key_padding_masks_flatten,
+ "pos_embed": lvl_pos_embed_flatten,
+ "memory_text": prompt,
+ "level_start_index": level_start_index,
+ "spatial_shapes": spatial_shapes,
+ "valid_ratios": valid_ratios,
+ }
+
+
+def pool_text_feat(prompt: torch.Tensor, prompt_mask: torch.Tensor, pool_with_mask: bool) -> torch.Tensor:
+    """Mean-pool the prompt embeddings over the valid tokens only.
+
+    Args:
+        prompt: Text features of shape (seq, bs, dim).
+        prompt_mask: Padding mask of shape (bs, seq); True marks padded tokens. Only read
+            when pool_with_mask is True.
+        pool_with_mask: If False, mean-pool over the full sequence, padding included.
+
+    Returns:
+        Pooled features of shape (bs, dim).
+    """
+    # prompt has shape (seq, bs, dim)
+    if not pool_with_mask:
+        return prompt.mean(dim=0)
+
+    # prompt_mask has shape (bs, seq), where False is valid and True is padding
+    assert prompt_mask.dim() == 2
+    # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
+    is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
+    # num_valid has shape (bs, 1); clamped to 1 to avoid division by zero for all-pad rows
+    num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
+
+    # mean pool over all the valid tokens
+    pooled_text = (prompt * is_valid).sum(dim=0) / num_valid
+    return pooled_text
diff --git a/ultralytics/models/sam/sam3/geometry_encoders.py b/ultralytics/models/sam/sam3/geometry_encoders.py
new file mode 100644
index 0000000000..433c392d78
--- /dev/null
+++ b/ultralytics/models/sam/sam3/geometry_encoders.py
@@ -0,0 +1,415 @@
+# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+import torch
+import torch.nn as nn
+import torchvision
+
+from ultralytics.nn.modules.utils import _get_clones
+from ultralytics.utils.ops import xywh2xyxy
+
+
+def is_right_padded(mask: torch.Tensor):
+    """Return a scalar bool tensor that is True iff every row of the padding mask (1s = padding) is right-padded."""
+    # A row is right-padded exactly when it is already sorted ascending: all 0s (valid) before all 1s (padding).
+    as_int = mask.long()
+    sorted_rows = torch.sort(as_int, dim=-1).values
+    return (as_int == sorted_rows).all()
+
+
+def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
+    """
+    Concatenates two right-padded sequences, such that the resulting sequence
+    is contiguous and also right-padded.
+
+    Following pytorch's convention, tensors are sequence first, and the mask are
+    batch first, with 1s for padded values.
+
+    :param seq1: A tensor of shape (seq1_length, batch_size, hidden_size).
+    :param mask1: A tensor of shape (batch_size, seq1_length).
+    :param seq2: A tensor of shape (seq2_length, batch_size, hidden_size).
+    :param mask2: A tensor of shape (batch_size, seq2_length).
+    :param return_index: If True, also returns the index of the ids of the element of seq2
+        in the concatenated sequence. This can be used to retrieve the elements of seq2
+    :return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
+        otherwise (concatenated_sequence, concatenated_mask, index).
+    """
+    seq1_length, batch_size, hidden_size = seq1.shape
+    seq2_length, batch_size, hidden_size = seq2.shape
+
+    # Shape consistency checks between sequences and their masks.
+    assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
+    assert hidden_size == seq1.size(2) == seq2.size(2)
+    assert seq1_length == mask1.size(1)
+    assert seq2_length == mask2.size(1)
+
+    # Async-friendly assertions (no host sync on GPU): both inputs must be right-padded.
+    torch._assert_async(is_right_padded(mask1))
+    torch._assert_async(is_right_padded(mask2))
+
+    # Per-batch counts of valid (non-padded) tokens in each sequence.
+    actual_seq1_lengths = (~mask1).sum(dim=-1)
+    actual_seq2_lengths = (~mask2).sum(dim=-1)
+
+    # Output mask: positions at or beyond the combined valid length are padding.
+    final_lengths = actual_seq1_lengths + actual_seq2_lengths
+    max_length = seq1_length + seq2_length
+    concatenated_mask = (
+        torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) >= final_lengths[:, None]
+    )
+
+    # (max_len, batch_size, hidden_size)
+    concatenated_sequence = torch.zeros((max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype)
+    concatenated_sequence[:seq1_length, :, :] = seq1
+
+    # At this point, the element of seq1 are in the right place
+    # We just need to shift the elements of seq2
+
+    # index[j, b] = position of seq2 token j in batch b of the output: right after batch b's valid seq1 tokens.
+    index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
+    index = index + actual_seq1_lengths[None]
+
+    # Scatter seq2 tokens into their per-batch slots along the sequence dimension.
+    concatenated_sequence = concatenated_sequence.scatter(0, index[:, :, None].expand(-1, -1, hidden_size), seq2)
+
+    if return_index:
+        return concatenated_sequence, concatenated_mask, index
+
+    return concatenated_sequence, concatenated_mask
+
+
+class Prompt:
+    """Utility class to manipulate geometric (box) prompts.
+
+    Sequences follow the pytorch convention (sequence first, batch second) and masks are
+    batch first with True marking padded positions:
+        box_embeddings: N_boxes x B x 4 tensor of box coordinates.
+        box_mask:       B x N_boxes bool tensor. Can be None if nothing is masked out.
+        box_labels:     N_boxes x B long tensor of positive/negative labels.
+                        If None, positive labels are assumed everywhere.
+    """
+
+    def __init__(self, box_embeddings=None, box_mask=None, box_labels=None):
+        """Initialize the Prompt object, filling in default labels/mask when omitted."""
+        # Check for null prompt
+        if box_embeddings is None:
+            self.box_embeddings = None
+            self.box_labels = None
+            self.box_mask = None
+            return
+
+        # Get sequence length, batch size, and device
+        box_seq_len = box_embeddings.shape[0]
+        bs = box_embeddings.shape[1]
+        device = box_embeddings.device
+
+        # Initialize labels and attention mask if not provided
+        # (defaults: every box positive, no box padded out)
+        if box_labels is None:
+            box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
+        if box_mask is None:
+            box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)
+
+        # Dimension checks
+        assert list(box_embeddings.shape[:2]) == [box_seq_len, bs], (
+            f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
+        )
+        assert box_embeddings.shape[-1] == 4, (
+            f"Expected box embeddings to have 4 coordinates, got {box_embeddings.shape[-1]}"
+        )
+        assert list(box_mask.shape) == [bs, box_seq_len], (
+            f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
+        )
+        assert list(box_labels.shape) == [box_seq_len, bs], (
+            f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
+        )
+
+        # Device checks
+        assert box_embeddings.device == device, (
+            f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
+        )
+        assert box_mask.device == device, f"Expected box mask to be on device {device}, got {box_mask.device}"
+        assert box_labels.device == device, f"Expected box labels to be on device {device}, got {box_labels.device}"
+
+        self.box_embeddings = box_embeddings
+        self.box_mask = box_mask
+        self.box_labels = box_labels
+
+    def append_boxes(self, boxes, labels=None, mask=None):
+        """Append box prompts to existing prompts.
+
+        Args:
+            boxes: Tensor of shape (N_new_boxes, B, 4) with normalized box coordinates
+            labels: Optional tensor of shape (N_new_boxes, B) with positive/negative labels
+            mask: Optional tensor of shape (B, N_new_boxes) for attention mask
+        """
+        if self.box_embeddings is None:
+            # First boxes - initialize
+            self.box_embeddings = boxes
+            bs = boxes.shape[1]
+            box_seq_len = boxes.shape[0]
+
+            # Same defaults as __init__: all positive, none padded.
+            if labels is None:
+                labels = torch.ones(box_seq_len, bs, device=boxes.device, dtype=torch.long)
+            if mask is None:
+                mask = torch.zeros(bs, box_seq_len, device=boxes.device, dtype=torch.bool)
+
+            self.box_labels = labels
+            self.box_mask = mask
+            return
+
+        # Append to existing boxes
+        bs = self.box_embeddings.shape[1]
+        assert boxes.shape[1] == bs, f"Batch size mismatch: expected {bs}, got {boxes.shape[1]}"
+
+        if labels is None:
+            labels = torch.ones(boxes.shape[0], bs, device=boxes.device, dtype=torch.long)
+        if mask is None:
+            mask = torch.zeros(bs, boxes.shape[0], dtype=torch.bool, device=boxes.device)
+
+        assert list(boxes.shape[:2]) == list(labels.shape[:2]), (
+            f"Shape mismatch between boxes {boxes.shape} and labels {labels.shape}"
+        )
+
+        # Concatenate using the helper function.
+        # Labels are (seq, bs); unsqueeze a trailing feature dim so they fit the (seq, bs, hidden) contract.
+        self.box_labels, _ = concat_padded_sequences(
+            self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
+        )
+        self.box_labels = self.box_labels.squeeze(-1)
+        self.box_embeddings, self.box_mask = concat_padded_sequences(self.box_embeddings, self.box_mask, boxes, mask)
+
+
+class SequenceGeometryEncoder(nn.Module):
+    """Encoder for geometric box prompts. Assumes boxes are passed in the "normalized CxCyWH" format.
+
+    Boxes can be encoded with any of the three possibilities:
+    - direct projection: linear projection from coordinate space to d_model
+    - pooling: RoI align features from the backbone
+    - pos encoder: position encoding of the box center
+
+    These three options are mutually compatible and will be summed if multiple are selected.
+
+    As an alternative, boxes can be encoded as two corner points (top-left and bottom-right).
+
+    The encoded sequence can be further processed with a transformer.
+    """
+
+    def __init__(
+        self,
+        encode_boxes_as_points: bool,
+        boxes_direct_project: bool,
+        boxes_pool: bool,
+        boxes_pos_enc: bool,
+        d_model: int,
+        pos_enc,
+        num_layers: int,
+        layer: nn.Module,
+        roi_size: int = 7,
+        add_cls: bool = True,
+        add_post_encode_proj: bool = True,
+        use_act_ckpt: bool = False,
+    ):
+        """Initialize the SequenceGeometryEncoder.
+
+        Args:
+            encode_boxes_as_points: If True, encode each box as its two corner points.
+            boxes_direct_project: Linearly project the 4 box coordinates to d_model.
+            boxes_pool: RoI-align backbone features inside each box.
+            boxes_pos_enc: Use the positional encoder on the box center/size.
+            d_model: Embedding dimension of the encoder.
+            pos_enc: Positional encoding module (must expose encode_boxes when boxes_pos_enc is set).
+            num_layers: Number of transformer layers applied on top of the encoded sequence (0 disables).
+            layer: Transformer layer to clone num_layers times.
+            roi_size: Spatial output size of RoI align.
+            add_cls: Prepend-able CLS token (required when num_layers > 0).
+            add_post_encode_proj: Apply a final Linear+LayerNorm on the encoded sequence.
+            use_act_ckpt: Flag for activation checkpointing (stored, not used in this block).
+        """
+        super().__init__()
+
+        self.d_model = d_model
+        self.pos_enc = pos_enc
+        self.encode_boxes_as_points = encode_boxes_as_points
+        self.roi_size = roi_size
+
+        # Label embeddings: 2 labels if encoding as boxes (pos/neg)
+        # 6 labels if encoding as points (regular pos/neg, top-left pos/neg, bottom-right pos/neg)
+        num_labels = 6 if self.encode_boxes_as_points else 2
+        self.label_embed = torch.nn.Embedding(num_labels, self.d_model)
+
+        # CLS token for pooling
+        self.cls_embed = None
+        if add_cls:
+            self.cls_embed = torch.nn.Embedding(1, self.d_model)
+
+        # Point encoding (used when encode_boxes_as_points is True)
+        if encode_boxes_as_points:
+            self.points_direct_project = nn.Linear(2, self.d_model)
+            self.points_pool_project = None
+            self.points_pos_enc_project = None
+        else:
+            # Box encoding modules
+            assert boxes_direct_project or boxes_pos_enc or boxes_pool, "Error: need at least one way to encode boxes"
+            self.points_direct_project = None
+            self.points_pool_project = None
+            self.points_pos_enc_project = None
+
+        self.boxes_direct_project = None
+        self.boxes_pool_project = None
+        self.boxes_pos_enc_project = None
+
+        if boxes_direct_project:
+            self.boxes_direct_project = nn.Linear(4, self.d_model)
+        if boxes_pool:
+            # Conv with kernel == roi_size collapses the RoI patch to a single d_model vector.
+            self.boxes_pool_project = nn.Conv2d(self.d_model, self.d_model, self.roi_size)
+        if boxes_pos_enc:
+            self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)
+
+        self.final_proj = None
+        if add_post_encode_proj:
+            self.final_proj = nn.Linear(self.d_model, self.d_model)
+            self.norm = nn.LayerNorm(self.d_model)
+
+        # Normalize image features before pooling from them; identity otherwise.
+        self.img_pre_norm = nn.Identity()
+        if self.points_pool_project is not None or self.boxes_pool_project is not None:
+            self.img_pre_norm = nn.LayerNorm(self.d_model)
+
+        self.encode = None
+        if num_layers > 0:
+            assert add_cls, "It's currently highly recommended to add a CLS when using a transformer"
+            self.encode = _get_clones(layer, num_layers)
+            self.encode_norm = nn.LayerNorm(self.d_model)
+
+        self.use_act_ckpt = use_act_ckpt
+
+    def _encode_points(self, points, points_mask, points_labels, img_feats):
+        """Encode points (used when boxes are converted to corner points)."""
+        # Direct projection of coordinates
+        # img_feats is only used here to pick the working dtype for the projection input.
+        points_embed = self.points_direct_project(points.to(img_feats.dtype))
+
+        # Add label embeddings
+        type_embed = self.label_embed(points_labels.long())
+        return type_embed + points_embed, points_mask
+
+    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats: torch.Tensor):
+        """Encode boxes using configured encoding methods (summed when several are enabled)."""
+        boxes_embed = None
+        n_boxes, bs = boxes.shape[:2]
+
+        if self.boxes_direct_project is not None:
+            proj = self.boxes_direct_project(boxes.to(img_feats.dtype))
+            boxes_embed = proj
+
+        if self.boxes_pool_project is not None:
+            H, W = img_feats.shape[-2:]
+
+            # Convert boxes to xyxy format and denormalize
+            boxes_xyxy = xywh2xyxy(boxes.to(img_feats.dtype))
+            scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
+            # NOTE(review): pin_memory() requires a CUDA-enabled build — confirm behavior on CPU-only installs.
+            scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
+            scale = scale.view(1, 1, 4)
+            boxes_xyxy = boxes_xyxy * scale
+
+            # RoI align
+            # roi_align takes a list of per-image (n_boxes, 4) tensors; unbind over the batch dim provides it.
+            sampled = torchvision.ops.roi_align(img_feats, boxes_xyxy.transpose(0, 1).unbind(0), self.roi_size)
+            assert list(sampled.shape) == [
+                bs * n_boxes,
+                self.d_model,
+                self.roi_size,
+                self.roi_size,
+            ]
+            proj = self.boxes_pool_project(sampled)
+            proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)
+
+            if boxes_embed is None:
+                boxes_embed = proj
+            else:
+                boxes_embed = boxes_embed + proj
+
+        if self.boxes_pos_enc_project is not None:
+            cx, cy, w, h = boxes.unbind(-1)
+            enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
+            enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])
+
+            proj = self.boxes_pos_enc_project(enc.to(img_feats.dtype))
+            if boxes_embed is None:
+                boxes_embed = proj
+            else:
+                boxes_embed = boxes_embed + proj
+
+        # Add label embeddings
+        type_embed = self.label_embed(boxes_labels.long())
+        return type_embed + boxes_embed, boxes_mask
+
+    def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
+        """Encode geometric box prompts.
+
+        Args:
+            geo_prompt: Prompt object containing box embeddings, masks, and labels
+            img_feats: List of image features from backbone
+            img_sizes: List of (H, W) tuples for each feature level
+            img_pos_embeds: Optional position embeddings for image features
+
+        Returns:
+            Tuple of (encoded_embeddings, attention_mask)
+        """
+        boxes = geo_prompt.box_embeddings
+        boxes_mask = geo_prompt.box_mask
+        boxes_labels = geo_prompt.box_labels
+
+        seq_first_img_feats = img_feats[-1]  # [H*W, B, C]
+        seq_first_img_pos_embeds = (
+            img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(seq_first_img_feats)
+        )
+
+        # Prepare image features for pooling if needed
+        if self.points_pool_project or self.boxes_pool_project:
+            assert len(img_feats) == len(img_sizes)
+            cur_img_feat = img_feats[-1]
+            cur_img_feat = self.img_pre_norm(cur_img_feat)
+            H, W = img_sizes[-1]
+            assert cur_img_feat.shape[0] == H * W
+            N, C = cur_img_feat.shape[-2:]
+            # Reshape to NxCxHxW
+            # (H*W, N, C) -> (N, C, H*W) -> (N, C, H, W); only the last dim is split so view is valid.
+            cur_img_feat = cur_img_feat.permute(1, 2, 0)
+            cur_img_feat = cur_img_feat.view(N, C, H, W)
+            img_feats = cur_img_feat
+
+        if self.encode_boxes_as_points:
+            # Convert boxes to corner points
+            assert boxes is not None and boxes.shape[-1] == 4
+
+            boxes_xyxy = xywh2xyxy(boxes)
+            top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)
+
+            # Adjust labels for corner points (offset by 2 and 4)
+            labels_tl = boxes_labels + 2
+            labels_br = boxes_labels + 4
+
+            # Concatenate top-left and bottom-right points
+            points = torch.cat([top_left, bottom_right], dim=0)
+            points_labels = torch.cat([labels_tl, labels_br], dim=0)
+            points_mask = torch.cat([boxes_mask, boxes_mask], dim=1)
+
+            final_embeds, final_mask = self._encode_points(
+                points=points,
+                points_mask=points_mask,
+                points_labels=points_labels,
+                img_feats=img_feats,
+            )
+        else:
+            # Encode boxes directly
+            final_embeds, final_mask = self._encode_boxes(
+                boxes=boxes,
+                boxes_mask=boxes_mask,
+                boxes_labels=boxes_labels,
+                img_feats=img_feats,
+            )
+
+        bs = final_embeds.shape[1]
+        assert final_mask.shape[0] == bs
+
+        # Add CLS token if configured
+        if self.cls_embed is not None:
+            cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
+            cls_mask = torch.zeros(bs, 1, dtype=final_mask.dtype, device=final_mask.device)
+            final_embeds, final_mask = concat_padded_sequences(final_embeds, final_mask, cls, cls_mask)
+
+        # Final projection
+        if self.final_proj is not None:
+            final_embeds = self.norm(self.final_proj(final_embeds))
+
+        # Transformer encoding layers
+        if self.encode is not None:
+            for lay in self.encode:
+                final_embeds = lay(
+                    tgt=final_embeds,
+                    memory=seq_first_img_feats,
+                    tgt_key_padding_mask=final_mask,
+                    pos=seq_first_img_pos_embeds,
+                )
+            final_embeds = self.encode_norm(final_embeds)
+
+        return final_embeds, final_mask
diff --git a/ultralytics/models/sam/sam3/maskformer_segmentation.py b/ultralytics/models/sam/sam3/maskformer_segmentation.py
new file mode 100644
index 0000000000..f91fd5b1bc
--- /dev/null
+++ b/ultralytics/models/sam/sam3/maskformer_segmentation.py
@@ -0,0 +1,286 @@
+# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from ultralytics.nn.modules.transformer import MLP
+
+
+class LinearPresenceHead(nn.Sequential):
+    """Predicts a single presence logit via a linear projection.
+
+    The prompt arguments are accepted to match the dot-product scorer interface but are ignored.
+    """
+
+    def __init__(self, d_model):
+        """Build the head; the two leading Identity modules keep state-dict keys aligned with old checkpoints."""
+        super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))
+
+    def forward(self, hs, prompt, prompt_mask):
+        """Project hs to a single logit; prompt and prompt_mask are unused."""
+        return nn.Sequential.forward(self, hs)
+
+
+class MaskPredictor(nn.Module):
+    """Computes mask logits as a dot product between embedded object queries and pixel embeddings."""
+
+    def __init__(self, hidden_dim, mask_dim):
+        """Build the 3-layer MLP that maps object queries into the mask-embedding space."""
+        super().__init__()
+        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+    def forward(self, obj_queries, pixel_embed):
+        """Return mask logits.
+
+        obj_queries is (bs, q, c) or, with aux masks, (layers, bs, q, c); pixel_embed is
+        (c, h, w) when the batch dim was omitted, else (bs, c, h, w).
+        """
+        # Build the einsum equation from the actual ranks instead of branching four ways.
+        query_spec = "bqc" if len(obj_queries.shape) == 3 else "lbqc"
+        pixel_spec = "chw" if pixel_embed.ndim == 3 else "bchw"
+        out_spec = query_spec[:-1] + "hw"
+        equation = f"{query_spec},{pixel_spec}->{out_spec}"
+        return torch.einsum(equation, self.mask_embed(obj_queries), pixel_embed)
+
+
+class SegmentationHead(nn.Module):
+    """Segmentation head that predicts masks from backbone features and object queries."""
+
+    def __init__(
+        self,
+        hidden_dim,
+        upsampling_stages,
+        use_encoder_inputs=False,
+        aux_masks=False,
+        no_dec=False,
+        pixel_decoder=None,
+        act_ckpt=False,
+        shared_conv=False,
+        compile_mode_pixel_decoder=None,
+    ):
+        """Initializes the SegmentationHead.
+
+        Args:
+            hidden_dim: Channel width of queries and pixel embeddings.
+            upsampling_stages: Number of FPN upsampling stages in the default PixelDecoder.
+            use_encoder_inputs: If True, splice encoder hidden states into the last backbone level.
+            aux_masks: If True, predict masks for every decoder layer, not only the last.
+            no_dec: If True, skip the query-based predictor and use a plain conv on pixel embeddings.
+            pixel_decoder: Optional externally built pixel decoder (overrides the default).
+            act_ckpt: Activation-checkpoint the pixel decoder call.
+            shared_conv: Share one conv across upsampling stages in the default PixelDecoder.
+            compile_mode_pixel_decoder: Optional torch.compile mode for the default PixelDecoder.
+        """
+        super().__init__()
+        self.use_encoder_inputs = use_encoder_inputs
+        self.aux_masks = aux_masks
+        if pixel_decoder is not None:
+            self.pixel_decoder = pixel_decoder
+        else:
+            self.pixel_decoder = PixelDecoder(
+                hidden_dim,
+                upsampling_stages,
+                shared_conv=shared_conv,
+                compile_mode=compile_mode_pixel_decoder,
+            )
+        self.no_dec = no_dec
+        if no_dec:
+            self.mask_predictor = nn.Conv2d(hidden_dim, 1, kernel_size=3, stride=1, padding=1)
+        else:
+            self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
+
+        self.act_ckpt = act_ckpt
+
+        # used to update the output dictionary
+        self.instance_keys = ["pred_masks"]
+
+    def _embed_pixels(self, backbone_feats: list[torch.Tensor], encoder_hidden_states) -> torch.Tensor:
+        """Embeds pixels using the pixel decoder, optionally swapping in encoder visual tokens."""
+        if self.use_encoder_inputs:
+            backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
+            # Extract visual embeddings
+            # encoder_hidden_states assumed (seq, bs, c); permuted to (bs, c, seq) — TODO confirm against caller.
+            encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
+            # The leading `spatial_dim` tokens are reshaped back to the last backbone level's (c, h, w) layout.
+            spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
+            encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(-1, *backbone_feats[-1].shape[1:])
+
+            backbone_visual_feats[-1] = encoder_visual_embed
+            if self.act_ckpt:
+                pixel_embed = checkpoint.checkpoint(self.pixel_decoder, backbone_visual_feats, use_reentrant=False)
+            else:
+                pixel_embed = self.pixel_decoder(backbone_visual_feats)
+        else:
+            backbone_feats = [x for x in backbone_feats]
+            pixel_embed = self.pixel_decoder(backbone_feats)
+            if pixel_embed.shape[0] == 1:
+                # For batch_size=1 training, we can avoid the indexing to save memory
+                pixel_embed = pixel_embed.squeeze(0)
+            else:
+                # NOTE(review): only the first batch element is kept here — presumably all batch entries
+                # share pixel features in this code path; verify against callers.
+                pixel_embed = pixel_embed[[0], ...]
+        return pixel_embed
+
+    def forward(
+        self,
+        backbone_feats: list[torch.Tensor],
+        obj_queries: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass of the SegmentationHead; returns {"pred_masks": mask_logits}."""
+        if self.use_encoder_inputs:
+            assert encoder_hidden_states is not None
+
+        pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)
+
+        if self.no_dec:
+            mask_pred = self.mask_predictor(pixel_embed)
+        elif self.aux_masks:
+            mask_pred = self.mask_predictor(obj_queries, pixel_embed)
+        else:
+            # Only the last decoder layer's queries are used when aux masks are disabled.
+            mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)
+
+        return {"pred_masks": mask_pred}
+
+
+class PixelDecoder(nn.Module):
+    """Pixel decoder module that upsamples backbone features with a lightweight top-down FPN."""
+
+    def __init__(
+        self,
+        hidden_dim,
+        num_upsampling_stages,
+        interpolation_mode="nearest",
+        shared_conv=False,
+        compile_mode=None,
+    ):
+        """Initializes the PixelDecoder.
+
+        Args:
+            hidden_dim: Channel width of all feature levels.
+            num_upsampling_stages: Number of top-down merge stages (one conv+norm each unless shared).
+            interpolation_mode: Mode passed to F.interpolate for upsampling.
+            shared_conv: Reuse a single conv/norm pair across all stages.
+            compile_mode: If set, wrap forward in torch.compile with this mode.
+        """
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_upsampling_stages = num_upsampling_stages
+        self.interpolation_mode = interpolation_mode
+        conv_layers = []
+        norms = []
+        num_convs = 1 if shared_conv else num_upsampling_stages
+        for _ in range(num_convs):
+            conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
+            norms.append(nn.GroupNorm(8, self.hidden_dim))
+
+        self.conv_layers = nn.ModuleList(conv_layers)
+        self.norms = nn.ModuleList(norms)
+        self.shared_conv = shared_conv
+        # Output channel count, exposed for downstream heads (e.g. semantic/instance 1x1 convs).
+        self.out_dim = self.conv_layers[-1].out_channels
+        if compile_mode is not None:
+            # Binds a compiled forward on the instance, shadowing the class method for this object only.
+            self.forward = torch.compile(self.forward, mode=compile_mode, dynamic=True, fullgraph=True)
+            # Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default.
+            torch._dynamo.config.optimize_ddp = False
+
+    def forward(self, backbone_feats: list[torch.Tensor]):
+        """Forward pass of the PixelDecoder: top-down merge from coarsest to finest level."""
+        prev_fpn = backbone_feats[-1]
+        fpn_feats = backbone_feats[:-1]
+        # Walk finer levels in coarse-to-fine order, adding the upsampled running feature at each step.
+        for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
+            curr_fpn = bb_feat
+            prev_fpn = curr_fpn + F.interpolate(prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode)
+            if self.shared_conv:
+                # only one conv layer
+                layer_idx = 0
+            prev_fpn = self.conv_layers[layer_idx](prev_fpn)
+            prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))
+
+        return prev_fpn
+
+
+class UniversalSegmentationHead(SegmentationHead):
+    """This module handles semantic+instance segmentation.
+
+    On top of the base SegmentationHead it adds a semantic-segmentation conv, an instance
+    embedding conv, an optional presence head, and optional cross-attention from the
+    encoder states to the prompt tokens.
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        upsampling_stages,
+        pixel_decoder,
+        aux_masks=False,
+        no_dec=False,
+        act_ckpt=False,
+        presence_head: bool = False,
+        dot_product_scorer=None,
+        cross_attend_prompt=None,
+    ):
+        """Initializes the UniversalSegmentationHead.
+
+        Args:
+            hidden_dim: Channel width shared by queries, encoder states, and pixel embeddings.
+            upsampling_stages: Forwarded to the base head.
+            pixel_decoder: Pixel decoder instance (required, unlike the base head).
+            aux_masks: Predict masks for every decoder layer.
+            no_dec: Use a plain conv predictor instead of query-based masks.
+            act_ckpt: Activation-checkpoint the pixel decoder.
+            presence_head: Add a head predicting whether the concept is present at all.
+            dot_product_scorer: Optional scorer used as the presence head (requires presence_head).
+            cross_attend_prompt: Optional attention module letting encoder states attend to the prompt.
+        """
+        super().__init__(
+            hidden_dim=hidden_dim,
+            upsampling_stages=upsampling_stages,
+            use_encoder_inputs=True,
+            aux_masks=aux_masks,
+            no_dec=no_dec,
+            pixel_decoder=pixel_decoder,
+            act_ckpt=act_ckpt,
+        )
+        self.d_model = hidden_dim
+
+        if dot_product_scorer is not None:
+            assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
+
+        self.presence_head = None
+        if presence_head:
+            self.presence_head = (
+                dot_product_scorer if dot_product_scorer is not None else LinearPresenceHead(self.d_model)
+            )
+
+        self.cross_attend_prompt = cross_attend_prompt
+        if self.cross_attend_prompt is not None:
+            self.cross_attn_norm = nn.LayerNorm(self.d_model)
+
+        # 1x1 convs: one binary semantic map, and per-pixel instance embeddings for the mask predictor.
+        self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
+        self.instance_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, self.d_model, kernel_size=1)
+
+    def forward(
+        self,
+        backbone_feats: list[torch.Tensor],
+        obj_queries: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        prompt: torch.Tensor = None,
+        prompt_mask: torch.Tensor = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass of the UniversalSegmentationHead.
+
+        Returns a dict with "pred_masks", "semantic_seg", and "presence_logit" (None when no presence head).
+        """
+        assert encoder_hidden_states is not None
+        bs = encoder_hidden_states.shape[1]
+
+        if self.cross_attend_prompt is not None:
+            # Pre-norm residual cross-attention: encoder states query the prompt tokens.
+            tgt2 = self.cross_attn_norm(encoder_hidden_states)
+            tgt2 = self.cross_attend_prompt(
+                query=tgt2,
+                key=prompt.to(tgt2.dtype),
+                value=prompt.to(tgt2.dtype),
+                key_padding_mask=prompt_mask,
+                need_weights=False,
+            )[0]
+            encoder_hidden_states = tgt2 + encoder_hidden_states
+
+        presence_logit = None
+        if self.presence_head is not None:
+            # Mean over the sequence dim, then viewed as (num_layer=1, bs, num_query=1, d_model)
+            # to match the scorer interface; the extra dims are squeezed back out below.
+            pooled_enc = encoder_hidden_states.mean(0)
+            presence_logit = (
+                self.presence_head(
+                    pooled_enc.view(1, bs, 1, self.d_model),
+                    prompt=prompt,
+                    prompt_mask=prompt_mask,
+                )
+                .squeeze(0)
+                .squeeze(1)
+            )
+
+        pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)
+
+        instance_embeds = self.instance_seg_head(pixel_embed)
+
+        if self.no_dec:
+            mask_pred = self.mask_predictor(instance_embeds)
+        elif self.aux_masks:
+            mask_pred = self.mask_predictor(obj_queries, instance_embeds)
+        else:
+            mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)
+
+        return {
+            "pred_masks": mask_pred,
+            "semantic_seg": self.semantic_seg_head(pixel_embed),
+            "presence_logit": presence_logit,
+        }
diff --git a/ultralytics/models/sam/sam3/model_misc.py b/ultralytics/models/sam/sam3/model_misc.py
new file mode 100644
index 0000000000..1b66b05a5f
--- /dev/null
+++ b/ultralytics/models/sam/sam3/model_misc.py
@@ -0,0 +1,198 @@
+# Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""Various utility models."""
+
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+
+class DotProductScoring(torch.nn.Module):
+    """Scores decoder queries against a mean-pooled prompt embedding with a scaled dot product."""
+
+    def __init__(
+        self,
+        d_model,
+        d_proj,
+        prompt_mlp=None,
+        clamp_logits=True,
+        clamp_max_val=12.0,
+    ):
+        """Set up the projections; prompt_mlp optionally pre-processes prompt tokens before pooling."""
+        super().__init__()
+        assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
+        self.prompt_mlp = prompt_mlp  # an optional MLP projection for prompt
+        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
+        self.hs_proj = torch.nn.Linear(d_model, d_proj)
+        self.d_proj = d_proj
+        # 1/sqrt(d) scaling, as in standard attention scoring.
+        self.scale = float(1.0 / np.sqrt(d_proj))
+        self.clamp_logits = clamp_logits
+        if self.clamp_logits:
+            self.clamp_max_val = clamp_max_val
+
+    def mean_pool_text(self, prompt, prompt_mask):
+        """Average prompt embeddings (seq, bs, d) over non-padded tokens; prompt_mask is (bs, seq), True = pad."""
+        # keep: (seq, bs, 1) with 1.0 on valid tokens, 0.0 on padding.
+        keep = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None]
+        # Floor the count at 1 to avoid dividing by zero for fully padded rows.
+        denom = torch.clamp(torch.sum(keep, dim=0), min=1.0)
+        return (prompt * keep).sum(dim=0) / denom
+
+    def forward(self, hs, prompt, prompt_mask):
+        """Return clamped, scaled dot-product scores of shape (num_layer, bs, num_query, 1).
+
+        hs: (num_layer, bs, num_query, d_model); prompt: (seq, bs, d_model); prompt_mask: (bs, seq).
+        """
+        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
+
+        # Optionally pre-process the prompt tokens (also aligns dtype with hs).
+        if self.prompt_mlp is not None:
+            prompt = self.prompt_mlp(prompt.to(hs.dtype))
+
+        # Pool the prompt over valid tokens, then project both sides into the scoring space.
+        pooled = self.prompt_proj(self.mean_pool_text(prompt, prompt_mask))  # (bs, d_proj)
+        queries = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)
+
+        # Batched dot product against the single pooled prompt vector.
+        scores = torch.matmul(queries, pooled.unsqueeze(-1)) * self.scale
+
+        # Clamp to keep losses/matchers numerically stable.
+        if self.clamp_logits:
+            scores = scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)
+        return scores
+
+
+class LayerScale(nn.Module):
+    """Learnable per-channel scaling (LayerScale), typically applied to residual branches."""
+
+    def __init__(
+        self,
+        dim: int,
+        init_values: float | Tensor = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        """Create the per-channel gamma parameter, initialized to init_values."""
+        super().__init__()
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+        self.inplace = inplace
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Scale x by gamma, mutating x in place when so configured."""
+        if self.inplace:
+            return x.mul_(self.gamma)
+        return x * self.gamma
+
+
+class TransformerWrapper(nn.Module):
+    """Bundles a transformer encoder and decoder behind a single module."""
+
+    def __init__(
+        self,
+        encoder,
+        decoder,
+        d_model: int,
+        two_stage_type="none",  # only "none" is supported for now
+        pos_enc_at_input_dec=True,
+    ):
+        """Store the encoder/decoder pair and Xavier-initialize the wrapper's weights."""
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.pos_enc_at_input_dec = pos_enc_at_input_dec
+        # Mirror the decoder's query count when a decoder is present.
+        self.num_queries = None if decoder is None else decoder.num_queries
+
+        # Two-stage variants are not implemented; reject anything else early.
+        assert two_stage_type in ["none"], f"unknown param {two_stage_type} of two_stage_type"
+        self.two_stage_type = two_stage_type
+
+        self._reset_parameters()
+        self.d_model = d_model
+
+    def _reset_parameters(self):
+        """Xavier-initialize matrix-shaped parameters, skipping embedding/reference tables."""
+        skip_keys = ("box_embed", "query_embed", "reference_points")
+        for name, param in self.named_parameters():
+            if param.dim() > 1 and not any(key in name for key in skip_keys):
+                nn.init.xavier_uniform_(param)
+
+
+def get_valid_ratio(mask):
+    """Return per-image (width_ratio, height_ratio) of the unpadded area from a (B, H, W) padding mask."""
+    _, height, width = mask.shape
+    # Count unpadded rows along the first column and unpadded columns along the first row.
+    ratio_h = torch.sum(~mask[:, :, 0], 1).float() / height
+    ratio_w = torch.sum(~mask[:, 0, :], 1).float() / width
+    # Stacked as (width_ratio, height_ratio) to match (x, y) coordinate order.
+    return torch.stack([ratio_w, ratio_h], -1)
+
+
+def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256):
+    """Generate sinusoidal position embeddings for 2D or 4D coordinate tensors.
+
+    This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to the
+    positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors (x, y, w,
+    h) for bounding box coordinates.
+
+    Args:
+        pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or (n_query, bs,
+            4) for 4D coordinates (bounding boxes).
+        num_feats (int): Number of feature dimensions for the output embedding. Must be even. Defaults to 256.
+
+    Returns:
+        (torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or (n_query, bs,
+            num_feats * 2) for 4D input.
+
+    Raises:
+        AssertionError: If num_feats is not even.
+        ValueError: If pos_tensor.size(-1) is not 2 or 4.
+
+    Examples:
+        >>> pos_2d = torch.rand(100, 8, 2)  # 100 queries, batch size 8, 2D coordinates
+        >>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256)
+        >>> embeddings_2d.shape
+        torch.Size([100, 8, 256])
+        >>> pos_4d = torch.rand(50, 4, 4)  # 50 queries, batch size 4, 4D coordinates
+        >>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128)
+        >>> embeddings_4d.shape
+        torch.Size([50, 4, 256])
+    """
+    assert num_feats % 2 == 0
+    # Half the features carry sin components, half carry cos, per coordinate.
+    num_feats = num_feats // 2
+    scale = 2 * math.pi
+    # Geometric frequency ladder, as in "Attention Is All You Need" positional encodings.
+    dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device)
+    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
+    x_embed = pos_tensor[:, :, 0] * scale
+    y_embed = pos_tensor[:, :, 1] * scale
+    pos_x = x_embed[:, :, None] / dim_t
+    pos_y = y_embed[:, :, None] / dim_t
+    # Interleave sin/cos over adjacent feature pairs, then flatten back to one feature dim.
+    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+    if pos_tensor.size(-1) == 2:
+        pos = torch.cat((pos_y, pos_x), dim=2)
+    elif pos_tensor.size(-1) == 4:
+        w_embed = pos_tensor[:, :, 2] * scale
+        pos_w = w_embed[:, :, None] / dim_t
+        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+        h_embed = pos_tensor[:, :, 3] * scale
+        pos_h = h_embed[:, :, None] / dim_t
+        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+    else:
+        raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
+    return pos
diff --git a/ultralytics/models/sam/sam3/necks.py b/ultralytics/models/sam/sam3/necks.py
new file mode 100644
index 0000000000..db2036ffde
--- /dev/null
+++ b/ultralytics/models/sam/sam3/necks.py
@@ -0,0 +1,129 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""Necks are the interface between a vision backbone and the rest of the detection model."""
+
+from __future__ import annotations
+
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+
class Sam3DualViTDetNeck(nn.Module):
    """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2)."""

    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        Initialize a SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).

        It supports a "dual neck" setting, where we have two structurally identical necks (for SAM3 and
        SAM2) that hold different weights.

        Args:
            trunk (nn.Module): The backbone; must expose `channel_list` with its per-stage output channels.
            position_encoding (nn.Module): Positional encoding module applied to every output feature map.
            d_model (int): Output channel dimension of the neck.
            scale_factors (tuple): Per-level resampling factor relative to the trunk's last feature map.
            add_sam2_neck (bool): If True, clone the SAM3 neck to create a separate SAM2 neck.
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()

        self.scale_factors = scale_factors
        use_bias = True
        dim: int = self.trunk.channel_list[-1]

        for scale in scale_factors:  # build one resampling branch per pyramid level
            current = nn.Sequential()

            if scale == 4.0:
                # two stride-2 deconvs: 4x upsampling, channels reduced to dim/4
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module(
                    "gelu",
                    nn.GELU(),
                )
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                out_dim = dim
            elif scale == 0.5:
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            # project to d_model (1x1), then refine locally (3x3)
            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=use_bias,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                ),
            )
            self.convs.append(current)

        self.sam2_convs = None
        if add_sam2_neck:
            # Assumes sam2 neck is just a clone of the original neck
            self.sam2_convs = deepcopy(self.convs)

    def forward(
        self, tensor_list: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Get the feature maps and positional encodings from the neck.

        Returns:
            (tuple): SAM3 feature maps, SAM3 positional encodings, and the SAM2 equivalents
                (None, None when no SAM2 neck was requested).
        """
        xs = self.trunk(tensor_list)
        sam3_out, sam3_pos = [], []
        sam2_out, sam2_pos = None, None
        if self.sam2_convs is not None:
            sam2_out, sam2_pos = [], []
        x = xs[-1]  # simpleFPN: every level is derived from the last trunk feature map
        for i, conv in enumerate(self.convs):
            sam3_x_out = conv(x)
            sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
            sam3_out.append(sam3_x_out)
            sam3_pos.append(sam3_pos_out)

            if self.sam2_convs is not None:
                sam2_x_out = self.sam2_convs[i](x)
                sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
                sam2_out.append(sam2_x_out)
                sam2_pos.append(sam2_pos_out)
        return sam3_out, sam3_pos, sam2_out, sam2_pos

    def set_imgsz(self, imgsz: list[int] | None = None):
        """Set the image size for the trunk backbone.

        Args:
            imgsz (list[int] | None): Target [height, width]; defaults to [1008, 1008]. The default is
                created per call so callers never share a mutable default list.
        """
        self.trunk.set_imgsz([1008, 1008] if imgsz is None else imgsz)
diff --git a/ultralytics/models/sam/sam3/sam3_image.py b/ultralytics/models/sam/sam3/sam3_image.py
new file mode 100644
index 0000000000..c8bccc92f4
--- /dev/null
+++ b/ultralytics/models/sam/sam3/sam3_image.py
@@ -0,0 +1,357 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+from __future__ import annotations
+
+from copy import deepcopy
+
+import torch
+
+from ultralytics.nn.modules.utils import inverse_sigmoid
+from ultralytics.utils.ops import xywh2xyxy
+
+from .geometry_encoders import Prompt
+from .vl_combiner import SAM3VLBackbone
+
+
+def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
+ """Helper function to update output dictionary with main and auxiliary outputs."""
+ out[out_name] = out_value[-1] if auxiliary else out_value
+ if auxiliary and update_aux:
+ if "aux_outputs" not in out:
+ out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
+ assert len(out["aux_outputs"]) == len(out_value) - 1
+ for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
+ aux_output[out_name] = aux_value
+
+
class SAM3SemanticModel(torch.nn.Module):
    """SAM3 model for semantic segmentation with vision-language backbone.

    Combines a joint vision-language backbone, a geometry (box/point) prompt encoder, a transformer
    encoder-decoder, and an optional segmentation head to detect and segment all instances of a concept
    described by text and/or geometric prompts.
    """

    def __init__(
        self,
        backbone: SAM3VLBackbone,
        transformer,
        input_geometry_encoder,
        segmentation_head=None,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=None,
        use_instance_query: bool = True,
        multimask_output: bool = True,
        use_act_checkpoint_seg_head: bool = True,
        matcher=None,
        use_dot_prod_scoring=True,
        supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
        detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
        separate_scorer_for_instance: bool = False,
        num_interactive_steps_val: int = 0,
    ):
        """Initialize the SAM3SemanticModel.

        Args:
            backbone (SAM3VLBackbone): Joint vision-language backbone producing image and text features.
            transformer: Encoder-decoder transformer; must expose `d_model` and a `decoder` with query
                embeddings, bbox heads, and `num_queries`/`num_o2m_queries`/`dac` attributes.
            input_geometry_encoder: Encoder turning geometric prompts (boxes/points) into prompt tokens.
            segmentation_head: Optional mask head; when None only boxes and scores are predicted.
            num_feature_levels (int): Number of backbone feature levels fed to the transformer.
            o2m_mask_predict (bool): Whether one-to-many (O2M) queries also predict masks.
            dot_prod_scoring: Module scoring queries against prompt features (required when
                `use_dot_prod_scoring` is True).
            use_instance_query (bool): Whether instance queries are enabled.
            multimask_output (bool): Whether multiple mask candidates are produced.
            use_act_checkpoint_seg_head (bool): Activation-checkpoint the segmentation head.
            matcher: Matcher used for target assignment during training.
            use_dot_prod_scoring (bool): Score via dot-product against prompt features instead of a
                plain linear classifier.
            supervise_joint_box_scores (bool): Fold the presence score into the box scores.
            detach_presence_in_joint_score (bool): Detach the presence score inside the joint score.
            separate_scorer_for_instance (bool): Use a separate (cloned) scorer for instance prompts.
            num_interactive_steps_val (int): Number of interactive steps during validation.
        """
        super().__init__()
        self.backbone = backbone
        self.geometry_encoder = input_geometry_encoder
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.segmentation_head = segmentation_head

        self.o2m_mask_predict = o2m_mask_predict

        # NOTE(review): re-assigned below when use_dot_prod_scoring is True; kept for the False branch too.
        self.dot_prod_scoring = dot_prod_scoring
        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
        self.matcher = matcher

        self.num_interactive_steps_val = num_interactive_steps_val
        self.use_dot_prod_scoring = use_dot_prod_scoring

        if self.use_dot_prod_scoring:
            assert dot_prod_scoring is not None
            self.dot_prod_scoring = dot_prod_scoring
            self.instance_dot_prod_scoring = None
            if separate_scorer_for_instance:
                # independent copy so instance prompts are scored with their own weights
                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
        else:
            # fall back to a plain linear objectness classifier over query features
            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
            self.instance_class_embed = None
            if separate_scorer_for_instance:
                self.instance_class_embed = deepcopy(self.class_embed)

        self.supervise_joint_box_scores = supervise_joint_box_scores
        self.detach_presence_in_joint_score = detach_presence_in_joint_score

        # verify the number of queries for O2O and O2M
        num_o2o_static = self.transformer.decoder.num_queries
        num_o2m_static = self.transformer.decoder.num_o2m_queries
        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
        self.dac = self.transformer.decoder.dac

        self.use_instance_query = use_instance_query
        self.multimask_output = multimask_output

        # Cache of text features produced by `set_classes`, and the corresponding class names.
        self.text_embeddings = {}
        self.names = []

    def _prepare_backbone_features(self, backbone_out, num_prompts=1):
        """Prepare and flatten visual features from the image backbone output for further processing.

        Args:
            backbone_out (dict): Backbone output holding "backbone_fpn" feature maps and
                "vision_pos_enc" positional encodings (lists of NxCxHxW tensors).
            num_prompts (int): Number of prompts; features are broadcast over the batch dim when > 1.

        Returns:
            (tuple): The (possibly expanded) backbone output, flattened vision features, flattened
                positional embeddings (both HWxNxC), and the (H, W) size of each used level.
        """
        if num_prompts > 1:  # expand features if there's more than one prompt
            for i, feat in enumerate(backbone_out["backbone_fpn"]):
                backbone_out["backbone_fpn"][i] = feat.expand(num_prompts, -1, -1, -1)
            for i, pos in enumerate(backbone_out["vision_pos_enc"]):
                pos = pos.expand(num_prompts, -1, -1, -1)
                backbone_out["vision_pos_enc"][i] = pos
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        # keep only the highest-resolution levels actually consumed by the transformer
        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def _encode_prompt(
        self,
        img_feats,
        img_pos_embeds,
        vis_feat_sizes,
        geometric_prompt,
        visual_prompt_embed=None,
        visual_prompt_mask=None,
        prev_mask_pred=None,
    ):
        """Encode the geometric and visual prompts into one sequence-first prompt tensor plus padding mask."""
        if prev_mask_pred is not None:
            # condition the geometry encoder on the previous mask prediction (added to the last level)
            img_feats = [img_feats[-1] + prev_mask_pred]
        # Encode geometry
        geo_feats, geo_masks = self.geometry_encoder(
            geo_prompt=geometric_prompt,
            img_feats=img_feats,
            img_sizes=vis_feat_sizes,
            img_pos_embeds=img_pos_embeds,
        )
        if visual_prompt_embed is None:
            # no visual prompt: zero-length placeholders make the concatenations below no-ops
            visual_prompt_embed = torch.zeros((0, *geo_feats.shape[1:]), device=geo_feats.device)
            visual_prompt_mask = torch.zeros(
                (*geo_masks.shape[:-1], 0),
                device=geo_masks.device,
                dtype=geo_masks.dtype,
            )
        prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
        prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
        return prompt, prompt_mask

    def _run_encoder(
        self,
        img_feats,
        img_pos_embeds,
        vis_feat_sizes,
        prompt,
        prompt_mask,
        encoder_extra_kwargs: dict | None = None,
    ):
        """Run the transformer encoder over image features fused with the prompt sequence."""
        # Run the encoder
        # make a copy of the image feature lists since the encoder may modify these lists in-place
        memory = self.transformer.encoder(
            src=img_feats.copy(),
            src_key_padding_mask=None,
            src_pos=img_pos_embeds.copy(),
            prompt=prompt,
            prompt_key_padding_mask=prompt_mask,
            feat_sizes=vis_feat_sizes,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        encoder_out = {
            # encoded image features
            "encoder_hidden_states": memory["memory"],
            "pos_embed": memory["pos_embed"],
            "padding_mask": memory["padding_mask"],
            "spatial_shapes": memory["spatial_shapes"],
            "valid_ratios": memory["valid_ratios"],
            "vis_feat_sizes": vis_feat_sizes,
            # encoded text features (or other prompts)
            "prompt_before_enc": prompt,
            # fall back to the input prompt when the encoder performs no text fusion
            "prompt_after_enc": memory.get("memory_text", prompt),
            "prompt_mask": prompt_mask,
        }
        return encoder_out

    def _run_decoder(
        self,
        pos_embed,
        memory,
        src_mask,
        out,
        prompt,
        prompt_mask,
        encoder_out,
    ):
        """Run the transformer decoder and write scores/boxes into `out`; returns (out, query features)."""
        bs = memory.shape[1]
        # learned object queries, tiled across the batch
        query_embed = self.transformer.decoder.query_embed.weight
        tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)

        hs, reference_boxes, dec_presence_out, _ = self.transformer.decoder(
            tgt=tgt,
            memory=memory,
            memory_key_padding_mask=src_mask,
            pos=pos_embed,
            reference_boxes=None,
            spatial_shapes=encoder_out["spatial_shapes"],
            valid_ratios=encoder_out["valid_ratios"],
            tgt_mask=None,
            memory_text=prompt,
            text_attention_mask=prompt_mask,
            apply_dac=False,
        )
        hs = hs.transpose(1, 2)  # seq-first to batch-first
        reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
        if dec_presence_out is not None:
            # seq-first to batch-first
            dec_presence_out = dec_presence_out.transpose(1, 2)
        self._update_scores_and_boxes(
            out,
            hs,
            reference_boxes,
            prompt,
            prompt_mask,
            dec_presence_out=dec_presence_out,
        )
        return out, hs

    def _update_scores_and_boxes(
        self,
        out,
        hs,
        reference_boxes,
        prompt,
        prompt_mask,
        dec_presence_out=None,
        is_instance_prompt=False,
    ):
        """Update output dict with class scores and box predictions (main outputs only, aux untouched)."""
        num_o2o = hs.size(2)  # number of one-to-one queries; O2M queries (if any) follow them
        # score prediction
        if self.use_dot_prod_scoring:
            dot_prod_scoring_head = self.dot_prod_scoring
            if is_instance_prompt and self.instance_dot_prod_scoring is not None:
                dot_prod_scoring_head = self.instance_dot_prod_scoring
            outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
        else:
            class_embed_head = self.class_embed
            if is_instance_prompt and self.instance_class_embed is not None:
                class_embed_head = self.instance_class_embed
            outputs_class = class_embed_head(hs)

        # box prediction: heads predict offsets relative to the (inverse-sigmoid) reference boxes
        box_head = self.transformer.decoder.bbox_embed
        if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None:
            box_head = self.transformer.decoder.instance_bbox_embed
        anchor_box_offsets = box_head(hs)
        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
        outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
        outputs_boxes_xyxy = xywh2xyxy(outputs_coord)

        if dec_presence_out is not None:
            _update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False)

        if self.supervise_joint_box_scores:
            assert dec_presence_out is not None
            prob_dec_presence_out = dec_presence_out.clone().sigmoid()
            if self.detach_presence_in_joint_score:
                prob_dec_presence_out = prob_dec_presence_out.detach()

            # joint score = box score * presence score, mapped back to logit space (clamped for stability)
            outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp(
                min=-10.0, max=10.0
            )

        _update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False)
        _update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False)
        _update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False)

    def _run_segmentation_heads(
        self,
        out,
        backbone_out,
        encoder_hidden_states,
        prompt,
        prompt_mask,
        hs,
    ):
        """Run segmentation heads and get masks; a no-op (besides freeing features) without a seg head."""
        if self.segmentation_head is not None:
            num_o2o = hs.size(2)
            # optionally restrict mask prediction to the one-to-one queries
            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
            seg_head_outputs = self.segmentation_head(
                backbone_feats=backbone_out["backbone_fpn"],
                obj_queries=obj_queries,
                encoder_hidden_states=encoder_hidden_states,
                prompt=prompt,
                prompt_mask=prompt_mask,
            )
            for k, v in seg_head_outputs.items():
                if k in self.segmentation_head.instance_keys:
                    # per-instance outputs are trimmed to the O2O queries
                    _update_out(out, k, v[:, :num_o2o], auxiliary=False)
                else:
                    out[k] = v
        else:
            # free the FPN features when no mask head consumes them
            backbone_out.pop("backbone_fpn", None)

    def forward_grounding(
        self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt | None = None
    ):
        """Forward pass for grounding (detection + segmentation) given input images and text.

        Args:
            backbone_out (dict[str, torch.Tensor]): Output of the vision backbone for one image.
            text_ids (torch.Tensor): Indices into the cached text embeddings (see `set_classes`).
            geometric_prompt (Prompt | None): Optional geometric prompt (boxes/points).

        Returns:
            (dict): Model outputs, including "pred_logits", "pred_boxes"/"pred_boxes_xyxy", optional
                presence logits and segmentation-head outputs, plus the backbone output.
        """
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = self._prepare_backbone_features(
            backbone_out, num_prompts=len(text_ids)
        )
        # merge the cached text embeddings (from `set_classes`) into the backbone output
        backbone_out.update({k: v for k, v in self.text_embeddings.items()})
        with torch.profiler.record_function("SAM3Image._encode_prompt"):
            prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
        # index text features (note that regardless of early or late fusion, the batch size of
        # `txt_feats` is always the number of *prompts* in the encoder)
        txt_feats = backbone_out["language_features"][:, text_ids]
        txt_masks = backbone_out["language_mask"][text_ids]
        # encode text: text tokens come first in the prompt sequence, geometry/visual tokens after
        prompt = torch.cat([txt_feats, prompt], dim=0)
        prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)

        # Run the encoder
        with torch.profiler.record_function("SAM3Image._run_encoder"):
            encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask)
        out = {"backbone_out": backbone_out}

        # Run the decoder
        with torch.profiler.record_function("SAM3Image._run_decoder"):
            out, hs = self._run_decoder(
                memory=encoder_out["encoder_hidden_states"],
                pos_embed=encoder_out["pos_embed"],
                src_mask=encoder_out["padding_mask"],
                out=out,
                prompt=prompt,
                prompt_mask=prompt_mask,
                encoder_out=encoder_out,
            )

        # Run segmentation heads
        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
            self._run_segmentation_heads(
                out=out,
                backbone_out=backbone_out,
                encoder_hidden_states=encoder_out["encoder_hidden_states"],
                prompt=prompt,
                prompt_mask=prompt_mask,
                hs=hs,
            )
        return out

    def set_classes(self, text: list[str]):
        """Set (and cache) the text embeddings for the given class names."""
        self.text_embeddings = self.backbone.forward_text(text)
        self.names = text

    def set_imgsz(self, imgsz: tuple[int, int]):
        """Set the image size for the model."""
        self.backbone.set_imgsz(imgsz)
diff --git a/ultralytics/models/sam/sam3/text_encoder_ve.py b/ultralytics/models/sam/sam3/text_encoder_ve.py
new file mode 100644
index 0000000000..7c45a2ed27
--- /dev/null
+++ b/ultralytics/models/sam/sam3/text_encoder_ve.py
@@ -0,0 +1,307 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+from __future__ import annotations
+
+from collections import OrderedDict
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .model_misc import LayerScale
+
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head attention and an MLP, each wrapped in a residual connection."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        """Build the attention, normalization, layer-scale, and feed-forward submodules."""
        super().__init__()
        # Attention
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # LayerNorm, LayerScale
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        use_layer_scale = ls_init_value is not None
        self.ls_1 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()
        self.ls_2 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()

        # MLP (expand by mlp_ratio, GELU, project back)
        hidden_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, hidden_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(hidden_width, d_model)),
                ]
            )
        )

    def attention(
        self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Run multi-head attention; keys and values default to the queries (self-attention)."""
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x
        # additive masks must match the query dtype; boolean masks pass through unchanged
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)
        attn_out, _ = self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)
        return attn_out

    def forward(
        self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Apply residual attention then the residual MLP; supports optional cross-attention inputs."""
        # an optional key/value norm may be attached by subclasses; absent here, so these stay None
        has_kv_norm = hasattr(self, "ln_1_kv")
        k_x = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None
        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
        return x + self.ls_2(self.mlp(self.ln_2(x)))
+
+
class Transformer(nn.Module):
    """A stack of residual attention blocks with optional torch.compile and gradient checkpointing."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = False,
    ):
        """Create `layers` identical residual attention blocks of the given width and head count."""
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        blocks = [
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
            for _ in range(layers)
        ]
        self.resblocks = nn.ModuleList(blocks)

        if compile_mode is not None:
            # compile the whole forward as one graph
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
            if self.grad_checkpointing:
                # DDP bucketing interferes with checkpointed compiled graphs
                torch._dynamo.config.optimize_ddp = False

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
        """Sequentially apply every block, checkpointing activations during training when enabled."""
        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting() and self.training
        for block in self.resblocks:
            if use_checkpoint:
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
+
+
def text_global_pool(
    x: torch.Tensor, text: torch.Tensor = None, pool_type: str = "argmax"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pool per-token text embeddings into a single vector per sequence.

    pool_type:
        "first":  pooled is the first token; tokens are the remainder.
        "last":   pooled is the last token; tokens are the remainder.
        "argmax": pooled is the token at the position of the max id in `text` (the EOT token, which has
                  the highest id in each sequence); tokens are all of `x`. Requires `text`.
        anything else: no pooling; both return values are `x`.
    """
    if pool_type == "first":
        return x[:, 0], x[:, 1:]
    if pool_type == "last":
        return x[:, -1], x[:, :-1]
    if pool_type == "argmax":
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        eot_positions = text.argmax(dim=-1)
        return x[torch.arange(x.shape[0]), eot_positions], x
    return x, x
+
+
class TextTransformer(nn.Module):
    """CLIP-style text encoder: token/positional embeddings, a causal transformer, pooling, and projection."""

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = False,
    ):
        """Set up embeddings, the transformer trunk, the final norm, the causal mask, and the projection."""
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        # NOTE: initialized empty — weights are expected to be loaded from a checkpoint
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
        if no_causal_mask:
            self.attn_mask = None
        else:
            # non-persistent buffer: moves with the module but is not saved in checkpoints
            self.register_buffer("attn_mask", self.build_causal_mask(), persistent=False)
        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self) -> torch.Tensor:
        """Build an additive causal mask: -inf above the diagonal (future tokens), 0 elsewhere."""
        # pytorch uses additive attention masks; fill with -inf and zero the lower triangle
        mask = torch.full((self.num_pos, self.num_pos), float("-inf"))
        return mask.triu_(1)

    def forward(self, text: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Encode token ids; returns the pooled embedding (and all tokens when `output_tokens` is set)."""
        seq_len = text.shape[1]
        # crop the precomputed mask to the actual sequence length
        attn_mask = None if self.attn_mask is None else self.attn_mask[:seq_len, :seq_len]

        x = self.token_embedding(text) + self.positional_embedding[:seq_len]  # [batch_size, n_ctx, d_model]
        x = self.transformer(x, attn_mask=attn_mask)
        x = self.ln_final(x)

        pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection
        return (pooled, tokens) if self.output_tokens else pooled
+
+
class VETextEncoder(nn.Module):
    """Text encoder for Vision Encoder (VE) models: a TextTransformer plus a linear resizer to d_model."""

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = True,
    ):
        """Build the underlying text transformer and the linear projection matching the decoder width."""
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer

        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self, text: list[str] | tuple[torch.Tensor, torch.Tensor, dict], input_boxes: list | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode raw strings (or pass through pre-encoded tensors) and resize to the decoder width.

        Returns:
            (tuple): attention mask (True = ignore), resized text memory (sequence-first), and the raw
                token embeddings (sequence-first).
        """
        if isinstance(text[0], str):
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            # tokenize on the resizer's device; padding positions hold token id 0
            device = self.resizer.weight.device
            tokenized = self.tokenizer(text, context_length=self.context_length).to(device)  # [b, seq_len]
            padding_mask = (tokenized != 0).bool()

            # manually embed the tokens
            inputs_embeds = self.encoder.token_embedding(tokenized)  # [b, seq_len, d=1024]
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, d=1024]
            assert text_memory.shape[1] == inputs_embeds.shape[1]

            # pytorch attention masks are inverted: True marks positions to ignore
            text_attention_mask = padding_mask.ne(1)
            # sequence-first for pytorch attention, then project to the decoder's d_model
            text_memory_resized = self.resizer(text_memory.transpose(0, 1))
        else:
            # The text is already encoded, use as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert input_boxes is None or len(input_boxes) == 0, "Can't replace boxes in text if it's already encoded"

        # Note that the input_embeds are returned in pytorch's convention (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )
diff --git a/ultralytics/models/sam/sam3/tokenizer_ve.py b/ultralytics/models/sam/sam3/tokenizer_ve.py
new file mode 100644
index 0000000000..05810b4dc4
--- /dev/null
+++ b/ultralytics/models/sam/sam3/tokenizer_ve.py
@@ -0,0 +1,242 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+Text Tokenizer.
+
+Copied and lightly adapted from VE repo, which in turn copied
+from open_clip and openAI CLIP.
+"""
+
+from __future__ import annotations
+
+import gzip
+import html
+import io
+import os
+import string
+from functools import lru_cache
+
+import ftfy
+import regex as re
+import torch
+from iopath.common.file_io import g_pathmgr
+
+
@lru_cache
def bytes_to_unicode():
    """Map every byte value (0-255) to a printable unicode character, reversibly.

    The BPE merge table operates on unicode strings, so every byte must render as a visible character:
    printable bytes map to themselves, while whitespace/control bytes (which the BPE code chokes on) are
    remapped to codepoints starting at 256. The result is a bijection, so `decode` can invert it exactly.
    """
    # bytes that are already safe to use verbatim: '!'..'~', '¡'..'¬', '®'..'ÿ'
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    mapping = {b: chr(b) for b in printable}
    # remap the remaining (unprintable) bytes to fresh codepoints 256, 257, ...
    offset = 0
    for b in range(2**8):
        if b not in mapping:
            mapping[b] = chr(2**8 + offset)
            offset += 1
    return mapping
+
+
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: Tuple of symbols (symbols being variable-length strings).

    Returns:
        (set): All consecutive (left, right) symbol pairs; empty for words of length 0 or 1.
    """
    # zip of the word with itself shifted by one yields each adjacent pair;
    # it also handles empty/single-symbol words gracefully (no indexing).
    return set(zip(word, word[1:]))
+
+
def basic_clean(text):
    """Repair broken unicode with ftfy, unescape HTML entities, and strip surrounding whitespace."""
    fixed = ftfy.fix_text(text)
    # unescape twice: some sources are double-escaped (e.g. "&amp;amp;")
    return html.unescape(html.unescape(fixed)).strip()
+
+
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and strip the ends."""
    return re.sub(r"\s+", " ", text).strip()
+
+
def _clean_canonicalize(x):
    """Clean text, then canonicalize it (lowercase, punctuation stripped, whitespace collapsed)."""
    cleaned = basic_clean(x)
    return canonicalize_text(cleaned)
+
+
def _clean_lower(x):
    """Clean text, collapse whitespace, and lowercase the result."""
    cleaned = whitespace_clean(basic_clean(x))
    return cleaned.lower()
+
+
def _clean_whitespace(x):
    """Clean text and collapse redundant whitespace (case is preserved)."""
    cleaned = basic_clean(x)
    return whitespace_clean(cleaned)
+
+
def get_clean_fn(type: str):
    """Look up a text-cleaning function by name.

    Args:
        type (str): One of "canonicalize", "lower", or "whitespace". (The name shadows the builtin but is
            kept for backward compatibility with keyword callers.)

    Returns:
        (Callable): The matching cleaning function.

    Raises:
        ValueError: If `type` is not a recognized cleaning mode.
    """
    if type == "canonicalize":
        return _clean_canonicalize
    if type == "lower":
        return _clean_lower
    if type == "whitespace":
        return _clean_whitespace
    # raise instead of `assert False` so the validation survives `python -O`
    raise ValueError(f"Invalid clean function ({type}).")
+
+
def canonicalize_text(text, *, keep_punctuation_exact_string=None):
    """Return canonicalized `text`: lowercase, punctuation removed, whitespace collapsed.

    From:
    https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94.

    Args:
        text: String to be canonicalized.
        keep_punctuation_exact_string: If provided, occurrences of this exact string are preserved — e.g.
            '{}' keeps any '{}' while lone '{' and '}' are still removed.
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # strip punctuation between occurrences, then stitch the kept string back in
        parts = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(part.translate(strip_punct) for part in parts)
    else:
        text = text.translate(strip_punct)
    return re.sub(r"\s+", " ", text.lower()).strip()
+
+
class SimpleTokenizer:
    """Byte-pair-encoding tokenizer for text prompts, adapted from OpenAI CLIP / open_clip."""

    def __init__(
        self,
        bpe_path: str | os.PathLike,
        additional_special_tokens: list[str] | None = None,
        context_length: int = 77,
        clean: str = "lower",
    ):
        """Load the gzipped BPE merge table and build the encoder/decoder vocabularies.

        Args:
            bpe_path (str | os.PathLike): Path to the gzipped BPE vocab (e.g. `bpe_simple_vocab_16e6.txt.gz`).
            additional_special_tokens (list[str] | None): Extra special tokens appended to the vocabulary.
            context_length (int): Default sequence length produced by `__call__`.
            clean (str): Text-cleaning mode; one of "canonicalize", "lower", "whitespace".
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
        merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
        # keep only the merge rules used by CLIP's 49152-token vocab (minus 256 bytes and 2 specials)
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # every base symbol also exists in an end-of-word variant marked with "</w>"
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        special_tokens = ["<|startoftext|>", "<|endoftext|>"]
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # special tokens never get merged; cache them as-is
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)

    def bpe(self, token):
        """Apply byte-pair-encoding merges to one pre-tokenized word; returns space-joined merged symbols."""
        if token in self.cache:
            return self.cache[token]
        # the final character carries the end-of-word marker
        word = (*tuple(token[:-1]), token[-1] + "</w>")
        pairs = get_pairs(word)
        if not pairs:
            return token + "</w>"
        while True:
            # greedily merge the lowest-ranked (most frequent) adjacent pair
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` no longer occurs: copy the tail and stop scanning
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode text into a list of BPE token ids."""
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            # map raw UTF-8 bytes onto the printable alphabet used by the BPE table
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a sequence of token ids back into a text string."""
        text = "".join([self.decoder[token] for token in tokens])
        # undo the byte-to-unicode mapping, then turn end-of-word markers back into spaces
        return bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")

    def __call__(self, texts: str | list[str], context_length: int | None = None) -> torch.LongTensor:
        """Tokenize one string or a list of strings into a fixed-size tensor of token ids.

        Args:
            texts (str | list[str]): An input string or a list of input strings to tokenize.
            context_length (int | None): Sequence length to pad/truncate to; defaults to
                `self.context_length` (CLIP models use 77).

        Returns:
            (torch.LongTensor): Tensor of shape [number of input strings, context_length].
        """
        if isinstance(texts, str):
            texts = [texts]
        context_length = context_length or self.context_length
        assert context_length, "Please set a valid context length"
        all_tokens = [[self.sot_token_id, *self.encode(text), self.eot_token_id] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                # truncated sequences must still end with the EOT token
                tokens[-1] = self.eot_token_id
            result[i, : len(tokens)] = torch.tensor(tokens)
        return result
diff --git a/ultralytics/models/sam/sam3/vitdet.py b/ultralytics/models/sam/sam3/vitdet.py
new file mode 100644
index 0000000000..ee28fba738
--- /dev/null
+++ b/ultralytics/models/sam/sam3/vitdet.py
@@ -0,0 +1,546 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+ViTDet backbone adapted from Detectron2.
+This module implements Vision Transformer (ViT) backbone for object detection.
+
+Rope embedding code adopted from:
+1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+2. https://github.com/naver-ai/rope-vit
+3. https://github.com/lucidrains/rotary-embedding-torch
+"""
+
+from __future__ import annotations
+
+import math
+from functools import partial
+from typing import Callable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch import Tensor
+
+from ultralytics.models.sam.modules.blocks import PatchEmbed
+from ultralytics.models.sam.modules.utils import (
+ apply_rotary_enc,
+ compute_axial_cis,
+ concat_rel_pos,
+ get_abs_pos,
+ window_partition,
+ window_unpartition,
+)
+from ultralytics.utils.checks import check_requirements
+
+from .model_misc import LayerScale
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings and 2d-rope."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: tuple[int, int] | None = None,
+ cls_token: bool = False,
+ use_rope: bool = False,
+ rope_theta: float = 10000.0,
+ rope_pt_size: tuple[int, int] | None = None,
+ rope_interp: bool = False,
+ ):
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ input_size (int or None): Input resolution for calculating the relative positional parameter size or rope
+ size.
+ attn_type: Type of attention operation, e.g. "vanilla", "vanilla-xformer".
+ cls_token: whether a cls_token is present.
+ use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
+ use_rel_pos: whether to use relative positional embeddings
+ rope_theta: control frequencies of rope
+ rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
+ rope_interp: whether to interpolate (or extrapolate) rope to match input size.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.cls_token = cls_token
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ # rel_pos embeddings and rope
+ self.use_rel_pos = use_rel_pos
+ self.input_size = input_size
+
+ self.use_rope = use_rope
+ self.rope_theta = rope_theta
+ self.rope_pt_size = rope_pt_size
+ self.rope_interp = rope_interp
+
+ # init rel_pos embeddings and rope
+ self._setup_rel_pos(rel_pos_zero_init, input_size)
+ self._setup_rope_freqs(input_size)
+
+ def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None:
+ """Setup relative positional embeddings."""
+ if not self.use_rel_pos:
+ self.rel_pos_h = None
+ self.rel_pos_w = None
+ return
+
+ assert input_size is not None
+ assert self.cls_token is False, "not supported"
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
+
+ if not rel_pos_zero_init:
+ nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
+ nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
+
+ # Precompute the relative coords
+ H, W = input_size
+ q_coords = torch.arange(H)[:, None]
+ k_coords = torch.arange(W)[None, :]
+ relative_coords = (q_coords - k_coords) + (H - 1)
+ self.relative_coords = relative_coords.long()
+
+ def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None:
+ """Setup 2d-rope frequencies."""
+ if not self.use_rope:
+ self.freqs_cis = None
+ return
+
+ assert input_size is not None
+ # determine rope input size
+ if self.rope_pt_size is None:
+ self.rope_pt_size = input_size
+
+ # initialize 2d rope freqs
+ self.compute_cis = partial(
+ compute_axial_cis,
+ dim=self.head_dim,
+ theta=self.rope_theta,
+ )
+
+ # interpolate rope
+ scale_pos = 1.0
+ if self.rope_interp:
+ scale_pos = self.rope_pt_size[0] / input_size[0]
+ # get scaled freqs_cis
+ freqs_cis = self.compute_cis(
+ end_x=input_size[0],
+ end_y=input_size[1],
+ scale_pos=scale_pos,
+ )
+ if self.cls_token:
+ t = torch.zeros(
+ self.head_dim // 2,
+ dtype=torch.float32,
+ device=freqs_cis.device,
+ )
+ cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
+ freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)
+
+ self.freqs_cis = freqs_cis
+
+ def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]:
+ """Apply 2d-rope to q and k."""
+ if not self.use_rope:
+ return q, k
+
+ assert self.freqs_cis is not None
+ return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis.to(q.device))
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass of attention block."""
+ s = 1 if self.cls_token else 0 # used to exclude cls_token
+ if x.ndim == 4:
+ B, H, W, _ = x.shape
+ assert s == 0 # no cls_token
+ L = H * W
+ ndim = 4
+ else:
+ assert x.ndim == 3
+ B, L, _ = x.shape
+ ndim = 3
+ H = W = math.sqrt(L - s)
+
+ # qkv with shape (3, B, nHead, L, C)
+ qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
+ # q, k, v with shape (B, nHead, L, C)
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
+
+ # handle rope and rel pos embeddings
+ q, k = self._apply_rope(q, k)
+ if self.use_rel_pos:
+ q, k = concat_rel_pos(
+ q.flatten(0, 1),
+ k.flatten(0, 1),
+ (H, W),
+ x.shape[1:3],
+ self.rel_pos_h,
+ self.rel_pos_w,
+ rescale=True,
+ relative_coords=self.relative_coords,
+ )
+
+ # sdpa expects [B, nheads, H*W, C] so we transpose back
+ q = q.reshape(B, self.num_heads, H * W, -1)
+ k = k.reshape(B, self.num_heads, H * W, -1)
+
+ x = F.scaled_dot_product_attention(q, k, v)
+
+ if ndim == 4:
+ x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ else:
+ x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
+
+ x = self.proj(x)
+
+ return x
+
+
+class Block(nn.Module):
+    """Transformer blocks with support of window attention."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        drop_path: float = 0.0,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: tuple[int, int] | None = None,
+        use_rope: bool = False,
+        rope_pt_size: tuple[int, int] | None = None,
+        rope_interp: bool = False,
+        cls_token: bool = False,
+        dropout: float = 0.0,
+        init_values: float | None = None,
+    ):
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            drop_path (float): Stochastic depth rate.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If it equals 0, then not use window attention.
+            input_size (int or None): Input resolution for calculating the relative positional parameter size.
+            dropout (float): Dropout rate.
+            cls_token: whether a cls_token is present.
+            use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
+            rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
+            rope_interp: whether to interpolate (or extrapolate) rope to match target input size, expected to specify
+                source size as rope_pt_size.
+            init_values: layer scale init, None for no layer scale.
+        """
+        super().__init__()
+
+        # Lazy import: timm is an optional dependency, only required when a Block is constructed.
+        check_requirements("timm")
+        from timm.layers import DropPath, Mlp
+
+        self.norm1 = norm_layer(dim)
+        # Windowed blocks attend within (window_size, window_size) tiles, so the attention's
+        # positional tables are sized to the window rather than the full input.
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+            use_rope=use_rope,
+            rope_pt_size=rope_pt_size,
+            rope_interp=rope_interp,
+            cls_token=cls_token,
+        )
+        # LayerScale is enabled only when init_values is truthy; otherwise a no-op.
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=(dropout, 0.0),  # dropout after activation only, none on the output projection
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.dropout = nn.Dropout(dropout)
+        self.window_size = window_size
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass of the transformer block (pre-norm residual layout)."""
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition: split (B, H, W, C) into window tiles so attention is local.
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.ls1(self.attn(x))
+        # Reverse window partition back to the original (possibly padded) spatial layout.
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        # Residual branches: drop_path (stochastic depth) then dropout wrap each sub-layer output.
+        x = shortcut + self.dropout(self.drop_path(x))
+        x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
+
+        return x
+
+
+class ViT(nn.Module):
+ """This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. "Exploring Plain Vision Transformer
+ Backbones for Object Detection", https://arxiv.org/abs/2203.16527.
+ """
+
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ drop_path_rate: float = 0.0,
+ norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ tile_abs_pos: bool = True,
+ rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
+ rel_pos_zero_init: bool = True,
+ window_size: int = 14,
+ global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
+ use_rope: bool = False,
+ rope_pt_size: int | None = None,
+ use_interp_rope: bool = False,
+ pretrain_img_size: int = 224,
+ pretrain_use_cls_token: bool = True,
+ retain_cls_token: bool = True,
+ dropout: float = 0.0,
+ return_interm_layers: bool = False,
+ init_values: float | None = None, # for layerscale
+ ln_pre: bool = False,
+ ln_post: bool = False,
+ bias_patch_embed: bool = True,
+ compile_mode: str | None = None,
+ use_act_checkpoint: bool = True,
+ ):
+ """
+ Args:
+ img_size (int): Input image size. Only relevant for rel pos or rope.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ drop_path_rate (float): Stochastic depth rate.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolation.
+ rel_pos_blocks (list): Blocks which have rel pos embeddings.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ global_att_blocks (list): Indexes for blocks using global attention (other blocks use window attention).
+ use_rope (bool): whether to use rope 2d (indep of rel_pos_blocks, as it can be used together).
+ rope_pt_size (int): size of rope in previous stage of training, needed for interpolation or tiling.
+ use_interp_rope: whether to interpolate (or extrapolate) rope to match target input size, expected to
+ specify source size as rope_pt_size.
+ use_act_checkpoint (bool): If True, use activation checkpointing.
+ pretrain_img_size (int): input image size for pretraining models.
+ pretrain_use_cls_token (bool): If True, pretraining models use class token.
+ retain_cls_token: whether cls_token should be retained.
+ dropout (float): Dropout rate. Applied in residual blocks of attn, mlp and inside the mlp.
+ return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
+ init_values: layer scale init, None for no layer scale.
+ ln_pre (bool): If True, apply layer norm before transformer blocks.
+ ln_post (bool): If True, apply layer norm after transformer blocks.
+ bias_patch_embed (bool): bias in conv for patch embed?
+ compile_mode (str): mode to compile the forward.
+ """
+ super().__init__()
+ self.pretrain_use_cls_token = pretrain_use_cls_token
+
+ window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
+ self.full_attn_ids = list(global_att_blocks)
+ self.rel_pos_blocks = [False] * depth
+ if isinstance(rel_pos_blocks, bool) and rel_pos_blocks:
+ self.rel_pos_blocks = [True] * depth
+ else:
+ for i in rel_pos_blocks:
+ self.rel_pos_blocks[i] = True
+
+ self.retain_cls_token = retain_cls_token
+ if self.retain_cls_token:
+ assert pretrain_use_cls_token
+ assert len(window_block_indexes) == 0, "windowing not supported with cls token"
+
+ assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
+
+ scale = embed_dim**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))
+
+ if isinstance(norm_layer, str):
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ bias=bias_patch_embed,
+ )
+
+ # Handle absolute positional embedding
+ self.tile_abs_pos = tile_abs_pos
+ self.use_abs_pos = use_abs_pos
+ if self.tile_abs_pos:
+ assert self.use_abs_pos
+
+ if self.use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
+ else:
+ self.pos_embed = None
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+ self.patch_size = patch_size
+ self.window_size = window_size
+ self.blocks = nn.ModuleList()
+ cur_stage = 1
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ use_rel_pos=self.rel_pos_blocks[i],
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i in window_block_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ use_rope=use_rope,
+ rope_pt_size=((window_size, window_size) if rope_pt_size is None else (rope_pt_size, rope_pt_size)),
+ rope_interp=use_interp_rope,
+ cls_token=self.retain_cls_token,
+ dropout=dropout,
+ init_values=init_values,
+ )
+
+ if i not in window_block_indexes:
+ cur_stage += 1
+
+ self.use_act_checkpoint = use_act_checkpoint
+
+ self.blocks.append(block)
+
+ self.return_interm_layers = return_interm_layers
+ self.channel_list = [embed_dim] * len(self.full_attn_ids) if return_interm_layers else [embed_dim]
+
+ if self.pos_embed is not None:
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
+ self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ if compile_mode is not None:
+ self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
+ if self.use_act_checkpoint and self.training:
+ torch._dynamo.config.optimize_ddp = False
+
+ def _init_weights(self, m: nn.Module) -> None:
+ """Initialize the weights."""
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+ """Vit forward path and get feature maps."""
+ x = self.patch_embed(x)
+ h, w = x.shape[1], x.shape[2]
+
+ s = 0
+ if self.retain_cls_token:
+ # If cls_token is retained, we don't
+ # maintain spatial shape
+ x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
+ s = 1
+
+ if self.pos_embed is not None:
+ x = x + get_abs_pos(
+ self.pos_embed,
+ self.pretrain_use_cls_token,
+ (h, w),
+ self.retain_cls_token,
+ tiling=self.tile_abs_pos,
+ )
+
+ x = self.ln_pre(x)
+
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ if self.use_act_checkpoint and self.training:
+ x = checkpoint.checkpoint(blk, x, use_reentrant=False)
+ else:
+ x = blk(x)
+ if (i == self.full_attn_ids[-1]) or (self.return_interm_layers and i in self.full_attn_ids):
+ if i == self.full_attn_ids[-1]:
+ x = self.ln_post(x)
+
+ feats = x[:, s:]
+ if feats.ndim == 4:
+ feats = feats.permute(0, 3, 1, 2)
+ else:
+ assert feats.ndim == 3
+ h = w = math.sqrt(feats.shape[1])
+ feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
+
+ outputs.append(feats)
+
+ return outputs
+
+ def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+ """Setup rel pos embeddings and rope freqs for a new input image size."""
+ for block in self.blocks:
+ if block.window_size != 0:
+ continue
+ block.attn._setup_rel_pos(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
+ block.attn._setup_rope_freqs(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
diff --git a/ultralytics/models/sam/sam3/vl_combiner.py b/ultralytics/models/sam/sam3/vl_combiner.py
new file mode 100644
index 0000000000..c47c0f9c09
--- /dev/null
+++ b/ultralytics/models/sam/sam3/vl_combiner.py
@@ -0,0 +1,165 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""Provides utility to combine a vision backbone with a language backbone."""
+
+from __future__ import annotations
+
+from copy import copy
+
+import torch
+import torch.nn as nn
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from .necks import Sam3DualViTDetNeck
+
+
+class SAM3VLBackbone(nn.Module):
+    """This backbone combines a vision backbone and a language backbone without fusion. As such it is more of a
+    convenience wrapper to handle the two backbones together.
+
+    It adds support for activation checkpointing and compilation.
+    """
+
+    def __init__(
+        self,
+        visual: Sam3DualViTDetNeck,
+        text,
+        compile_visual: bool = False,
+        act_ckpt_whole_vision_backbone: bool = False,
+        act_ckpt_whole_language_backbone: bool = False,
+        scalp=0,
+    ):
+        """Initialize the backbone combiner.
+
+        :param visual: The vision backbone to use
+        :param text: The text encoder to use; expected to be callable as
+            ``text(texts, input_boxes) -> (attention_mask, memory, embeds)`` (see forward_text)
+        :param compile_visual: If True, wrap the vision backbone in torch.compile
+        :param act_ckpt_whole_vision_backbone: Allow activation checkpointing over the whole vision backbone
+        :param act_ckpt_whole_language_backbone: Allow activation checkpointing over the whole language backbone
+        :param scalp: Number of lowest-resolution feature levels to discard from the outputs
+        """
+        super().__init__()
+        self.vision_backbone: Sam3DualViTDetNeck = torch.compile(visual) if compile_visual else visual
+        self.language_backbone = text
+        self.scalp = scalp
+        # allow running activation checkpointing on the entire vision and language backbones
+        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
+        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone
+
+    def forward(
+        self,
+        samples: torch.Tensor,
+        captions: list[str],
+        input_boxes: torch.Tensor = None,
+        additional_text: list[str] | None = None,
+    ):
+        """Forward pass of the backbone combiner.
+
+        :param samples: The input images
+        :param captions: The input captions
+        :param input_boxes: If the text contains place-holders for boxes, this
+            parameter contains the tensor containing their spatial features
+        :param additional_text: This can be used to encode some additional text
+            (different from the captions) in the same forward of the backbone
+        :return: Output dictionary with the following keys:
+            - vision_features: The output of the vision backbone
+            - language_features: The output of the language backbone
+            - language_mask: The attention mask of the language backbone
+            - vision_pos_enc: The positional encoding of the vision backbone
+            - (optional) additional_text_features: The output of the language
+              backbone for the additional text
+            - (optional) additional_text_mask: The attention mask of the
+              language backbone for the additional text
+        """
+        # Vision and text are encoded independently (no cross-modal fusion here);
+        # the two result dicts use disjoint keys so update() cannot clobber.
+        output = self.forward_image(samples)
+        output.update(self.forward_text(captions, input_boxes, additional_text))
+        return output
+
+    def forward_image(self, samples: torch.Tensor):
+        """Forward pass of the vision backbone and get both SAM3 and SAM2 features."""
+        # Forward through backbone
+        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(samples)
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            sam3_features, sam3_pos = (
+                sam3_features[: -self.scalp],
+                sam3_pos[: -self.scalp],
+            )
+            if sam2_features is not None and sam2_pos is not None:
+                sam2_features, sam2_pos = (
+                    sam2_features[: -self.scalp],
+                    sam2_pos[: -self.scalp],
+                )
+
+        sam2_output = None
+
+        if sam2_features is not None and sam2_pos is not None:
+            # Highest-level (last) feature map is exposed as the main vision feature.
+            sam2_src = sam2_features[-1]
+            sam2_output = {
+                "vision_features": sam2_src,
+                "vision_pos_enc": sam2_pos,
+                "backbone_fpn": sam2_features,
+            }
+
+        sam3_src = sam3_features[-1]
+        return {
+            "vision_features": sam3_src,
+            "vision_pos_enc": sam3_pos,
+            "backbone_fpn": sam3_features,
+            "sam2_backbone_out": sam2_output,
+        }
+
+    def forward_image_sam2(self, samples: torch.Tensor):
+        """Forward pass of the vision backbone to get SAM2 features only."""
+        xs = self.vision_backbone.trunk(samples)
+        sam2_features, sam2_pos = [], []
+        x = xs[-1]  # simpleFPN
+
+        assert self.vision_backbone.sam2_convs is not None, "SAM2 neck is not available."
+        # Each SAM2 neck conv produces one pyramid level from the trunk's last output.
+        for i in range(len(self.vision_backbone.sam2_convs)):
+            sam2_x_out = self.vision_backbone.sam2_convs[i](x)
+            sam2_pos_out = self.vision_backbone.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
+            sam2_features.append(sam2_x_out)
+            sam2_pos.append(sam2_pos_out)
+
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            sam2_features, sam2_pos = (
+                sam2_features[: -self.scalp],
+                sam2_pos[: -self.scalp],
+            )
+
+        return {
+            "vision_features": sam2_features[-1],
+            "vision_pos_enc": sam2_pos,
+            "backbone_fpn": sam2_features,
+        }
+
+    def forward_text(self, captions, input_boxes=None, additional_text=None):
+        """Forward pass of the text encoder."""
+        output = {}
+
+        # Forward through text_encoder
+        text_to_encode = copy(captions)
+        if additional_text is not None:
+            # if there are additional_text, we piggy-back them into this forward.
+            # They'll be used later for output alignment
+            # NOTE(review): assumes additional_text is non-empty when provided — with an empty
+            # list the `-0:` slices below would return the full tensors; confirm with callers.
+            text_to_encode += additional_text
+
+        # Pin the allowed SDPA backends so the text encoder runs regardless of which
+        # fused-attention kernels are available on this device.
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
+            text_attention_mask, text_memory, text_embeds = self.language_backbone(text_to_encode, input_boxes)
+
+        if additional_text is not None:
+            # text_memory/embeds are (seq, batch, ...): the batch dim is dim 1, while the
+            # attention mask is batch-first — hence the differing slice axes below.
+            output["additional_text_features"] = text_memory[:, -len(additional_text) :]
+            output["additional_text_mask"] = text_attention_mask[-len(additional_text) :]
+
+        text_memory = text_memory[:, : len(captions)]
+        text_attention_mask = text_attention_mask[: len(captions)]
+        text_embeds = text_embeds[:, : len(captions)]
+        output["language_features"] = text_memory
+        output["language_mask"] = text_attention_mask
+        output["language_embeds"] = text_embeds  # Text embeddings before forward to the encoder
+
+        return output
+
+    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+        """Set the image size for the vision backbone (default list is never mutated)."""
+        self.vision_backbone.set_imgsz(imgsz)
diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py
index c8b5f04eb7..0e54e21198 100644
--- a/ultralytics/nn/modules/transformer.py
+++ b/ultralytics/nn/modules/transformer.py
@@ -359,7 +359,15 @@ class MLP(nn.Module):
"""
def __init__(
- self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ act=nn.ReLU,
+ sigmoid: bool = False,
+ residual: bool = False,
+ out_norm: nn.Module = None,
):
"""Initialize the MLP with specified input, hidden, output dimensions and number of layers.
@@ -370,6 +378,8 @@ class MLP(nn.Module):
num_layers (int): Number of layers.
act (nn.Module): Activation function.
sigmoid (bool): Whether to apply sigmoid to the output.
+ residual (bool): Whether to use residual connections.
+ out_norm (nn.Module, optional): Normalization layer for the output.
"""
super().__init__()
self.num_layers = num_layers
@@ -377,6 +387,12 @@ class MLP(nn.Module):
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim]))
self.sigmoid = sigmoid
self.act = act()
+ if residual and input_dim != output_dim:
+ raise ValueError("residual is only supported if input_dim == output_dim")
+ self.residual = residual
+ # whether to apply a normalization layer to the output
+ assert isinstance(out_norm, nn.Module) or out_norm is None
+ self.out_norm = out_norm or nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass for the entire MLP.
@@ -387,8 +403,12 @@ class MLP(nn.Module):
Returns:
(torch.Tensor): Output tensor after MLP.
"""
+ orig_x = x
for i, layer in enumerate(self.layers):
x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if getattr(self, "residual", False):
+ x = x + orig_x
+ x = getattr(self, "out_norm", nn.Identity())(x)
return x.sigmoid() if getattr(self, "sigmoid", False) else x
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
index 787b48e889..2f8cec5c0e 100644
--- a/ultralytics/utils/ops.py
+++ b/ultralytics/utils/ops.py
@@ -660,6 +660,4 @@ def clean_str(s):
def empty_like(x):
"""Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
- return (
- torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
- )
+ return torch.empty_like(x, dtype=x.dtype) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=x.dtype)