ultralytics 8.3.237 SAM3 integration (#22897)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: fatih akyon <34196005+fcakyon@users.noreply.github.com>
Jing Qiu 2025-12-12 21:04:33 +08:00 committed by GitHub
parent 24b1ec6252
commit e0764aa55d
45 changed files with 7070 additions and 252 deletions

@@ -114,13 +114,17 @@ SAM 3 will be available directly in the Ultralytics package once integration lan
pip install ultralytics
```
Models will download automatically when first used. You can then use standard [predict mode](../modes/predict.md) and later [export](../modes/export.md) models to formats like [ONNX](../integrations/onnx.md) and [TensorRT](../integrations/tensorrt.md) for deployment. Watch for a package update with SAM-3 weights and configs soon.
!!! warning "SAM 3 Model Weights Required"
Unlike other Ultralytics models, SAM 3 weights (`sam3.pt`) are **not automatically downloaded**. You must manually download the model weights from the [official SAM 3 repository](https://github.com/facebookresearch/sam3) before using SAM 3. Place the downloaded `sam3.pt` file in your working directory or specify the full path when loading the model.
!!! note "BPE Vocabulary for Text Prompts"
If you plan to use text-based concept segmentation with `SAM3SemanticPredictor`, you also need to download the BPE vocabulary file `bpe_simple_vocab_16e6.txt.gz` from the [SAM 3 assets](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz).
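For reference, the vocabulary file is a gzipped text file of byte-pair-encoding merge rules in the CLIP tokenizer format: the first line is a version header, and each following line holds one space-separated merge pair. A minimal sketch of reading such a file (illustrative only — parsing is normally handled by the tokenizer internally; `load_bpe_merges` is a hypothetical helper):

```python
import gzip


def load_bpe_merges(path: str) -> list[tuple[str, str]]:
    """Read BPE merge rules from a gzipped, CLIP-style vocabulary file."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        lines = f.read().split("\n")
    merges = []
    for line in lines[1:]:  # skip the version header on the first line
        parts = line.split()
        if len(parts) == 2:  # keep only well-formed "token_a token_b" pairs
            merges.append((parts[0], parts[1]))
    return merges
```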
## How to Use SAM 3: Versatility in Concept Segmentation
!!! warning "Ultralytics API preview"
The following examples show the intended Ultralytics API once SAM 3 ships in the package. Until integration lands, details may change.
SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual Segmentation (PVS) tasks through different predictor interfaces.
### Supported Tasks and Models
@@ -138,143 +142,176 @@ SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual
!!! example "Text-based Concept Segmentation"
Find and segment all instances of a concept using a text description. Text prompts require the `SAM3SemanticPredictor` interface.
=== "Python"
```python
from ultralytics.models.sam.predict import SAM3SemanticPredictor

# Initialize predictor with configuration
overrides = dict(
    conf=0.25,
    task="segment",
    mode="predict",
    model="sam3.pt",
    half=True,  # Use FP16 for faster inference
)
predictor = SAM3SemanticPredictor(
    overrides=overrides,
    bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz",  # Required for text encoding
)

# Set image once for multiple queries
predictor.set_image("path/to/image.jpg")

# Query with multiple text prompts
results = predictor(text=["person", "bus", "glasses"], save=True)

# Works with descriptive phrases
results = predictor(text=["person with red cloth", "person with blue cloth"], save=True)

# Query with a single concept
results = predictor(text=["a person"], save=True)
```
=== "CLI"
!!! note "Text Encoding Requirement"
```bash
# Segment all matching concepts in an image
yolo segment model=sam3.pt source=path/to/image.jpg prompt="yellow school bus"
```
!!! warning "API Preview"
This example shows intended usage. Actual implementation pending Ultralytics integration.
The `bpe_path` parameter is required for text prompt encoding. Download the BPE vocabulary file from the [bpe_simple_vocab_16e6.txt.gz](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz).
#### Segment with Image Exemplars
!!! example "Image Exemplar-based Segmentation"
Use bounding boxes as visual prompts to find all similar instances. This also requires `SAM3SemanticPredictor` for concept-based matching.
=== "Python"
```python
from ultralytics.models.sam.predict import SAM3SemanticPredictor

# Initialize predictor
overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True)
predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")

# Set image
predictor.set_image("path/to/image.jpg")

# Provide bounding box examples to segment similar objects
results = predictor(bboxes=[[480.0, 290.0, 590.0, 650.0]], save=True)

# Multiple bounding boxes for different concepts
results = predictor(bboxes=[[539, 599, 589, 639], [343, 267, 499, 662]], save=True)
```
!!! warning "API Preview"
#### Feature-based Inference for Efficiency
This example shows intended usage. Actual implementation pending Ultralytics integration.
!!! example "Reusing Image Features for Multiple Queries"
#### Interactive Refinement
!!! example "Iterative Refinement with Exemplars"
Progressively improve results by adding exemplar prompts based on initial output.
Extract image features once and reuse them for multiple segmentation queries to improve efficiency.
=== "Python"
```python
import cv2

from ultralytics.models.sam.predict import SAM3SemanticPredictor
from ultralytics.utils.plotting import Annotator, colors

# Initialize predictors
overrides = dict(conf=0.50, task="segment", mode="predict", model="sam3.pt", verbose=False)
predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")
predictor2 = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")

# Extract features from the first predictor
source = "path/to/image.jpg"
predictor.set_image(source)
src_shape = cv2.imread(source).shape[:2]

# Setup second predictor and reuse features
predictor2.setup_model()

# Perform inference using shared features with text prompt
masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, text=["person"])

# Perform inference using shared features with bounding box prompt
masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, bboxes=[439, 437, 524, 709])

# Visualize results
masks, boxes = masks.cpu().numpy(), boxes.cpu().numpy()
im = cv2.imread(source)
annotator = Annotator(im, pil=False)
annotator.masks(masks, [colors(x, True) for x in range(len(masks))])
cv2.imshow("result", annotator.result())
cv2.waitKey(0)
```
!!! warning "API Preview"
This example shows intended usage. Actual implementation pending Ultralytics integration.
### Video Concept Segmentation
!!! example "Track Concepts Across Video"
#### Track Concepts Across Video with Bounding Boxes
Detect and track all instances of a concept throughout a video.
!!! example "Video Tracking with Visual Prompts"
Detect and track object instances across video frames using bounding box prompts.
=== "Python"
```python
from ultralytics.models.sam.predict import SAM3VideoPredictor

# Create video predictor
overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True)
predictor = SAM3VideoPredictor(overrides=overrides)

# Track objects using bounding box prompts
results = predictor(source="path/to/video.mp4", bboxes=[[706.5, 442.5, 905.25, 555], [598, 635, 725, 750]], stream=True)

# Process and display results
for r in results:
    r.show()  # Display frame with segmentation masks
```
#### Track Concepts with Text Prompts
!!! example "Video Tracking with Semantic Queries"
Track all instances of concepts specified by text across video frames.
=== "Python"
```python
from ultralytics.models.sam.sam3_video_model import SAM3VideoSemanticPredictor

# Initialize semantic video predictor
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=640, model="sam3.pt", half=True)
predictor = SAM3VideoSemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")

# Track concepts using text prompts
results = predictor(source="path/to/video.mp4", text=["person", "bicycle"], stream=True, save=True)

# Process results
for r in results:
    r.show()  # Display frame with tracked objects

# Alternative: Track with bounding box prompts
results = predictor(
    source="path/to/video.mp4",
    bboxes=[[864, 383, 975, 620], [705, 229, 782, 402]],
    labels=[1, 1],  # Positive labels
    stream=True,
    save=True,
)
```
!!! warning "API Preview"
This example shows intended usage. Actual implementation pending Ultralytics integration.
For broader streaming and production setups, see [object counting](../guides/object-counting.md) and [view results in terminal](../guides/view-results-in-terminal.md).
### Visual Prompts (SAM 2 Compatibility)
SAM 3 maintains full backward compatibility with SAM 2's visual prompting for single-object segmentation:
!!! example "SAM 2 Style Visual Prompts"
The basic `SAM` interface behaves exactly like SAM 2, segmenting only the specific area indicated by visual prompts (points, boxes, or masks).
=== "Python"
```python
@@ -282,19 +319,21 @@ SAM 3 maintains full backward compatibility with SAM 2's visual prompting:
model = SAM("sam3.pt")

# Single point prompt - segments object at specific location
results = model.predict(source="path/to/image.jpg", points=[900, 370], labels=[1])
results[0].show()

# Multiple points - segments single object with multiple point hints
results = model.predict(source="path/to/image.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])

# Box prompt - segments object within bounding box
results = model.predict(source="path/to/image.jpg", bboxes=[100, 150, 300, 400])
results[0].show()
```
!!! warning "API Preview"
!!! warning "Visual Prompts vs Concept Segmentation"
This example shows intended usage. Actual implementation pending Ultralytics integration.
Using `SAM("sam3.pt")` with visual prompts (points/boxes/masks) will segment **only the specific object** at that location, just like SAM 2. To segment **all instances of a concept**, use `SAM3SemanticPredictor` with text or exemplar prompts as shown above.
## Performance Benchmarks

@@ -11,6 +11,10 @@ keywords: Ultralytics, SAM model, Segment Anything Model, SAM 2 model, Segment A
<br>
## ::: ultralytics.models.sam.build._load_checkpoint
<br><br><hr><br>
## ::: ultralytics.models.sam.build.build_sam_vit_h
<br><br><hr><br>

@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/build_sam3.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.build_sam3._create_vision_backbone
<br><br><hr><br>
## ::: ultralytics.models.sam.build_sam3._create_sam3_transformer
<br><br><hr><br>
## ::: ultralytics.models.sam.build_sam3.build_sam3_image_model
<br><br><hr><br>
## ::: ultralytics.models.sam.build_sam3.build_interactive_sam3
<br><br><hr><br>
## ::: ultralytics.models.sam.build_sam3._load_checkpoint
<br><br>

@@ -17,4 +17,8 @@ keywords: Ultralytics, SAM Module, SAM 2 Module, object segmentation, image enco
## ::: ultralytics.models.sam.modules.sam.SAM2Model
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.sam.SAM3Model
<br><br>

@@ -49,4 +49,12 @@ keywords: Ultralytics, SAM, SAM 2, API Reference, models, window partition, data
## ::: ultralytics.models.sam.modules.utils.add_decomposed_rel_pos
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.get_abs_pos
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.concat_rel_pos
<br><br>

@@ -25,4 +25,20 @@ keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Mode
## ::: ultralytics.models.sam.predict.SAM2DynamicInteractivePredictor
<br><br><hr><br>
## ::: ultralytics.models.sam.predict.SAM3Predictor
<br><br><hr><br>
## ::: ultralytics.models.sam.predict.SAM3SemanticPredictor
<br><br><hr><br>
## ::: ultralytics.models.sam.predict.SAM3VideoPredictor
<br><br><hr><br>
## ::: ultralytics.models.sam.predict.SAM3VideoSemanticPredictor
<br><br>

@@ -0,0 +1,20 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/decoder.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoder
<br><br>

@@ -0,0 +1,28 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/encoder.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderLayer
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoder
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderFusion
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.encoder.pool_text_feat
<br><br>

@@ -0,0 +1,28 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/geometry_encoders.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.geometry_encoders.Prompt
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.geometry_encoders.SequenceGeometryEncoder
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.geometry_encoders.is_right_padded
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.geometry_encoders.concat_padded_sequences
<br><br>

@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/maskformer_segmentation.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.maskformer_segmentation.LinearPresenceHead
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.maskformer_segmentation.MaskPredictor
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.maskformer_segmentation.SegmentationHead
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.maskformer_segmentation.PixelDecoder
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.maskformer_segmentation.UniversalSegmentationHead
<br><br>

@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/model_misc.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.model_misc.DotProductScoring
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.model_misc.LayerScale
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.model_misc.TransformerWrapper
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.model_misc.get_valid_ratio
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.model_misc.gen_sineembed_for_position
<br><br>

@@ -0,0 +1,16 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/necks.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.necks.Sam3DualViTDetNeck
<br><br>

@@ -0,0 +1,20 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/sam3_image.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.sam3_image.SAM3SemanticModel
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.sam3_image._update_out
<br><br>

@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/text_encoder_ve.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.text_encoder_ve.ResidualAttentionBlock
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.text_encoder_ve.Transformer
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.text_encoder_ve.TextTransformer
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.text_encoder_ve.VETextEncoder
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.text_encoder_ve.text_global_pool
<br><br>

@@ -0,0 +1,52 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/tokenizer_ve.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve.SimpleTokenizer
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve.bytes_to_unicode
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_pairs
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve.basic_clean
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve.whitespace_clean
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_canonicalize
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_lower
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_whitespace
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_clean_fn
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.tokenizer_ve.canonicalize_text
<br><br>

@@ -0,0 +1,24 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/vitdet.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.vitdet.Attention
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.vitdet.Block
<br><br><hr><br>
## ::: ultralytics.models.sam.sam3.vitdet.ViT
<br><br>

@@ -0,0 +1,16 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---
# Reference for `ultralytics/models/sam/sam3/vl_combiner.py`
!!! success "Improvements"
This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏
<br>
## ::: ultralytics.models.sam.sam3.vl_combiner.SAM3VLBackbone
<br><br>

@@ -592,6 +592,7 @@ nav:
- sam:
- amg: reference/models/sam/amg.md
- build: reference/models/sam/build.md
- build_sam3: reference/models/sam/build_sam3.md
- model: reference/models/sam/model.md
- modules:
- blocks: reference/models/sam/modules/blocks.md
@@ -603,6 +604,18 @@ nav:
- transformer: reference/models/sam/modules/transformer.md
- utils: reference/models/sam/modules/utils.md
- predict: reference/models/sam/predict.md
- sam3:
- decoder: reference/models/sam/sam3/decoder.md
- encoder: reference/models/sam/sam3/encoder.md
- geometry_encoders: reference/models/sam/sam3/geometry_encoders.md
- maskformer_segmentation: reference/models/sam/sam3/maskformer_segmentation.md
- model_misc: reference/models/sam/sam3/model_misc.md
- necks: reference/models/sam/sam3/necks.md
- sam3_image: reference/models/sam/sam3/sam3_image.md
- text_encoder_ve: reference/models/sam/sam3/text_encoder_ve.md
- tokenizer_ve: reference/models/sam/sam3/tokenizer_ve.md
- vitdet: reference/models/sam/sam3/vitdet.md
- vl_combiner: reference/models/sam/sam3/vl_combiner.md
- utils:
- loss: reference/models/utils/loss.md
- ops: reference/models/utils/ops.md

@@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
__version__ = "8.3.236"
__version__ = "8.3.237"
import importlib
import os

@@ -244,14 +244,15 @@ class BasePredictor:
for _ in gen: # sourcery skip: remove-empty-nested-block, noqa
pass
def setup_source(self, source):
def setup_source(self, source, stride: int | None = None):
"""Set up source and inference mode.
Args:
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
inference.
stride (int, optional): Model stride for image size checking.
"""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.imgsz = check_imgsz(self.args.imgsz, stride=stride or self.model.stride, min_dim=2) # check image size
self.dataset = load_inference_source(
source=source,
batch=self.args.batch,
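The effect of the new `stride` argument is that image size validation can use a predictor-supplied stride before falling back to the loaded model's. Conceptually, such a check rounds each requested dimension up to the nearest multiple of the stride — a simplified sketch of that idea (not the actual `check_imgsz` implementation):

```python
import math


def round_imgsz_to_stride(imgsz: list[int], stride: int = 32) -> list[int]:
    """Round each image dimension up to the nearest multiple of stride."""
    return [math.ceil(x / stride) * stride for x in imgsz]
```

For example, a requested size of `[1000, 750]` with stride 32 becomes `[1024, 768]`.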

@@ -1,7 +1,16 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .model import SAM
from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
from .predict import (
Predictor,
SAM2DynamicInteractivePredictor,
SAM2Predictor,
SAM2VideoPredictor,
SAM3Predictor,
SAM3SemanticPredictor,
SAM3VideoPredictor,
SAM3VideoSemanticPredictor,
)
__all__ = (
"SAM",
@@ -9,4 +18,8 @@ __all__ = (
"SAM2DynamicInteractivePredictor",
"SAM2Predictor",
"SAM2VideoPredictor",
"SAM3Predictor",
"SAM3SemanticPredictor",
"SAM3VideoPredictor",
"SAM3VideoSemanticPredictor",
) # tuple or list of exportable items

@@ -21,6 +21,21 @@ from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
def _load_checkpoint(model, checkpoint):
"""Load checkpoint into model from file path."""
if checkpoint is None:
return model
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, "rb") as f:
state_dict = torch_load(f)
# Handle nested "model" key
if "model" in state_dict and isinstance(state_dict["model"], dict):
state_dict = state_dict["model"]
model.load_state_dict(state_dict)
return model
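The nested-key handling above reflects a common checkpoint layout where the weights are stored under a `"model"` key. That unwrapping logic can be isolated as a small sketch on plain dictionaries (`unwrap_state_dict` is a hypothetical helper mirroring the branch above):

```python
def unwrap_state_dict(ckpt: dict) -> dict:
    """Return the inner state dict when weights are nested under a "model" key."""
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        return ckpt["model"]
    return ckpt
```

Checkpoints saved as a flat state dict pass through unchanged, so a single loader handles both layouts.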
def build_sam_vit_h(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
return _build_sam(
@@ -205,10 +220,7 @@ def _build_sam(
pixel_std=[58.395, 57.12, 57.375],
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, "rb") as f:
state_dict = torch_load(f)
sam.load_state_dict(state_dict)
sam = _load_checkpoint(sam, checkpoint)
sam.eval()
return sam
@@ -299,10 +311,7 @@ def _build_sam2(
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, "rb") as f:
state_dict = torch_load(f)["model"]
sam2.load_state_dict(state_dict)
sam2 = _load_checkpoint(sam2, checkpoint)
sam2.eval()
return sam2

@@ -0,0 +1,374 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import torch.nn as nn
from ultralytics.nn.modules.transformer import MLP
from ultralytics.utils.patches import torch_load
from .modules.blocks import PositionEmbeddingSine, RoPEAttention
from .modules.encoders import MemoryEncoder
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
from .modules.sam import SAM3Model
from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
from .sam3.geometry_encoders import SequenceGeometryEncoder
from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
from .sam3.model_misc import DotProductScoring, TransformerWrapper
from .sam3.necks import Sam3DualViTDetNeck
from .sam3.sam3_image import SAM3SemanticModel
from .sam3.text_encoder_ve import VETextEncoder
from .sam3.tokenizer_ve import SimpleTokenizer
from .sam3.vitdet import ViT
from .sam3.vl_combiner import SAM3VLBackbone
def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
"""Create SAM3 visual backbone with ViT and neck."""
# Position encoding
position_encoding = PositionEmbeddingSine(
num_pos_feats=256,
normalize=True,
scale=None,
temperature=10000,
)
# ViT backbone
vit_backbone = ViT(
img_size=1008,
pretrain_img_size=336,
patch_size=14,
embed_dim=1024,
depth=32,
num_heads=16,
mlp_ratio=4.625,
norm_layer="LayerNorm",
drop_path_rate=0.1,
qkv_bias=True,
use_abs_pos=True,
tile_abs_pos=True,
global_att_blocks=(7, 15, 23, 31),
rel_pos_blocks=(),
use_rope=True,
use_interp_rope=True,
window_size=24,
pretrain_use_cls_token=True,
retain_cls_token=False,
ln_pre=True,
ln_post=False,
return_interm_layers=False,
bias_patch_embed=False,
compile_mode=compile_mode,
)
return Sam3DualViTDetNeck(
position_encoding=position_encoding,
d_model=256,
scale_factors=[4.0, 2.0, 1.0, 0.5],
trunk=vit_backbone,
add_sam2_neck=enable_inst_interactivity,
)
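As a sanity check on the configuration above: with `img_size=1008` and `patch_size=14`, the ViT token grid is 72×72, and the neck's `scale_factors` of 4.0, 2.0, 1.0, and 0.5 correspond to progressively coarser feature maps. A quick arithmetic sketch (illustrative only; actual tensor shapes depend on the neck implementation):

```python
def feature_map_sizes(img_size: int, patch_size: int, scale_factors: list[float]) -> list[int]:
    """Spatial size of each multi-scale output, derived from the ViT token grid."""
    grid = img_size // patch_size  # base token grid per side, e.g. 1008 // 14 = 72
    return [int(grid * s) for s in scale_factors]
```

For the configuration above this yields per-side sizes of 288, 144, 72, and 36.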
def _create_sam3_transformer() -> TransformerWrapper:
"""Create SAM3 detector encoder and decoder."""
encoder: TransformerEncoderFusion = TransformerEncoderFusion(
layer=TransformerEncoderLayer(
d_model=256,
dim_feedforward=2048,
dropout=0.1,
pos_enc_at_attn=True,
pos_enc_at_cross_attn_keys=False,
pos_enc_at_cross_attn_queries=False,
pre_norm=True,
self_attention=nn.MultiheadAttention(
num_heads=8,
dropout=0.1,
embed_dim=256,
batch_first=True,
),
cross_attention=nn.MultiheadAttention(
num_heads=8,
dropout=0.1,
embed_dim=256,
batch_first=True,
),
),
num_layers=6,
d_model=256,
num_feature_levels=1,
frozen=False,
use_act_checkpoint=True,
add_pooled_text_to_img_feat=False,
pool_text_with_mask=True,
)
decoder: TransformerDecoder = TransformerDecoder(
layer=TransformerDecoderLayer(
d_model=256,
dim_feedforward=2048,
dropout=0.1,
cross_attention=nn.MultiheadAttention(
num_heads=8,
dropout=0.1,
embed_dim=256,
),
n_heads=8,
use_text_cross_attention=True,
),
num_layers=6,
num_queries=200,
return_intermediate=True,
box_refine=True,
num_o2m_queries=0,
dac=True,
boxRPB="log",
d_model=256,
frozen=False,
interaction_layer=None,
dac_use_selfatt_ln=True,
use_act_checkpoint=True,
presence_token=True,
)
return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
def build_sam3_image_model(
checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False
):
"""Build SAM3 image model.
Args:
checkpoint_path: Optional path to model checkpoint
bpe_path: Path to the BPE tokenizer vocabulary
enable_segmentation: Whether to enable segmentation head
compile: Whether to compile the vision backbone; when True, the "default" compile mode is used
Returns:
A SAM3 image model
"""
# Create visual components
compile_mode = "default" if compile else None
vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)
# Create text components
text_encoder = VETextEncoder(
tokenizer=SimpleTokenizer(bpe_path=bpe_path),
d_model=256,
width=1024,
heads=16,
layers=24,
)
# Create visual-language backbone
backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)
# Create transformer components
transformer = _create_sam3_transformer()
# Create dot product scoring
dot_prod_scoring = DotProductScoring(
d_model=256,
d_proj=256,
prompt_mlp=MLP(
input_dim=256,
hidden_dim=2048,
output_dim=256,
num_layers=2,
residual=True,
out_norm=nn.LayerNorm(256),
),
)
# Create segmentation head if enabled
segmentation_head = (
UniversalSegmentationHead(
hidden_dim=256,
upsampling_stages=3,
aux_masks=False,
presence_head=False,
dot_product_scorer=None,
act_ckpt=True,
cross_attend_prompt=nn.MultiheadAttention(
num_heads=8,
dropout=0,
embed_dim=256,
),
pixel_decoder=PixelDecoder(
num_upsampling_stages=3,
interpolation_mode="nearest",
hidden_dim=256,
compile_mode=compile_mode,
),
)
if enable_segmentation
else None
)
# Create geometry encoder
input_geometry_encoder = SequenceGeometryEncoder(
pos_enc=PositionEmbeddingSine(
num_pos_feats=256,
normalize=True,
scale=None,
temperature=10000,
),
encode_boxes_as_points=False,
boxes_direct_project=True,
boxes_pool=True,
boxes_pos_enc=True,
d_model=256,
num_layers=3,
layer=TransformerEncoderLayer(
d_model=256,
dim_feedforward=2048,
dropout=0.1,
pos_enc_at_attn=False,
pre_norm=True,
pos_enc_at_cross_attn_queries=False,
pos_enc_at_cross_attn_keys=True,
),
use_act_ckpt=True,
add_cls=True,
add_post_encode_proj=True,
)
# Create the SAM3SemanticModel model
model = SAM3SemanticModel(
backbone=backbone,
transformer=transformer,
input_geometry_encoder=input_geometry_encoder,
segmentation_head=segmentation_head,
num_feature_levels=1,
o2m_mask_predict=True,
dot_prod_scoring=dot_prod_scoring,
use_instance_query=False,
multimask_output=True,
)
# Load checkpoint
model = _load_checkpoint(model, checkpoint_path)
model.eval()
return model
def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
"""Build the SAM3 Tracker module for video tracking.
Returns:
Sam3TrackerPredictor: Wrapped SAM3 Tracker module
"""
# Create model components
memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
memory_attention = MemoryAttention(
batch_first=True,
d_model=256,
pos_enc_at_input=True,
layer=MemoryAttentionLayer(
dim_feedforward=2048,
dropout=0.1,
pos_enc_at_attn=False,
pos_enc_at_cross_attn_keys=True,
pos_enc_at_cross_attn_queries=False,
self_attn=RoPEAttention(
embedding_dim=256,
num_heads=1,
downsample_rate=1,
rope_theta=10000.0,
feat_sizes=[72, 72],
),
d_model=256,
cross_attn=RoPEAttention(
embedding_dim=256,
num_heads=1,
downsample_rate=1,
kv_in_dim=64,
rope_theta=10000.0,
feat_sizes=[72, 72],
rope_k_repeat=True,
),
),
num_layers=4,
)
backbone = (
SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
if with_backbone
else None
)
model = SAM3Model(
image_size=1008,
image_encoder=backbone,
memory_attention=memory_attention,
memory_encoder=memory_encoder,
backbone_stride=14,
num_maskmem=7,
sigmoid_scale_for_mem_enc=20.0,
sigmoid_bias_for_mem_enc=-10.0,
use_mask_input_as_output_without_sam=True,
directly_add_no_mem_embed=True,
use_high_res_features_in_sam=True,
multimask_output_in_sam=True,
iou_prediction_use_sigmoid=True,
use_obj_ptrs_in_encoder=True,
add_tpos_enc_to_obj_ptrs=True,
only_obj_ptrs_in_the_past_for_eval=True,
pred_obj_scores=True,
pred_obj_scores_mlp=True,
fixed_no_obj_ptr=True,
multimask_output_for_tracking=True,
use_multimask_token_for_obj_ptr=True,
multimask_min_pt_num=0,
multimask_max_pt_num=1,
use_mlp_for_obj_ptr_proj=True,
compile_image_encoder=False,
no_obj_embed_spatial=True,
proj_tpos_enc_in_obj_ptrs=True,
use_signed_tpos_enc_to_obj_ptrs=True,
sam_mask_decoder_extra_args=dict(
dynamic_multimask_via_stability=True,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
),
)
# Load checkpoint if provided
model = _load_checkpoint(model, checkpoint_path, interactive=True)
# Setup device and mode
model.eval()
return model
def _load_checkpoint(model, checkpoint, interactive=False):
"""Load SAM3 model checkpoint from file."""
with open(checkpoint, "rb") as f:
ckpt = torch_load(f)
if "model" in ckpt and isinstance(ckpt["model"], dict):
ckpt = ckpt["model"]
sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
if interactive:
sam3_image_ckpt.update(
{
k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
for k, v in sam3_image_ckpt.items()
if "backbone.vision_backbone" in k
}
)
sam3_image_ckpt.update(
{
k.replace("tracker.transformer.encoder", "memory_attention"): v
for k, v in ckpt.items()
if "tracker.transformer" in k
}
)
sam3_image_ckpt.update(
{
k.replace("tracker.maskmem_backbone", "memory_encoder"): v
for k, v in ckpt.items()
if "tracker.maskmem_backbone" in k
}
)
sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
model.load_state_dict(sam3_image_ckpt, strict=False)
return model
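The checkpoint loading above is mostly dictionary-key surgery: detector weights are kept with their `detector.` prefix stripped, and (in interactive mode) tracker weights are remapped onto the memory modules. A minimal pure-Python sketch of the first remapping step, using a hypothetical toy checkpoint dict:

```python
from typing import Any


def remap_detector_keys(ckpt: dict[str, Any]) -> dict[str, Any]:
    # Mirror _load_checkpoint: keep only detector weights, stripping the "detector." prefix
    return {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}


toy_ckpt = {"detector.backbone.proj.weight": 1, "tracker.head.bias": 2}
print(remap_detector_keys(toy_ckpt))  # {'backbone.proj.weight': 1}
```

Note that `load_state_dict(..., strict=False)` is what tolerates the tracker-only keys that remain unmatched after this remap.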


@ -21,7 +21,7 @@ from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .predict import Predictor, SAM2Predictor
from .predict import Predictor, SAM2Predictor, SAM3Predictor
class SAM(Model):
@ -59,6 +59,7 @@ class SAM(Model):
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
self.is_sam2 = "sam2" in Path(model).stem
self.is_sam3 = "sam3" in Path(model).stem
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
@ -72,9 +73,14 @@ class SAM(Model):
>>> sam = SAM("sam_b.pt")
>>> sam._load("path/to/custom_weights.pt")
"""
from .build import build_sam # slow import
if self.is_sam3:
from .build_sam3 import build_interactive_sam3
self.model = build_sam(weights)
self.model = build_interactive_sam3(weights)
else:
from .build import build_sam # slow import
self.model = build_sam(weights)
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""Perform segmentation prediction on the given image or video source.
@ -158,4 +164,6 @@ class SAM(Model):
>>> print(task_map)
{'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
"""
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
return {
"segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
}
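The predictor dispatch above keys off substrings of the weight file stem. A small standalone sketch of that selection logic (string names stand in for the actual predictor classes):

```python
from pathlib import Path


def select_predictor(model: str) -> str:
    # Mirror the task_map dispatch: substring checks on the weight file stem
    stem = Path(model).stem
    if "sam2" in stem:
        return "SAM2Predictor"
    if "sam3" in stem:
        return "SAM3Predictor"
    return "Predictor"
```

So `SAM("sam3.pt")` routes to `SAM3Predictor`, `SAM("sam2_b.pt")` to `SAM2Predictor`, and any other SAM weight falls back to the base `Predictor`.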


@ -79,6 +79,7 @@ class MaskDownSampler(nn.Module):
padding: int = 0,
total_stride: int = 16,
activation: type[nn.Module] = nn.GELU,
interpol_size: tuple[int, int] | None = None,
):
"""Initialize a mask downsampler module for progressive downsampling and channel expansion."""
super().__init__()
@ -102,9 +103,24 @@ class MaskDownSampler(nn.Module):
mask_in_chans = mask_out_chans
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
self.interpol_size = interpol_size
if self.interpol_size is not None:
assert isinstance(self.interpol_size, (list, tuple)), (
f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
)
self.interpol_size = list(interpol_size)
assert len(self.interpol_size) == 2
def forward(self, x: Tensor) -> Tensor:
"""Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
x = F.interpolate(
x.float(),
size=self.interpol_size,
align_corners=False,
mode="bilinear",
antialias=True,
).to(x.dtype)
return self.encoder(x)
@ -429,13 +445,7 @@ class RoPEAttention(Attention):
)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = F.scaled_dot_product_attention(q, k, v)
out = self._recombine_heads(out)
out = self.out_proj(out)
@ -1033,6 +1043,7 @@ class PatchEmbed(nn.Module):
padding: tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
) -> None:
"""Initialize the PatchEmbed module for converting image patches to embeddings.
@ -1045,10 +1056,11 @@ class PatchEmbed(nn.Module):
padding (tuple[int, int]): Padding applied to the input before convolution.
in_chans (int): Number of input image channels.
embed_dim (int): Dimensionality of the output patch embeddings.
bias (bool): Whether to include a bias term in the convolutional layer.
"""
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Compute patch embedding by applying convolution and transposing resulting tensor."""


@ -436,9 +436,8 @@ class SAM2MaskDecoder(nn.Module):
def _get_stability_scores(self, mask_logits):
"""Compute mask stability scores based on IoU between upper and lower thresholds."""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float()
return torch.where(area_u > 0, area_i / area_u, 1.0)
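The stability score above is the IoU between the mask binarized at `+delta` and at `-delta`; logits far from zero are insensitive to the threshold, so stable masks score near 1. A pure-Python sketch of the same computation on a flat list of logits:

```python
def stability_score(mask_logits, delta):
    # IoU of the mask thresholded at +delta vs -delta; empty union defaults to 1.0
    area_i = sum(1 for v in mask_logits if v > delta)
    area_u = sum(1 for v in mask_logits if v > -delta)
    return area_i / area_u if area_u > 0 else 1.0


print(stability_score([5.0, 2.0, 0.5, -0.5, -5.0], 1.0))  # 0.5
```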
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):


@ -361,6 +361,7 @@ class MemoryEncoder(nn.Module):
self,
out_dim,
in_dim=256, # in_dim of pix_feats
interpol_size: tuple[int, int] | None = None,
):
"""Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
@ -370,10 +371,12 @@ class MemoryEncoder(nn.Module):
Args:
out_dim (int): Output dimension of the encoded features.
in_dim (int): Input dimension of the pixel features.
interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel
features.
"""
super().__init__()
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size)
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)


@ -59,6 +59,8 @@ class MemoryAttentionLayer(nn.Module):
pos_enc_at_attn: bool = False,
pos_enc_at_cross_attn_keys: bool = True,
pos_enc_at_cross_attn_queries: bool = False,
self_attn: nn.Module | None = None,
cross_attn: nn.Module | None = None,
):
"""Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
@ -69,13 +71,15 @@ class MemoryAttentionLayer(nn.Module):
pos_enc_at_attn (bool): Whether to add positional encoding at attention.
pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used.
cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used.
"""
super().__init__()
self.d_model = d_model
self.dim_feedforward = dim_feedforward
self.dropout_value = dropout
self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
self.cross_attn_image = RoPEAttention(
self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
self.cross_attn_image = cross_attn or RoPEAttention(
rope_k_repeat=True,
embedding_dim=256,
num_heads=1,


@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_
from ultralytics.nn.modules import MLP
from ultralytics.utils import LOGGER
from .blocks import SAM2TwoWayTransformer
from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
from .decoders import MaskDecoder, SAM2MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
from .utils import get_1d_sine_pe, select_closest_cond_frames
@ -329,6 +329,7 @@ class SAM2Model(torch.nn.Module):
self._build_sam_heads()
self.max_cond_frames_in_attn = max_cond_frames_in_attn
self.add_all_frames_to_correct_as_cond = True
# Model compilation
if compile_image_encoder:
@ -473,7 +474,7 @@ class SAM2Model(torch.nn.Module):
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
mask_inputs.float(),
mask_inputs.to(backbone_features.dtype),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
@ -571,7 +572,7 @@ class SAM2Model(torch.nn.Module):
# produce an object pointer using the SAM decoder from the mask input
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
backbone_features=backbone_features,
mask_inputs=self.mask_downsample(mask_inputs_float),
mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
high_res_features=high_res_features,
)
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
@ -818,7 +819,6 @@ class SAM2Model(torch.nn.Module):
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
# add a no-object embedding to the spatial memory to indicate that the frame
# is predicted to be occluded (i.e. no object is appearing in the frame)
if self.no_obj_embed_spatial is not None:
@ -827,7 +827,7 @@ class SAM2Model(torch.nn.Module):
..., None, None
].expand(*maskmem_features.shape)
return maskmem_features, maskmem_pos_enc
return maskmem_features, maskmem_out["vision_pos_enc"]
def _track_step(
self,
@ -1005,7 +1005,151 @@ class SAM2Model(torch.nn.Module):
def set_imgsz(self, imgsz):
"""Set image size to make model compatible with different image sizes."""
if hasattr(self.image_encoder, "set_imgsz"):
self.image_encoder.set_imgsz(imgsz)
self.image_size = imgsz[0]
self.sam_prompt_encoder.input_image_size = imgsz
self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16
self.sam_prompt_encoder.image_embedding_size = [
x // self.backbone_stride for x in imgsz
] # embedding grid follows the backbone stride
self.sam_prompt_encoder.mask_input_size = [
x // self.backbone_stride * 4 for x in imgsz
] # mask inputs are 4x the embedding grid
self.sam_image_embedding_size = self.image_size // self.backbone_stride # update image embedding size
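The size bookkeeping above generalizes the old hard-coded ViT patch size of 16 to the configurable `backbone_stride` (14 for SAM 3). A quick arithmetic sketch showing the resulting grids for both strides:

```python
def prompt_encoder_sizes(imgsz, backbone_stride):
    # Embedding grid is imgsz // stride; mask inputs are 4x the embedding grid
    embed = [x // backbone_stride for x in imgsz]
    mask_in = [x // backbone_stride * 4 for x in imgsz]
    return embed, mask_in


print(prompt_encoder_sizes((1008, 1008), 14))  # ([72, 72], [288, 288])
print(prompt_encoder_sizes((1024, 1024), 16))  # ([64, 64], [256, 256])
```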
class SAM3Model(SAM2Model):
"""SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
def __init__(
self,
image_encoder,
memory_attention,
memory_encoder,
num_maskmem=7,
image_size=1008,
backbone_stride=14,
sigmoid_scale_for_mem_enc=1,
sigmoid_bias_for_mem_enc=0,
binarize_mask_from_pts_for_mem_enc=False,
use_mask_input_as_output_without_sam=False,
max_cond_frames_in_attn=-1,
directly_add_no_mem_embed=False,
use_high_res_features_in_sam=False,
multimask_output_in_sam=False,
multimask_min_pt_num=1,
multimask_max_pt_num=1,
multimask_output_for_tracking=False,
use_multimask_token_for_obj_ptr: bool = False,
iou_prediction_use_sigmoid=False,
memory_temporal_stride_for_eval=1,
non_overlap_masks_for_mem_enc=False,
use_obj_ptrs_in_encoder=False,
max_obj_ptrs_in_encoder=16,
add_tpos_enc_to_obj_ptrs=True,
proj_tpos_enc_in_obj_ptrs=False,
use_signed_tpos_enc_to_obj_ptrs=False,
only_obj_ptrs_in_the_past_for_eval=False,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
fixed_no_obj_ptr: bool = False,
soft_no_obj_ptr: bool = False,
use_mlp_for_obj_ptr_proj: bool = False,
no_obj_embed_spatial: bool = False,
sam_mask_decoder_extra_args=None,
compile_image_encoder: bool = False,
):
"""SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
super().__init__(
image_encoder,
memory_attention,
memory_encoder,
num_maskmem,
image_size,
backbone_stride,
sigmoid_scale_for_mem_enc,
sigmoid_bias_for_mem_enc,
binarize_mask_from_pts_for_mem_enc,
use_mask_input_as_output_without_sam,
max_cond_frames_in_attn,
directly_add_no_mem_embed,
use_high_res_features_in_sam,
multimask_output_in_sam,
multimask_min_pt_num,
multimask_max_pt_num,
multimask_output_for_tracking,
use_multimask_token_for_obj_ptr,
iou_prediction_use_sigmoid,
memory_temporal_stride_for_eval,
non_overlap_masks_for_mem_enc,
use_obj_ptrs_in_encoder,
max_obj_ptrs_in_encoder,
add_tpos_enc_to_obj_ptrs,
proj_tpos_enc_in_obj_ptrs,
use_signed_tpos_enc_to_obj_ptrs,
only_obj_ptrs_in_the_past_for_eval,
pred_obj_scores,
pred_obj_scores_mlp,
fixed_no_obj_ptr,
soft_no_obj_ptr,
use_mlp_for_obj_ptr_proj,
no_obj_embed_spatial,
sam_mask_decoder_extra_args,
compile_image_encoder,
)
self.sam_mask_decoder = SAM2MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=self.sam_prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=self.sam_prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
use_high_res_features=self.use_high_res_features_in_sam,
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
pred_obj_scores=self.pred_obj_scores,
pred_obj_scores_mlp=self.pred_obj_scores_mlp,
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
**(self.sam_mask_decoder_extra_args or {}),
)
def forward_image(self, img_batch: torch.Tensor):
"""Process image batch through encoder to extract multi-level features for SAM model."""
backbone_out = self.image_encoder.forward_image_sam2(img_batch)
if self.use_high_res_features_in_sam:
# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
return backbone_out
def set_imgsz(self, imgsz: tuple[int, int]):
"""Set the image size for the model and mask downsampler."""
super().set_imgsz(imgsz)
self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz]
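The `size // 14 * 16` expression above converts the 14-pixel-stride token grid into a 16-pixel-aligned resolution for the memory mask downsampler. A one-line sketch confirming it reproduces the `[1152, 1152]` interpol_size the MemoryEncoder is constructed with for the default 1008-pixel input:

```python
def maskmem_interpol_size(imgsz):
    # Token grid at stride 14, re-encoded on a 16-pixel grid for the memory encoder
    return [s // 14 * 16 for s in imgsz]


print(maskmem_interpol_size([1008, 1008]))  # [1152, 1152]
```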
@staticmethod
def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
"""Suppress masks that shrink in area after applying pixelwise non-overlapping constraints."""
area_before = (pred_masks > 0).sum(dim=(-1, -2))
area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
area_before = torch.clamp(area_before, min=1.0)
area_ratio = area_after / area_before
keep = area_ratio >= shrink_threshold
keep_mask = keep[..., None, None].expand_as(pred_masks)
pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
return pred_masks_after
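The keep/suppress decision above reduces to a per-mask area ratio check. A pure-Python sketch of that rule on scalar areas (the torch version applies it per mask and then clamps suppressed logits to at most -10):

```python
def keep_mask(area_before, area_after, shrink_threshold=0.3):
    # Keep a mask only if it retains at least shrink_threshold of its pre-constraint area
    area_before = max(area_before, 1.0)
    return (area_after / area_before) >= shrink_threshold


print(keep_mask(100, 20))  # False: shrank to 20% of its area, below the 30% threshold
print(keep_mask(100, 30))  # True: exactly at the threshold
```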
def _suppress_object_pw_area_shrinkage(self, pred_masks):
"""This function suppresses masks that shrink in area after applying pixelwise non-overlapping constraints. Note
that the final output can still be overlapping.
"""
# Apply pixel-wise non-overlapping constraint based on mask scores
pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
# Fully suppress masks with high shrinkage (probably noisy) based on the pixelwise non-overlapping constraints
# NOTE: The output of this function can be a no-op if none of the masks shrank by a large factor.
pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
return pred_masks


@ -2,6 +2,7 @@
from __future__ import annotations
import math
from typing import Any
import torch
@ -86,7 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
return pos_embed
def init_t_xy(end_x: int, end_y: int):
def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
"""Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
@ -95,6 +96,8 @@ def init_t_xy(end_x: int, end_y: int):
Args:
end_x (int): Width of the grid (number of columns).
end_y (int): Height of the grid (number of rows).
scale (float): Scaling factor to apply to the coordinates.
offset (int): Offset to add to the coordinates.
Returns:
t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
@ -110,10 +113,10 @@ def init_t_xy(end_x: int, end_y: int):
t = torch.arange(end_x * end_y, dtype=torch.float32)
t_x = (t % end_x).float()
t_y = torch.div(t, end_x, rounding_mode="floor").float()
return t_x, t_y
return t_x * scale + offset, t_y * scale + offset
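The new `scale`/`offset` parameters simply apply an affine transform to the row-major grid coordinates. A pure-Python sketch of the same layout (x cycles fastest, then scale and offset are applied), using plain lists instead of tensors:

```python
def init_t_xy(end_x, end_y, scale=1.0, offset=0):
    # Row-major grid: x = t % end_x, y = t // end_x, then affine scale/offset
    t_x = [(i % end_x) * scale + offset for i in range(end_x * end_y)]
    t_y = [(i // end_x) * scale + offset for i in range(end_x * end_y)]
    return t_x, t_y


print(init_t_xy(3, 2))  # ([0.0, 1.0, 2.0, 0.0, 1.0, 2.0], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
```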
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
"""Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
@ -124,6 +127,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
end_x (int): Width of the 2D grid.
end_y (int): Height of the 2D grid.
theta (float, optional): Scaling factor for frequency computation.
scale_pos (float, optional): Scaling factor for position coordinates.
Returns:
(torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
@ -137,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
t_x, t_y = init_t_xy(end_x, end_y)
t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
freqs_x = torch.outer(t_x, freqs_x)
freqs_y = torch.outer(t_y, freqs_y)
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
@ -375,3 +379,129 @@ def add_decomposed_rel_pos(
)
return attn
def get_abs_pos(
abs_pos: torch.Tensor,
has_cls_token: bool,
hw: tuple[int, int],
retain_cls_token: bool = False,
tiling: bool = False,
) -> torch.Tensor:
"""Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the
original embeddings.
Args:
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
hw (Tuple): size of input image tokens.
retain_cls_token: whether to retain the cls_token
tiling: whether to tile the embeddings, *instead* of interpolation (a la abs_win)
Returns:
Absolute positional embeddings after processing with shape (1, H, W, C),: if retain_cls_token is False,
otherwise (1, 1+H*W, C).
"""
if retain_cls_token:
assert has_cls_token
h, w = hw
if has_cls_token:
cls_pos = abs_pos[:, :1]
abs_pos = abs_pos[:, 1:]
xy_num = abs_pos.shape[1]
size = int(math.sqrt(xy_num))
assert size * size == xy_num
if size != h or size != w:
new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
if tiling:
new_abs_pos = new_abs_pos.tile([1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])])[
:, :, :h, :w
]
else:
new_abs_pos = F.interpolate(
new_abs_pos,
size=(h, w),
mode="bicubic",
align_corners=False,
)
if not retain_cls_token:
return new_abs_pos.permute(0, 2, 3, 1)
else:
# add cls_token back, flatten spatial dims
assert has_cls_token
return torch.cat(
[cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)],
dim=1,
)
else:
if not retain_cls_token:
return abs_pos.reshape(1, h, w, -1)
else:
assert has_cls_token
return torch.cat([cls_pos, abs_pos], dim=1)
def concat_rel_pos(
q: torch.Tensor,
k: torch.Tensor,
q_hw: tuple[int, int],
k_hw: tuple[int, int],
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
rescale: bool = False,
relative_coords: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Concatenate rel pos coeffs to the q & k tensors, so that qk^T is now effectively including rel pos biases.
Args:
q (torch.Tensor): Query tensor with shape (B, L_q, C).
k (torch.Tensor): Key tensor with shape (B, L_k, C).
q_hw (tuple[int, int]): Spatial size of the q tensor.
k_hw (tuple[int, int]): Spatial size of the k tensor.
rel_pos_h (torch.Tensor): Relative positional embeddings for the height axis.
rel_pos_w (torch.Tensor): Relative positional embeddings for the width axis.
rescale (bool): Whether to rescale, e.g. when using sdpa, where PyTorch would otherwise scale by the wrong
factor due to the concatenation.
relative_coords (torch.Tensor, optional): Precomputed relative coordinate index tensor.
Returns:
q, k: Padded so that qk^T accounts for relative positional biases.
"""
q_h, q_w = q_hw
k_h, k_w = k_hw
assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"
if relative_coords is not None:
Rh = rel_pos_h[relative_coords]
Rw = rel_pos_w[relative_coords]
else:
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
old_scale = dim**0.5
new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale # for sdpa
# attn will be divided by new_scale, but we want to divide q by old_scale
scale_ratio = new_scale / old_scale
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale # (B, q_h, q_w, k_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale # (B, q_h, q_w, k_w)
eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)
eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])
q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(B, k_h * k_w, -1)
return q, k
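The `rescale` path above compensates for sdpa dividing attention logits by the square root of the *concatenated* feature width `dim + k_h + k_w` instead of the intended `sqrt(dim)`; pre-multiplying q by `new_scale / old_scale` cancels the difference. A small sketch of that correction factor:

```python
def sdpa_rescale_ratio(dim, k_h, k_w):
    # sdpa scales by 1/sqrt(dim + k_h + k_w) after the concat; multiplying q by
    # new_scale / old_scale restores the intended 1/sqrt(dim) attention scaling
    old_scale = dim**0.5
    new_scale = (dim + k_h + k_w) ** 0.5
    return new_scale / old_scale
```

With no relative-position padding (`k_h = k_w = 0`) the ratio is exactly 1, i.e. no correction is needed.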

File diff suppressed because it is too large


@ -0,0 +1,3 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved


@ -0,0 +1,546 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Transformer decoder.
Inspired by PyTorch's version; adds the pre-norm variant.
"""
from __future__ import annotations
import numpy as np
import torch
from torch import nn
from torchvision.ops.roi_align import RoIAlign
from ultralytics.nn.modules.transformer import MLP
from ultralytics.nn.modules.utils import _get_clones, inverse_sigmoid
from ultralytics.utils.ops import xywh2xyxy
from .model_misc import gen_sineembed_for_position
class TransformerDecoderLayer(nn.Module):
"""TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN)."""
def __init__(
self,
d_model: int,
dim_feedforward: int,
dropout: float,
cross_attention: nn.Module,
n_heads: int,
use_text_cross_attention: bool = False,
):
"""Initialize the TransformerDecoderLayer."""
super().__init__()
# cross attention
self.cross_attn = cross_attention
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm1 = nn.LayerNorm(d_model)
# cross attention text
self.use_text_cross_attention = use_text_cross_attention
if use_text_cross_attention:
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.catext_norm = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = nn.ReLU()
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm3 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
"""Add positional embedding to the tensor."""
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
"""Feedforward network forward pass."""
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward(
self,
# for tgt
tgt: torch.Tensor, # nq, bs, d_model
tgt_query_pos: torch.Tensor = None, # pos for query. MLP(Sine(pos))
memory_text: torch.Tensor = None, # num_token, bs, d_model
text_attention_mask: torch.Tensor = None, # bs, num_token
# for memory
memory: torch.Tensor = None, # hw, bs, d_model
memory_key_padding_mask: torch.Tensor = None,
memory_pos: torch.Tensor = None, # pos for memory
# sa
self_attn_mask: torch.Tensor = None, # mask used for self-attention
cross_attn_mask: torch.Tensor = None, # mask used for cross-attention
# dac
dac=False,
dac_use_selfatt_ln=True,
presence_token=None,
# skip inside deformable attn
**kwargs, # additional kwargs for compatibility
):
"""Input: - tgt/tgt_query_pos: nq, bs, d_model. -."""
# self attention
tgt, tgt_query_pos = self._apply_self_attention(
tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask
)
if self.use_text_cross_attention:
tgt2 = self.ca_text(
self.with_pos_embed(tgt, tgt_query_pos),
memory_text.to(tgt.dtype),
memory_text.to(tgt.dtype),
key_padding_mask=text_attention_mask,
)[0]
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)
if presence_token is not None:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
cross_attn_mask = torch.cat([presence_token_mask, cross_attn_mask], dim=1) # (bs*nheads, 1+nq, hw)
# Cross attention to image
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=cross_attn_mask,
key_padding_mask=(memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None),
need_weights=False,
)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt.to(memory.dtype))
presence_token_out = None
if presence_token is not None:
presence_token_out = tgt[:1]
tgt = tgt[1:]
return tgt, presence_token_out
def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask):
"""Apply self-attention with optional DAC splitting."""
if self.self_attn is None:
return tgt, tgt_query_pos  # match the (tgt, tgt_query_pos) contract expected by callers
if dac:
# Split queries for DAC (detect-and-classify)
assert tgt.shape[0] % 2 == 0, "DAC requires even number of queries"
num_o2o_queries = tgt.shape[0] // 2
tgt_o2o = tgt[:num_o2o_queries]
tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
tgt_o2m = tgt[num_o2o_queries:]
else:
tgt_o2o = tgt
tgt_query_pos_o2o = tgt_query_pos
# Handle presence token
if presence_token is not None:
tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
tgt_query_pos_o2o = torch.cat([torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0).to(
tgt_o2o.dtype
)
tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0)
# Self-attention
q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0].to(tgt.dtype)
tgt_o2o = tgt_o2o + self.dropout2(tgt2)
# Recombine and normalize
if dac:
if not dac_use_selfatt_ln:
tgt_o2o = self.norm2(tgt_o2o)
tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)
if dac_use_selfatt_ln:
tgt = self.norm2(tgt)
else:
tgt = tgt_o2o
tgt = self.norm2(tgt)
return tgt, tgt_query_pos
class TransformerDecoder(nn.Module):
"""Transformer Decoder consisting of multiple layers."""
def __init__(
self,
d_model: int,
frozen: bool,
interaction_layer,
layer,
num_layers: int,
num_queries: int,
return_intermediate: bool,
box_refine: bool = False,
num_o2m_queries: int = 0,
dac: bool = False,
boxRPB: str = "none",
# Experimental: An object query for SAM 2 tasks
instance_query: bool = False,
# Defines the number of additional instance queries,
# 1 or 4 are the most likely for single vs multi mask support
num_instances: int = 1, # Irrelevant if instance_query is False
dac_use_selfatt_ln: bool = True,
use_act_checkpoint: bool = False,
compile_mode=None,
presence_token: bool = False,
clamp_presence_logits: bool = True,
clamp_presence_logit_max_val: float = 10.0,
use_normed_output_consistently: bool = True,
separate_box_head_instance: bool = False,
separate_norm_instance: bool = False,
):
"""Initialize the TransformerDecoder."""
super().__init__()
self.d_model = d_model
self.layers = _get_clones(layer, num_layers)
self.fine_layers = (
_get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers
)
self.num_layers = num_layers
self.num_queries = num_queries
self.dac = dac
if dac:
self.num_o2m_queries = num_queries
tot_num_queries = num_queries
else:
self.num_o2m_queries = num_o2m_queries
tot_num_queries = num_queries + num_o2m_queries
self.norm = nn.LayerNorm(d_model)
self.return_intermediate = return_intermediate
self.bbox_embed = MLP(d_model, d_model, 4, 3)
self.query_embed = nn.Embedding(tot_num_queries, d_model)
self.instance_query_embed = None
self.instance_query_reference_points = None
self.use_instance_query = instance_query
self.num_instances = num_instances
self.use_normed_output_consistently = use_normed_output_consistently
self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
self.instance_bbox_embed = None
if separate_box_head_instance:
self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
if instance_query:
self.instance_query_embed = nn.Embedding(num_instances, d_model)
self.box_refine = box_refine
if box_refine:
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
self.reference_points = nn.Embedding(num_queries, 4)
if instance_query:
self.instance_reference_points = nn.Embedding(num_instances, 4)
assert boxRPB in ["none", "log", "linear", "both"]
self.boxRPB = boxRPB
if boxRPB != "none":
try:
nheads = self.layers[0].cross_attn_image.num_heads
except AttributeError:
nheads = self.layers[0].cross_attn.num_heads
n_input = 4 if boxRPB == "both" else 2
self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
self.compilable_cord_cache = None
self.compilable_stored_size = None
self.coord_cache = {}
self.roi_pooler = (
RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
if interaction_layer is not None
else None
)
if frozen:
for p in self.parameters():
p.requires_grad_(False)
self.presence_token = None
self.clamp_presence_logits = clamp_presence_logits
self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
if presence_token:
self.presence_token = nn.Embedding(1, d_model)
self.presence_token_head = MLP(d_model, d_model, 1, 3)
self.presence_token_out_norm = nn.LayerNorm(d_model)
self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
self.dac_use_selfatt_ln = dac_use_selfatt_ln
self.use_act_checkpoint = use_act_checkpoint
nn.init.normal_(self.query_embed.weight.data)
if self.instance_query_embed is not None:
nn.init.normal_(self.instance_query_embed.weight.data)
assert self.roi_pooler is None
assert self.return_intermediate, "support return_intermediate only"
assert self.box_refine, "support box refine only"
self.compile_mode = compile_mode
self.compiled = False
# We defer compilation till after the first forward, to first warm-up the boxRPB cache
# assign layer index to each layer so that some layers can decide what to do
# based on which layer index they are (e.g. cross attention to memory bank only
# in selected layers)
for layer_idx, layer in enumerate(self.layers):
layer.layer_idx = layer_idx
@staticmethod
def _get_coords(H, W, device, dtype):
"""Get normalized coordinates for height and width."""
coords_h = torch.arange(0, H, dtype=dtype, device=device) / H
coords_w = torch.arange(0, W, dtype=dtype, device=device) / W
return coords_h, coords_w
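A minimal plain-Python sketch of what `_get_coords` produces (assumption: list-based stand-in for the torch version above): each axis is sampled at integer positions and divided by its length, yielding normalized coordinates in `[0, 1)`.

```python
def get_coords(H, W):
    # Normalized grid coordinates along each axis, matching the torch version above
    coords_h = [i / H for i in range(H)]
    coords_w = [j / W for j in range(W)]
    return coords_h, coords_w


coords_h, coords_w = get_coords(4, 2)
# coords_h -> [0.0, 0.25, 0.5, 0.75]; coords_w -> [0.0, 0.5]
```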
def _get_rpb_matrix(self, reference_boxes, feat_size):
"""Get the relative position bias (RPB) matrix for box-relative position bias."""
H, W = feat_size
boxes_xyxy = xywh2xyxy(reference_boxes).transpose(0, 1)
bs, num_queries, _ = boxes_xyxy.shape
if self.compilable_cord_cache is None:
self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
self.compilable_stored_size = (H, W)
if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
H,
W,
):
# good, hitting the cache, will be compilable
coords_h, coords_w = self.compilable_cord_cache
else:
# cache miss, will create compilation issue
# In case we're not compiling, we'll still rely on the dict-based cache
if feat_size not in self.coord_cache:
self.coord_cache[feat_size] = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
coords_h, coords_w = self.coord_cache[feat_size]
assert coords_h.shape == (H,)
assert coords_w.shape == (W,)
deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
deltas_y = deltas_y.view(bs, num_queries, -1, 2)
deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
deltas_x = deltas_x.view(bs, num_queries, -1, 2)
if self.boxRPB in ["log", "both"]:
deltas_x_log = deltas_x * 8 # normalize to -8, 8
deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8)
deltas_y_log = deltas_y * 8 # normalize to -8, 8
deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8)
if self.boxRPB == "log":
deltas_x = deltas_x_log
deltas_y = deltas_y_log
else:
deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
if self.training:
assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
deltas_x = self.boxRPB_embed_x(x=deltas_x) # bs, num_queries, W, n_heads
deltas_y = self.boxRPB_embed_y(x=deltas_y) # bs, num_queries, H, n_heads
if not torch.compiler.is_dynamo_compiling():
assert deltas_x.shape[:3] == (bs, num_queries, W)
assert deltas_y.shape[:3] == (bs, num_queries, H)
B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2) # bs, num_queries, H, W, n_heads
if not torch.compiler.is_dynamo_compiling():
assert B.shape[:4] == (bs, num_queries, H, W)
B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
B = B.contiguous()  # memory-efficient attention prefers contiguous, ordered strides
if not torch.compiler.is_dynamo_compiling():
assert B.shape[2:] == (num_queries, H * W)
return B
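The signed-log compression used for the "log" boxRPB variant above can be sketched on a single scalar delta (assumption: plain-Python stand-in for the tensor ops): the delta is scaled to `[-8, 8]`, then log-compressed while preserving sign, so small offsets keep fine resolution and large offsets are represented more coarsely.

```python
import math


def log_delta(delta):
    # Scalar version of: sign(d) * log2(|d| + 1) / log2(8), with d = delta * 8
    d = delta * 8  # normalize to [-8, 8]
    return math.copysign(math.log2(abs(d) + 1.0), d) / math.log2(8)


# Zero maps to zero, and the transform is odd-symmetric
# log_delta(0.0) -> 0.0; log_delta(-x) == -log_delta(x)
```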
def forward(
self,
tgt,
memory,
tgt_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
pos: torch.Tensor = None,
reference_boxes: torch.Tensor = None, # num_queries, bs, 4
# for memory
spatial_shapes: torch.Tensor = None, # bs, num_levels, 2
valid_ratios: torch.Tensor = None,
# for text
memory_text: torch.Tensor = None,
text_attention_mask: torch.Tensor = None,
# if `apply_dac` is None, it will default to `self.dac`
apply_dac: bool | None = None,
is_instance_prompt=False,
decoder_extra_kwargs: dict | None = None,
# ROI memory bank
obj_roi_memory_feat=None,
obj_roi_memory_mask=None,
box_head_trk=None,
):
"""Forward pass of the TransformerDecoder."""
if memory_mask is not None:
assert self.boxRPB == "none", (
"inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
)
apply_dac = apply_dac if apply_dac is not None else self.dac
if apply_dac:
assert (tgt.shape[0] == self.num_queries) or (
self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
)
tgt = tgt.repeat(2, 1, 1)
# note that we don't tile tgt_mask, since DAC doesn't
# use self-attention in o2m queries
if reference_boxes is not None:
assert (reference_boxes.shape[0] == self.num_queries) or (
self.use_instance_query and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings)
)
reference_boxes = reference_boxes.repeat(2, 1, 1)
bs = tgt.shape[1]
intermediate = []
intermediate_presence_logits = []
presence_feats = None
if self.box_refine:
if reference_boxes is None:
# In this case, we're in a one-stage model, so we generate the reference boxes
reference_boxes = self.reference_points.weight.unsqueeze(1)
reference_boxes = reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1)
reference_boxes = reference_boxes.sigmoid()
intermediate_ref_boxes = [reference_boxes]
else:
reference_boxes = None
intermediate_ref_boxes = None
output = tgt
presence_out = None
if self.presence_token is not None and is_instance_prompt is False:
# expand to batch dim
presence_out = self.presence_token.weight[None].expand(1, bs, -1)
box_head = self.bbox_embed
if is_instance_prompt and self.instance_bbox_embed is not None:
box_head = self.instance_bbox_embed
out_norm = self.norm
if is_instance_prompt and self.instance_norm is not None:
out_norm = self.instance_norm
for layer_idx, layer in enumerate(self.layers):
reference_points_input = (
reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
) # nq, bs, nlevel, 4
query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :], self.d_model
) # nq, bs, d_model*2
# conditional query
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
if self.boxRPB != "none" and reference_boxes is not None:
assert spatial_shapes.shape[0] == 1, "only single scale support implemented"
memory_mask = self._get_rpb_matrix(
reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
)
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
if self.training:
assert self.use_act_checkpoint, "Activation checkpointing not enabled in the decoder"
output, presence_out = layer(
tgt=output,
tgt_query_pos=query_pos,
memory_text=memory_text,
text_attention_mask=text_attention_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
memory_pos=pos,
self_attn_mask=tgt_mask,
cross_attn_mask=memory_mask,
dac=apply_dac,
dac_use_selfatt_ln=self.dac_use_selfatt_ln,
presence_token=presence_out,
**(decoder_extra_kwargs or {}),
# ROI memory bank
obj_roi_memory_feat=obj_roi_memory_feat,
obj_roi_memory_mask=obj_roi_memory_mask,
)
# iter update
if self.box_refine:
reference_before_sigmoid = inverse_sigmoid(reference_boxes)
if box_head_trk is None:
# delta_unsig = self.bbox_embed(output)
if not self.use_normed_output_consistently:
delta_unsig = box_head(output)
else:
delta_unsig = box_head(out_norm(output))
else:
# box_head_trk use a separate box head for tracking queries
Q_det = decoder_extra_kwargs["Q_det"]
assert output.size(0) >= Q_det
delta_unsig_det = self.bbox_embed(output[:Q_det])
delta_unsig_trk = box_head_trk(output[Q_det:])
delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
reference_boxes = new_reference_points.detach()
if layer_idx != self.num_layers - 1:
intermediate_ref_boxes.append(new_reference_points)
else:
raise NotImplementedError("not implemented yet")
intermediate.append(out_norm(output))
if self.presence_token is not None and is_instance_prompt is False:
# norm, mlp head
intermediate_layer_presence_logits = self.presence_token_head(
self.presence_token_out_norm(presence_out)
).squeeze(-1)
# clamp to mitigate numerical issues
if self.clamp_presence_logits:
intermediate_layer_presence_logits = intermediate_layer_presence_logits.clamp(
min=-self.clamp_presence_logit_max_val,
max=self.clamp_presence_logit_max_val,
)
intermediate_presence_logits.append(intermediate_layer_presence_logits)
presence_feats = presence_out.clone()
if not self.compiled and self.compile_mode is not None:
self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True)
self.compiled = True
return (
torch.stack(intermediate),
torch.stack(intermediate_ref_boxes),
(
torch.stack(intermediate_presence_logits)
if self.presence_token is not None and is_instance_prompt is False
else None
),
presence_feats,
)
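The iterative box refinement in the decoder loop above can be sketched for a single coordinate (assumption: scalar stand-in for the tensor ops): each layer predicts a delta in logit space, adds it to the inverse-sigmoid of the current reference, and squashes the result back into `(0, 1)`.

```python
import math


def refine(ref, delta):
    # inverse_sigmoid(ref) + delta, then sigmoid, as in the box_refine branch above
    logit = math.log(ref / (1 - ref))  # inverse sigmoid
    return 1 / (1 + math.exp(-(logit + delta)))


# A zero delta leaves the reference unchanged; a positive delta moves it up
# refine(0.5, 0.0) -> 0.5
```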


@ -0,0 +1,535 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# Based on https://github.com/IDEA-Research/GroundingDINO
from __future__ import annotations
import torch
from torch import nn
from ultralytics.nn.modules.utils import _get_clones
from .model_misc import get_valid_ratio
class TransformerEncoderLayer(nn.Module):
"""Transformer encoder layer that performs self-attention followed by cross-attention.
This layer was previously called TransformerDecoderLayer but was renamed to better reflect its role in the
architecture. It processes input sequences through self-attention and then cross-attention with another input
(typically image features).
The layer supports both pre-norm and post-norm configurations, as well as positional encoding at different stages of
the attention mechanism.
"""
def __init__(
self,
d_model: int,
dim_feedforward: int,
dropout: float,
pos_enc_at_attn: bool,
pos_enc_at_cross_attn_keys: bool,
pos_enc_at_cross_attn_queries: bool,
pre_norm: bool,
self_attention: nn.Module = None,
cross_attention: nn.Module = None,
):
"""Initialize a transformer encoder layer.
Args:
cross_attention: Cross-attention module for attending to image features
d_model: Model dimension/hidden size
dim_feedforward: Dimension of the feedforward network
dropout: Dropout probability
pos_enc_at_attn: Whether to add positional encodings at self-attention
pos_enc_at_cross_attn_keys: Whether to add positional encodings to keys in cross-attention
pos_enc_at_cross_attn_queries: Whether to add positional encodings to queries in cross-attention
pre_norm: Whether to use pre-norm (True) or post-norm (False) architecture
self_attention: Self-attention module
"""
super().__init__()
self.d_model = d_model
self.dim_feedforward = dim_feedforward
self.dropout_value = dropout
self.self_attn = self_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256)
self.cross_attn_image = cross_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = nn.ReLU()
self.pre_norm = pre_norm
self.pos_enc_at_attn = pos_enc_at_attn
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
self.layer_idx = None
def forward_post(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
tgt_key_padding_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
pos: torch.Tensor = None,
query_pos: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""Forward pass for post-norm architecture.
In post-norm architecture, normalization is applied after attention and feedforward operations.
Args:
tgt: Input tensor to be processed
memory: Memory tensor for cross-attention
tgt_mask: Mask for self-attention
memory_mask: Mask for cross-attention
tgt_key_padding_mask: Key padding mask for self-attention
memory_key_padding_mask: Key padding mask for cross-attention
pos: Positional encoding for memory
query_pos: Positional encoding for query
**kwargs: Additional keyword arguments
Returns:
Processed tensor
"""
q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
# Self attention
tgt2 = self.self_attn(
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask, need_weights=False
)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# Cross attention to image
tgt2 = self.cross_attn_image(
query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
need_weights=False,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# FFN
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
dac: bool = False,
tgt_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
tgt_key_padding_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
pos: torch.Tensor = None,
query_pos: torch.Tensor = None,
# **kwargs,
) -> torch.Tensor:
"""Forward pass for pre-norm architecture.
In pre-norm architecture, normalization is applied before attention and feedforward operations.
Args:
tgt: Input tensor to be processed
memory: Memory tensor for cross-attention
dac: Whether to use Divide-and-Conquer attention
tgt_mask: Mask for self-attention
memory_mask: Mask for cross-attention
tgt_key_padding_mask: Key padding mask for self-attention
memory_key_padding_mask: Key padding mask for cross-attention
pos: Positional encoding for memory
query_pos: Positional encoding for query
Returns:
Processed tensor
"""
if dac:
# we only apply self attention to the first half of the queries
assert tgt.shape[0] % 2 == 0
other_tgt = tgt[tgt.shape[0] // 2 :]
tgt = tgt[: tgt.shape[0] // 2]
tgt2 = self.norm1(tgt)
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
if dac:
# Recombine
tgt = torch.cat((tgt, other_tgt), dim=0)
tgt2 = self.norm2(tgt)
tgt2 = self.cross_attn_image(
query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
key=memory.to(tgt2.dtype) + pos if self.pos_enc_at_cross_attn_keys else memory.to(tgt2.dtype),
value=memory.to(tgt2.dtype),
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
# attn_bias=attn_bias,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
dac: bool = False,
tgt_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
tgt_key_padding_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
pos: torch.Tensor = None,
query_pos: torch.Tensor = None,
# **kwds: Any,
) -> torch.Tensor:
"""Forward pass for the transformer encoder layer.
Args:
tgt: Input tensor to be processed
memory: Memory tensor (e.g., image features) for cross-attention
dac: Whether to use Divide-and-Conquer attention (only apply self-attention to first half)
tgt_mask: Mask for self-attention
memory_mask: Mask for cross-attention
tgt_key_padding_mask: Key padding mask for self-attention
memory_key_padding_mask: Key padding mask for cross-attention
pos: Positional encoding for memory
query_pos: Positional encoding for query
Returns:
Processed tensor after self-attention, cross-attention, and feedforward network
"""
fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
return fwd_fn(
tgt,
memory,
dac=dac,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
# attn_bias=attn_bias,
# **kwds,
)
class TransformerEncoder(nn.Module):
"""Transformer encoder that processes multi-level features.
This encoder takes multi-level features (e.g., from a backbone network) and processes them through a stack of
transformer encoder layers. It supports features from multiple levels (e.g., different resolutions) and can apply
activation checkpointing for memory efficiency during training.
Args:
layer: The encoder layer to be stacked multiple times
num_layers: Number of encoder layers to stack
d_model: Model dimension/hidden size
num_feature_levels: Number of feature levels to process
frozen: Whether to freeze the parameters of this module
use_act_checkpoint: Whether to use activation checkpointing during training
"""
def __init__(
self,
layer: nn.Module,
num_layers: int,
d_model: int,
num_feature_levels: int,
frozen: bool = False,
use_act_checkpoint: bool = False,
):
"""Initialize the transformer encoder."""
super().__init__()
self.layers = _get_clones(layer, num_layers)
self.num_layers = num_layers
self.num_feature_levels = num_feature_levels
self.level_embed = None
if num_feature_levels > 1:
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
if frozen:
for p in self.parameters():
p.requires_grad_(False)
self.use_act_checkpoint = use_act_checkpoint
# assign layer index to each layer so that some layers can decide what to do
# based on which layer index they are (e.g. cross attention to memory bank only
# in selected layers)
for layer_idx, layer in enumerate(self.layers):
layer.layer_idx = layer_idx
def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
"""Prepare multi-level features for transformer encoder."""
assert len(srcs) == self.num_feature_levels, "mismatch between expected and received # of feature levels"
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
has_mask = masks is not None and masks[0] is not None
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
_, _, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
if has_mask:
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
if self.level_embed is not None:
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
if has_mask:
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None # bs, \sum{hxw}
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat(
(
spatial_shapes.new_zeros((1,)),
spatial_shapes.prod(1).cumsum(0)[:-1],
)
)
if has_mask:
valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1)
else:
valid_ratios = torch.ones(
(src_flatten.shape[0], self.num_feature_levels, 2),
device=src_flatten.device,
dtype=src_flatten.dtype,
)
return (
src_flatten,
mask_flatten,
lvl_pos_embed_flatten,
level_start_index,
valid_ratios,
spatial_shapes,
)
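The `level_start_index` computed above can be illustrated with a plain-Python sketch (assumption: list-based stand-in for the torch `cumsum`): each level's flattened features start where the cumulative `h * w` of the previous levels ends.

```python
def level_start_indices(spatial_shapes):
    # Offsets into the flattened (sum of h*w) feature sequence, one per level
    starts, offset = [], 0
    for h, w in spatial_shapes:
        starts.append(offset)
        offset += h * w
    return starts


# Three pyramid levels: 8x8, 4x4, 2x2
# level_start_indices([(8, 8), (4, 4), (2, 2)]) -> [0, 64, 80]
```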
def forward(
self,
src: list[torch.Tensor],
src_key_padding_masks: list[torch.Tensor] | None = None,
pos: list[torch.Tensor] | None = None,
prompt: torch.Tensor = None,
prompt_key_padding_mask: torch.Tensor = None,
encoder_extra_kwargs: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Process multi-level features through the transformer encoder.
Args:
src: List of multi-level features, each with shape (batch_size, channels, height, width)
src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height,
width)
pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height,
width)
prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model)
prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len)
encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer
Returns:
A tuple containing:
- output: Processed features with shape (seq_len, batch_size, d_model)
- key_padding_masks_flatten: Flattened padding masks
- lvl_pos_embed_flatten: Flattened positional embeddings
- level_start_index: Starting indices for each feature level
- spatial_shapes: Spatial dimensions of each feature level
- valid_ratios: Valid ratios for each feature level
"""
assert len(src) == self.num_feature_levels, "must be equal to num_feature_levels"
if src_key_padding_masks is not None:
assert len(src_key_padding_masks) == self.num_feature_levels
if pos is not None:
assert len(pos) == self.num_feature_levels
# Flatten multilevel feats and add level pos embeds
(
src_flatten,
key_padding_masks_flatten,
lvl_pos_embed_flatten,
level_start_index,
valid_ratios,
spatial_shapes,
) = self._prepare_multilevel_features(src, src_key_padding_masks, pos)
output = src_flatten
for layer in self.layers:
layer_kwargs = {}
assert isinstance(layer, TransformerEncoderLayer)
layer_kwargs["memory"] = prompt
layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask
layer_kwargs["query_pos"] = lvl_pos_embed_flatten
layer_kwargs["tgt"] = output
layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten
if self.training:
assert self.use_act_checkpoint, "activation ckpt not enabled in encoder"
if encoder_extra_kwargs is not None:
layer_kwargs.update(encoder_extra_kwargs)
output = layer(**layer_kwargs)
# return as seq first
return (
output.transpose(0, 1),
(key_padding_masks_flatten.transpose(0, 1) if key_padding_masks_flatten is not None else None),
lvl_pos_embed_flatten.transpose(0, 1),
level_start_index,
spatial_shapes,
valid_ratios,
)
class TransformerEncoderFusion(TransformerEncoder):
"""Transformer encoder that fuses text and image features.
This encoder extends TransformerEncoder to handle both text and image features, with the ability to add pooled text
features to image features for better cross-modal fusion. It supports torch.compile for performance optimization.
Args:
layer: The encoder layer to be stacked multiple times
num_layers: Number of encoder layers to stack
d_model: Model dimension/hidden size
num_feature_levels: Number of feature levels to process
add_pooled_text_to_img_feat: Whether to add pooled text features to image features
pool_text_with_mask: Whether to use the mask when pooling text features
compile_mode: Mode for torch.compile, or None to disable compilation
**kwargs: Additional arguments to pass to the parent class
"""
def __init__(
self,
layer: nn.Module,
num_layers: int,
d_model: int,
num_feature_levels: int,
add_pooled_text_to_img_feat: bool = True,
pool_text_with_mask: bool = False,
compile_mode: str | None = None,
**kwargs,
):
"""Initialize the transformer encoder with text-image fusion."""
super().__init__(
layer,
num_layers,
d_model,
num_feature_levels,
**kwargs,
)
self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat
if self.add_pooled_text_to_img_feat:
self.text_pooling_proj = nn.Linear(d_model, d_model)
self.pool_text_with_mask = pool_text_with_mask
if compile_mode is not None:
self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
def forward(
self,
src: list[torch.Tensor],
prompt: torch.Tensor,
src_key_padding_mask: list[torch.Tensor] | None = None,
src_pos: list[torch.Tensor] | None = None,
prompt_key_padding_mask: torch.Tensor = None,
feat_sizes: list[int] | None = None,
encoder_extra_kwargs: dict | None = None,
):
"""Forward pass for the transformer encoder with text-image fusion."""
# Restore spatial shapes of vision
bs = src[0].shape[1] # seq first
if feat_sizes is not None:
assert len(feat_sizes) == len(src)
if src_key_padding_mask is None:
src_key_padding_mask = [None] * len(src)
for i, (h, w) in enumerate(feat_sizes):
src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
src_key_padding_mask[i] = (
src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1)
if src_key_padding_mask[i] is not None
else None
)
else:
assert all(x.dim() == 4 for x in src), "expected list of (bs, c, h, w) tensors"
if self.add_pooled_text_to_img_feat:
# Fusion: Add mean pooled text to image features
pooled_text = pool_text_feat(prompt, prompt_key_padding_mask, self.pool_text_with_mask)
pooled_text = self.text_pooling_proj(pooled_text)[..., None, None] # prompt is seq first
src = [x.add_(pooled_text) for x in src]
(
out,
key_padding_masks_flatten,
lvl_pos_embed_flatten,
level_start_index,
spatial_shapes,
valid_ratios,
) = super().forward(
src,
src_key_padding_masks=src_key_padding_mask,
pos=src_pos,
prompt=prompt.transpose(0, 1),
prompt_key_padding_mask=prompt_key_padding_mask,
encoder_extra_kwargs=encoder_extra_kwargs,
)
return {
"memory": out,
"padding_mask": key_padding_masks_flatten,
"pos_embed": lvl_pos_embed_flatten,
"memory_text": prompt,
"level_start_index": level_start_index,
"spatial_shapes": spatial_shapes,
"valid_ratios": valid_ratios,
}
def pool_text_feat(prompt, prompt_mask, pool_with_mask):
"""Mean-pool the prompt embeddings over the valid tokens only."""
# prompt has shape (seq, bs, dim)
if not pool_with_mask:
return prompt.mean(dim=0)
# prompt_mask has shape (bs, seq), where False is valid and True is padding
assert prompt_mask.dim() == 2
# is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
# num_valid has shape (bs, 1)
num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
# mean pool over all the valid tokens
pooled_text = (prompt * is_valid).sum(dim=0) / num_valid
return pooled_text
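The masked mean-pooling above can be sketched for a single batch element (assumption: plain-Python list stand-in): padded positions are excluded from the sum, and the divisor is the number of valid tokens, clamped to at least 1 to avoid division by zero.

```python
def pool_text(tokens, pad_mask):
    # pad_mask follows pytorch convention: True marks padding
    valid = [t for t, pad in zip(tokens, pad_mask) if not pad]
    return sum(valid) / max(len(valid), 1)


# The padded 100.0 token is ignored: (2.0 + 4.0) / 2 == 3.0
# pool_text([2.0, 4.0, 100.0], [False, False, True]) -> 3.0
```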


@ -0,0 +1,415 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import torch
import torch.nn as nn
import torchvision
from ultralytics.nn.modules.utils import _get_clones
from ultralytics.utils.ops import xywh2xyxy
def is_right_padded(mask: torch.Tensor):
"""Given a padding mask (following pytorch convention, 1s for padded values), returns whether the padding is on the
right or not.
"""
return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all()
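The sort-based check above has a simple interpretation, sketched here in plain Python for one mask row (assumption: list-of-bools stand-in): a mask is right-padded iff sorting it ascending (False before True) leaves it unchanged.

```python
def right_padded(mask_row):
    # True (padding) sorts after False (valid), so a right-padded row is already sorted
    return mask_row == sorted(mask_row)


# right_padded([False, False, True, True]) -> True
# right_padded([False, True, False, True]) -> False (padding interleaved with valid tokens)
```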
def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
"""
Concatenates two right-padded sequences, such that the resulting sequence
is contiguous and also right-padded.
Following pytorch's convention, tensors are sequence first, and the mask are
batch first, with 1s for padded values.
:param seq1: A tensor of shape (seq1_length, batch_size, hidden_size).
:param mask1: A tensor of shape (batch_size, seq1_length).
:param seq2: A tensor of shape (seq2_length, batch_size, hidden_size).
:param mask2: A tensor of shape (batch_size, seq2_length).
:param return_index: If True, also returns the index of the ids of the element of seq2
in the concatenated sequence. This can be used to retrieve the elements of seq2
:return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
otherwise (concatenated_sequence, concatenated_mask, index).
"""
seq1_length, batch_size, hidden_size = seq1.shape
seq2_length, batch_size, hidden_size = seq2.shape
assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
assert hidden_size == seq1.size(2) == seq2.size(2)
assert seq1_length == mask1.size(1)
assert seq2_length == mask2.size(1)
torch._assert_async(is_right_padded(mask1))
torch._assert_async(is_right_padded(mask2))
actual_seq1_lengths = (~mask1).sum(dim=-1)
actual_seq2_lengths = (~mask2).sum(dim=-1)
final_lengths = actual_seq1_lengths + actual_seq2_lengths
max_length = seq1_length + seq2_length
concatenated_mask = (
torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) >= final_lengths[:, None]
)
# (max_len, batch_size, hidden_size)
concatenated_sequence = torch.zeros((max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype)
concatenated_sequence[:seq1_length, :, :] = seq1
# At this point, the element of seq1 are in the right place
# We just need to shift the elements of seq2
index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
index = index + actual_seq1_lengths[None]
concatenated_sequence = concatenated_sequence.scatter(0, index[:, :, None].expand(-1, -1, hidden_size), seq2)
if return_index:
return concatenated_sequence, concatenated_mask, index
return concatenated_sequence, concatenated_mask
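The concatenation semantics can be sketched for a single batch element (assumption: plain-Python lists, scalar tokens, `0` as the pad value): the valid prefix of `seq2` is appended directly after the valid prefix of `seq1`, and all padding is pushed to the right.

```python
def concat_right_padded(seq1, mask1, seq2, mask2):
    # Masks follow pytorch convention: True marks padding
    valid1 = [t for t, pad in zip(seq1, mask1) if not pad]
    valid2 = [t for t, pad in zip(seq2, mask2) if not pad]
    merged = valid1 + valid2
    pad_len = len(seq1) + len(seq2) - len(merged)
    return merged + [0] * pad_len, [False] * len(merged) + [True] * pad_len


seq, mask = concat_right_padded([1, 2, 0], [False, False, True], [3, 0], [False, True])
# seq -> [1, 2, 3, 0, 0]; mask -> [False, False, False, True, True]
```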
class Prompt:
"""Utility class to manipulate geometric prompts.
We expect the sequences in PyTorch convention, that is, sequence first, batch second. The dimensions are expected
as follows:
- box_embeddings shape: N_boxes x B x C_box
- box_mask shape: B x N_boxes. Can be None if nothing is masked out
- point_embeddings shape: N_points x B x C_point
- point_mask shape: B x N_points. Can be None if nothing is masked out
- mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask
- mask_mask shape: B x N_masks. Can be None if nothing is masked out

We also store positive/negative labels. These tensors are also stored sequence-first. If they are None, we'll
assume positive labels everywhere.
- box_labels: long tensor of shape N_boxes x B
- point_labels: long tensor of shape N_points x B
- mask_labels: long tensor of shape N_masks x B
"""
def __init__(self, box_embeddings=None, box_mask=None, box_labels=None):
"""Initialize the Prompt object."""
# Check for null prompt
if box_embeddings is None:
self.box_embeddings = None
self.box_labels = None
self.box_mask = None
return
# Get sequence length, batch size, and device
box_seq_len = box_embeddings.shape[0]
bs = box_embeddings.shape[1]
device = box_embeddings.device
# Initialize labels and attention mask if not provided
if box_labels is None:
box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
if box_mask is None:
box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)
# Dimension checks
assert list(box_embeddings.shape[:2]) == [box_seq_len, bs], (
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
)
assert box_embeddings.shape[-1] == 4, (
f"Expected box embeddings to have 4 coordinates, got {box_embeddings.shape[-1]}"
)
assert list(box_mask.shape) == [bs, box_seq_len], (
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
)
assert list(box_labels.shape) == [box_seq_len, bs], (
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
)
# Device checks
assert box_embeddings.device == device, (
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
)
assert box_mask.device == device, f"Expected box mask to be on device {device}, got {box_mask.device}"
assert box_labels.device == device, f"Expected box labels to be on device {device}, got {box_labels.device}"
self.box_embeddings = box_embeddings
self.box_mask = box_mask
self.box_labels = box_labels
def append_boxes(self, boxes, labels=None, mask=None):
"""Append box prompts to existing prompts.
Args:
boxes: Tensor of shape (N_new_boxes, B, 4) with normalized box coordinates
labels: Optional tensor of shape (N_new_boxes, B) with positive/negative labels
mask: Optional tensor of shape (B, N_new_boxes) for attention mask
"""
if self.box_embeddings is None:
# First boxes - initialize
self.box_embeddings = boxes
bs = boxes.shape[1]
box_seq_len = boxes.shape[0]
if labels is None:
labels = torch.ones(box_seq_len, bs, device=boxes.device, dtype=torch.long)
if mask is None:
mask = torch.zeros(bs, box_seq_len, device=boxes.device, dtype=torch.bool)
self.box_labels = labels
self.box_mask = mask
return
# Append to existing boxes
bs = self.box_embeddings.shape[1]
assert boxes.shape[1] == bs, f"Batch size mismatch: expected {bs}, got {boxes.shape[1]}"
if labels is None:
labels = torch.ones(boxes.shape[0], bs, device=boxes.device, dtype=torch.long)
if mask is None:
mask = torch.zeros(bs, boxes.shape[0], dtype=torch.bool, device=boxes.device)
assert list(boxes.shape[:2]) == list(labels.shape[:2]), (
f"Shape mismatch between boxes {boxes.shape} and labels {labels.shape}"
)
# Concatenate using the helper function
self.box_labels, _ = concat_padded_sequences(
self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
)
self.box_labels = self.box_labels.squeeze(-1)
self.box_embeddings, self.box_mask = concat_padded_sequences(self.box_embeddings, self.box_mask, boxes, mask)
class SequenceGeometryEncoder(nn.Module):
"""Encoder for geometric box prompts. Assumes boxes are passed in the "normalized CxCyWH" format.
Boxes can be encoded with any of the three possibilities:
- direct projection: linear projection from coordinate space to d_model
- pooling: RoI align features from the backbone
- pos encoder: position encoding of the box center
These three options are mutually compatible and will be summed if multiple are selected.
As an alternative, boxes can be encoded as two corner points (top-left and bottom-right).
The encoded sequence can be further processed with a transformer.
"""
def __init__(
self,
encode_boxes_as_points: bool,
boxes_direct_project: bool,
boxes_pool: bool,
boxes_pos_enc: bool,
d_model: int,
pos_enc,
num_layers: int,
layer: nn.Module,
roi_size: int = 7,
add_cls: bool = True,
add_post_encode_proj: bool = True,
use_act_ckpt: bool = False,
):
"""Initialize the SequenceGeometryEncoder."""
super().__init__()
self.d_model = d_model
self.pos_enc = pos_enc
self.encode_boxes_as_points = encode_boxes_as_points
self.roi_size = roi_size
# Label embeddings: 2 labels if encoding as boxes (pos/neg)
# 6 labels if encoding as points (regular pos/neg, top-left pos/neg, bottom-right pos/neg)
num_labels = 6 if self.encode_boxes_as_points else 2
self.label_embed = torch.nn.Embedding(num_labels, self.d_model)
# CLS token for pooling
self.cls_embed = None
if add_cls:
self.cls_embed = torch.nn.Embedding(1, self.d_model)
# Point encoding (used when encode_boxes_as_points is True)
if encode_boxes_as_points:
self.points_direct_project = nn.Linear(2, self.d_model)
self.points_pool_project = None
self.points_pos_enc_project = None
else:
# Box encoding modules
assert boxes_direct_project or boxes_pos_enc or boxes_pool, "Error: need at least one way to encode boxes"
self.points_direct_project = None
self.points_pool_project = None
self.points_pos_enc_project = None
self.boxes_direct_project = None
self.boxes_pool_project = None
self.boxes_pos_enc_project = None
if boxes_direct_project:
self.boxes_direct_project = nn.Linear(4, self.d_model)
if boxes_pool:
self.boxes_pool_project = nn.Conv2d(self.d_model, self.d_model, self.roi_size)
if boxes_pos_enc:
self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)
self.final_proj = None
if add_post_encode_proj:
self.final_proj = nn.Linear(self.d_model, self.d_model)
self.norm = nn.LayerNorm(self.d_model)
self.img_pre_norm = nn.Identity()
if self.points_pool_project is not None or self.boxes_pool_project is not None:
self.img_pre_norm = nn.LayerNorm(self.d_model)
self.encode = None
if num_layers > 0:
assert add_cls, "It's currently highly recommended to add a CLS when using a transformer"
self.encode = _get_clones(layer, num_layers)
self.encode_norm = nn.LayerNorm(self.d_model)
self.use_act_ckpt = use_act_ckpt
def _encode_points(self, points, points_mask, points_labels, img_feats):
"""Encode points (used when boxes are converted to corner points)."""
# Direct projection of coordinates
points_embed = self.points_direct_project(points.to(img_feats.dtype))
# Add label embeddings
type_embed = self.label_embed(points_labels.long())
return type_embed + points_embed, points_mask
def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats: torch.Tensor):
"""Encode boxes using configured encoding methods."""
boxes_embed = None
n_boxes, bs = boxes.shape[:2]
if self.boxes_direct_project is not None:
proj = self.boxes_direct_project(boxes.to(img_feats.dtype))
boxes_embed = proj
if self.boxes_pool_project is not None:
H, W = img_feats.shape[-2:]
# Convert boxes to xyxy format and denormalize
boxes_xyxy = xywh2xyxy(boxes.to(img_feats.dtype))
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
scale = scale.view(1, 1, 4)
boxes_xyxy = boxes_xyxy * scale
# RoI align
sampled = torchvision.ops.roi_align(img_feats, boxes_xyxy.transpose(0, 1).unbind(0), self.roi_size)
assert list(sampled.shape) == [
bs * n_boxes,
self.d_model,
self.roi_size,
self.roi_size,
]
proj = self.boxes_pool_project(sampled)
proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)
if boxes_embed is None:
boxes_embed = proj
else:
boxes_embed = boxes_embed + proj
if self.boxes_pos_enc_project is not None:
cx, cy, w, h = boxes.unbind(-1)
enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])
proj = self.boxes_pos_enc_project(enc.to(img_feats.dtype))
if boxes_embed is None:
boxes_embed = proj
else:
boxes_embed = boxes_embed + proj
# Add label embeddings
type_embed = self.label_embed(boxes_labels.long())
return type_embed + boxes_embed, boxes_mask
def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
"""Encode geometric box prompts.
Args:
geo_prompt: Prompt object containing box embeddings, masks, and labels
img_feats: List of image features from backbone
img_sizes: List of (H, W) tuples for each feature level
img_pos_embeds: Optional position embeddings for image features
Returns:
Tuple of (encoded_embeddings, attention_mask)
"""
boxes = geo_prompt.box_embeddings
boxes_mask = geo_prompt.box_mask
boxes_labels = geo_prompt.box_labels
seq_first_img_feats = img_feats[-1] # [H*W, B, C]
seq_first_img_pos_embeds = (
img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(seq_first_img_feats)
)
# Prepare image features for pooling if needed
if self.points_pool_project or self.boxes_pool_project:
assert len(img_feats) == len(img_sizes)
cur_img_feat = img_feats[-1]
cur_img_feat = self.img_pre_norm(cur_img_feat)
H, W = img_sizes[-1]
assert cur_img_feat.shape[0] == H * W
N, C = cur_img_feat.shape[-2:]
# Reshape to NxCxHxW
cur_img_feat = cur_img_feat.permute(1, 2, 0)
cur_img_feat = cur_img_feat.view(N, C, H, W)
img_feats = cur_img_feat
if self.encode_boxes_as_points:
# Convert boxes to corner points
assert boxes is not None and boxes.shape[-1] == 4
boxes_xyxy = xywh2xyxy(boxes)
top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)
# Adjust labels for corner points (offset by 2 and 4)
labels_tl = boxes_labels + 2
labels_br = boxes_labels + 4
# Concatenate top-left and bottom-right points
points = torch.cat([top_left, bottom_right], dim=0)
points_labels = torch.cat([labels_tl, labels_br], dim=0)
points_mask = torch.cat([boxes_mask, boxes_mask], dim=1)
final_embeds, final_mask = self._encode_points(
points=points,
points_mask=points_mask,
points_labels=points_labels,
img_feats=img_feats,
)
else:
# Encode boxes directly
final_embeds, final_mask = self._encode_boxes(
boxes=boxes,
boxes_mask=boxes_mask,
boxes_labels=boxes_labels,
img_feats=img_feats,
)
bs = final_embeds.shape[1]
assert final_mask.shape[0] == bs
# Add CLS token if configured
if self.cls_embed is not None:
cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
cls_mask = torch.zeros(bs, 1, dtype=final_mask.dtype, device=final_mask.device)
final_embeds, final_mask = concat_padded_sequences(final_embeds, final_mask, cls, cls_mask)
# Final projection
if self.final_proj is not None:
final_embeds = self.norm(self.final_proj(final_embeds))
# Transformer encoding layers
if self.encode is not None:
for lay in self.encode:
final_embeds = lay(
tgt=final_embeds,
memory=seq_first_img_feats,
tgt_key_padding_mask=final_mask,
pos=seq_first_img_pos_embeds,
)
final_embeds = self.encode_norm(final_embeds)
return final_embeds, final_mask
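When encode_boxes_as_points is enabled above, each cxcywh box is split into two labeled corner points, with label offsets +2 (top-left) and +4 (bottom-right) indexing the 6-entry label embedding. A pure-Python sketch of that conversion for a single box (illustrative only):

```python
def box_to_corner_points(cx, cy, w, h, label):
    """Split one normalized cxcywh box into two (point, label) pairs.

    label is 0 (negative) or 1 (positive); corner labels land in slots 2-3
    (top-left) and 4-5 (bottom-right) of the 6-way label embedding.
    """
    top_left = (cx - w / 2, cy - h / 2)
    bottom_right = (cx + w / 2, cy + h / 2)
    return [(top_left, label + 2), (bottom_right, label + 4)]


points = box_to_corner_points(0.5, 0.5, 0.2, 0.4, label=1)
# top-left ~(0.4, 0.3) with label 3, bottom-right ~(0.6, 0.7) with label 5
```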


@@ -0,0 +1,286 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from ultralytics.nn.modules.transformer import MLP
class LinearPresenceHead(nn.Sequential):
"""Linear presence head for predicting the presence of classes in an image."""
def __init__(self, d_model):
"""Initializes the LinearPresenceHead."""
# a hack to make `LinearPresenceHead` compatible with old checkpoints
super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))
def forward(self, hs, prompt, prompt_mask):
"""Forward pass of the presence head."""
return super().forward(hs)
class MaskPredictor(nn.Module):
"""Predicts masks from object queries and pixel embeddings."""
def __init__(self, hidden_dim, mask_dim):
"""Initializes the MaskPredictor."""
super().__init__()
self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
def forward(self, obj_queries, pixel_embed):
"""Predicts masks from object queries and pixel embeddings."""
if len(obj_queries.shape) == 3:
if pixel_embed.ndim == 3:
# batch size was omitted
mask_preds = torch.einsum("bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed)
else:
mask_preds = torch.einsum("bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed)
else:
# Assumed to have aux masks
if pixel_embed.ndim == 3:
# batch size was omitted
mask_preds = torch.einsum("lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed)
else:
mask_preds = torch.einsum("lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed)
return mask_preds
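The einsum contraction "bqc,bchw->bqhw" above is a per-query dot product between the projected query vector and each pixel embedding. A tiny pure-Python sketch of the same contraction with nested lists and no batch dimension (illustrative only):

```python
def query_pixel_masks(queries, pixel_embed):
    """queries: [Q][C]; pixel_embed: [C][H][W] -> mask logits [Q][H][W]."""
    C = len(pixel_embed)
    H, W = len(pixel_embed[0]), len(pixel_embed[0][0])
    return [
        [[sum(q[c] * pixel_embed[c][h][w] for c in range(C)) for w in range(W)]
         for h in range(H)]
        for q in queries
    ]


# one query with C=2 channels over a 1x2 grid
masks = query_pixel_masks([[1.0, 2.0]], [[[1.0, 0.0]], [[0.0, 1.0]]])
# masks == [[[1.0, 2.0]]]: each logit is the query-pixel dot product
```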
class SegmentationHead(nn.Module):
"""Segmentation head that predicts masks from backbone features and object queries."""
def __init__(
self,
hidden_dim,
upsampling_stages,
use_encoder_inputs=False,
aux_masks=False,
no_dec=False,
pixel_decoder=None,
act_ckpt=False,
shared_conv=False,
compile_mode_pixel_decoder=None,
):
"""Initializes the SegmentationHead."""
super().__init__()
self.use_encoder_inputs = use_encoder_inputs
self.aux_masks = aux_masks
if pixel_decoder is not None:
self.pixel_decoder = pixel_decoder
else:
self.pixel_decoder = PixelDecoder(
hidden_dim,
upsampling_stages,
shared_conv=shared_conv,
compile_mode=compile_mode_pixel_decoder,
)
self.no_dec = no_dec
if no_dec:
self.mask_predictor = nn.Conv2d(hidden_dim, 1, kernel_size=3, stride=1, padding=1)
else:
self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
self.act_ckpt = act_ckpt
# used to update the output dictionary
self.instance_keys = ["pred_masks"]
def _embed_pixels(self, backbone_feats: list[torch.Tensor], encoder_hidden_states) -> torch.Tensor:
"""Embeds pixels using the pixel decoder."""
if self.use_encoder_inputs:
backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
# Extract visual embeddings
encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(-1, *backbone_feats[-1].shape[1:])
backbone_visual_feats[-1] = encoder_visual_embed
if self.act_ckpt:
pixel_embed = checkpoint.checkpoint(self.pixel_decoder, backbone_visual_feats, use_reentrant=False)
else:
pixel_embed = self.pixel_decoder(backbone_visual_feats)
else:
backbone_feats = [x for x in backbone_feats]
pixel_embed = self.pixel_decoder(backbone_feats)
if pixel_embed.shape[0] == 1:
# For batch_size=1 training, we can avoid the indexing to save memory
pixel_embed = pixel_embed.squeeze(0)
else:
pixel_embed = pixel_embed[[0], ...]
return pixel_embed
def forward(
self,
backbone_feats: list[torch.Tensor],
obj_queries: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Forward pass of the SegmentationHead."""
if self.use_encoder_inputs:
assert encoder_hidden_states is not None
pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)
if self.no_dec:
mask_pred = self.mask_predictor(pixel_embed)
elif self.aux_masks:
mask_pred = self.mask_predictor(obj_queries, pixel_embed)
else:
mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)
return {"pred_masks": mask_pred}
class PixelDecoder(nn.Module):
"""Pixel decoder module that upsamples backbone features."""
def __init__(
self,
hidden_dim,
num_upsampling_stages,
interpolation_mode="nearest",
shared_conv=False,
compile_mode=None,
):
"""Initializes the PixelDecoder."""
super().__init__()
self.hidden_dim = hidden_dim
self.num_upsampling_stages = num_upsampling_stages
self.interpolation_mode = interpolation_mode
conv_layers = []
norms = []
num_convs = 1 if shared_conv else num_upsampling_stages
for _ in range(num_convs):
conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
norms.append(nn.GroupNorm(8, self.hidden_dim))
self.conv_layers = nn.ModuleList(conv_layers)
self.norms = nn.ModuleList(norms)
self.shared_conv = shared_conv
self.out_dim = self.conv_layers[-1].out_channels
if compile_mode is not None:
self.forward = torch.compile(self.forward, mode=compile_mode, dynamic=True, fullgraph=True)
# Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default.
torch._dynamo.config.optimize_ddp = False
def forward(self, backbone_feats: list[torch.Tensor]):
"""Forward pass of the PixelDecoder."""
prev_fpn = backbone_feats[-1]
fpn_feats = backbone_feats[:-1]
for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
curr_fpn = bb_feat
prev_fpn = curr_fpn + F.interpolate(prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode)
if self.shared_conv:
# only one conv layer
layer_idx = 0
prev_fpn = self.conv_layers[layer_idx](prev_fpn)
prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))
return prev_fpn
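The decoder walks the pyramid top-down: each coarser map is interpolated to the next finer map's size and added before the conv/norm/ReLU refinement. A minimal sketch of the upsample-and-add step on plain 2D lists (nearest-neighbor, fixed factor 2, convolutions omitted; illustrative only):

```python
def upsample2x_nearest(feat):
    """Repeat each value twice along both axes (nearest-neighbor upsampling)."""
    return [[v for v in row for _ in range(2)] for row in feat for _ in range(2)]


def topdown_merge(coarse, fine):
    """Upsample the coarse map to the fine map's size and add elementwise."""
    up = upsample2x_nearest(coarse)
    return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(up, fine)]


merged = topdown_merge([[1.0]], [[0.1, 0.2], [0.3, 0.4]])
# merged is approximately [[1.1, 1.2], [1.3, 1.4]]
```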
class UniversalSegmentationHead(SegmentationHead):
"""This module handles semantic+instance segmentation."""
def __init__(
self,
hidden_dim,
upsampling_stages,
pixel_decoder,
aux_masks=False,
no_dec=False,
act_ckpt=False,
presence_head: bool = False,
dot_product_scorer=None,
cross_attend_prompt=None,
):
"""Initializes the UniversalSegmentationHead."""
super().__init__(
hidden_dim=hidden_dim,
upsampling_stages=upsampling_stages,
use_encoder_inputs=True,
aux_masks=aux_masks,
no_dec=no_dec,
pixel_decoder=pixel_decoder,
act_ckpt=act_ckpt,
)
self.d_model = hidden_dim
if dot_product_scorer is not None:
assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
self.presence_head = None
if presence_head:
self.presence_head = (
dot_product_scorer if dot_product_scorer is not None else LinearPresenceHead(self.d_model)
)
self.cross_attend_prompt = cross_attend_prompt
if self.cross_attend_prompt is not None:
self.cross_attn_norm = nn.LayerNorm(self.d_model)
self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
self.instance_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, self.d_model, kernel_size=1)
def forward(
self,
backbone_feats: list[torch.Tensor],
obj_queries: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
prompt: torch.Tensor = None,
prompt_mask: torch.Tensor = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Forward pass of the UniversalSegmentationHead."""
assert encoder_hidden_states is not None
bs = encoder_hidden_states.shape[1]
if self.cross_attend_prompt is not None:
tgt2 = self.cross_attn_norm(encoder_hidden_states)
tgt2 = self.cross_attend_prompt(
query=tgt2,
key=prompt.to(tgt2.dtype),
value=prompt.to(tgt2.dtype),
key_padding_mask=prompt_mask,
need_weights=False,
)[0]
encoder_hidden_states = tgt2 + encoder_hidden_states
presence_logit = None
if self.presence_head is not None:
pooled_enc = encoder_hidden_states.mean(0)
presence_logit = (
self.presence_head(
pooled_enc.view(1, bs, 1, self.d_model),
prompt=prompt,
prompt_mask=prompt_mask,
)
.squeeze(0)
.squeeze(1)
)
pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)
instance_embeds = self.instance_seg_head(pixel_embed)
if self.no_dec:
mask_pred = self.mask_predictor(instance_embeds)
elif self.aux_masks:
mask_pred = self.mask_predictor(obj_queries, instance_embeds)
else:
mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)
return {
"pred_masks": mask_pred,
"semantic_seg": self.semantic_seg_head(pixel_embed),
"presence_logit": presence_logit,
}


@@ -0,0 +1,198 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Various utility models."""
from __future__ import annotations
import math
import numpy as np
import torch
from torch import Tensor, nn
class DotProductScoring(torch.nn.Module):
"""A module that computes dot-product scores between a set of query features and a."""
def __init__(
self,
d_model,
d_proj,
prompt_mlp=None,
clamp_logits=True,
clamp_max_val=12.0,
):
"""Initialize the DotProductScoring module."""
super().__init__()
self.d_proj = d_proj
assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
self.prompt_mlp = prompt_mlp # an optional MLP projection for prompt
self.prompt_proj = torch.nn.Linear(d_model, d_proj)
self.hs_proj = torch.nn.Linear(d_model, d_proj)
self.scale = float(1.0 / np.sqrt(d_proj))
self.clamp_logits = clamp_logits
if self.clamp_logits:
self.clamp_max_val = clamp_max_val
def mean_pool_text(self, prompt, prompt_mask):
"""Mean-pool the prompt embeddings over the valid tokens only."""
# is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
is_valid = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None]
# num_valid has shape (bs, 1)
num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
# mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim)
pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
return pooled_prompt
def forward(self, hs, prompt, prompt_mask):
"""Compute dot-product scores between hs and prompt."""
# hs has shape (num_layer, bs, num_query, d_model)
# prompt has shape (seq, bs, d_model)
# prompt_mask has shape (bs, seq), where 1 is valid and 0 is padding
assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
# apply MLP on prompt if specified
if self.prompt_mlp is not None:
prompt = self.prompt_mlp(prompt.to(hs.dtype))
# first, get the mean-pooled version of the prompt
pooled_prompt = self.mean_pool_text(prompt, prompt_mask)
# then, project pooled_prompt and hs to d_proj dimensions
proj_pooled_prompt = self.prompt_proj(pooled_prompt) # (bs, d_proj)
proj_hs = self.hs_proj(hs) # (num_layer, bs, num_query, d_proj)
# finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
scores *= self.scale
# clamp scores to a max value to avoid numerical issues in loss or matcher
if self.clamp_logits:
scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)
return scores
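mean_pool_text above averages the prompt over valid tokens only, with the divisor clamped so fully padded rows do not divide by zero. The same masked-mean logic as a scalar pure-Python sketch (illustrative only):

```python
def masked_mean(tokens, pad_mask):
    """Average token values where pad_mask is False; safe when all are padding."""
    valid = [t for t, pad in zip(tokens, pad_mask) if not pad]
    return sum(valid) / max(len(valid), 1)  # clamped divisor, as in the module


pooled = masked_mean([2.0, 4.0, 99.0], [False, False, True])  # 3.0, pad ignored
empty = masked_mean([0.0, 0.0], [True, True])  # 0.0, no ZeroDivisionError
```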
class LayerScale(nn.Module):
"""LayerScale module as introduced in "Meta Pseudo Labels" and used in."""
def __init__(
self,
dim: int,
init_values: float | Tensor = 1e-5,
inplace: bool = False,
) -> None:
"""Initialize the LayerScale module."""
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
"""Apply LayerScale to the input tensor."""
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class TransformerWrapper(nn.Module):
"""A wrapper for the transformer consisting of an encoder and a decoder."""
def __init__(
self,
encoder,
decoder,
d_model: int,
two_stage_type="none", # ["none"] only for now
pos_enc_at_input_dec=True,
):
"""Initialize the TransformerWrapper."""
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.num_queries = decoder.num_queries if decoder is not None else None
self.pos_enc_at_input_dec = pos_enc_at_input_dec
# for two stage
assert two_stage_type in ["none"], f"unknown param {two_stage_type} of two_stage_type"
self.two_stage_type = two_stage_type
self._reset_parameters()
self.d_model = d_model
def _reset_parameters(self):
"""Initialize the parameters of the model."""
for n, p in self.named_parameters():
if p.dim() > 1:
if "box_embed" not in n and "query_embed" not in n and "reference_points" not in n:
nn.init.xavier_uniform_(p)
def get_valid_ratio(mask):
"""Compute the valid ratio of height and width from the mask."""
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
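get_valid_ratio measures, per batch element, what fraction of the padded height and width is actually valid by scanning the first column and first row of the padding mask. A pure-Python sketch for a single mask (True = padding; illustrative only):

```python
def valid_ratio(pad_mask):
    """pad_mask: [H][W] bools with True = padding; returns (w_ratio, h_ratio)."""
    H, W = len(pad_mask), len(pad_mask[0])
    valid_h = sum(1 for h in range(H) if not pad_mask[h][0])  # scan first column
    valid_w = sum(1 for w in range(W) if not pad_mask[0][w])  # scan first row
    return valid_w / W, valid_h / H


# a 2x3 mask whose last column is padding
ratios = valid_ratio([[False, False, True], [False, False, True]])
# ratios is approximately (2/3, 1.0)
```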
def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256):
"""Generate sinusoidal position embeddings for 2D or 4D coordinate tensors.
This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to the
positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors (x, y, w,
h) for bounding box coordinates.
Args:
pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or (n_query, bs,
4) for 4D coordinates (bounding boxes).
num_feats (int): Number of feature dimensions for the output embedding. Must be even. Defaults to 256.
Returns:
(torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or (n_query, bs,
num_feats * 2) for 4D input.
Raises:
AssertionError: If num_feats is not even.
ValueError: If pos_tensor.size(-1) is not 2 or 4.
Examples:
>>> pos_2d = torch.rand(100, 8, 2) # 100 queries, batch size 8, 2D coordinates
>>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256)
>>> embeddings_2d.shape
torch.Size([100, 8, 256])
>>> pos_4d = torch.rand(50, 4, 4) # 50 queries, batch size 4, 4D coordinates
>>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128)
>>> embeddings_4d.shape
torch.Size([50, 4, 256])
"""
assert num_feats % 2 == 0
num_feats = num_feats // 2
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device)
dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
return pos
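Each coordinate above is scaled by 2*pi, divided by the geometric frequency ladder dim_t = 10000^(2*floor(i/2)/num_feats), and then sin/cos are interleaved. A scalar pure-Python sketch of the per-coordinate embedding (illustrative only, not the exact stacking or num_feats halving used above):

```python
import math


def sine_embed_coord(x, num_feats=8):
    """Interleaved sin/cos embedding of one normalized coordinate."""
    scale = 2 * math.pi
    dim_t = [10000 ** (2 * (i // 2) / num_feats) for i in range(num_feats)]
    vals = [x * scale / d for d in dim_t]
    out = []
    # interleave sin of even slots with cos of odd slots
    for i in range(0, num_feats, 2):
        out.append(math.sin(vals[i]))
        out.append(math.cos(vals[i + 1]))
    return out


emb = sine_embed_coord(0.25, num_feats=8)
# first pair: sin(pi/2) == 1.0 and cos(pi/2) ~ 0.0
```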


@@ -0,0 +1,129 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Necks are the interface between a vision backbone and the rest of the detection model."""
from __future__ import annotations
from copy import deepcopy
import torch
import torch.nn as nn
class Sam3DualViTDetNeck(nn.Module):
"""A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2)."""
def __init__(
self,
trunk: nn.Module,
position_encoding: nn.Module,
d_model: int,
scale_factors=(4.0, 2.0, 1.0, 0.5),
add_sam2_neck: bool = False,
):
"""
SimpleFPN neck a la ViTDet
(From detectron2, very lightly adapted)
It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2), with different weights.
:param trunk: the backbone
:param position_encoding: the positional encoding to use
:param d_model: the dimension of the model
:param scale_factors: the up/downsampling factor applied to the trunk output for each feature level
:param add_sam2_neck: if True, build a second identical neck (with its own weights) for SAM2
"""
super().__init__()
self.trunk = trunk
self.position_encoding = position_encoding
self.convs = nn.ModuleList()
self.scale_factors = scale_factors
use_bias = True
dim: int = self.trunk.channel_list[-1]
for _, scale in enumerate(scale_factors):
current = nn.Sequential()
if scale == 4.0:
current.add_module(
"dconv_2x2_0",
nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
)
current.add_module(
"gelu",
nn.GELU(),
)
current.add_module(
"dconv_2x2_1",
nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
)
out_dim = dim // 4
elif scale == 2.0:
current.add_module(
"dconv_2x2",
nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
)
out_dim = dim // 2
elif scale == 1.0:
out_dim = dim
elif scale == 0.5:
current.add_module(
"maxpool_2x2",
nn.MaxPool2d(kernel_size=2, stride=2),
)
out_dim = dim
else:
raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
current.add_module(
"conv_1x1",
nn.Conv2d(
in_channels=out_dim,
out_channels=d_model,
kernel_size=1,
bias=use_bias,
),
)
current.add_module(
"conv_3x3",
nn.Conv2d(
in_channels=d_model,
out_channels=d_model,
kernel_size=3,
padding=1,
bias=use_bias,
),
)
self.convs.append(current)
self.sam2_convs = None
if add_sam2_neck:
# Assumes sam2 neck is just a clone of the original neck
self.sam2_convs = deepcopy(self.convs)
def forward(
self, tensor_list: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
"""Get the feature maps and positional encodings from the neck."""
xs = self.trunk(tensor_list)
sam3_out, sam3_pos = [], []
sam2_out, sam2_pos = None, None
if self.sam2_convs is not None:
sam2_out, sam2_pos = [], []
x = xs[-1] # simpleFPN
for i in range(len(self.convs)):
sam3_x_out = self.convs[i](x)
sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
sam3_out.append(sam3_x_out)
sam3_pos.append(sam3_pos_out)
if self.sam2_convs is not None:
sam2_x_out = self.sam2_convs[i](x)
sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
sam2_out.append(sam2_x_out)
sam2_pos.append(sam2_pos_out)
return sam3_out, sam3_pos, sam2_out, sam2_pos
def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
"""Set the image size for the trunk backbone."""
self.trunk.set_imgsz(imgsz)
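Each scale factor above gets its own branch off the single ViT feature map: two stride-2 transposed convs for 4.0 (halving channels each time), one for 2.0, identity for 1.0, and a 2x2 max-pool for 0.5, before the 1x1 + 3x3 projection to d_model. A sketch of the channel/resolution bookkeeping per branch (illustrative only):

```python
def branch_spec(dim, scale):
    """Return (channels before the 1x1/3x3 projection, resolution multiplier)."""
    if scale == 4.0:
        return dim // 4, 4     # two stride-2 transposed convs
    if scale == 2.0:
        return dim // 2, 2     # one stride-2 transposed conv
    if scale == 1.0:
        return dim, 1          # identity
    if scale == 0.5:
        return dim, 0.5        # 2x2 max-pool
    raise NotImplementedError(f"scale_factor={scale} is not supported yet.")


specs = [branch_spec(1024, s) for s in (4.0, 2.0, 1.0, 0.5)]
# specs == [(256, 4), (512, 2), (1024, 1), (1024, 0.5)]
```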


@@ -0,0 +1,357 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from __future__ import annotations
from copy import deepcopy
import torch
from ultralytics.nn.modules.utils import inverse_sigmoid
from ultralytics.utils.ops import xywh2xyxy
from .geometry_encoders import Prompt
from .vl_combiner import SAM3VLBackbone
def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
"""Helper function to update output dictionary with main and auxiliary outputs."""
out[out_name] = out_value[-1] if auxiliary else out_value
if auxiliary and update_aux:
if "aux_outputs" not in out:
out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
assert len(out["aux_outputs"]) == len(out_value) - 1
for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
aux_output[out_name] = aux_value
class SAM3SemanticModel(torch.nn.Module):
"""SAM3 model for semantic segmentation with vision-language backbone."""
def __init__(
self,
backbone: SAM3VLBackbone,
transformer,
input_geometry_encoder,
segmentation_head=None,
num_feature_levels=1,
o2m_mask_predict=True,
dot_prod_scoring=None,
use_instance_query: bool = True,
multimask_output: bool = True,
use_act_checkpoint_seg_head: bool = True,
matcher=None,
use_dot_prod_scoring=True,
supervise_joint_box_scores: bool = False, # only relevant if using presence token/score
detach_presence_in_joint_score: bool = False, # only relevant if using presence token/score
separate_scorer_for_instance: bool = False,
num_interactive_steps_val: int = 0,
):
"""Initialize the SAM3SemanticModel."""
super().__init__()
self.backbone = backbone
self.geometry_encoder = input_geometry_encoder
self.transformer = transformer
self.hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.segmentation_head = segmentation_head
self.o2m_mask_predict = o2m_mask_predict
self.dot_prod_scoring = dot_prod_scoring
self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
self.matcher = matcher
self.num_interactive_steps_val = num_interactive_steps_val
self.use_dot_prod_scoring = use_dot_prod_scoring
if self.use_dot_prod_scoring:
assert dot_prod_scoring is not None
self.dot_prod_scoring = dot_prod_scoring
self.instance_dot_prod_scoring = None
if separate_scorer_for_instance:
self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
else:
self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
self.instance_class_embed = None
if separate_scorer_for_instance:
self.instance_class_embed = deepcopy(self.class_embed)
self.supervise_joint_box_scores = supervise_joint_box_scores
self.detach_presence_in_joint_score = detach_presence_in_joint_score
# verify the number of queries for O2O and O2M
num_o2o_static = self.transformer.decoder.num_queries
num_o2m_static = self.transformer.decoder.num_o2m_queries
assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
self.dac = self.transformer.decoder.dac
self.use_instance_query = use_instance_query
self.multimask_output = multimask_output
self.text_embeddings = {}
self.names = []
def _prepare_backbone_features(self, backbone_out, num_prompts=1):
"""Prepare and flatten visual features from the image backbone output for further processing."""
if num_prompts > 1: # expand features if there's more than one prompt
for i, feat in enumerate(backbone_out["backbone_fpn"]):
backbone_out["backbone_fpn"][i] = feat.expand(num_prompts, -1, -1, -1)
for i, pos in enumerate(backbone_out["vision_pos_enc"]):
pos = pos.expand(num_prompts, -1, -1, -1)
backbone_out["vision_pos_enc"][i] = pos
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
def _encode_prompt(
self,
img_feats,
img_pos_embeds,
vis_feat_sizes,
geometric_prompt,
visual_prompt_embed=None,
visual_prompt_mask=None,
prev_mask_pred=None,
):
"""Encode the geometric and visual prompts."""
if prev_mask_pred is not None:
img_feats = [img_feats[-1] + prev_mask_pred]
# Encode geometry
geo_feats, geo_masks = self.geometry_encoder(
geo_prompt=geometric_prompt,
img_feats=img_feats,
img_sizes=vis_feat_sizes,
img_pos_embeds=img_pos_embeds,
)
if visual_prompt_embed is None:
visual_prompt_embed = torch.zeros((0, *geo_feats.shape[1:]), device=geo_feats.device)
visual_prompt_mask = torch.zeros(
(*geo_masks.shape[:-1], 0),
device=geo_masks.device,
dtype=geo_masks.dtype,
)
prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
return prompt, prompt_mask
def _run_encoder(
self,
img_feats,
img_pos_embeds,
vis_feat_sizes,
prompt,
prompt_mask,
encoder_extra_kwargs: dict | None = None,
):
"""Run the transformer encoder."""
# Run the encoder
# make a copy of the image feature lists since the encoder may modify these lists in-place
memory = self.transformer.encoder(
src=img_feats.copy(),
src_key_padding_mask=None,
src_pos=img_pos_embeds.copy(),
prompt=prompt,
prompt_key_padding_mask=prompt_mask,
feat_sizes=vis_feat_sizes,
encoder_extra_kwargs=encoder_extra_kwargs,
)
encoder_out = {
# encoded image features
"encoder_hidden_states": memory["memory"],
"pos_embed": memory["pos_embed"],
"padding_mask": memory["padding_mask"],
"spatial_shapes": memory["spatial_shapes"],
"valid_ratios": memory["valid_ratios"],
"vis_feat_sizes": vis_feat_sizes,
# encoded text features (or other prompts)
"prompt_before_enc": prompt,
"prompt_after_enc": memory.get("memory_text", prompt),
"prompt_mask": prompt_mask,
}
return encoder_out
def _run_decoder(
self,
pos_embed,
memory,
src_mask,
out,
prompt,
prompt_mask,
encoder_out,
):
"""Run the transformer decoder."""
bs = memory.shape[1]
query_embed = self.transformer.decoder.query_embed.weight
tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
hs, reference_boxes, dec_presence_out, _ = self.transformer.decoder(
tgt=tgt,
memory=memory,
memory_key_padding_mask=src_mask,
pos=pos_embed,
reference_boxes=None,
spatial_shapes=encoder_out["spatial_shapes"],
valid_ratios=encoder_out["valid_ratios"],
tgt_mask=None,
memory_text=prompt,
text_attention_mask=prompt_mask,
apply_dac=False,
)
hs = hs.transpose(1, 2) # seq-first to batch-first
reference_boxes = reference_boxes.transpose(1, 2) # seq-first to batch-first
if dec_presence_out is not None:
# seq-first to batch-first
dec_presence_out = dec_presence_out.transpose(1, 2)
self._update_scores_and_boxes(
out,
hs,
reference_boxes,
prompt,
prompt_mask,
dec_presence_out=dec_presence_out,
)
return out, hs
def _update_scores_and_boxes(
self,
out,
hs,
reference_boxes,
prompt,
prompt_mask,
dec_presence_out=None,
is_instance_prompt=False,
):
"""Update output dict with class scores and box predictions."""
num_o2o = hs.size(2)
# score prediction
if self.use_dot_prod_scoring:
dot_prod_scoring_head = self.dot_prod_scoring
if is_instance_prompt and self.instance_dot_prod_scoring is not None:
dot_prod_scoring_head = self.instance_dot_prod_scoring
outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
else:
class_embed_head = self.class_embed
if is_instance_prompt and self.instance_class_embed is not None:
class_embed_head = self.instance_class_embed
outputs_class = class_embed_head(hs)
# box prediction
box_head = self.transformer.decoder.bbox_embed
if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None:
box_head = self.transformer.decoder.instance_bbox_embed
anchor_box_offsets = box_head(hs)
reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
outputs_boxes_xyxy = xywh2xyxy(outputs_coord)
if dec_presence_out is not None:
_update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False)
if self.supervise_joint_box_scores:
assert dec_presence_out is not None
prob_dec_presence_out = dec_presence_out.clone().sigmoid()
if self.detach_presence_in_joint_score:
prob_dec_presence_out = prob_dec_presence_out.detach()
outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp(
min=-10.0, max=10.0
)
_update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False)
_update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False)
_update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False)
def _run_segmentation_heads(
self,
out,
backbone_out,
encoder_hidden_states,
prompt,
prompt_mask,
hs,
):
"""Run segmentation heads and get masks."""
if self.segmentation_head is not None:
num_o2o = hs.size(2)
obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
seg_head_outputs = self.segmentation_head(
backbone_feats=backbone_out["backbone_fpn"],
obj_queries=obj_queries,
encoder_hidden_states=encoder_hidden_states,
prompt=prompt,
prompt_mask=prompt_mask,
)
for k, v in seg_head_outputs.items():
if k in self.segmentation_head.instance_keys:
_update_out(out, k, v[:, :num_o2o], auxiliary=False)
else:
out[k] = v
else:
backbone_out.pop("backbone_fpn", None)
def forward_grounding(
self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt = None
):
"""Forward pass for grounding (detection + segmentation) given input images and text."""
backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = self._prepare_backbone_features(
backbone_out, num_prompts=len(text_ids)
)
backbone_out.update(self.text_embeddings)
with torch.profiler.record_function("SAM3Image._encode_prompt"):
prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
# index text features (note that regardless of early or late fusion, the batch size of
# `txt_feats` is always the number of *prompts* in the encoder)
txt_feats = backbone_out["language_features"][:, text_ids]
txt_masks = backbone_out["language_mask"][text_ids]
# encode text
prompt = torch.cat([txt_feats, prompt], dim=0)
prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)
# Run the encoder
with torch.profiler.record_function("SAM3Image._run_encoder"):
encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask)
out = {"backbone_out": backbone_out}
# Run the decoder
with torch.profiler.record_function("SAM3Image._run_decoder"):
out, hs = self._run_decoder(
memory=encoder_out["encoder_hidden_states"],
pos_embed=encoder_out["pos_embed"],
src_mask=encoder_out["padding_mask"],
out=out,
prompt=prompt,
prompt_mask=prompt_mask,
encoder_out=encoder_out,
)
# Run segmentation heads
with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
self._run_segmentation_heads(
out=out,
backbone_out=backbone_out,
encoder_hidden_states=encoder_out["encoder_hidden_states"],
prompt=prompt,
prompt_mask=prompt_mask,
hs=hs,
)
return out
def set_classes(self, text: list[str]):
"""Set the text embeddings for the given class names."""
self.text_embeddings = self.backbone.forward_text(text)
self.names = text
def set_imgsz(self, imgsz: tuple[int, int]):
"""Set the image size for the model."""
self.backbone.set_imgsz(imgsz)
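The box refinement in `_update_scores_and_boxes` adds predicted offsets to reference boxes in logit space, then converts back through a sigmoid and to corner format. A minimal sketch of that step, with a local `inverse_sigmoid` standing in for the imported helper (assumed behavior: clamped logit):

```python
import torch


def inverse_sigmoid(x, eps=1e-5):
    """Clamped logit, a sketch of the helper imported from ultralytics.nn.modules.utils."""
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))


# Reference boxes in normalized cxcywh; offsets are predicted by the box head in logit space
reference_boxes = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
anchor_box_offsets = torch.zeros_like(reference_boxes)  # zero offset keeps the anchor unchanged
outputs_coord = (inverse_sigmoid(reference_boxes) + anchor_box_offsets).sigmoid()

# cxcywh -> xyxy, as xywh2xyxy does
cx, cy, w, h = outputs_coord.unbind(-1)
outputs_boxes_xyxy = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
```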

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from __future__ import annotations
from collections import OrderedDict
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .model_misc import LayerScale
class ResidualAttentionBlock(nn.Module):
"""Transformer block with multi-head attention, layer normalization, and MLP feed-forward network."""
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float | None = None,
act_layer: Callable[[], nn.Module] = nn.GELU,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
):
"""Initialize residual attention block with configurable dimensions and normalization."""
super().__init__()
# Attention
self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
# LayerNorm, LayerScale
self.ln_1 = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
# MLP
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model)),
]
)
)
def attention(
self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
) -> torch.Tensor:
"""Compute multi-head attention with optional cross-attention support and masking."""
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
if attn_mask is not None:
# Leave boolean masks as is
if not attn_mask.dtype == torch.bool:
attn_mask = attn_mask.to(q_x.dtype)
return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]
def forward(
self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
) -> torch.Tensor:
"""Apply residual attention with layer normalization and MLP, supporting optional cross-attention."""
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class Transformer(nn.Module):
"""Stack of residual attention blocks forming a transformer encoder with optional gradient checkpointing."""
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float | None = None,
act_layer: Callable[[], nn.Module] = nn.GELU,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
compile_mode: str | None = None,
use_act_checkpoint: bool = False,
):
"""Initialize transformer with configurable depth, width, and optional compilation/checkpointing."""
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = use_act_checkpoint
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
for _ in range(layers)
]
)
if compile_mode is not None:
self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
if self.grad_checkpointing:
torch._dynamo.config.optimize_ddp = False
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
"""Process input through all transformer blocks with optional gradient checkpointing during training."""
for _, r in enumerate(self.resblocks):
if self.grad_checkpointing and not torch.jit.is_scripting() and self.training:
x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
else:
x = r(x, attn_mask=attn_mask)
return x
def text_global_pool(
x: torch.Tensor, text: torch.Tensor = None, pool_type: str = "argmax"
) -> tuple[torch.Tensor, torch.Tensor]:
"""Extract pooled representation and tokens from text embeddings using specified pooling strategy
(first/last/argmax/none).
"""
if pool_type == "first":
pooled, tokens = x[:, 0], x[:, 1:]
elif pool_type == "last":
pooled, tokens = x[:, -1], x[:, :-1]
elif pool_type == "argmax":
# take features from the eot embedding (eot_token is the highest number in each sequence)
assert text is not None
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
else:
pooled = tokens = x
return pooled, tokens
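With `pool_type="argmax"`, `text_global_pool` gathers the feature at the end-of-text token, which by construction has the highest id in each sequence. A toy sketch of that indexing (values are arbitrary):

```python
import torch

x = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # [batch, seq, dim]
text = torch.tensor(
    [
        [1, 5, 9, 0],  # eot id (9) at position 2
        [1, 9, 0, 0],  # eot id (9) at position 1
    ]
)
pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]  # one feature vector per sequence
```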
class TextTransformer(nn.Module):
"""Text transformer encoder with causal masking and flexible pooling strategies."""
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
mlp_ratio: float = 4.0,
ls_init_value: float | None = None,
output_dim: int = 512,
no_causal_mask: bool = False,
pool_type: str = "none", # no pooling
proj_bias: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
output_tokens: bool = False,
use_ln_post: bool = True,
compile_mode: str | None = None,
use_act_checkpoint: bool = False,
):
"""Initialize text transformer with embedding layers, transformer blocks, and pooling options."""
super().__init__()
assert pool_type in ("first", "last", "argmax", "none")
self.output_tokens = output_tokens
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.heads = heads
self.pool_type = pool_type
self.token_embedding = nn.Embedding(self.vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
compile_mode=compile_mode,
use_act_checkpoint=use_act_checkpoint,
)
self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
if no_causal_mask:
self.attn_mask = None
else:
self.register_buffer("attn_mask", self.build_causal_mask(), persistent=False)
if proj_bias:
self.text_projection = nn.Linear(width, output_dim)
else:
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
def build_causal_mask(self) -> torch.Tensor:
"""Create a causal attention mask to prevent attention to future tokens."""
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def forward(self, text: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the text transformer, returning pooled output and optionally token embeddings."""
seq_len = text.shape[1]
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if attn_mask is not None:
attn_mask = attn_mask[:seq_len, :seq_len]
x = x + self.positional_embedding[:seq_len]
x = self.transformer(x, attn_mask=attn_mask)
x = self.ln_final(x)
pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
pooled = self.text_projection(pooled)
else:
pooled = pooled @ self.text_projection
if self.output_tokens:
return pooled, tokens
return pooled
class VETextEncoder(nn.Module):
"""Text encoder for Vision Encoder (VE) models, combining a text transformer and a linear resizer."""
def __init__(
self,
d_model: int,
tokenizer: Callable,
width: int = 1024,
heads: int = 16,
layers: int = 24,
context_length: int = 32,
vocab_size: int = 49408,
use_ln_post: bool = True,
compile_mode: str | None = None,
use_act_checkpoint: bool = True,
):
"""Initialize VE text encoder with a text transformer and a linear resizer to match decoder dimensions."""
super().__init__()
self.context_length = context_length
self.use_ln_post = use_ln_post
self.tokenizer = tokenizer
self.encoder = TextTransformer(
context_length=self.context_length,
vocab_size=vocab_size,
width=width,
heads=heads,
layers=layers,
# we want the tokens, not just the pooled output
output_tokens=True,
use_ln_post=use_ln_post,
compile_mode=compile_mode,
use_act_checkpoint=use_act_checkpoint,
)
self.resizer = nn.Linear(self.encoder.width, d_model)
def forward(
self, text: list[str] | tuple[torch.Tensor, torch.Tensor, dict], input_boxes: list | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode text input, either raw strings or pre-encoded tensors, and resize to match decoder dimensions."""
if isinstance(text[0], str):
# no use case for this
assert input_boxes is None or len(input_boxes) == 0, "not supported"
# Encode the text
tokenized = self.tokenizer(text, context_length=self.context_length).to(
self.resizer.weight.device
) # [b, seq_len]
text_attention_mask = (tokenized != 0).bool()
# manually embed the tokens
inputs_embeds = self.encoder.token_embedding(tokenized) # [b, seq_len, d=1024]
_, text_memory = self.encoder(tokenized) # [b, seq_len, d=1024]
assert text_memory.shape[1] == inputs_embeds.shape[1]
# Invert the attention mask because PyTorch's transformer uses the opposite convention (True = masked)
text_attention_mask = text_attention_mask.ne(1)
# Transpose memory because pytorch's attention expects sequence first
text_memory = text_memory.transpose(0, 1)
# Resize the encoder hidden states to be of the same d_model as the decoder
text_memory_resized = self.resizer(text_memory)
else:
# The text is already encoded, use as is.
text_attention_mask, text_memory_resized, tokenized = text
inputs_embeds = tokenized["inputs_embeds"]
assert input_boxes is None or len(input_boxes) == 0, "Can't replace boxes in text if it's already encoded"
# Note that the input_embeds are returned in pytorch's convention (sequence first)
return (
text_attention_mask,
text_memory_resized,
inputs_embeds.transpose(0, 1),
)
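`TextTransformer.build_causal_mask` above produces an additive mask: `-inf` strictly above the diagonal blocks attention to future tokens, while the diagonal and below stay zero. The same construction in isolation:

```python
import torch

num_pos = 4
mask = torch.empty(num_pos, num_pos)
mask.fill_(float("-inf"))
mask.triu_(1)  # strictly upper triangle stays -inf; diagonal and below become 0
```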

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Text Tokenizer.
Copied and lightly adapted from VE repo, which in turn copied
from open_clip and openAI CLIP.
"""
from __future__ import annotations
import gzip
import html
import io
import os
import string
from functools import lru_cache
import ftfy
import regex as re
import torch
from iopath.common.file_io import g_pathmgr
@lru_cache
def bytes_to_unicode():
"""Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode
strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When
you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a
significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8
bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
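The table built above is a bijection over all 256 byte values: printable bytes map to their own character, and the rest are shifted to codepoints at or above 256 so no whitespace or control characters ever appear in BPE strings. A self-contained sketch using the same construction:

```python
def bytes_to_unicode_table():
    # Printable byte ranges keep their own character; remaining bytes get codepoints >= 256
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))


encoder = bytes_to_unicode_table()
decoder = {v: k for k, v in encoder.items()}
```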
def get_pairs(word):
"""Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length
strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
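`get_pairs` enumerates the adjacent symbol pairs that the BPE merge loop later ranks against `bpe_ranks`. For example:

```python
def get_pairs(word):
    # Word is a tuple of symbols; collect every adjacent pair
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


pairs = get_pairs(("l", "o", "w", "</w>"))
```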
def basic_clean(text):
"""Basic text cleaning: fix unicode and unescape HTML entities."""
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
"""Remove redundant whitespace."""
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def _clean_canonicalize(x):
"""Clean text and canonicalize it."""
# basic, remove whitespace, remove punctuation, lower case
return canonicalize_text(basic_clean(x))
def _clean_lower(x):
"""Clean text and return lowercase."""
# basic, remove whitespace, lower case
return whitespace_clean(basic_clean(x)).lower()
def _clean_whitespace(x):
"""Clean text and remove redundant whitespace."""
# basic, remove whitespace
return whitespace_clean(basic_clean(x))
def get_clean_fn(type: str):
"""Get text cleaning function by name."""
if type == "canonicalize":
return _clean_canonicalize
elif type == "lower":
return _clean_lower
elif type == "whitespace":
return _clean_whitespace
else:
assert False, f"Invalid clean function ({type})."
def canonicalize_text(text, *, keep_punctuation_exact_string=None):
"""Returns canonicalized `text` (lowercase and punctuation removed). From:
https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94.
Args:
text: string to be canonicalized.
keep_punctuation_exact_string: If provided, this exact string is kept. For example, providing '{}' will keep
    any occurrences of '{}' (but will still remove '{' and '}' that appear separately).
"""
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans("", "", string.punctuation))
for part in text.split(keep_punctuation_exact_string)
)
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
text = re.sub(r"\s+", " ", text)
return text.strip()
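A self-contained sketch of `canonicalize_text` (using the stdlib `re` module rather than `regex`, which behaves the same for these patterns) shows how `keep_punctuation_exact_string` preserves a placeholder while all other punctuation is stripped:

```python
import re
import string


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # Split on the protected string, strip punctuation in each part, then rejoin
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string)
        )
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    return re.sub(r"\s+", " ", text.lower()).strip()
```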
class SimpleTokenizer:
"""A simple tokenizer for text inputs."""
def __init__(
self,
bpe_path: str | os.PathLike,
additional_special_tokens: list[str] | None = None,
context_length: int = 77,
clean: str = "lower",
):
"""The tokenizer for text inputs."""
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with g_pathmgr.open(bpe_path, "rb") as fh:
bpe_bytes = io.BytesIO(fh.read())
merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
special_tokens = ["<start_of_text>", "<end_of_text>"]
if additional_special_tokens:
special_tokens += additional_special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t: t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(
special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
self.sot_token_id = self.all_special_ids[0]
self.eot_token_id = self.all_special_ids[1]
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
def bpe(self, token):
"""Byte Pair Encoding."""
if token in self.cache:
return self.cache[token]
word = (*tuple(token[:-1]), token[-1] + "</w>")
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:  # `first` not found in the remainder of the word
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
"""Encode text to a sequence of BPE tokens."""
bpe_tokens = []
text = self.clean_fn(text)
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def decode(self, tokens):
"""Decodes a sequence of tokens back into a text string."""
text = "".join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
return text
def __call__(self, texts: str | list[str], context_length: int | None = None) -> torch.LongTensor:
"""Returns the tokenized representation of given input string(s) Parameters. ---------- texts : Union[str,
list[str]] An input string or a list of input strings to tokenize context_length : int The context
length to use; all CLIP models use 77 as the context length.
Returns:
-------: A two-dimensional tensor containing the resulting tokens, shape = [number of input strings,
context_length]
"""
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, "Please set a valid context length"
all_tokens = [[self.sot_token_id, *self.encode(text), self.eot_token_id] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = self.eot_token_id
result[i, : len(tokens)] = torch.tensor(tokens)
return result
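The padding/truncation logic in `__call__` can be sketched with plain lists; the `sot`/`eot` ids below are hypothetical stand-ins for `sot_token_id`/`eot_token_id`:

```python
sot, eot, context_length = 49406, 49407, 8
body = list(range(1, 11))  # hypothetical BPE ids, too long to fit in the context
tokens = [sot, *body, eot]
if len(tokens) > context_length:
    tokens = tokens[:context_length]
    tokens[-1] = eot  # truncation always preserves the end-of-text marker
row = tokens + [0] * (context_length - len(tokens))  # zero-pad up to context_length
```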

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
ViTDet backbone adapted from Detectron2.
This module implements Vision Transformer (ViT) backbone for object detection.
Rope embedding code adopted from:
1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
2. https://github.com/naver-ai/rope-vit
3. https://github.com/lucidrains/rotary-embedding-torch
"""
from __future__ import annotations
import math
from functools import partial
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import Tensor
from ultralytics.models.sam.modules.blocks import PatchEmbed
from ultralytics.models.sam.modules.utils import (
apply_rotary_enc,
compute_axial_cis,
concat_rel_pos,
get_abs_pos,
window_partition,
window_unpartition,
)
from ultralytics.utils.checks import check_requirements
from .model_misc import LayerScale
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings and 2d-rope."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: tuple[int, int] | None = None,
cls_token: bool = False,
use_rope: bool = False,
rope_theta: float = 10000.0,
rope_pt_size: tuple[int, int] | None = None,
rope_interp: bool = False,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution for calculating the relative positional parameter size or rope
size.
attn_type: Type of attention operation, e.g. "vanilla", "vanilla-xformer".
cls_token: whether a cls_token is present.
use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
use_rel_pos: whether to use relative positional embeddings
rope_theta: control frequencies of rope
rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
rope_interp: whether to interpolate (or extrapolate) rope to match input size.
"""
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.cls_token = cls_token
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
# rel_pos embeddings and rope
self.use_rel_pos = use_rel_pos
self.input_size = input_size
self.use_rope = use_rope
self.rope_theta = rope_theta
self.rope_pt_size = rope_pt_size
self.rope_interp = rope_interp
# init rel_pos embeddings and rope
self._setup_rel_pos(rel_pos_zero_init, input_size)
self._setup_rope_freqs(input_size)
def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None:
"""Setup relative positional embeddings."""
if not self.use_rel_pos:
self.rel_pos_h = None
self.rel_pos_w = None
return
assert input_size is not None
assert self.cls_token is False, "not supported"
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
if not rel_pos_zero_init:
nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
# Precompute the relative coords
H, W = input_size
q_coords = torch.arange(H)[:, None]
k_coords = torch.arange(W)[None, :]
relative_coords = (q_coords - k_coords) + (H - 1)
self.relative_coords = relative_coords.long()
def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None:
"""Setup 2d-rope frequencies."""
if not self.use_rope:
self.freqs_cis = None
return
assert input_size is not None
# determine rope input size
if self.rope_pt_size is None:
self.rope_pt_size = input_size
# initialize 2d rope freqs
self.compute_cis = partial(
compute_axial_cis,
dim=self.head_dim,
theta=self.rope_theta,
)
# interpolate rope
scale_pos = 1.0
if self.rope_interp:
scale_pos = self.rope_pt_size[0] / input_size[0]
# get scaled freqs_cis
freqs_cis = self.compute_cis(
end_x=input_size[0],
end_y=input_size[1],
scale_pos=scale_pos,
)
if self.cls_token:
t = torch.zeros(
self.head_dim // 2,
dtype=torch.float32,
device=freqs_cis.device,
)
cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)
self.freqs_cis = freqs_cis
def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]:
"""Apply 2d-rope to q and k."""
if not self.use_rope:
return q, k
assert self.freqs_cis is not None
return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis.to(q.device))
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of attention block."""
s = 1 if self.cls_token else 0 # used to exclude cls_token
if x.ndim == 4:
B, H, W, _ = x.shape
assert s == 0 # no cls_token
L = H * W
ndim = 4
else:
assert x.ndim == 3
B, L, _ = x.shape
ndim = 3
H = W = int(math.sqrt(L - s))  # int() so the reshapes below receive integer dims
# qkv with shape (3, B, nHead, L, C)
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
# q, k, v with shape (B, nHead, L, C)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
# handle rope and rel pos embeddings
q, k = self._apply_rope(q, k)
if self.use_rel_pos:
q, k = concat_rel_pos(
q.flatten(0, 1),
k.flatten(0, 1),
(H, W),
x.shape[1:3],
self.rel_pos_h,
self.rel_pos_w,
rescale=True,
relative_coords=self.relative_coords,
)
# sdpa expects [B, nheads, H*W, C] so we transpose back
q = q.reshape(B, self.num_heads, H * W, -1)
k = k.reshape(B, self.num_heads, H * W, -1)
x = F.scaled_dot_product_attention(q, k, v)
if ndim == 4:
x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
else:
x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
x = self.proj(x)
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention."""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_path: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
act_layer: Callable[..., nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: tuple[int, int] | None = None,
use_rope: bool = False,
rope_pt_size: tuple[int, int] | None = None,
rope_interp: bool = False,
cls_token: bool = False,
dropout: float = 0.0,
init_values: float | None = None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then not use window attention.
input_size (int or None): Input resolution for calculating the relative positional parameter size.
dropout (float): Dropout rate.
cls_token (bool): Whether a cls_token is present.
use_rope (bool): Whether to use 2D RoPE (independent of use_rel_pos; both can be used together).
rope_pt_size (tuple[int, int] | None): RoPE size from a previous training stage, needed for interpolation or tiling.
rope_interp (bool): Whether to interpolate (or extrapolate) RoPE to match the target input size; the source size must be given as rope_pt_size.
init_values (float | None): Layer scale init value; None disables layer scale.
"""
super().__init__()
check_requirements("timm")
from timm.layers import DropPath, Mlp
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rope=use_rope,
rope_pt_size=rope_pt_size,
rope_interp=rope_interp,
cls_token=cls_token,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=(dropout, 0.0),
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.dropout = nn.Dropout(dropout)
self.window_size = window_size
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the transformer block."""
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.ls1(self.attn(x))
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + self.dropout(self.drop_path(x))
x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
return x
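The windowed path in `Block.forward` pads the feature map up to a multiple of `window_size` before partitioning, then crops after `window_unpartition`. A minimal sketch of that padding arithmetic, with assumed sizes (1008 px input / patch 16 = 63, the 14x14 window used by default) rather than values taken from the source:

```python
# Padding arithmetic behind window_partition, sketched with assumed sizes.
H, W, window = 63, 63, 14
pad_h = (window - H % window) % window
pad_w = (window - W % window) % window
Hp, Wp = H + pad_h, W + pad_w  # padded spatial size
num_windows = (Hp // window) * (Wp // window)
print(Hp, Wp, num_windows)  # 70 70 25
```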
class ViT(nn.Module):
"""This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. "Exploring Plain Vision Transformer
Backbones for Object Detection", https://arxiv.org/abs/2203.16527.
"""
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_path_rate: float = 0.0,
norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
act_layer: Callable[..., nn.Module] = nn.GELU,
use_abs_pos: bool = True,
tile_abs_pos: bool = True,
rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
rel_pos_zero_init: bool = True,
window_size: int = 14,
global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
use_rope: bool = False,
rope_pt_size: int | None = None,
use_interp_rope: bool = False,
pretrain_img_size: int = 224,
pretrain_use_cls_token: bool = True,
retain_cls_token: bool = True,
dropout: float = 0.0,
return_interm_layers: bool = False,
init_values: float | None = None, # for layerscale
ln_pre: bool = False,
ln_post: bool = False,
bias_patch_embed: bool = True,
compile_mode: str | None = None,
use_act_checkpoint: bool = True,
):
"""
Args:
img_size (int): Input image size. Only relevant for rel pos or rope.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolation.
rel_pos_blocks (list): Blocks which have rel pos embeddings.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_att_blocks (list): Indexes for blocks using global attention (other blocks use window attention).
use_rope (bool): whether to use rope 2d (indep of rel_pos_blocks, as it can be used together).
rope_pt_size (int): size of rope in previous stage of training, needed for interpolation or tiling.
use_interp_rope (bool): Whether to interpolate (or extrapolate) RoPE to match the target input size; the source size must be given as rope_pt_size.
use_act_checkpoint (bool): If True, use activation checkpointing.
pretrain_img_size (int): Input image size for pretraining models.
pretrain_use_cls_token (bool): If True, pretraining models use a class token.
retain_cls_token (bool): Whether the cls_token should be retained.
dropout (float): Dropout rate, applied in the residual branches of attn and mlp and inside the mlp.
return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
init_values (float | None): Layer scale init value; None disables layer scale.
ln_pre (bool): If True, apply layer norm before transformer blocks.
ln_post (bool): If True, apply layer norm after transformer blocks.
bias_patch_embed (bool): Whether to use a bias in the patch embedding convolution.
compile_mode (str): mode to compile the forward.
"""
super().__init__()
self.pretrain_use_cls_token = pretrain_use_cls_token
window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
self.full_attn_ids = list(global_att_blocks)
self.rel_pos_blocks = [False] * depth
if isinstance(rel_pos_blocks, bool):
self.rel_pos_blocks = [rel_pos_blocks] * depth  # all blocks on or off
else:
for i in rel_pos_blocks:
self.rel_pos_blocks[i] = True
self.retain_cls_token = retain_cls_token
if self.retain_cls_token:
assert pretrain_use_cls_token
assert len(window_block_indexes) == 0, "windowing not supported with cls token"
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
scale = embed_dim**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))
if isinstance(norm_layer, str):
norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
bias=bias_patch_embed,
)
# Handle absolute positional embedding
self.tile_abs_pos = tile_abs_pos
self.use_abs_pos = use_abs_pos
if self.tile_abs_pos:
assert self.use_abs_pos
if self.use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
else:
self.pos_embed = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.patch_size = patch_size
self.window_size = window_size
self.blocks = nn.ModuleList()
cur_stage = 1
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=self.rel_pos_blocks[i],
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i in window_block_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
use_rope=use_rope,
rope_pt_size=((window_size, window_size) if rope_pt_size is None else (rope_pt_size, rope_pt_size)),
rope_interp=use_interp_rope,
cls_token=self.retain_cls_token,
dropout=dropout,
init_values=init_values,
)
if i not in window_block_indexes:
cur_stage += 1
self.use_act_checkpoint = use_act_checkpoint
self.blocks.append(block)
self.return_interm_layers = return_interm_layers
self.channel_list = [embed_dim] * len(self.full_attn_ids) if return_interm_layers else [embed_dim]
if self.pos_embed is not None:
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()
self.apply(self._init_weights)
if compile_mode is not None:
self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
if self.use_act_checkpoint and self.training:
torch._dynamo.config.optimize_ddp = False
def _init_weights(self, m: nn.Module) -> None:
"""Initialize the weights."""
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""Vit forward path and get feature maps."""
x = self.patch_embed(x)
h, w = x.shape[1], x.shape[2]
s = 0
if self.retain_cls_token:
# If cls_token is retained, we don't maintain spatial shape
x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
s = 1
if self.pos_embed is not None:
x = x + get_abs_pos(
self.pos_embed,
self.pretrain_use_cls_token,
(h, w),
self.retain_cls_token,
tiling=self.tile_abs_pos,
)
x = self.ln_pre(x)
outputs = []
for i, blk in enumerate(self.blocks):
if self.use_act_checkpoint and self.training:
x = checkpoint.checkpoint(blk, x, use_reentrant=False)
else:
x = blk(x)
if (i == self.full_attn_ids[-1]) or (self.return_interm_layers and i in self.full_attn_ids):
if i == self.full_attn_ids[-1]:
x = self.ln_post(x)
feats = x[:, s:]
if feats.ndim == 4:
feats = feats.permute(0, 3, 1, 2)
else:
assert feats.ndim == 3
h = w = int(math.sqrt(feats.shape[1]))
feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
outputs.append(feats)
return outputs
def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
"""Setup rel pos embeddings and rope freqs for a new input image size."""
for block in self.blocks:
if block.window_size != 0:
continue
block.attn._setup_rel_pos(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
block.attn._setup_rope_freqs(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
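The constructor above splits the `depth` blocks into windowed and global-attention blocks and assigns each a stochastic-depth rate via `torch.linspace`. A small sketch of that scheduling, using the class defaults (`depth=12`, `global_att_blocks=(2, 5, 8, 11)`) and an assumed `drop_path_rate` of 0.1:

```python
# How the ViT constructor schedules its blocks (class defaults, assumed rate).
depth = 12
global_att_blocks = (2, 5, 8, 11)
drop_path_rate = 0.1

# every block not in global_att_blocks uses window attention
window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
# pure-Python equivalent of torch.linspace(0, drop_path_rate, depth)
dpr = [drop_path_rate * i / (depth - 1) for i in range(depth)]
```

With these defaults, two out of every three blocks are windowed, and the drop-path rate grows linearly from 0 at the first block to `drop_path_rate` at the last.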

@ -0,0 +1,165 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Provides utility to combine a vision backbone with a language backbone."""
from __future__ import annotations
from copy import copy
import torch
import torch.nn as nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from .necks import Sam3DualViTDetNeck
class SAM3VLBackbone(nn.Module):
"""This backbone combines a vision backbone and a language backbone without fusion. As such it is more of a
convenience wrapper to handle the two backbones together.
It adds support for activation checkpointing and compilation.
"""
def __init__(
self,
visual: Sam3DualViTDetNeck,
text,
compile_visual: bool = False,
act_ckpt_whole_vision_backbone: bool = False,
act_ckpt_whole_language_backbone: bool = False,
scalp=0,
):
"""Initialize the backbone combiner.
:param visual: The vision backbone to use
:param text: The text encoder to use
"""
super().__init__()
self.vision_backbone: Sam3DualViTDetNeck = torch.compile(visual) if compile_visual else visual
self.language_backbone = text
self.scalp = scalp
# allow running activation checkpointing on the entire vision and language backbones
self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone
def forward(
self,
samples: torch.Tensor,
captions: list[str],
input_boxes: torch.Tensor = None,
additional_text: list[str] | None = None,
):
"""Forward pass of the backbone combiner.
:param samples: The input images
:param captions: The input captions
:param input_boxes: If the text contains place-holders for boxes, this
parameter contains the tensor containing their spatial features
:param additional_text: This can be used to encode some additional text
(different from the captions) in the same forward of the backbone
:return: Output dictionary with the following keys:
- vision_features: The output of the vision backbone
- language_features: The output of the language backbone
- language_mask: The attention mask of the language backbone
- vision_pos_enc: The positional encoding of the vision backbone
- (optional) additional_text_features: The output of the language
backbone for the additional text
- (optional) additional_text_mask: The attention mask of the
language backbone for the additional text
"""
output = self.forward_image(samples)
output.update(self.forward_text(captions, input_boxes, additional_text))
return output
def forward_image(self, samples: torch.Tensor):
"""Forward pass of the vision backbone and get both SAM3 and SAM2 features."""
# Forward through backbone
sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(samples)
if self.scalp > 0:
# Discard the lowest resolution features
sam3_features, sam3_pos = (
sam3_features[: -self.scalp],
sam3_pos[: -self.scalp],
)
if sam2_features is not None and sam2_pos is not None:
sam2_features, sam2_pos = (
sam2_features[: -self.scalp],
sam2_pos[: -self.scalp],
)
sam2_output = None
if sam2_features is not None and sam2_pos is not None:
sam2_src = sam2_features[-1]
sam2_output = {
"vision_features": sam2_src,
"vision_pos_enc": sam2_pos,
"backbone_fpn": sam2_features,
}
sam3_src = sam3_features[-1]
return {
"vision_features": sam3_src,
"vision_pos_enc": sam3_pos,
"backbone_fpn": sam3_features,
"sam2_backbone_out": sam2_output,
}
def forward_image_sam2(self, samples: torch.Tensor):
"""Forward pass of the vision backbone to get SAM2 features only."""
xs = self.vision_backbone.trunk(samples)
sam2_features, sam2_pos = [], []
x = xs[-1] # simpleFPN
assert self.vision_backbone.sam2_convs is not None, "SAM2 neck is not available."
for i in range(len(self.vision_backbone.sam2_convs)):
sam2_x_out = self.vision_backbone.sam2_convs[i](x)
sam2_pos_out = self.vision_backbone.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
sam2_features.append(sam2_x_out)
sam2_pos.append(sam2_pos_out)
if self.scalp > 0:
# Discard the lowest resolution features
sam2_features, sam2_pos = (
sam2_features[: -self.scalp],
sam2_pos[: -self.scalp],
)
return {
"vision_features": sam2_features[-1],
"vision_pos_enc": sam2_pos,
"backbone_fpn": sam2_features,
}
def forward_text(self, captions, input_boxes=None, additional_text=None):
"""Forward pass of the text encoder."""
output = {}
# Forward through text_encoder
text_to_encode = copy(captions)
if additional_text is not None:
# if there are additional_text, we piggy-back them into this forward.
# They'll be used later for output alignment
text_to_encode += additional_text
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
text_attention_mask, text_memory, text_embeds = self.language_backbone(text_to_encode, input_boxes)
if additional_text is not None:
output["additional_text_features"] = text_memory[:, -len(additional_text) :]
output["additional_text_mask"] = text_attention_mask[-len(additional_text) :]
text_memory = text_memory[:, : len(captions)]
text_attention_mask = text_attention_mask[: len(captions)]
text_embeds = text_embeds[:, : len(captions)]
output["language_features"] = text_memory
output["language_mask"] = text_attention_mask
output["language_embeds"] = text_embeds # Text embeddings before forward to the encoder
return output
def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
"""Set the image size for the vision backbone."""
self.vision_backbone.set_imgsz(imgsz)
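`forward_text` above piggy-backs `additional_text` onto the caption batch: both are encoded in one forward pass, then sliced apart by length. A toy sketch of that split, with plain strings standing in for the real feature tensors (names and values assumed, not from the source):

```python
# Batch piggy-backing in forward_text: encode captions + additional_text
# together, then slice the joint output back apart (toy stand-ins).
captions = ["a red car", "a dog"]
additional_text = ["a person"]
text_to_encode = list(captions) + additional_text    # one joint encoder forward
text_memory = [f"feat:{t}" for t in text_to_encode]  # stand-in for real features

additional_feats = text_memory[-len(additional_text):]  # trailing slice
caption_feats = text_memory[: len(captions)]            # leading slice
```

The real code performs the same leading/trailing slices on tensor dimensions rather than Python lists.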


@ -359,7 +359,15 @@ class MLP(nn.Module):
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
act=nn.ReLU,
sigmoid: bool = False,
residual: bool = False,
out_norm: nn.Module = None,
):
"""Initialize the MLP with specified input, hidden, output dimensions and number of layers.
@ -370,6 +378,8 @@ class MLP(nn.Module):
num_layers (int): Number of layers.
act (nn.Module): Activation function.
sigmoid (bool): Whether to apply sigmoid to the output.
residual (bool): Whether to use residual connections.
out_norm (nn.Module, optional): Normalization layer for the output.
"""
super().__init__()
self.num_layers = num_layers
@ -377,6 +387,12 @@ class MLP(nn.Module):
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim]))
self.sigmoid = sigmoid
self.act = act()
if residual and input_dim != output_dim:
raise ValueError("residual is only supported if input_dim == output_dim")
self.residual = residual
# whether to apply a normalization layer to the output
assert isinstance(out_norm, nn.Module) or out_norm is None
self.out_norm = out_norm or nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass for the entire MLP.
@ -387,8 +403,12 @@ class MLP(nn.Module):
Returns:
(torch.Tensor): Output tensor after MLP.
"""
orig_x = x
for i, layer in enumerate(self.layers):
x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)
if getattr(self, "residual", False):
x = x + orig_x
x = getattr(self, "out_norm", nn.Identity())(x)
return x.sigmoid() if getattr(self, "sigmoid", False) else x
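The residual flow added to `MLP.forward` above can be sketched with plain callables standing in for the linear layers and activation (all values assumed for illustration):

```python
# Toy sketch of the residual MLP forward: stand-ins for nn.Linear / ReLU.
layers = [lambda v: v * 2, lambda v: v + 1]  # stand-ins for nn.Linear
act = lambda v: max(v, 0)                    # stand-in for ReLU
residual = True

def mlp_forward(x):
    orig_x = x
    for i, layer in enumerate(layers):
        # activation on all but the last layer, mirroring the loop above
        x = act(layer(x)) if i < len(layers) - 1 else layer(x)
    if residual:
        x = x + orig_x  # residual requires input_dim == output_dim
    return x

print(mlp_forward(3))  # 10: act(3*2)=6 -> 6+1=7 -> +3 residual
```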


@ -660,6 +660,4 @@ def clean_str(s):
def empty_like(x):
"""Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
return (
torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
)
return torch.empty_like(x, dtype=x.dtype) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=x.dtype)
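The revised `empty_like` preserves the input dtype instead of forcing float32; a quick numpy-only check of that behavior (assuming numpy is available, and using only the numpy branch of the function):

```python
# Dtype-preserving empty_like, numpy branch only.
import numpy as np

def empty_like(x):
    """Create an empty array with the same shape and dtype as the input."""
    return np.empty_like(x, dtype=x.dtype)

a = np.zeros((2, 3), dtype=np.float16)
b = empty_like(a)
print(b.shape, b.dtype)  # (2, 3) float16 -- not upcast to float32
```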