mirror of https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00

ultralytics 8.3.237 SAM3 integration (#22897)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: fatih akyon <34196005+fcakyon@users.noreply.github.com>

This commit is contained in:
parent 24b1ec6252
commit e0764aa55d

45 changed files with 7070 additions and 252 deletions
@@ -114,13 +114,17 @@ SAM 3 will be available directly in the Ultralytics package once integration lan

    pip install ultralytics
    ```

You can then use standard [predict mode](../modes/predict.md) and later [export](../modes/export.md) models to formats like [ONNX](../integrations/onnx.md) and [TensorRT](../integrations/tensorrt.md) for deployment.

!!! warning "SAM 3 Model Weights Required"

    Unlike other Ultralytics models, SAM 3 weights (`sam3.pt`) are **not automatically downloaded**. You must manually download the model weights from the [official SAM 3 repository](https://github.com/facebookresearch/sam3) before using SAM 3. Place the downloaded `sam3.pt` file in your working directory or specify the full path when loading the model.

!!! note "BPE Vocabulary for Text Prompts"

    If you plan to use text-based concept segmentation with `SAM3SemanticPredictor`, you also need to download the BPE vocabulary file `bpe_simple_vocab_16e6.txt.gz` from the [SAM 3 assets](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz).
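Because both files above must be fetched manually, a quick pre-flight check gives a clearer error than a deep load failure later. A minimal sketch; `missing_assets` is our own illustrative helper, not part of the Ultralytics API:

```python
from pathlib import Path


def missing_assets(*paths: str) -> list[str]:
    """Return the subset of required files that do not exist on disk."""
    return [p for p in paths if not Path(p).exists()]


# Check both manual downloads before constructing any predictor
missing = missing_assets("sam3.pt", "bpe_simple_vocab_16e6.txt.gz")
if missing:
    print(f"Download these from the SAM 3 repository first: {missing}")
```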
## How to Use SAM 3: Versatility in Concept Segmentation

!!! warning "Ultralytics API preview"

    The following examples show the intended Ultralytics API once SAM 3 ships in the package. Until integration lands, details may change.

SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual Segmentation (PVS) tasks through different predictor interfaces.

### Supported Tasks and Models
@@ -138,143 +142,176 @@ SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual

!!! example "Text-based Concept Segmentation"

    Find and segment all instances of a concept using a text description. Text prompts require the `SAM3SemanticPredictor` interface.

    === "Python"

        ```python
        from ultralytics.models.sam.predict import SAM3SemanticPredictor

        # Initialize predictor with configuration
        overrides = dict(
            conf=0.25,
            task="segment",
            mode="predict",
            model="sam3.pt",
            half=True,  # Use FP16 for faster inference
        )
        predictor = SAM3SemanticPredictor(
            overrides=overrides,
            bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz",  # Required for text encoding
        )

        # Set image once for multiple queries
        predictor.set_image("path/to/image.jpg")

        # Query with multiple text prompts
        results = predictor(text=["person", "bus", "glasses"], save=True)

        # Works with descriptive phrases
        results = predictor(text=["person with red cloth", "person with blue cloth"], save=True)

        # Query with a single concept
        results = predictor(text=["a person"], save=True)
        ```

!!! note "Text Encoding Requirement"

    The `bpe_path` parameter is required for text prompt encoding. Download the BPE vocabulary file [bpe_simple_vocab_16e6.txt.gz](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz) from the SAM 3 assets.
#### Segment with Image Exemplars

!!! example "Image Exemplar-based Segmentation"

    Use bounding boxes as visual prompts to find all similar instances. This also requires `SAM3SemanticPredictor` for concept-based matching.

    === "Python"

        ```python
        from ultralytics.models.sam.predict import SAM3SemanticPredictor

        # Initialize predictor
        overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True)
        predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")

        # Set image
        predictor.set_image("path/to/image.jpg")

        # Provide bounding box examples to segment similar objects
        results = predictor(bboxes=[[480.0, 290.0, 590.0, 650.0]], save=True)

        # Multiple bounding boxes for different concepts
        results = predictor(bboxes=[[539, 599, 589, 639], [343, 267, 499, 662]], save=True)
        ```
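The exemplar boxes above appear to use absolute pixel coordinates in `[x1, y1, x2, y2]` (top-left, bottom-right) order, as elsewhere in Ultralytics SAM prompting. If your annotations are stored as `[x, y, w, h]` instead, convert them before prompting; a small sketch with a helper name of our own choosing:

```python
def xywh_to_xyxy(box: list[float]) -> list[float]:
    """Convert a [x, y, w, h] box to [x1, y1, x2, y2] corner form."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


# Convert annotations before passing them as bboxes= prompts
boxes = [xywh_to_xyxy(b) for b in [[480.0, 290.0, 110.0, 360.0]]]
print(boxes)  # [[480.0, 290.0, 590.0, 650.0]]
```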
#### Feature-based Inference for Efficiency

!!! example "Reusing Image Features for Multiple Queries"

    Extract image features once and reuse them for multiple segmentation queries to improve efficiency.

    === "Python"

        ```python
        import cv2

        from ultralytics.models.sam.predict import SAM3SemanticPredictor
        from ultralytics.utils.plotting import Annotator, colors

        # Initialize predictors
        overrides = dict(conf=0.50, task="segment", mode="predict", model="sam3.pt", verbose=False)
        predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")
        predictor2 = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")

        # Extract features from the first predictor
        source = "path/to/image.jpg"
        predictor.set_image(source)
        src_shape = cv2.imread(source).shape[:2]

        # Setup second predictor and reuse features
        predictor2.setup_model()

        # Perform inference using shared features with a text prompt
        masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, text=["person"])

        # Perform inference using shared features with a bounding box prompt
        masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, bboxes=[439, 437, 524, 709])

        # Visualize results
        masks, boxes = masks.cpu().numpy(), boxes.cpu().numpy()
        im = cv2.imread(source)
        annotator = Annotator(im, pil=False)
        annotator.masks(masks, [colors(x, True) for x in range(len(masks))])

        cv2.imshow("result", annotator.result())
        cv2.waitKey(0)
        ```
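Note that `src_shape` from `cv2.imread(source).shape[:2]` is `(height, width)`. If a downstream tool needs normalized coordinates, pixel boxes can be divided through by that shape; `normalize_xyxy` below is our own illustrative helper, not an Ultralytics function:

```python
def normalize_xyxy(box: list[float], src_shape: tuple[int, int]) -> list[float]:
    """Normalize an [x1, y1, x2, y2] pixel box by the source (height, width) shape."""
    h, w = src_shape
    x1, y1, x2, y2 = box
    return [x1 / w, y1 / h, x2 / w, y2 / h]


# src_shape is (height, width), matching cv2.imread(...).shape[:2]
print(normalize_xyxy([0, 0, 640, 360], (720, 1280)))  # [0.0, 0.0, 0.5, 0.5]
```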
### Video Concept Segmentation

#### Track Concepts Across Video with Bounding Boxes

!!! example "Video Tracking with Visual Prompts"

    Detect and track object instances across video frames using bounding box prompts.

    === "Python"

        ```python
        from ultralytics.models.sam.predict import SAM3VideoPredictor

        # Create video predictor
        overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True)
        predictor = SAM3VideoPredictor(overrides=overrides)

        # Track objects using bounding box prompts
        results = predictor(source="path/to/video.mp4", bboxes=[[706.5, 442.5, 905.25, 555], [598, 635, 725, 750]], stream=True)

        # Process and display results
        for r in results:
            r.show()  # Display frame with segmentation masks
        ```
#### Track Concepts with Text Prompts

!!! example "Video Tracking with Semantic Queries"

    Track all instances of concepts specified by text across video frames.

    === "Python"

        ```python
        from ultralytics.models.sam.sam3_video_model import SAM3VideoSemanticPredictor

        # Initialize semantic video predictor
        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=640, model="sam3.pt", half=True)
        predictor = SAM3VideoSemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz")

        # Track concepts using text prompts
        results = predictor(source="path/to/video.mp4", text=["person", "bicycle"], stream=True, save=True)

        # Process results
        for r in results:
            r.show()  # Display frame with tracked objects

        # Alternative: Track with bounding box prompts
        results = predictor(
            source="path/to/video.mp4",
            bboxes=[[864, 383, 975, 620], [705, 229, 782, 402]],
            labels=[1, 1],  # Positive labels
            stream=True,
            save=True,
        )
        ```
For broader streaming and production setups, see [object tracking](../guides/object-counting.md) and [view results in terminal](../guides/view-results-in-terminal.md).

### Visual Prompts (SAM 2 Compatibility)

SAM 3 maintains full backward compatibility with SAM 2's visual prompting for single-object segmentation:

!!! example "SAM 2 Style Visual Prompts"

    The basic `SAM` interface behaves exactly like SAM 2, segmenting only the specific area indicated by visual prompts (points, boxes, or masks).

    === "Python"

        ```python
@@ -282,19 +319,21 @@ SAM 3 maintains full backward compatibility with SAM 2's visual prompting:

        model = SAM("sam3.pt")

        # Single point prompt - segments object at specific location
        results = model.predict(source="path/to/image.jpg", points=[900, 370], labels=[1])
        results[0].show()

        # Multiple points - segments single object with multiple point hints
        results = model.predict(source="path/to/image.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])

        # Box prompt - segments object within bounding box
        results = model.predict(source="path/to/image.jpg", bboxes=[100, 150, 300, 400])
        results[0].show()
        ```

!!! warning "Visual Prompts vs Concept Segmentation"

    Using `SAM("sam3.pt")` with visual prompts (points/boxes/masks) will segment **only the specific object** at that location, just like SAM 2. To segment **all instances of a concept**, use `SAM3SemanticPredictor` with text or exemplar prompts as shown above.

## Performance Benchmarks
@@ -11,6 +11,10 @@ keywords: Ultralytics, SAM model, Segment Anything Model, SAM 2 model, Segment A

<br>

## ::: ultralytics.models.sam.build._load_checkpoint

<br><br><hr><br>

## ::: ultralytics.models.sam.build.build_sam_vit_h

<br><br><hr><br>
32 docs/en/reference/models/sam/build_sam3.md Normal file
@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/build_sam3.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.build_sam3._create_vision_backbone

<br><br><hr><br>

## ::: ultralytics.models.sam.build_sam3._create_sam3_transformer

<br><br><hr><br>

## ::: ultralytics.models.sam.build_sam3.build_sam3_image_model

<br><br><hr><br>

## ::: ultralytics.models.sam.build_sam3.build_interactive_sam3

<br><br><hr><br>

## ::: ultralytics.models.sam.build_sam3._load_checkpoint

<br><br>
@@ -17,4 +17,8 @@ keywords: Ultralytics, SAM Module, SAM 2 Module, object segmentation, image enco

## ::: ultralytics.models.sam.modules.sam.SAM2Model

<br><br><hr><br>

## ::: ultralytics.models.sam.modules.sam.SAM3Model

<br><br>

@@ -49,4 +49,12 @@ keywords: Ultralytics, SAM, SAM 2, API Reference, models, window partition, data

## ::: ultralytics.models.sam.modules.utils.add_decomposed_rel_pos

<br><br><hr><br>

## ::: ultralytics.models.sam.modules.utils.get_abs_pos

<br><br><hr><br>

## ::: ultralytics.models.sam.modules.utils.concat_rel_pos

<br><br>

@@ -25,4 +25,20 @@ keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Mode

## ::: ultralytics.models.sam.predict.SAM2DynamicInteractivePredictor

<br><br><hr><br>

## ::: ultralytics.models.sam.predict.SAM3Predictor

<br><br><hr><br>

## ::: ultralytics.models.sam.predict.SAM3SemanticPredictor

<br><br><hr><br>

## ::: ultralytics.models.sam.predict.SAM3VideoPredictor

<br><br><hr><br>

## ::: ultralytics.models.sam.predict.SAM3VideoSemanticPredictor

<br><br>
20 docs/en/reference/models/sam/sam3/decoder.md Normal file

@@ -0,0 +1,20 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/decoder.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoder

<br><br>

28 docs/en/reference/models/sam/sam3/encoder.md Normal file

@@ -0,0 +1,28 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/encoder.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderLayer

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoder

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderFusion

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.encoder.pool_text_feat

<br><br>

28 docs/en/reference/models/sam/sam3/geometry_encoders.md Normal file

@@ -0,0 +1,28 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/geometry_encoders.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.geometry_encoders.Prompt

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.geometry_encoders.SequenceGeometryEncoder

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.geometry_encoders.is_right_padded

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.geometry_encoders.concat_padded_sequences

<br><br>

32 docs/en/reference/models/sam/sam3/maskformer_segmentation.md Normal file

@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/maskformer_segmentation.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.maskformer_segmentation.LinearPresenceHead

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.maskformer_segmentation.MaskPredictor

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.maskformer_segmentation.SegmentationHead

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.maskformer_segmentation.PixelDecoder

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.maskformer_segmentation.UniversalSegmentationHead

<br><br>
32 docs/en/reference/models/sam/sam3/model_misc.md Normal file

@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/model_misc.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.model_misc.DotProductScoring

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.model_misc.LayerScale

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.model_misc.TransformerWrapper

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.model_misc.get_valid_ratio

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.model_misc.gen_sineembed_for_position

<br><br>

16 docs/en/reference/models/sam/sam3/necks.md Normal file

@@ -0,0 +1,16 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/necks.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.necks.Sam3DualViTDetNeck

<br><br>

20 docs/en/reference/models/sam/sam3/sam3_image.md Normal file

@@ -0,0 +1,20 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/sam3_image.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.sam3_image.SAM3SemanticModel

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.sam3_image._update_out

<br><br>

32 docs/en/reference/models/sam/sam3/text_encoder_ve.md Normal file

@@ -0,0 +1,32 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/text_encoder_ve.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.text_encoder_ve.ResidualAttentionBlock

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.text_encoder_ve.Transformer

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.text_encoder_ve.TextTransformer

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.text_encoder_ve.VETextEncoder

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.text_encoder_ve.text_global_pool

<br><br>
52 docs/en/reference/models/sam/sam3/tokenizer_ve.md Normal file

@@ -0,0 +1,52 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/tokenizer_ve.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve.SimpleTokenizer

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve.bytes_to_unicode

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_pairs

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve.basic_clean

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve.whitespace_clean

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_canonicalize

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_lower

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_whitespace

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_clean_fn

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.tokenizer_ve.canonicalize_text

<br><br>

24 docs/en/reference/models/sam/sam3/vitdet.md Normal file

@@ -0,0 +1,24 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/vitdet.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.vitdet.Attention

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.vitdet.Block

<br><br><hr><br>

## ::: ultralytics.models.sam.sam3.vitdet.ViT

<br><br>

16 docs/en/reference/models/sam/sam3/vl_combiner.md Normal file

@@ -0,0 +1,16 @@
---
description: TODO ADD DESCRIPTION
keywords: TODO ADD KEYWORDS
---

# Reference for `ultralytics/models/sam/sam3/vl_combiner.py`

!!! success "Improvements"

    This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) — thank you! 🙏

<br>

## ::: ultralytics.models.sam.sam3.vl_combiner.SAM3VLBackbone

<br><br>
13 mkdocs.yml

@@ -592,6 +592,7 @@ nav:
        - sam:
            - amg: reference/models/sam/amg.md
            - build: reference/models/sam/build.md
            - build_sam3: reference/models/sam/build_sam3.md
            - model: reference/models/sam/model.md
            - modules:
                - blocks: reference/models/sam/modules/blocks.md

@@ -603,6 +604,18 @@ nav:
                - transformer: reference/models/sam/modules/transformer.md
                - utils: reference/models/sam/modules/utils.md
            - predict: reference/models/sam/predict.md
            - sam3:
                - decoder: reference/models/sam/sam3/decoder.md
                - encoder: reference/models/sam/sam3/encoder.md
                - geometry_encoders: reference/models/sam/sam3/geometry_encoders.md
                - maskformer_segmentation: reference/models/sam/sam3/maskformer_segmentation.md
                - model_misc: reference/models/sam/sam3/model_misc.md
                - necks: reference/models/sam/sam3/necks.md
                - sam3_image: reference/models/sam/sam3/sam3_image.md
                - text_encoder_ve: reference/models/sam/sam3/text_encoder_ve.md
                - tokenizer_ve: reference/models/sam/sam3/tokenizer_ve.md
                - vitdet: reference/models/sam/sam3/vitdet.md
                - vl_combiner: reference/models/sam/sam3/vl_combiner.md
        - utils:
            - loss: reference/models/utils/loss.md
            - ops: reference/models/utils/ops.md
@@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

__version__ = "8.3.237"

import importlib
import os
@@ -244,14 +244,15 @@ class BasePredictor:
        for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
            pass

    def setup_source(self, source, stride: int | None = None):
        """Set up source and inference mode.

        Args:
            source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
                inference.
            stride (int, optional): Model stride for image size checking.
        """
        self.imgsz = check_imgsz(self.args.imgsz, stride=stride or self.model.stride, min_dim=2)  # check image size
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
|||
|
|
@@ -1,7 +1,16 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import SAM
-from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
+from .predict import (
+    Predictor,
+    SAM2DynamicInteractivePredictor,
+    SAM2Predictor,
+    SAM2VideoPredictor,
+    SAM3Predictor,
+    SAM3SemanticPredictor,
+    SAM3VideoPredictor,
+    SAM3VideoSemanticPredictor,
+)

__all__ = (
    "SAM",
@@ -9,4 +18,8 @@ __all__ = (
    "SAM2DynamicInteractivePredictor",
    "SAM2Predictor",
    "SAM2VideoPredictor",
    "SAM3Predictor",
    "SAM3SemanticPredictor",
    "SAM3VideoPredictor",
    "SAM3VideoSemanticPredictor",
)  # tuple or list of exportable items
@@ -21,6 +21,21 @@ from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer


def _load_checkpoint(model, checkpoint):
    """Load checkpoint into model from file path."""
    if checkpoint is None:
        return model

    checkpoint = attempt_download_asset(checkpoint)
    with open(checkpoint, "rb") as f:
        state_dict = torch_load(f)
    # Handle nested "model" key
    if "model" in state_dict and isinstance(state_dict["model"], dict):
        state_dict = state_dict["model"]
    model.load_state_dict(state_dict)
    return model


def build_sam_vit_h(checkpoint=None):
    """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
    return _build_sam(

@@ -205,10 +220,7 @@ def _build_sam(
        pixel_std=[58.395, 57.12, 57.375],
    )
-    if checkpoint is not None:
-        checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, "rb") as f:
-            state_dict = torch_load(f)
-        sam.load_state_dict(state_dict)
+    sam = _load_checkpoint(sam, checkpoint)
    sam.eval()
    return sam

@@ -299,10 +311,7 @@ def _build_sam2(
    )

-    if checkpoint is not None:
-        checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, "rb") as f:
-            state_dict = torch_load(f)["model"]
-        sam2.load_state_dict(state_dict)
+    sam2 = _load_checkpoint(sam2, checkpoint)
    sam2.eval()
    return sam2
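The consolidated `_load_checkpoint` helper above handles checkpoints that nest their weights under a `"model"` key (as SAM2 checkpoints do) as well as flat ones. A torch-free sketch of just that unwrapping step, with a hypothetical `unwrap_state_dict` name and dummy dict values:

```python
def unwrap_state_dict(ckpt: dict) -> dict:
    """Return the inner "model" dict when a checkpoint nests its weights, else the checkpoint itself."""
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        return ckpt["model"]
    return ckpt


flat = {"layer.weight": [1.0]}
nested = {"model": {"layer.weight": [1.0]}, "epoch": 3}
print(unwrap_state_dict(flat) == unwrap_state_dict(nested))  # True: both yield the same weights
```

This lets one helper serve both the SAM (flat) and SAM2 (nested) checkpoint layouts instead of duplicating the load logic per builder.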
374
ultralytics/models/sam/build_sam3.py
Normal file
@@ -0,0 +1,374 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import torch.nn as nn

from ultralytics.nn.modules.transformer import MLP
from ultralytics.utils.patches import torch_load

from .modules.blocks import PositionEmbeddingSine, RoPEAttention
from .modules.encoders import MemoryEncoder
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
from .modules.sam import SAM3Model
from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
from .sam3.geometry_encoders import SequenceGeometryEncoder
from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
from .sam3.model_misc import DotProductScoring, TransformerWrapper
from .sam3.necks import Sam3DualViTDetNeck
from .sam3.sam3_image import SAM3SemanticModel
from .sam3.text_encoder_ve import VETextEncoder
from .sam3.tokenizer_ve import SimpleTokenizer
from .sam3.vitdet import ViT
from .sam3.vl_combiner import SAM3VLBackbone


def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
    """Create SAM3 visual backbone with ViT and neck."""
    # Position encoding
    position_encoding = PositionEmbeddingSine(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
    )

    # ViT backbone
    vit_backbone = ViT(
        img_size=1008,
        pretrain_img_size=336,
        patch_size=14,
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        use_abs_pos=True,
        tile_abs_pos=True,
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),
        use_rope=True,
        use_interp_rope=True,
        window_size=24,
        pretrain_use_cls_token=True,
        retain_cls_token=False,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
        bias_patch_embed=False,
        compile_mode=compile_mode,
    )
    return Sam3DualViTDetNeck(
        position_encoding=position_encoding,
        d_model=256,
        scale_factors=[4.0, 2.0, 1.0, 0.5],
        trunk=vit_backbone,
        add_sam2_neck=enable_inst_interactivity,
    )


def _create_sam3_transformer() -> TransformerWrapper:
    """Create SAM3 detector encoder and decoder."""
    encoder: TransformerEncoderFusion = TransformerEncoderFusion(
        layer=TransformerEncoderLayer(
            d_model=256,
            dim_feedforward=2048,
            dropout=0.1,
            pos_enc_at_attn=True,
            pos_enc_at_cross_attn_keys=False,
            pos_enc_at_cross_attn_queries=False,
            pre_norm=True,
            self_attention=nn.MultiheadAttention(
                num_heads=8,
                dropout=0.1,
                embed_dim=256,
                batch_first=True,
            ),
            cross_attention=nn.MultiheadAttention(
                num_heads=8,
                dropout=0.1,
                embed_dim=256,
                batch_first=True,
            ),
        ),
        num_layers=6,
        d_model=256,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )
    decoder: TransformerDecoder = TransformerDecoder(
        layer=TransformerDecoderLayer(
            d_model=256,
            dim_feedforward=2048,
            dropout=0.1,
            cross_attention=nn.MultiheadAttention(
                num_heads=8,
                dropout=0.1,
                embed_dim=256,
            ),
            n_heads=8,
            use_text_cross_attention=True,
        ),
        num_layers=6,
        num_queries=200,
        return_intermediate=True,
        box_refine=True,
        num_o2m_queries=0,
        dac=True,
        boxRPB="log",
        d_model=256,
        frozen=False,
        interaction_layer=None,
        dac_use_selfatt_ln=True,
        use_act_checkpoint=True,
        presence_token=True,
    )

    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)


def build_sam3_image_model(
    checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False
):
    """Build SAM3 image model.

    Args:
        checkpoint_path: Optional path to model checkpoint
        bpe_path: Path to the BPE tokenizer vocabulary
        enable_segmentation: Whether to enable segmentation head
        compile: To enable compilation, set to "default"

    Returns:
        A SAM3 image model
    """
    # Create visual components
    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)

    # Create text components
    text_encoder = VETextEncoder(
        tokenizer=SimpleTokenizer(bpe_path=bpe_path),
        d_model=256,
        width=1024,
        heads=16,
        layers=24,
    )

    # Create visual-language backbone
    backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)

    # Create transformer components
    transformer = _create_sam3_transformer()

    # Create dot product scoring
    dot_prod_scoring = DotProductScoring(
        d_model=256,
        d_proj=256,
        prompt_mlp=MLP(
            input_dim=256,
            hidden_dim=2048,
            output_dim=256,
            num_layers=2,
            residual=True,
            out_norm=nn.LayerNorm(256),
        ),
    )

    # Create segmentation head if enabled
    segmentation_head = (
        UniversalSegmentationHead(
            hidden_dim=256,
            upsampling_stages=3,
            aux_masks=False,
            presence_head=False,
            dot_product_scorer=None,
            act_ckpt=True,
            cross_attend_prompt=nn.MultiheadAttention(
                num_heads=8,
                dropout=0,
                embed_dim=256,
            ),
            pixel_decoder=PixelDecoder(
                num_upsampling_stages=3,
                interpolation_mode="nearest",
                hidden_dim=256,
                compile_mode=compile_mode,
            ),
        )
        if enable_segmentation
        else None
    )

    # Create geometry encoder
    input_geometry_encoder = SequenceGeometryEncoder(
        pos_enc=PositionEmbeddingSine(
            num_pos_feats=256,
            normalize=True,
            scale=None,
            temperature=10000,
        ),
        encode_boxes_as_points=False,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=256,
        num_layers=3,
        layer=TransformerEncoderLayer(
            d_model=256,
            dim_feedforward=2048,
            dropout=0.1,
            pos_enc_at_attn=False,
            pre_norm=True,
            pos_enc_at_cross_attn_queries=False,
            pos_enc_at_cross_attn_keys=True,
        ),
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )

    # Create the SAM3SemanticModel model
    model = SAM3SemanticModel(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
    )

    # Load checkpoint
    model = _load_checkpoint(model, checkpoint_path)
    model.eval()
    return model


def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
    """Build the SAM3 Tracker module for video tracking.

    Returns:
        Sam3TrackerPredictor: Wrapped SAM3 Tracker module
    """
    # Create model components
    memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
    memory_attention = MemoryAttention(
        batch_first=True,
        d_model=256,
        pos_enc_at_input=True,
        layer=MemoryAttentionLayer(
            dim_feedforward=2048,
            dropout=0.1,
            pos_enc_at_attn=False,
            pos_enc_at_cross_attn_keys=True,
            pos_enc_at_cross_attn_queries=False,
            self_attn=RoPEAttention(
                embedding_dim=256,
                num_heads=1,
                downsample_rate=1,
                rope_theta=10000.0,
                feat_sizes=[72, 72],
            ),
            d_model=256,
            cross_attn=RoPEAttention(
                embedding_dim=256,
                num_heads=1,
                downsample_rate=1,
                kv_in_dim=64,
                rope_theta=10000.0,
                feat_sizes=[72, 72],
                rope_k_repeat=True,
            ),
        ),
        num_layers=4,
    )

    backbone = (
        SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
        if with_backbone
        else None
    )
    model = SAM3Model(
        image_size=1008,
        image_encoder=backbone,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        backbone_stride=14,
        num_maskmem=7,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        no_obj_embed_spatial=True,
        proj_tpos_enc_in_obj_ptrs=True,
        use_signed_tpos_enc_to_obj_ptrs=True,
        sam_mask_decoder_extra_args=dict(
            dynamic_multimask_via_stability=True,
            dynamic_multimask_stability_delta=0.05,
            dynamic_multimask_stability_thresh=0.98,
        ),
    )

    # Load checkpoint if provided
    model = _load_checkpoint(model, checkpoint_path, interactive=True)

    # Setup device and mode
    model.eval()
    return model


def _load_checkpoint(model, checkpoint, interactive=False):
    """Load SAM3 model checkpoint from file."""
    with open(checkpoint, "rb") as f:
        ckpt = torch_load(f)
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]
    sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
    if interactive:
        sam3_image_ckpt.update(
            {
                k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
                for k, v in sam3_image_ckpt.items()
                if "backbone.vision_backbone" in k
            }
        )
        sam3_image_ckpt.update(
            {
                k.replace("tracker.transformer.encoder", "memory_attention"): v
                for k, v in ckpt.items()
                if "tracker.transformer" in k
            }
        )
        sam3_image_ckpt.update(
            {
                k.replace("tracker.maskmem_backbone", "memory_encoder"): v
                for k, v in ckpt.items()
                if "tracker.maskmem_backbone" in k
            }
        )
        sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
    model.load_state_dict(sam3_image_ckpt, strict=False)
    return model
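The SAM3 `_load_checkpoint` above maps official checkpoint namespaces (`detector.*`, `tracker.*`) onto the Ultralytics module names before a non-strict `load_state_dict`. The remapping pattern can be illustrated with plain dictionaries and hypothetical keys (no torch required):

```python
def remap_keys(ckpt: dict, old_prefix: str, new_prefix: str) -> dict:
    """Copy entries whose keys contain old_prefix, with it renamed to new_prefix."""
    return {k.replace(old_prefix, new_prefix): v for k, v in ckpt.items() if old_prefix in k}


# Hypothetical checkpoint keys, mimicking the detector/tracker split
ckpt = {
    "detector.backbone.vision_backbone.w": 1,
    "tracker.maskmem_backbone.w": 2,
}
state = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
state.update(remap_keys(ckpt, "tracker.maskmem_backbone", "memory_encoder"))
print(sorted(state))  # ['backbone.vision_backbone.w', 'memory_encoder.w']
```

Loading with `strict=False` then tolerates any official keys that have no counterpart in the Ultralytics module tree.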
@@ -21,7 +21,7 @@ from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info

-from .predict import Predictor, SAM2Predictor
+from .predict import Predictor, SAM2Predictor, SAM3Predictor


class SAM(Model):

@@ -59,6 +59,7 @@ class SAM(Model):
        if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
        self.is_sam2 = "sam2" in Path(model).stem
        self.is_sam3 = "sam3" in Path(model).stem
        super().__init__(model=model, task="segment")

    def _load(self, weights: str, task=None):

@@ -72,9 +73,14 @@ class SAM(Model):
        >>> sam = SAM("sam_b.pt")
        >>> sam._load("path/to/custom_weights.pt")
        """
-        from .build import build_sam  # slow import
-
-        self.model = build_sam(weights)
+        if self.is_sam3:
+            from .build_sam3 import build_interactive_sam3
+
+            self.model = build_interactive_sam3(weights)
+        else:
+            from .build import build_sam  # slow import
+
+            self.model = build_sam(weights)

    def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
        """Perform segmentation prediction on the given image or video source.

@@ -158,4 +164,6 @@ class SAM(Model):
        >>> print(task_map)
        {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
        """
-        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
+        return {
+            "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
+        }
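The nested ternary in the updated `task_map` dispatches on the weight-file stem (`sam2` takes precedence, then `sam3`, then the default SAM predictor). The selection logic can be sketched standalone with hypothetical stand-in classes:

```python
class Predictor: ...
class SAM2Predictor: ...
class SAM3Predictor: ...


def pick_predictor(stem: str):
    """Mirror of the task_map selection: sam2 beats sam3 beats the default predictor."""
    is_sam2, is_sam3 = "sam2" in stem, "sam3" in stem
    return SAM2Predictor if is_sam2 else SAM3Predictor if is_sam3 else Predictor


print(pick_predictor("sam3").__name__)      # SAM3Predictor
print(pick_predictor("sam2.1_b").__name__)  # SAM2Predictor
print(pick_predictor("sam_b").__name__)     # Predictor
```

Note the ordering matters only for stems containing both substrings; for real weight names the two flags are mutually exclusive.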
@@ -79,6 +79,7 @@ class MaskDownSampler(nn.Module):
        padding: int = 0,
        total_stride: int = 16,
        activation: type[nn.Module] = nn.GELU,
        interpol_size: tuple[int, int] | None = None,
    ):
        """Initialize a mask downsampler module for progressive downsampling and channel expansion."""
        super().__init__()

@@ -102,9 +103,24 @@ class MaskDownSampler(nn.Module):
            mask_in_chans = mask_out_chans

        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
        self.interpol_size = interpol_size
        if self.interpol_size is not None:
            assert isinstance(self.interpol_size, (list, tuple)), (
                f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
            )
            self.interpol_size = list(interpol_size)
            assert len(self.interpol_size) == 2

    def forward(self, x: Tensor) -> Tensor:
        """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
        if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
            x = F.interpolate(
                x.float(),
                size=self.interpol_size,
                align_corners=False,
                mode="bilinear",
                antialias=True,
            ).to(x.dtype)
        return self.encoder(x)
@@ -429,13 +445,7 @@ class RoPEAttention(Attention):
        )

-        # Attention
-        _, _, _, c_per_head = q.shape
-        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
-        attn = attn / math.sqrt(c_per_head)
-        attn = torch.softmax(attn, dim=-1)
-
-        # Get output
-        out = attn @ v
+        out = F.scaled_dot_product_attention(q, k, v)

        out = self._recombine_heads(out)
        out = self.out_proj(out)
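The change above replaces the manual `q @ kᵀ` / scale / softmax / `@ v` sequence with PyTorch's fused `F.scaled_dot_product_attention`. What that call computes can be checked with a naive, torch-free sketch (hypothetical `attention` helper on nested lists):

```python
import math


def attention(q, k, v):
    """Naive scaled dot-product attention over 2-D lists of shape (seq_len, dim)."""
    d = len(q[0])
    # scores[i][j] = <q_i, k_j> / sqrt(d)
    scores = [[sum(qi * ki for qi, ki in zip(qrow, krow)) / math.sqrt(d) for krow in k] for qrow in q]
    out = []
    for row in scores:
        m = max(row)  # subtract max for numerical stability
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        w = [e / z for e in exps]  # softmax weights, sum to 1
        out.append([sum(wi * vrow[j] for wi, vrow in zip(w, v)) for j in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0], [0.0]]
print(attention(q, k, v))  # each query attends mostly to its matching key
```

The fused kernel gives the same result (up to numerics) while avoiding the materialized `N×N` attention matrix.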
@@ -1033,6 +1043,7 @@ class PatchEmbed(nn.Module):
        padding: tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
        bias: bool = True,
    ) -> None:
        """Initialize the PatchEmbed module for converting image patches to embeddings.

@@ -1045,10 +1056,11 @@ class PatchEmbed(nn.Module):
            padding (tuple[int, int]): Padding applied to the input before convolution.
            in_chans (int): Number of input image channels.
            embed_dim (int): Dimensionality of the output patch embeddings.
            bias (bool): Whether to include a bias term in the convolutional layer.
        """
        super().__init__()

-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute patch embedding by applying convolution and transposing resulting tensor."""
@@ -436,9 +436,8 @@ class SAM2MaskDecoder(nn.Module):
    def _get_stability_scores(self, mask_logits):
        """Compute mask stability scores based on IoU between upper and lower thresholds."""
        mask_logits = mask_logits.flatten(-2)
-        stability_delta = self.dynamic_multimask_stability_delta
-        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
-        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+        area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float()
+        area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float()
        return torch.where(area_u > 0, area_i / area_u, 1.0)

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
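The stability score above is the IoU between the mask binarized at `+δ` and at `-δ`: logits near zero flip between the two thresholds, so an unstable mask has a small ratio. A list-based sketch of the same computation (hypothetical `stability_score` helper):

```python
def stability_score(logits, delta=0.05):
    """IoU between the tight (> +delta) and loose (> -delta) binarizations of flattened mask logits."""
    area_i = sum(v > delta for v in logits)   # pixels confidently inside
    area_u = sum(v > -delta for v in logits)  # pixels at least marginally inside
    return area_i / area_u if area_u > 0 else 1.0


# One confident pixel, two borderline pixels, one confident background pixel
print(stability_score([0.5, 0.02, -0.02, -0.5]))  # 1/3: borderline pixels make the mask unstable
```

Scores near 1.0 mean the mask boundary barely moves under threshold perturbation, which is what the dynamic multimask selection rewards.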
@@ -361,6 +361,7 @@ class MemoryEncoder(nn.Module):
        self,
        out_dim,
        in_dim=256,  # in_dim of pix_feats
        interpol_size: tuple[int, int] | None = None,
    ):
        """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.

@@ -370,10 +371,12 @@ class MemoryEncoder(nn.Module):
        Args:
            out_dim (int): Output dimension of the encoded features.
            in_dim (int): Input dimension of the pixel features.
            interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel
                features.
        """
        super().__init__()

-        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size)

        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
@@ -59,6 +59,8 @@ class MemoryAttentionLayer(nn.Module):
        pos_enc_at_attn: bool = False,
        pos_enc_at_cross_attn_keys: bool = True,
        pos_enc_at_cross_attn_queries: bool = False,
        self_attn: nn.Module | None = None,
        cross_attn: nn.Module | None = None,
    ):
        """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.

@@ -69,13 +71,15 @@ class MemoryAttentionLayer(nn.Module):
            pos_enc_at_attn (bool): Whether to add positional encoding at attention.
            pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
            pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
            self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used.
            cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used.
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
-        self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
-        self.cross_attn_image = RoPEAttention(
+        self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+        self.cross_attn_image = cross_attn or RoPEAttention(
            rope_k_repeat=True,
            embedding_dim=256,
            num_heads=1,
@@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_
from ultralytics.nn.modules import MLP
from ultralytics.utils import LOGGER

-from .blocks import SAM2TwoWayTransformer
+from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
from .decoders import MaskDecoder, SAM2MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
from .utils import get_1d_sine_pe, select_closest_cond_frames
@@ -329,6 +329,7 @@ class SAM2Model(torch.nn.Module):

        self._build_sam_heads()
        self.max_cond_frames_in_attn = max_cond_frames_in_attn
        self.add_all_frames_to_correct_as_cond = True

        # Model compilation
        if compile_image_encoder:

@@ -473,7 +474,7 @@ class SAM2Model(torch.nn.Module):
        assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
        if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
            sam_mask_prompt = F.interpolate(
-                mask_inputs.float(),
+                mask_inputs.to(backbone_features.dtype),
                size=self.sam_prompt_encoder.mask_input_size,
                align_corners=False,
                mode="bilinear",

@@ -571,7 +572,7 @@ class SAM2Model(torch.nn.Module):
            # produce an object pointer using the SAM decoder from the mask input
            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
                backbone_features=backbone_features,
-                mask_inputs=self.mask_downsample(mask_inputs_float),
+                mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
                high_res_features=high_res_features,
            )
        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;

@@ -818,7 +819,6 @@ class SAM2Model(torch.nn.Module):
        mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
        maskmem_features = maskmem_out["vision_features"]
-        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:

@@ -827,7 +827,7 @@ class SAM2Model(torch.nn.Module):
                ..., None, None
            ].expand(*maskmem_features.shape)

-        return maskmem_features, maskmem_pos_enc
+        return maskmem_features, maskmem_out["vision_pos_enc"]

    def _track_step(
        self,
@@ -1005,7 +1005,151 @@ class SAM2Model(torch.nn.Module):

    def set_imgsz(self, imgsz):
        """Set image size to make model compatible with different image sizes."""
        if hasattr(self.image_encoder, "set_imgsz"):
            self.image_encoder.set_imgsz(imgsz)
        self.image_size = imgsz[0]
        self.sam_prompt_encoder.input_image_size = imgsz
-        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.image_embedding_size = [
+            x // self.backbone_stride for x in imgsz
+        ]  # fixed ViT patch size of 16
        self.sam_prompt_encoder.mask_input_size = [
            x // self.backbone_stride * 4 for x in imgsz
        ]  # fixed ViT patch size of 16
        self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size


class SAM3Model(SAM2Model):
    """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""

    def __init__(
        self,
        image_encoder,
        memory_attention,
        memory_encoder,
        num_maskmem=7,
        image_size=1008,
        backbone_stride=14,
        sigmoid_scale_for_mem_enc=1,
        sigmoid_bias_for_mem_enc=0,
        binarize_mask_from_pts_for_mem_enc=False,
        use_mask_input_as_output_without_sam=False,
        max_cond_frames_in_attn=-1,
        directly_add_no_mem_embed=False,
        use_high_res_features_in_sam=False,
        multimask_output_in_sam=False,
        multimask_min_pt_num=1,
        multimask_max_pt_num=1,
        multimask_output_for_tracking=False,
        use_multimask_token_for_obj_ptr: bool = False,
        iou_prediction_use_sigmoid=False,
        memory_temporal_stride_for_eval=1,
        non_overlap_masks_for_mem_enc=False,
        use_obj_ptrs_in_encoder=False,
        max_obj_ptrs_in_encoder=16,
        add_tpos_enc_to_obj_ptrs=True,
        proj_tpos_enc_in_obj_ptrs=False,
        use_signed_tpos_enc_to_obj_ptrs=False,
        only_obj_ptrs_in_the_past_for_eval=False,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        fixed_no_obj_ptr: bool = False,
        soft_no_obj_ptr: bool = False,
        use_mlp_for_obj_ptr_proj: bool = False,
        no_obj_embed_spatial: bool = False,
        sam_mask_decoder_extra_args=None,
        compile_image_encoder: bool = False,
    ):
        """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
        super().__init__(
            image_encoder,
            memory_attention,
            memory_encoder,
            num_maskmem,
            image_size,
            backbone_stride,
            sigmoid_scale_for_mem_enc,
            sigmoid_bias_for_mem_enc,
            binarize_mask_from_pts_for_mem_enc,
            use_mask_input_as_output_without_sam,
            max_cond_frames_in_attn,
            directly_add_no_mem_embed,
            use_high_res_features_in_sam,
            multimask_output_in_sam,
            multimask_min_pt_num,
            multimask_max_pt_num,
            multimask_output_for_tracking,
            use_multimask_token_for_obj_ptr,
            iou_prediction_use_sigmoid,
            memory_temporal_stride_for_eval,
            non_overlap_masks_for_mem_enc,
            use_obj_ptrs_in_encoder,
            max_obj_ptrs_in_encoder,
            add_tpos_enc_to_obj_ptrs,
            proj_tpos_enc_in_obj_ptrs,
            use_signed_tpos_enc_to_obj_ptrs,
            only_obj_ptrs_in_the_past_for_eval,
            pred_obj_scores,
            pred_obj_scores_mlp,
            fixed_no_obj_ptr,
            soft_no_obj_ptr,
            use_mlp_for_obj_ptr_proj,
            no_obj_embed_spatial,
            sam_mask_decoder_extra_args,
            compile_image_encoder,
        )
        self.sam_mask_decoder = SAM2MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            **(self.sam_mask_decoder_extra_args or {}),
        )

    def forward_image(self, img_batch: torch.Tensor):
        """Process image batch through encoder to extract multi-level features for SAM model."""
        backbone_out = self.image_encoder.forward_image_sam2(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
        return backbone_out

    def set_imgsz(self, imgsz: tuple[int, int]):
        """Set the image size for the model and mask downsampler."""
        super().set_imgsz(imgsz)
        self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz]

    @staticmethod
    def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
        """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints."""
        area_before = (pred_masks > 0).sum(dim=(-1, -2))
        area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
        area_before = torch.clamp(area_before, min=1.0)
        area_ratio = area_after / area_before
        keep = area_ratio >= shrink_threshold
        keep_mask = keep[..., None, None].expand_as(pred_masks)
        pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks_after

    def _suppress_object_pw_area_shrinkage(self, pred_masks):
        """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints.

        Note that the final output can still be overlapping.
        """
        # Apply pixel-wise non-overlapping constraint based on mask scores
        pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
        # Fully suppress masks with high shrinkage (probably noisy) based on the pixel-wise non-overlapping constraints
        # NOTE: The output of this function can be a no-op if none of the masks shrank by a large factor.
        pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
        return pred_masks
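The area-ratio test in `_suppress_shrinked_masks` above keeps a mask only if it retains at least `shrink_threshold` of its pre-constraint area; masks that collapse under the non-overlap constraint are treated as noise and pushed to strongly negative logits. The per-mask decision can be sketched torch-free with scalar areas (hypothetical `keep_mask` helper):

```python
def keep_mask(area_before: int, area_after: int, shrink_threshold: float = 0.3) -> bool:
    """A mask survives if it retains at least shrink_threshold of its pre-constraint area."""
    return area_after / max(area_before, 1) >= shrink_threshold  # max(..., 1) guards empty masks


print(keep_mask(100, 80))  # True: only 20% shrinkage
print(keep_mask(100, 10))  # False: 90% shrinkage, suppress
```

In the tensor version the same comparison is broadcast over the object dimension, and suppressed masks are clamped to at most -10.0 rather than removed.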
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
|
@@ -86,7 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
    return pos_embed


-def init_t_xy(end_x: int, end_y: int):
+def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
    """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.

    This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index

@@ -95,6 +96,8 @@ def init_t_xy(end_x: int, end_y: int):
    Args:
        end_x (int): Width of the grid (number of columns).
        end_y (int): Height of the grid (number of rows).
+        scale (float): Scaling factor to apply to the coordinates.
+        offset (int): Offset to add to the coordinates.

    Returns:
        t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).

@@ -110,10 +113,10 @@ def init_t_xy(end_x: int, end_y: int):
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode="floor").float()
-    return t_x, t_y
+    return t_x * scale + offset, t_y * scale + offset
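The new `scale`/`offset` parameters only affine-transform the grid coordinates after the linear index is decomposed into columns and rows. A torch-free sketch of the same decomposition (hypothetical helper name, plain lists instead of tensors):

```python
def init_t_xy_py(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
    # Decompose linear indices 0..end_x*end_y-1 into column (x) and row (y)
    # coordinates, then apply the affine transform t * scale + offset.
    t_x = [(i % end_x) * scale + offset for i in range(end_x * end_y)]
    t_y = [(i // end_x) * scale + offset for i in range(end_x * end_y)]
    return t_x, t_y
```

For a 3×2 grid the indices 0..5 unroll row-major, so `t_x` cycles 0,1,2 and `t_y` steps up once per row.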
-def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
    """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.

    This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate

@@ -124,6 +127,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
        end_x (int): Width of the 2D grid.
        end_y (int): Height of the 2D grid.
        theta (float, optional): Scaling factor for frequency computation.
+        scale_pos (float, optional): Scaling factor for position coordinates.

    Returns:
        (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).

@@ -137,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

-    t_x, t_y = init_t_xy(end_x, end_y)
+    t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
    freqs_x = torch.outer(t_x, freqs_x)
    freqs_y = torch.outer(t_y, freqs_y)
    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
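The axial encoding assigns `dim // 4` frequencies to each spatial axis and turns each position-frequency product into a unit complex exponential (what `torch.polar` does). A stdlib sketch of that construction, assuming the same frequency schedule (hypothetical helper, not the library API):

```python
import cmath

def axial_cis_py(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
    # dim // 4 frequencies, shared between the x and y axes as in the torch version
    freqs = [1.0 / theta ** (4 * i / dim) for i in range(dim // 4)]
    rows = []
    for t in range(end_x * end_y):
        tx = (t % end_x) * scale_pos
        ty = (t // end_x) * scale_pos
        # unit-magnitude complex exponentials e^{i * pos * freq}: x half, then y half
        rows.append([cmath.rect(1.0, tx * f) for f in freqs] + [cmath.rect(1.0, ty * f) for f in freqs])
    return rows  # shape (end_x * end_y, dim // 2)
```

Every entry has magnitude 1, so applying these encodings rotates query/key channels without changing their norm.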
@@ -375,3 +379,129 @@ def add_decomposed_rel_pos(
    )

    return attn


def get_abs_pos(
    abs_pos: torch.Tensor,
    has_cls_token: bool,
    hw: tuple[int, int],
    retain_cls_token: bool = False,
    tiling: bool = False,
) -> torch.Tensor:
    """Calculate absolute positional embeddings, resizing them and dropping the cls_token dimension if needed.

    Args:
        abs_pos (torch.Tensor): Absolute positional embeddings with shape (1, num_position, C).
        has_cls_token (bool): If True, abs_pos has one extra embedding for the cls token.
        hw (tuple[int, int]): Size of the input image tokens.
        retain_cls_token (bool): Whether to retain the cls_token.
        tiling (bool): Whether to tile the embeddings *instead* of interpolating (a la abs_win).

    Returns:
        (torch.Tensor): Processed absolute positional embeddings with shape (1, H, W, C) if retain_cls_token is
            False, otherwise (1, 1 + H * W, C).
    """
    if retain_cls_token:
        assert has_cls_token

    h, w = hw
    if has_cls_token:
        cls_pos = abs_pos[:, :1]
        abs_pos = abs_pos[:, 1:]

    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
        if tiling:
            new_abs_pos = new_abs_pos.tile([1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])])[
                :, :, :h, :w
            ]
        else:
            new_abs_pos = F.interpolate(
                new_abs_pos,
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            )

        if not retain_cls_token:
            return new_abs_pos.permute(0, 2, 3, 1)
        else:
            # add cls_token back, flatten spatial dims
            assert has_cls_token
            return torch.cat(
                [cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)],
                dim=1,
            )

    else:
        if not retain_cls_token:
            return abs_pos.reshape(1, h, w, -1)
        else:
            assert has_cls_token
            return torch.cat([cls_pos, abs_pos], dim=1)
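In the `tiling` branch, `get_abs_pos` repeats the embedding grid just enough times to cover `(h, w)` and then crops, rather than interpolating. A list-based, torch-free sketch of that tile-then-crop step (hypothetical helper):

```python
def tile_to_size(grid, h, w):
    # Repeat a 2-D grid (target // source + 1) times per axis, then crop to
    # (h, w), mirroring new_abs_pos.tile(...)[:, :, :h, :w] in get_abs_pos.
    gh, gw = len(grid), len(grid[0])
    reps_h, reps_w = h // gh + 1, w // gw + 1
    return [(row * reps_w)[:w] for row in (grid * reps_h)[:h]]
```

Tiling preserves the local periodic structure of windowed position embeddings, which is why it is preferred over bicubic interpolation for the abs_win-style embeddings.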
def concat_rel_pos(
    q: torch.Tensor,
    k: torch.Tensor,
    q_hw: tuple[int, int],
    k_hw: tuple[int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rescale: bool = False,
    relative_coords: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Concatenate relative position coefficients to the q and k tensors, so that qk^T effectively includes the
    relative position biases.

    Args:
        q (torch.Tensor): Query tensor with shape (B, L_q, C).
        k (torch.Tensor): Key tensor with shape (B, L_k, C).
        q_hw (tuple[int, int]): Spatial size of the q tensor.
        k_hw (tuple[int, int]): Spatial size of the k tensor.
        rel_pos_h (torch.Tensor): Relative positional embeddings/params for height.
        rel_pos_w (torch.Tensor): Relative positional embeddings/params for width.
        rescale (bool): Whether to rescale, e.g. when using sdpa, where PyTorch would otherwise scale by the wrong
            factor due to the concat.
        relative_coords (torch.Tensor, optional): Precomputed relative coords index tensor.

    Returns:
        q, k (torch.Tensor): Padded so that qk^T accounts for the relative position biases.
    """
    q_h, q_w = q_hw
    k_h, k_w = k_hw

    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"

    if relative_coords is not None:
        Rh = rel_pos_h[relative_coords]
        Rw = rel_pos_w[relative_coords]
    else:
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)

    old_scale = dim**0.5
    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale  # for sdpa
    # attn will be divided by new_scale, but we want to divide q by old_scale
    scale_ratio = new_scale / old_scale

    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale  # (B, q_h, q_w, k_w)

    eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
    eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)

    eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
    eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])

    q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
    k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(B, k_h * k_w, -1)

    return q, k
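The concat trick in `concat_rel_pos` works because appending precomputed bias values to q and a matching one-hot selector to k makes the padded dot product equal the original score plus the bias. A tiny sketch for a single query against one of several keys (hypothetical helper names, plain lists):

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def pad_qk(q, k, bias, key_idx, num_keys):
    # Append per-key biases to q and a one-hot key selector to k so that
    # dot(q_pad, k_pad) == dot(q, k) + bias[key_idx].
    one_hot = [1.0 if i == key_idx else 0.0 for i in range(num_keys)]
    return q + bias, k + one_hot

q, k = [1.0, 2.0], [3.0, 4.0]
bias = [0.5, -0.25, 0.0]  # relative-position bias toward each of 3 keys
q_pad, k_pad = pad_qk(q, k, bias, key_idx=1, num_keys=3)
```

This is what lets a fused attention kernel, which only computes qk^T, apply relative-position biases without an explicit additive mask.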
File diff suppressed because it is too large

3	ultralytics/models/sam/sam3/__init__.py	Normal file

@@ -0,0 +1,3 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

546	ultralytics/models/sam/sam3/decoder.py	Normal file

@@ -0,0 +1,546 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Transformer decoder.

Inspired by PyTorch's version; adds the pre-norm variant.
"""

from __future__ import annotations

import numpy as np
import torch
from torch import nn
from torchvision.ops.roi_align import RoIAlign

from ultralytics.nn.modules.transformer import MLP
from ultralytics.nn.modules.utils import _get_clones, inverse_sigmoid
from ultralytics.utils.ops import xywh2xyxy

from .model_misc import gen_sineembed_for_position


class TransformerDecoderLayer(nn.Module):
    """TransformerDecoderLayer is made up of self-attn, cross-attn, and a feedforward network (FFN)."""

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        cross_attention: nn.Module,
        n_heads: int,
        use_text_cross_attention: bool = False,
    ):
        """Initialize the TransformerDecoderLayer."""
        super().__init__()
        # cross attention
        self.cross_attn = cross_attention
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # cross attention to text
        self.use_text_cross_attention = use_text_cross_attention
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            self.catext_norm = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add positional embedding to the tensor."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Feedforward network forward pass."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: torch.Tensor,  # nq, bs, d_model
        tgt_query_pos: torch.Tensor = None,  # pos for query. MLP(Sine(pos))
        memory_text: torch.Tensor = None,  # num_token, bs, d_model
        text_attention_mask: torch.Tensor = None,  # bs, num_token
        # for memory
        memory: torch.Tensor = None,  # hw, bs, d_model
        memory_key_padding_mask: torch.Tensor = None,
        memory_pos: torch.Tensor = None,  # pos for memory
        # sa
        self_attn_mask: torch.Tensor = None,  # mask used for self-attention
        cross_attn_mask: torch.Tensor = None,  # mask used for cross-attention
        # dac
        dac=False,
        dac_use_selfatt_ln=True,
        presence_token=None,
        # skip inside deformable attn
        **kwargs,  # additional kwargs for compatibility
    ):
        """Run self-attention, optional text cross-attention, image cross-attention, and the FFN on tgt (nq, bs, d_model)."""
        # self attention
        tgt, tgt_query_pos = self._apply_self_attention(
            tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask
        )

        if self.use_text_cross_attention:
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text.to(tgt.dtype),
                memory_text.to(tgt.dtype),
                key_padding_mask=text_attention_mask,
            )[0]
            tgt = tgt + self.catext_dropout(tgt2)
            tgt = self.catext_norm(tgt)

        if presence_token is not None:
            presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
            cross_attn_mask = torch.cat([presence_token_mask, cross_attn_mask], dim=1)  # (bs*nheads, 1+nq, hw)

        # Cross attention to image
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=cross_attn_mask,
            key_padding_mask=(memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None),
            need_weights=False,
        )[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt.to(memory.dtype))

        presence_token_out = None
        if presence_token is not None:
            presence_token_out = tgt[:1]
            tgt = tgt[1:]

        return tgt, presence_token_out

    def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask):
        """Apply self-attention with optional DAC splitting."""
        if self.self_attn is None:
            return tgt, tgt_query_pos

        if dac:
            # Split queries for DAC (detect-and-classify)
            assert tgt.shape[0] % 2 == 0, "DAC requires an even number of queries"
            num_o2o_queries = tgt.shape[0] // 2
            tgt_o2o = tgt[:num_o2o_queries]
            tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
            tgt_o2m = tgt[num_o2o_queries:]
        else:
            tgt_o2o = tgt
            tgt_query_pos_o2o = tgt_query_pos

        # Handle presence token
        if presence_token is not None:
            tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
            tgt_query_pos_o2o = torch.cat([torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0).to(
                tgt_o2o.dtype
            )
            tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0)

        # Self-attention
        q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
        tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0].to(tgt.dtype)
        tgt_o2o = tgt_o2o + self.dropout2(tgt2)

        # Recombine and normalize
        if dac:
            if not dac_use_selfatt_ln:
                tgt_o2o = self.norm2(tgt_o2o)
            tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)
            if dac_use_selfatt_ln:
                tgt = self.norm2(tgt)
        else:
            tgt = self.norm2(tgt_o2o)

        return tgt, tgt_query_pos
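`_apply_self_attention` runs self-attention only on the first (one-to-one) half of the tiled DAC queries and passes the one-to-many half through untouched before recombining. A shape-level sketch of that split/recombine, using plain lists in place of tensors (hypothetical helper, attention reduced to a per-query callable):

```python
def dac_self_attention(queries, attend):
    # DAC tiles the queries: first half is one-to-one (o2o), second half
    # one-to-many (o2m). Only the o2o half goes through self-attention.
    assert len(queries) % 2 == 0, "DAC requires an even number of queries"
    n = len(queries) // 2
    o2o, o2m = queries[:n], queries[n:]
    return [attend(x) for x in o2o] + o2m  # recombine in the original order
```

Skipping self-attention for the o2m half is what lets DAC keep one-to-many supervision cheap while the o2o queries still de-duplicate each other.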
class TransformerDecoder(nn.Module):
    """Transformer decoder consisting of multiple layers."""

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        interaction_layer,
        layer,
        num_layers: int,
        num_queries: int,
        return_intermediate: bool,
        box_refine: bool = False,
        num_o2m_queries: int = 0,
        dac: bool = False,
        boxRPB: str = "none",
        # Experimental: an object query for SAM 2 tasks
        instance_query: bool = False,
        # Defines the number of additional instance queries;
        # 1 or 4 are the most likely for single vs multi mask support
        num_instances: int = 1,  # irrelevant if instance_query is False
        dac_use_selfatt_ln: bool = True,
        use_act_checkpoint: bool = False,
        compile_mode=None,
        presence_token: bool = False,
        clamp_presence_logits: bool = True,
        clamp_presence_logit_max_val: float = 10.0,
        use_normed_output_consistently: bool = True,
        separate_box_head_instance: bool = False,
        separate_norm_instance: bool = False,
    ):
        """Initialize the TransformerDecoder."""
        super().__init__()
        self.d_model = d_model
        self.layers = _get_clones(layer, num_layers)
        self.fine_layers = (
            _get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers
        )
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.dac = dac
        if dac:
            self.num_o2m_queries = num_queries
            tot_num_queries = num_queries
        else:
            self.num_o2m_queries = num_o2m_queries
            tot_num_queries = num_queries + num_o2m_queries
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.bbox_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(tot_num_queries, d_model)
        self.instance_query_embed = None
        self.instance_query_reference_points = None
        self.use_instance_query = instance_query
        self.num_instances = num_instances
        self.use_normed_output_consistently = use_normed_output_consistently

        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
        self.instance_bbox_embed = None
        if separate_box_head_instance:
            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
        if instance_query:
            self.instance_query_embed = nn.Embedding(num_instances, d_model)
        self.box_refine = box_refine
        if box_refine:
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

            self.reference_points = nn.Embedding(num_queries, 4)
            if instance_query:
                self.instance_reference_points = nn.Embedding(num_instances, 4)

        assert boxRPB in ["none", "log", "linear", "both"]
        self.boxRPB = boxRPB
        if boxRPB != "none":
            try:
                nheads = self.layers[0].cross_attn_image.num_heads
            except AttributeError:
                nheads = self.layers[0].cross_attn.num_heads

            n_input = 4 if boxRPB == "both" else 2
            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
            self.compilable_cord_cache = None
            self.compilable_stored_size = None
            self.coord_cache = {}

        self.roi_pooler = (
            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
            if interaction_layer is not None
            else None
        )
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)

        self.presence_token = None
        self.clamp_presence_logits = clamp_presence_logits
        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
        if presence_token:
            self.presence_token = nn.Embedding(1, d_model)
            self.presence_token_head = MLP(d_model, d_model, 1, 3)
            self.presence_token_out_norm = nn.LayerNorm(d_model)

        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
        self.dac_use_selfatt_ln = dac_use_selfatt_ln
        self.use_act_checkpoint = use_act_checkpoint

        nn.init.normal_(self.query_embed.weight.data)
        if self.instance_query_embed is not None:
            nn.init.normal_(self.instance_query_embed.weight.data)

        assert self.roi_pooler is None
        assert self.return_intermediate, "support return_intermediate only"
        assert self.box_refine, "support box refine only"

        self.compile_mode = compile_mode
        self.compiled = False
        # We defer compilation until after the first forward, to first warm up the boxRPB cache

        # Assign a layer index to each layer so that some layers can decide what to do based on their index
        # (e.g. cross attention to the memory bank only in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    @staticmethod
    def _get_coords(H, W, device, dtype):
        """Get normalized coordinates for height and width."""
        coords_h = torch.arange(0, H, dtype=dtype, device=device) / H
        coords_w = torch.arange(0, W, dtype=dtype, device=device) / W
        return coords_h, coords_w

    def _get_rpb_matrix(self, reference_boxes, feat_size):
        """Get the relative position bias (RPB) matrix for box-relative position bias."""
        H, W = feat_size
        boxes_xyxy = xywh2xyxy(reference_boxes).transpose(0, 1)
        bs, num_queries, _ = boxes_xyxy.shape
        if self.compilable_cord_cache is None:
            self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
            self.compilable_stored_size = (H, W)

        if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (H, W):
            # good, hitting the cache, will be compilable
            coords_h, coords_w = self.compilable_cord_cache
        else:
            # Cache miss, which would create a compilation issue.
            # In case we're not compiling, we can still rely on the dict-based cache.
            if feat_size not in self.coord_cache:
                self.coord_cache[feat_size] = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype)
            coords_h, coords_w = self.coord_cache[feat_size]

        assert coords_h.shape == (H,)
        assert coords_w.shape == (W,)

        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(bs, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(bs, num_queries, -1, 2)

        if self.boxRPB in ["log", "both"]:
            deltas_x_log = deltas_x * 8  # normalize to [-8, 8]
            deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8)

            deltas_y_log = deltas_y * 8  # normalize to [-8, 8]
            deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / np.log2(8)
            if self.boxRPB == "log":
                deltas_x = deltas_x_log
                deltas_y = deltas_y_log
            else:
                deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
                deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)

        if self.training:
            assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
        deltas_x = self.boxRPB_embed_x(x=deltas_x)  # bs, num_queries, W, n_heads
        deltas_y = self.boxRPB_embed_y(x=deltas_y)  # bs, num_queries, H, n_heads

        if not torch.compiler.is_dynamo_compiling():
            assert deltas_x.shape[:3] == (bs, num_queries, W)
            assert deltas_y.shape[:3] == (bs, num_queries, H)

        B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2)  # bs, num_queries, H, W, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[:4] == (bs, num_queries, H, W)
        B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
        B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
        B = B.contiguous()  # memeff attn likes ordered strides
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[2:] == (num_queries, H * W)
        return B
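For the `log` and `both` variants, each normalized coordinate delta is scaled to roughly [-8, 8] and compressed with a signed log before the MLP embeds it, so far-away pixels get progressively coarser resolution. A scalar sketch of that transform (hypothetical helper, deltas assumed in [-1, 1]):

```python
import math

def signed_log_delta(delta: float, span: float = 8.0) -> float:
    # sign(d) * log2(|d * span| + 1) / log2(span), matching the torch expression
    d = delta * span
    return math.copysign(math.log2(abs(d) + 1.0), d) / math.log2(span)
```

The transform is odd-symmetric, zero at zero, and monotonic, so it preserves the ordering of deltas while compressing their magnitude.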
    def forward(
        self,
        tgt,
        memory,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
        pos: torch.Tensor = None,
        reference_boxes: torch.Tensor = None,  # num_queries, bs, 4
        # for memory
        spatial_shapes: torch.Tensor = None,  # bs, num_levels, 2
        valid_ratios: torch.Tensor = None,
        # for text
        memory_text: torch.Tensor = None,
        text_attention_mask: torch.Tensor = None,
        # if `apply_dac` is None, it will default to `self.dac`
        apply_dac: bool | None = None,
        is_instance_prompt=False,
        decoder_extra_kwargs: dict | None = None,
        # ROI memory bank
        obj_roi_memory_feat=None,
        obj_roi_memory_mask=None,
        box_head_trk=None,
    ):
        """Forward pass of the TransformerDecoder."""
        if memory_mask is not None:
            assert self.boxRPB == "none", (
                "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
            )

        apply_dac = apply_dac if apply_dac is not None else self.dac
        if apply_dac:
            assert (tgt.shape[0] == self.num_queries) or (
                self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
            )
            tgt = tgt.repeat(2, 1, 1)
            # note that we don't tile tgt_mask, since DAC doesn't use self-attention in o2m queries
            if reference_boxes is not None:
                assert (reference_boxes.shape[0] == self.num_queries) or (
                    self.use_instance_query and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings)
                )
                reference_boxes = reference_boxes.repeat(2, 1, 1)

        bs = tgt.shape[1]
        intermediate = []
        intermediate_presence_logits = []
        presence_feats = None

        if self.box_refine:
            if reference_boxes is None:
                # In this case, we're in a one-stage model, so we generate the reference boxes
                reference_boxes = self.reference_points.weight.unsqueeze(1)
                reference_boxes = reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1)
                reference_boxes = reference_boxes.sigmoid()
            intermediate_ref_boxes = [reference_boxes]
        else:
            reference_boxes = None
            intermediate_ref_boxes = None

        output = tgt
        presence_out = None
        if self.presence_token is not None and not is_instance_prompt:
            # expand to batch dim
            presence_out = self.presence_token.weight[None].expand(1, bs, -1)

        box_head = self.bbox_embed
        if is_instance_prompt and self.instance_bbox_embed is not None:
            box_head = self.instance_bbox_embed

        out_norm = self.norm
        if is_instance_prompt and self.instance_norm is not None:
            out_norm = self.instance_norm

        for layer_idx, layer in enumerate(self.layers):
            reference_points_input = (
                reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
            )  # nq, bs, nlevel, 4

            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :], self.d_model
            )  # nq, bs, d_model*2

            # conditional query
            query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model

            if self.boxRPB != "none" and reference_boxes is not None:
                assert spatial_shapes.shape[0] == 1, "only single scale support implemented"
                memory_mask = self._get_rpb_matrix(
                    reference_boxes,
                    (spatial_shapes[0, 0], spatial_shapes[0, 1]),
                )
                memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)
            if self.training:
                assert self.use_act_checkpoint, "Activation checkpointing not enabled in the decoder"
            output, presence_out = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                dac=apply_dac,
                dac_use_selfatt_ln=self.dac_use_selfatt_ln,
                presence_token=presence_out,
                **(decoder_extra_kwargs or {}),
                # ROI memory bank
                obj_roi_memory_feat=obj_roi_memory_feat,
                obj_roi_memory_mask=obj_roi_memory_mask,
            )

            # iterative update
            if self.box_refine:
                reference_before_sigmoid = inverse_sigmoid(reference_boxes)
                if box_head_trk is None:
                    if not self.use_normed_output_consistently:
                        delta_unsig = box_head(output)
                    else:
                        delta_unsig = box_head(out_norm(output))
                else:
                    # box_head_trk is a separate box head for tracking queries
                    Q_det = decoder_extra_kwargs["Q_det"]
                    assert output.size(0) >= Q_det
                    delta_unsig_det = self.bbox_embed(output[:Q_det])
                    delta_unsig_trk = box_head_trk(output[Q_det:])
                    delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                reference_boxes = new_reference_points.detach()
                if layer_idx != self.num_layers - 1:
                    intermediate_ref_boxes.append(new_reference_points)
            else:
                raise NotImplementedError("not implemented yet")

            intermediate.append(out_norm(output))
            if self.presence_token is not None and not is_instance_prompt:
                # norm, then MLP head
                intermediate_layer_presence_logits = self.presence_token_head(
                    self.presence_token_out_norm(presence_out)
                ).squeeze(-1)

                # clamp to mitigate numerical issues
                if self.clamp_presence_logits:
                    intermediate_layer_presence_logits = intermediate_layer_presence_logits.clamp(
                        min=-self.clamp_presence_logit_max_val,
                        max=self.clamp_presence_logit_max_val,
                    )

                intermediate_presence_logits.append(intermediate_layer_presence_logits)
                presence_feats = presence_out.clone()

        if not self.compiled and self.compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True)
            self.compiled = True

        return (
            torch.stack(intermediate),
            torch.stack(intermediate_ref_boxes),
            (
                torch.stack(intermediate_presence_logits)
                if self.presence_token is not None and not is_instance_prompt
                else None
            ),
            presence_feats,
        )
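Each decoder layer refines the reference boxes in logit space: the predicted delta is added to `inverse_sigmoid(reference_boxes)` and squashed back through a sigmoid, so coordinates always stay in (0, 1). A scalar sketch of one refinement step (hypothetical helpers mirroring the torch ops):

```python
import math

def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def _inverse_sigmoid(p: float, eps: float = 1e-5) -> float:
    p = min(max(p, eps), 1.0 - eps)  # clamp away from 0/1 for numerical safety
    return math.log(p / (1.0 - p))

def refine_coord(coord: float, delta: float) -> float:
    # One iterative-refinement step: update in unconstrained logit space,
    # then map back into (0, 1) so the box stays inside the image.
    return _sigmoid(_inverse_sigmoid(coord) + delta)
```

Updating in logit space lets the box head predict unbounded deltas while the refined box remains a valid normalized coordinate; detaching between layers (as the decoder does) keeps each layer's gradient local to its own delta.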
535	ultralytics/models/sam/sam3/encoder.py	Normal file

@@ -0,0 +1,535 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# Based on https://github.com/IDEA-Research/GroundingDINO

from __future__ import annotations

import torch
from torch import nn

from ultralytics.nn.modules.utils import _get_clones

from .model_misc import get_valid_ratio


class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer that performs self-attention followed by cross-attention.

    This layer was previously called TransformerDecoderLayer but was renamed to better reflect its role in the
    architecture. It processes input sequences through self-attention and then cross-attention with another input
    (typically image features).

    The layer supports both pre-norm and post-norm configurations, as well as positional encoding at different stages of
    the attention mechanism.
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module = None,
        cross_attention: nn.Module = None,
    ):
        """Initialize a transformer encoder layer.

        Args:
            d_model: Model dimension/hidden size
            dim_feedforward: Dimension of the feedforward network
            dropout: Dropout probability
            pos_enc_at_attn: Whether to add positional encodings at self-attention
            pos_enc_at_cross_attn_keys: Whether to add positional encodings to keys in cross-attention
            pos_enc_at_cross_attn_queries: Whether to add positional encodings to queries in cross-attention
            pre_norm: Whether to use pre-norm (True) or post-norm (False) architecture
            self_attention: Self-attention module
            cross_attention: Cross-attention module for attending to image features
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256)
        self.cross_attn_image = cross_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()
        self.pre_norm = pre_norm

        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

        self.layer_idx = None

    def forward_post(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
        pos: torch.Tensor = None,
        query_pos: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass for post-norm architecture.

        In post-norm architecture, normalization is applied after attention and feedforward operations.

        Args:
            tgt: Input tensor to be processed
            memory: Memory tensor for cross-attention
            tgt_mask: Mask for self-attention
            memory_mask: Mask for cross-attention
            tgt_key_padding_mask: Key padding mask for self-attention
            memory_key_padding_mask: Key padding mask for cross-attention
            pos: Positional encoding for memory
            query_pos: Positional encoding for query
            **kwargs: Additional keyword arguments

        Returns:
            Processed tensor
        """
        q = k = tgt + query_pos if self.pos_enc_at_attn else tgt

        # Self attention
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask, need_weights=False
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # Cross attention to image
        tgt2 = self.cross_attn_image(
            query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        dac: bool = False,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
        pos: torch.Tensor = None,
        query_pos: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for pre-norm architecture.

        In pre-norm architecture, normalization is applied before attention and feedforward operations.

        Args:
            tgt: Input tensor to be processed
            memory: Memory tensor for cross-attention
            dac: Whether to use Divide-and-Conquer attention (self-attention is applied only to the first half of
                the queries, which are recombined with the second half before cross-attention)
            tgt_mask: Mask for self-attention
            memory_mask: Mask for cross-attention
            tgt_key_padding_mask: Key padding mask for self-attention
            memory_key_padding_mask: Key padding mask for cross-attention
            pos: Positional encoding for memory
            query_pos: Positional encoding for query

        Returns:
            Processed tensor
        """
        if dac:
            # We only apply self-attention to the first half of the queries
            assert tgt.shape[0] % 2 == 0
            other_tgt = tgt[tgt.shape[0] // 2 :]
            tgt = tgt[: tgt.shape[0] // 2]
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        if dac:
            # Recombine
            tgt = torch.cat((tgt, other_tgt), dim=0)
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            key=memory.to(tgt2.dtype) + pos if self.pos_enc_at_cross_attn_keys else memory.to(tgt2.dtype),
            value=memory.to(tgt2.dtype),
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        dac: bool = False,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
        pos: torch.Tensor = None,
        query_pos: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward pass for the transformer encoder layer.

        Args:
            tgt: Input tensor to be processed
            memory: Memory tensor (e.g., image features) for cross-attention
            dac: Whether to use Divide-and-Conquer attention (only apply self-attention to first half)
            tgt_mask: Mask for self-attention
            memory_mask: Mask for cross-attention
            tgt_key_padding_mask: Key padding mask for self-attention
            memory_key_padding_mask: Key padding mask for cross-attention
            pos: Positional encoding for memory
            query_pos: Positional encoding for query

        Returns:
            Processed tensor after self-attention, cross-attention, and feedforward network
        """
        fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
        return fwd_fn(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
        )


class TransformerEncoder(nn.Module):
    """Transformer encoder that processes multi-level features.

    This encoder takes multi-level features (e.g., from a backbone network) and processes them through a stack of
    transformer encoder layers. It supports features from multiple levels (e.g., different resolutions) and can apply
    activation checkpointing for memory efficiency during training.

    Args:
        layer: The encoder layer to be stacked multiple times
        num_layers: Number of encoder layers to stack
        d_model: Model dimension/hidden size
        num_feature_levels: Number of feature levels to process
        frozen: Whether to freeze the parameters of this module
        use_act_checkpoint: Whether to use activation checkpointing during training
    """

    def __init__(
        self,
        layer: nn.Module,
        num_layers: int,
        d_model: int,
        num_feature_levels: int,
        frozen: bool = False,
        use_act_checkpoint: bool = False,
    ):
        """Initialize the transformer encoder."""
        super().__init__()
        self.layers = _get_clones(layer, num_layers)
        self.num_layers = num_layers

        self.num_feature_levels = num_feature_levels
        self.level_embed = None
        if num_feature_levels > 1:
            self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)

        self.use_act_checkpoint = use_act_checkpoint

        # Assign a layer index to each layer so that individual layers can condition their behavior on their
        # position in the stack (e.g., cross-attention to the memory bank only in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
        """Prepare multi-level features for the transformer encoder."""
        assert len(srcs) == self.num_feature_levels, "mismatch between expected and received # of feature levels"

        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        has_mask = masks is not None and masks[0] is not None
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            if has_mask:
                mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            if has_mask:
                mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat(
            (
                spatial_shapes.new_zeros((1,)),
                spatial_shapes.prod(1).cumsum(0)[:-1],
            )
        )
        if has_mask:
            valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1)
        else:
            valid_ratios = torch.ones(
                (src_flatten.shape[0], self.num_feature_levels, 2),
                device=src_flatten.device,
                dtype=src_flatten.dtype,
            )

        return (
            src_flatten,
            mask_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            valid_ratios,
            spatial_shapes,
        )

    def forward(
        self,
        src: list[torch.Tensor],
        src_key_padding_masks: list[torch.Tensor] | None = None,
        pos: list[torch.Tensor] | None = None,
        prompt: torch.Tensor = None,
        prompt_key_padding_mask: torch.Tensor = None,
        encoder_extra_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Process multi-level features through the transformer encoder.

        Args:
            src: List of multi-level features, each with shape (batch_size, channels, height, width)
            src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height,
                width)
            pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height,
                width)
            prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model)
            prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len)
            encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer

        Returns:
            A tuple containing:
                - output: Processed features with shape (seq_len, batch_size, d_model)
                - key_padding_masks_flatten: Flattened padding masks
                - lvl_pos_embed_flatten: Flattened positional embeddings
                - level_start_index: Starting indices for each feature level
                - spatial_shapes: Spatial dimensions of each feature level
                - valid_ratios: Valid ratios for each feature level
        """
        assert len(src) == self.num_feature_levels, "must be equal to num_feature_levels"
        if src_key_padding_masks is not None:
            assert len(src_key_padding_masks) == self.num_feature_levels
        if pos is not None:
            assert len(pos) == self.num_feature_levels
        # Flatten multilevel feats and add level pos embeds
        (
            src_flatten,
            key_padding_masks_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            valid_ratios,
            spatial_shapes,
        ) = self._prepare_multilevel_features(src, src_key_padding_masks, pos)

        output = src_flatten
        for layer in self.layers:
            layer_kwargs = {}

            assert isinstance(layer, TransformerEncoderLayer)
            layer_kwargs["memory"] = prompt
            layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask
            layer_kwargs["query_pos"] = lvl_pos_embed_flatten
            layer_kwargs["tgt"] = output
            layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten

            if self.training:
                assert self.use_act_checkpoint, "activation ckpt not enabled in encoder"
            if encoder_extra_kwargs is not None:
                layer_kwargs.update(encoder_extra_kwargs)
            output = layer(**layer_kwargs)
        # Return as sequence-first
        return (
            output.transpose(0, 1),
            (key_padding_masks_flatten.transpose(0, 1) if key_padding_masks_flatten is not None else None),
            lvl_pos_embed_flatten.transpose(0, 1),
            level_start_index,
            spatial_shapes,
            valid_ratios,
        )


class TransformerEncoderFusion(TransformerEncoder):
    """Transformer encoder that fuses text and image features.

    This encoder extends TransformerEncoder to handle both text and image features, with the ability to add pooled text
    features to image features for better cross-modal fusion. It supports torch.compile for performance optimization.

    Args:
        layer: The encoder layer to be stacked multiple times
        num_layers: Number of encoder layers to stack
        d_model: Model dimension/hidden size
        num_feature_levels: Number of feature levels to process
        add_pooled_text_to_img_feat: Whether to add pooled text features to image features
        pool_text_with_mask: Whether to use the mask when pooling text features
        compile_mode: Mode for torch.compile, or None to disable compilation
        **kwargs: Additional arguments to pass to the parent class
    """

    def __init__(
        self,
        layer: nn.Module,
        num_layers: int,
        d_model: int,
        num_feature_levels: int,
        add_pooled_text_to_img_feat: bool = True,
        pool_text_with_mask: bool = False,
        compile_mode: str | None = None,
        **kwargs,
    ):
        """Initialize the transformer encoder with text-image fusion."""
        super().__init__(
            layer,
            num_layers,
            d_model,
            num_feature_levels,
            **kwargs,
        )
        self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat
        if self.add_pooled_text_to_img_feat:
            self.text_pooling_proj = nn.Linear(d_model, d_model)
        self.pool_text_with_mask = pool_text_with_mask
        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)

    def forward(
        self,
        src: list[torch.Tensor],
        prompt: torch.Tensor,
        src_key_padding_mask: list[torch.Tensor] | None = None,
        src_pos: list[torch.Tensor] | None = None,
        prompt_key_padding_mask: torch.Tensor = None,
        feat_sizes: list[tuple[int, int]] | None = None,
        encoder_extra_kwargs: dict | None = None,
    ):
        """Forward pass for the transformer encoder with text-image fusion."""
        # Restore spatial shapes of vision features (inputs are sequence-first)
        bs = src[0].shape[1]
        if feat_sizes is not None:
            assert len(feat_sizes) == len(src)
            if src_key_padding_mask is None:
                src_key_padding_mask = [None] * len(src)
            for i, (h, w) in enumerate(feat_sizes):
                src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
                src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
                src_key_padding_mask[i] = (
                    src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1)
                    if src_key_padding_mask[i] is not None
                    else None
                )
        else:
            assert all(x.dim() == 4 for x in src), "expected list of (bs, c, h, w) tensors"

        if self.add_pooled_text_to_img_feat:
            # Fusion: add mean-pooled text features to the image features
            pooled_text = pool_text_feat(prompt, prompt_key_padding_mask, self.pool_text_with_mask)
            pooled_text = self.text_pooling_proj(pooled_text)[..., None, None]  # prompt is seq first
            src = [x.add_(pooled_text) for x in src]

        (
            out,
            key_padding_masks_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
        ) = super().forward(
            src,
            src_key_padding_masks=src_key_padding_mask,
            pos=src_pos,
            prompt=prompt.transpose(0, 1),
            prompt_key_padding_mask=prompt_key_padding_mask,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )

        return {
            "memory": out,
            "padding_mask": key_padding_masks_flatten,
            "pos_embed": lvl_pos_embed_flatten,
            "memory_text": prompt,
            "level_start_index": level_start_index,
            "spatial_shapes": spatial_shapes,
            "valid_ratios": valid_ratios,
        }


def pool_text_feat(prompt, prompt_mask, pool_with_mask):
    """Mean-pool the prompt embeddings over the valid tokens only."""
    # prompt has shape (seq, bs, dim)
    if not pool_with_mask:
        return prompt.mean(dim=0)

    # prompt_mask has shape (bs, seq), where False is valid and True is padding
    assert prompt_mask.dim() == 2
    # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
    is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
    # num_valid has shape (bs, 1)
    num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)

    # Mean pool over all the valid tokens
    pooled_text = (prompt * is_valid).sum(dim=0) / num_valid
    return pooled_text
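# The masked mean pooling above can be sketched without tensors: average only the
# positions whose padding flag is False, guarding against an all-padded row. This
# is an illustrative sketch, not part of the library API.

```python
def masked_mean(tokens, pad_mask):
    """Mean over valid (non-padded) positions; pad_mask uses True for padding."""
    valid = [t for t, p in zip(tokens, pad_mask) if not p]
    return sum(valid) / max(len(valid), 1)  # clamp denominator like torch.clamp(min=1.0)


# The padded 100.0 token is ignored, so the mean is (2.0 + 4.0) / 2 = 3.0
pooled = masked_mean([2.0, 4.0, 100.0], [False, False, True])
```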

# File: ultralytics/models/sam/sam3/geometry_encoders.py (new file, 415 lines)

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torchvision

from ultralytics.nn.modules.utils import _get_clones
from ultralytics.utils.ops import xywh2xyxy


def is_right_padded(mask: torch.Tensor):
    """Given a padding mask (following the PyTorch convention, 1s for padded values), return whether the padding is
    on the right or not.
    """
    return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all()

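# A right-padded mask is one whose padding flags are already in ascending order per
# row, so sorting a row leaves it unchanged. A plain-Python sketch of the same check
# used by is_right_padded above (illustrative, not part of the library API):

```python
def row_is_right_padded(row):
    """True if all padding flags (True values) come after the valid entries."""
    return list(row) == sorted(row)  # False < True, so sorted == already right-padded


ok = row_is_right_padded([False, False, True])   # padding only on the right
bad = row_is_right_padded([True, False, False])  # padding on the left
```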
def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """Concatenate two right-padded sequences such that the result is contiguous and also right-padded.

    Following PyTorch's convention, the tensors are sequence-first and the masks are batch-first, with 1s for padded
    values.

    Args:
        seq1: A tensor of shape (seq1_length, batch_size, hidden_size)
        mask1: A tensor of shape (batch_size, seq1_length)
        seq2: A tensor of shape (seq2_length, batch_size, hidden_size)
        mask2: A tensor of shape (batch_size, seq2_length)
        return_index: If True, also return the indices of the elements of seq2 in the concatenated sequence, which
            can be used to retrieve the elements of seq2

    Returns:
        A tuple (concatenated_sequence, concatenated_mask) if return_index is False, otherwise
        (concatenated_sequence, concatenated_mask, index).
    """
    seq1_length, batch_size, hidden_size = seq1.shape
    seq2_length, batch_size, hidden_size = seq2.shape

    assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
    assert hidden_size == seq1.size(2) == seq2.size(2)
    assert seq1_length == mask1.size(1)
    assert seq2_length == mask2.size(1)

    torch._assert_async(is_right_padded(mask1))
    torch._assert_async(is_right_padded(mask2))

    actual_seq1_lengths = (~mask1).sum(dim=-1)
    actual_seq2_lengths = (~mask2).sum(dim=-1)

    final_lengths = actual_seq1_lengths + actual_seq2_lengths
    max_length = seq1_length + seq2_length
    concatenated_mask = (
        torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) >= final_lengths[:, None]
    )

    # (max_len, batch_size, hidden_size)
    concatenated_sequence = torch.zeros((max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype)
    concatenated_sequence[:seq1_length, :, :] = seq1

    # At this point, the elements of seq1 are in the right place.
    # We just need to shift the elements of seq2.
    index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
    index = index + actual_seq1_lengths[None]

    concatenated_sequence = concatenated_sequence.scatter(0, index[:, :, None].expand(-1, -1, hidden_size), seq2)

    if return_index:
        return concatenated_sequence, concatenated_mask, index

    return concatenated_sequence, concatenated_mask


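# concat_padded_sequences merges two right-padded sequences so the result is again
# contiguous and right-padded. A plain-Python sketch of the same idea for a single
# batch row, using 0 as the pad value (illustrative, not part of the library API):

```python
def concat_right_padded(row1, row2, pad=0):
    """Concatenate the valid prefixes of two right-padded rows, re-padding on the right."""
    valid1 = [t for t in row1 if t != pad]
    valid2 = [t for t in row2 if t != pad]
    total = len(row1) + len(row2)  # the output keeps the combined maximum length
    merged = valid1 + valid2
    return merged + [pad] * (total - len(merged))


# Valid tokens [1, 2] and [3] become [1, 2, 3], padded back out to length 6
merged = concat_right_padded([1, 2, 0], [3, 0, 0])
```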
class Prompt:
    """Utility class to manipulate geometric prompts.

    Sequences follow the PyTorch convention: sequence-first, batch-second. The dimensions are expected as follows:
        box_embeddings: N_boxes x B x C_box
        box_mask: B x N_boxes. Can be None if nothing is masked out
        point_embeddings: N_points x B x C_point
        point_mask: B x N_points. Can be None if nothing is masked out
        mask_embeddings: N_masks x B x 1 x H_mask x W_mask
        mask_mask: B x N_masks. Can be None if nothing is masked out

    Positive/negative labels are also stored. If they are None, positive labels are assumed everywhere:
        box_labels: long tensor of shape N_boxes x B
        point_labels: long tensor of shape N_points x B
        mask_labels: long tensor of shape N_masks x B
    """

    def __init__(self, box_embeddings=None, box_mask=None, box_labels=None):
        """Initialize the Prompt object."""
        # Check for null prompt
        if box_embeddings is None:
            self.box_embeddings = None
            self.box_labels = None
            self.box_mask = None
            return

        # Get sequence length, batch size, and device
        box_seq_len = box_embeddings.shape[0]
        bs = box_embeddings.shape[1]
        device = box_embeddings.device

        # Initialize labels and attention mask if not provided
        if box_labels is None:
            box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
        if box_mask is None:
            box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)

        # Dimension checks
        assert list(box_embeddings.shape[:2]) == [box_seq_len, bs], (
            f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
        )
        assert box_embeddings.shape[-1] == 4, (
            f"Expected box embeddings to have 4 coordinates, got {box_embeddings.shape[-1]}"
        )
        assert list(box_mask.shape) == [bs, box_seq_len], (
            f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
        )
        assert list(box_labels.shape) == [box_seq_len, bs], (
            f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
        )

        # Device checks
        assert box_embeddings.device == device, (
            f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
        )
        assert box_mask.device == device, f"Expected box mask to be on device {device}, got {box_mask.device}"
        assert box_labels.device == device, f"Expected box labels to be on device {device}, got {box_labels.device}"

        self.box_embeddings = box_embeddings
        self.box_mask = box_mask
        self.box_labels = box_labels

    def append_boxes(self, boxes, labels=None, mask=None):
        """Append box prompts to the existing prompts.

        Args:
            boxes: Tensor of shape (N_new_boxes, B, 4) with normalized box coordinates
            labels: Optional tensor of shape (N_new_boxes, B) with positive/negative labels
            mask: Optional tensor of shape (B, N_new_boxes) for attention mask
        """
        if self.box_embeddings is None:
            # First boxes - initialize
            self.box_embeddings = boxes
            bs = boxes.shape[1]
            box_seq_len = boxes.shape[0]

            if labels is None:
                labels = torch.ones(box_seq_len, bs, device=boxes.device, dtype=torch.long)
            if mask is None:
                mask = torch.zeros(bs, box_seq_len, device=boxes.device, dtype=torch.bool)

            self.box_labels = labels
            self.box_mask = mask
            return

        # Append to existing boxes
        bs = self.box_embeddings.shape[1]
        assert boxes.shape[1] == bs, f"Batch size mismatch: expected {bs}, got {boxes.shape[1]}"

        if labels is None:
            labels = torch.ones(boxes.shape[0], bs, device=boxes.device, dtype=torch.long)
        if mask is None:
            mask = torch.zeros(bs, boxes.shape[0], dtype=torch.bool, device=boxes.device)

        assert list(boxes.shape[:2]) == list(labels.shape[:2]), (
            f"Shape mismatch between boxes {boxes.shape} and labels {labels.shape}"
        )

        # Concatenate using the helper function
        self.box_labels, _ = concat_padded_sequences(
            self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
        )
        self.box_labels = self.box_labels.squeeze(-1)
        self.box_embeddings, self.box_mask = concat_padded_sequences(self.box_embeddings, self.box_mask, boxes, mask)


class SequenceGeometryEncoder(nn.Module):
    """Encoder for geometric box prompts. Assumes boxes are passed in the normalized CxCyWH format.

    Boxes can be encoded with any of three options:
        - direct projection: linear projection from coordinate space to d_model
        - pooling: RoI-align features from the backbone
        - pos encoder: position encoding of the box center

    These three options are mutually compatible and will be summed if multiple are selected.

    As an alternative, boxes can be encoded as two corner points (top-left and bottom-right).

    The encoded sequence can be further processed with a transformer.
    """

    def __init__(
        self,
        encode_boxes_as_points: bool,
        boxes_direct_project: bool,
        boxes_pool: bool,
        boxes_pos_enc: bool,
        d_model: int,
        pos_enc,
        num_layers: int,
        layer: nn.Module,
        roi_size: int = 7,
        add_cls: bool = True,
        add_post_encode_proj: bool = True,
        use_act_ckpt: bool = False,
    ):
        """Initialize the SequenceGeometryEncoder."""
        super().__init__()

        self.d_model = d_model
        self.pos_enc = pos_enc
        self.encode_boxes_as_points = encode_boxes_as_points
        self.roi_size = roi_size

        # Label embeddings: 2 labels if encoding as boxes (pos/neg),
        # 6 labels if encoding as points (regular pos/neg, top-left pos/neg, bottom-right pos/neg)
        num_labels = 6 if self.encode_boxes_as_points else 2
        self.label_embed = torch.nn.Embedding(num_labels, self.d_model)

        # CLS token for pooling
        self.cls_embed = None
        if add_cls:
            self.cls_embed = torch.nn.Embedding(1, self.d_model)

        # Point encoding (used when encode_boxes_as_points is True)
        if encode_boxes_as_points:
            self.points_direct_project = nn.Linear(2, self.d_model)
            self.points_pool_project = None
            self.points_pos_enc_project = None
        else:
            # Box encoding modules
            assert boxes_direct_project or boxes_pos_enc or boxes_pool, "Error: need at least one way to encode boxes"
            self.points_direct_project = None
            self.points_pool_project = None
            self.points_pos_enc_project = None

        self.boxes_direct_project = None
        self.boxes_pool_project = None
        self.boxes_pos_enc_project = None

        if boxes_direct_project:
            self.boxes_direct_project = nn.Linear(4, self.d_model)
        if boxes_pool:
            self.boxes_pool_project = nn.Conv2d(self.d_model, self.d_model, self.roi_size)
        if boxes_pos_enc:
            self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)

        self.final_proj = None
        if add_post_encode_proj:
            self.final_proj = nn.Linear(self.d_model, self.d_model)
            self.norm = nn.LayerNorm(self.d_model)

        self.img_pre_norm = nn.Identity()
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            self.img_pre_norm = nn.LayerNorm(self.d_model)

        self.encode = None
        if num_layers > 0:
            assert add_cls, "It's currently highly recommended to add a CLS when using a transformer"
            self.encode = _get_clones(layer, num_layers)
            self.encode_norm = nn.LayerNorm(self.d_model)

        self.use_act_ckpt = use_act_ckpt

    def _encode_points(self, points, points_mask, points_labels, img_feats):
        """Encode points (used when boxes are converted to corner points)."""
        # Direct projection of coordinates
        points_embed = self.points_direct_project(points.to(img_feats.dtype))

        # Add label embeddings
        type_embed = self.label_embed(points_labels.long())
        return type_embed + points_embed, points_mask

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats: torch.Tensor):
        """Encode boxes using the configured encoding methods."""
        boxes_embed = None
        n_boxes, bs = boxes.shape[:2]

        if self.boxes_direct_project is not None:
            proj = self.boxes_direct_project(boxes.to(img_feats.dtype))
            boxes_embed = proj

        if self.boxes_pool_project is not None:
            H, W = img_feats.shape[-2:]

            # Convert boxes to xyxy format and denormalize
            boxes_xyxy = xywh2xyxy(boxes.to(img_feats.dtype))
            scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
            scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
            scale = scale.view(1, 1, 4)
            boxes_xyxy = boxes_xyxy * scale

            # RoI align
            sampled = torchvision.ops.roi_align(img_feats, boxes_xyxy.transpose(0, 1).unbind(0), self.roi_size)
            assert list(sampled.shape) == [
                bs * n_boxes,
                self.d_model,
                self.roi_size,
                self.roi_size,
            ]
            proj = self.boxes_pool_project(sampled)
            proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)

            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        if self.boxes_pos_enc_project is not None:
            cx, cy, w, h = boxes.unbind(-1)
            enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
            enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])

            proj = self.boxes_pos_enc_project(enc.to(img_feats.dtype))
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        # Add label embeddings
        type_embed = self.label_embed(boxes_labels.long())
        return type_embed + boxes_embed, boxes_mask

def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
|
||||
"""Encode geometric box prompts.
|
||||
|
||||
Args:
|
||||
geo_prompt: Prompt object containing box embeddings, masks, and labels
|
||||
img_feats: List of image features from backbone
|
||||
img_sizes: List of (H, W) tuples for each feature level
|
||||
img_pos_embeds: Optional position embeddings for image features
|
||||
|
||||
Returns:
|
||||
Tuple of (encoded_embeddings, attention_mask)
|
||||
"""
|
||||
boxes = geo_prompt.box_embeddings
|
||||
boxes_mask = geo_prompt.box_mask
|
||||
boxes_labels = geo_prompt.box_labels
|
||||
|
||||
seq_first_img_feats = img_feats[-1] # [H*W, B, C]
|
||||
seq_first_img_pos_embeds = (
|
||||
img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(seq_first_img_feats)
|
||||
)
|
||||
|
||||
# Prepare image features for pooling if needed
|
||||
if self.points_pool_project or self.boxes_pool_project:
|
||||
assert len(img_feats) == len(img_sizes)
|
||||
cur_img_feat = img_feats[-1]
|
||||
cur_img_feat = self.img_pre_norm(cur_img_feat)
|
||||
H, W = img_sizes[-1]
|
||||
assert cur_img_feat.shape[0] == H * W
|
||||
N, C = cur_img_feat.shape[-2:]
|
||||
# Reshape to NxCxHxW
|
||||
cur_img_feat = cur_img_feat.permute(1, 2, 0)
|
||||
cur_img_feat = cur_img_feat.view(N, C, H, W)
|
||||
img_feats = cur_img_feat
|
||||
|
||||
if self.encode_boxes_as_points:
|
||||
# Convert boxes to corner points
|
||||
assert boxes is not None and boxes.shape[-1] == 4
|
||||
|
||||
boxes_xyxy = xywh2xyxy(boxes)
|
||||
top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)
|
||||
|
||||
# Adjust labels for corner points (offset by 2 and 4)
|
||||
labels_tl = boxes_labels + 2
|
||||
labels_br = boxes_labels + 4
|
||||
|
||||
# Concatenate top-left and bottom-right points
|
||||
points = torch.cat([top_left, bottom_right], dim=0)
|
||||
points_labels = torch.cat([labels_tl, labels_br], dim=0)
|
||||
points_mask = torch.cat([boxes_mask, boxes_mask], dim=1)
|
||||
|
||||
final_embeds, final_mask = self._encode_points(
|
||||
points=points,
|
||||
points_mask=points_mask,
|
||||
points_labels=points_labels,
|
||||
img_feats=img_feats,
|
||||
)
|
||||
else:
|
||||
# Encode boxes directly
|
||||
final_embeds, final_mask = self._encode_boxes(
|
||||
boxes=boxes,
|
||||
boxes_mask=boxes_mask,
|
||||
boxes_labels=boxes_labels,
|
||||
img_feats=img_feats,
|
||||
)
|
||||
|
||||
bs = final_embeds.shape[1]
|
||||
assert final_mask.shape[0] == bs
|
||||
|
||||
# Add CLS token if configured
|
||||
if self.cls_embed is not None:
|
||||
cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
|
||||
cls_mask = torch.zeros(bs, 1, dtype=final_mask.dtype, device=final_mask.device)
|
||||
final_embeds, final_mask = concat_padded_sequences(final_embeds, final_mask, cls, cls_mask)
|
||||
|
||||
# Final projection
|
||||
if self.final_proj is not None:
|
||||
final_embeds = self.norm(self.final_proj(final_embeds))
|
||||
|
||||
# Transformer encoding layers
|
||||
if self.encode is not None:
|
||||
for lay in self.encode:
|
||||
final_embeds = lay(
|
||||
tgt=final_embeds,
|
||||
memory=seq_first_img_feats,
|
||||
tgt_key_padding_mask=final_mask,
|
||||
pos=seq_first_img_pos_embeds,
|
||||
)
|
||||
final_embeds = self.encode_norm(final_embeds)
|
||||
|
||||
return final_embeds, final_mask
|
||||
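The `encode_boxes_as_points` path above turns each cxcywh box into a top-left and a bottom-right corner point, offsetting the labels by +2 and +4 so the two corner types stay distinguishable. A minimal NumPy sketch of that conversion (illustrative only, not the torch code path; `boxes_to_corner_points` is a hypothetical helper name):

```python
import numpy as np


def boxes_to_corner_points(boxes_cxcywh, labels):
    """Each (cx, cy, w, h) box becomes two points with offset labels."""
    cx, cy, w, h = boxes_cxcywh.T
    x1, y1 = cx - w / 2, cy - h / 2  # xywh -> xyxy, top-left corner
    x2, y2 = cx + w / 2, cy + h / 2  # bottom-right corner
    top_left = np.stack([x1, y1], axis=-1)
    bottom_right = np.stack([x2, y2], axis=-1)
    points = np.concatenate([top_left, bottom_right], axis=0)
    point_labels = np.concatenate([labels + 2, labels + 4], axis=0)
    return points, point_labels


boxes = np.array([[0.5, 0.5, 0.2, 0.4]])  # one centered normalized box
pts, lbls = boxes_to_corner_points(boxes, np.array([1]))
# pts is approximately [[0.4, 0.3], [0.6, 0.7]]; lbls is [3, 5]
```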
286
ultralytics/models/sam/sam3/maskformer_segmentation.py
Normal file
@@ -0,0 +1,286 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from ultralytics.nn.modules.transformer import MLP


class LinearPresenceHead(nn.Sequential):
    """Linear presence head for predicting the presence of classes in an image."""

    def __init__(self, d_model):
        """Initializes the LinearPresenceHead."""
        # a hack to make `LinearPresenceHead` compatible with old checkpoints
        super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))

    def forward(self, hs, prompt, prompt_mask):
        """Forward pass of the presence head."""
        return super().forward(hs)


class MaskPredictor(nn.Module):
    """Predicts masks from object queries and pixel embeddings."""

    def __init__(self, hidden_dim, mask_dim):
        """Initializes the MaskPredictor."""
        super().__init__()
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, obj_queries, pixel_embed):
        """Predicts masks from object queries and pixel embeddings."""
        if len(obj_queries.shape) == 3:
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum("bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed)
            else:
                mask_preds = torch.einsum("bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed)
        else:
            # Assumed to have aux masks
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum("lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed)
            else:
                mask_preds = torch.einsum("lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed)

        return mask_preds


class SegmentationHead(nn.Module):
    """Segmentation head that predicts masks from backbone features and object queries."""

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        use_encoder_inputs=False,
        aux_masks=False,
        no_dec=False,
        pixel_decoder=None,
        act_ckpt=False,
        shared_conv=False,
        compile_mode_pixel_decoder=None,
    ):
        """Initializes the SegmentationHead."""
        super().__init__()
        self.use_encoder_inputs = use_encoder_inputs
        self.aux_masks = aux_masks
        if pixel_decoder is not None:
            self.pixel_decoder = pixel_decoder
        else:
            self.pixel_decoder = PixelDecoder(
                hidden_dim,
                upsampling_stages,
                shared_conv=shared_conv,
                compile_mode=compile_mode_pixel_decoder,
            )
        self.no_dec = no_dec
        if no_dec:
            self.mask_predictor = nn.Conv2d(hidden_dim, 1, kernel_size=3, stride=1, padding=1)
        else:
            self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)

        self.act_ckpt = act_ckpt

        # used to update the output dictionary
        self.instance_keys = ["pred_masks"]

    def _embed_pixels(self, backbone_feats: list[torch.Tensor], encoder_hidden_states) -> torch.Tensor:
        """Embeds pixels using the pixel decoder."""
        if self.use_encoder_inputs:
            backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
            # Extract visual embeddings
            encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
            spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
            encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(-1, *backbone_feats[-1].shape[1:])

            backbone_visual_feats[-1] = encoder_visual_embed
            if self.act_ckpt:
                pixel_embed = checkpoint.checkpoint(self.pixel_decoder, backbone_visual_feats, use_reentrant=False)
            else:
                pixel_embed = self.pixel_decoder(backbone_visual_feats)
        else:
            backbone_feats = [x for x in backbone_feats]
            pixel_embed = self.pixel_decoder(backbone_feats)
            if pixel_embed.shape[0] == 1:
                # For batch_size=1 training, we can avoid the indexing to save memory
                pixel_embed = pixel_embed.squeeze(0)
            else:
                pixel_embed = pixel_embed[[0], ...]
        return pixel_embed

    def forward(
        self,
        backbone_feats: list[torch.Tensor],
        obj_queries: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Forward pass of the SegmentationHead."""
        if self.use_encoder_inputs:
            assert encoder_hidden_states is not None

        pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)

        if self.no_dec:
            mask_pred = self.mask_predictor(pixel_embed)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, pixel_embed)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)

        return {"pred_masks": mask_pred}


class PixelDecoder(nn.Module):
    """Pixel decoder module that upsamples backbone features."""

    def __init__(
        self,
        hidden_dim,
        num_upsampling_stages,
        interpolation_mode="nearest",
        shared_conv=False,
        compile_mode=None,
    ):
        """Initializes the PixelDecoder."""
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_upsampling_stages = num_upsampling_stages
        self.interpolation_mode = interpolation_mode
        conv_layers = []
        norms = []
        num_convs = 1 if shared_conv else num_upsampling_stages
        for _ in range(num_convs):
            conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
            norms.append(nn.GroupNorm(8, self.hidden_dim))

        self.conv_layers = nn.ModuleList(conv_layers)
        self.norms = nn.ModuleList(norms)
        self.shared_conv = shared_conv
        self.out_dim = self.conv_layers[-1].out_channels
        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, dynamic=True, fullgraph=True)
            # Needed to make checkpointing happy. But we don't know if the module is checkpointed,
            # so we disable it by default.
            torch._dynamo.config.optimize_ddp = False

    def forward(self, backbone_feats: list[torch.Tensor]):
        """Forward pass of the PixelDecoder."""
        prev_fpn = backbone_feats[-1]
        fpn_feats = backbone_feats[:-1]
        for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
            curr_fpn = bb_feat
            prev_fpn = curr_fpn + F.interpolate(prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode)
            if self.shared_conv:
                # only one conv layer
                layer_idx = 0
            prev_fpn = self.conv_layers[layer_idx](prev_fpn)
            prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))

        return prev_fpn


class UniversalSegmentationHead(SegmentationHead):
    """This module handles semantic+instance segmentation."""

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        pixel_decoder,
        aux_masks=False,
        no_dec=False,
        act_ckpt=False,
        presence_head: bool = False,
        dot_product_scorer=None,
        cross_attend_prompt=None,
    ):
        """Initializes the UniversalSegmentationHead."""
        super().__init__(
            hidden_dim=hidden_dim,
            upsampling_stages=upsampling_stages,
            use_encoder_inputs=True,
            aux_masks=aux_masks,
            no_dec=no_dec,
            pixel_decoder=pixel_decoder,
            act_ckpt=act_ckpt,
        )
        self.d_model = hidden_dim

        if dot_product_scorer is not None:
            assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"

        self.presence_head = None
        if presence_head:
            self.presence_head = (
                dot_product_scorer if dot_product_scorer is not None else LinearPresenceHead(self.d_model)
            )

        self.cross_attend_prompt = cross_attend_prompt
        if self.cross_attend_prompt is not None:
            self.cross_attn_norm = nn.LayerNorm(self.d_model)

        self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
        self.instance_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, self.d_model, kernel_size=1)

    def forward(
        self,
        backbone_feats: list[torch.Tensor],
        obj_queries: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        prompt: torch.Tensor = None,
        prompt_mask: torch.Tensor = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Forward pass of the UniversalSegmentationHead."""
        assert encoder_hidden_states is not None
        bs = encoder_hidden_states.shape[1]

        if self.cross_attend_prompt is not None:
            tgt2 = self.cross_attn_norm(encoder_hidden_states)
            tgt2 = self.cross_attend_prompt(
                query=tgt2,
                key=prompt.to(tgt2.dtype),
                value=prompt.to(tgt2.dtype),
                key_padding_mask=prompt_mask,
                need_weights=False,
            )[0]
            encoder_hidden_states = tgt2 + encoder_hidden_states

        presence_logit = None
        if self.presence_head is not None:
            pooled_enc = encoder_hidden_states.mean(0)
            presence_logit = (
                self.presence_head(
                    pooled_enc.view(1, bs, 1, self.d_model),
                    prompt=prompt,
                    prompt_mask=prompt_mask,
                )
                .squeeze(0)
                .squeeze(1)
            )

        pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)

        instance_embeds = self.instance_seg_head(pixel_embed)

        if self.no_dec:
            mask_pred = self.mask_predictor(instance_embeds)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, instance_embeds)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)

        return {
            "pred_masks": mask_pred,
            "semantic_seg": self.semantic_seg_head(pixel_embed),
            "presence_logit": presence_logit,
        }
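The einsum contractions in `MaskPredictor.forward` reduce to a per-query dot product over channels between each query embedding and every pixel embedding. A minimal NumPy shape sketch of the batched case (illustrative only; the MLP projection is omitted here):

```python
import numpy as np

# Shapes mirror MaskPredictor.forward: obj_queries (b, q, c), pixel_embed (b, c, h, w).
b, q, c, h, w = 2, 3, 4, 5, 6
obj_queries = np.random.rand(b, q, c)
pixel_embed = np.random.rand(b, c, h, w)

# "bqc,bchw->bqhw": one mask-logit map per query, via a dot product over channel c.
mask_preds = np.einsum("bqc,bchw->bqhw", obj_queries, pixel_embed)
assert mask_preds.shape == (b, q, h, w)

# The aux-mask case only adds a leading decoder-layer axis l: "lbqc,bchw->lbqhw".
n_layers = 7
aux_queries = np.random.rand(n_layers, b, q, c)
aux_preds = np.einsum("lbqc,bchw->lbqhw", aux_queries, pixel_embed)
assert aux_preds.shape == (n_layers, b, q, h, w)
```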
198
ultralytics/models/sam/sam3/model_misc.py
Normal file
@@ -0,0 +1,198 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""Various utility models."""

from __future__ import annotations

import math

import numpy as np
import torch
from torch import Tensor, nn


class DotProductScoring(torch.nn.Module):
    """A module that computes dot-product scores between a set of query features and a mean-pooled prompt embedding."""

    def __init__(
        self,
        d_model,
        d_proj,
        prompt_mlp=None,
        clamp_logits=True,
        clamp_max_val=12.0,
    ):
        """Initialize the DotProductScoring module."""
        super().__init__()
        self.d_proj = d_proj
        assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
        self.prompt_mlp = prompt_mlp  # an optional MLP projection for prompt
        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
        self.hs_proj = torch.nn.Linear(d_model, d_proj)
        self.scale = float(1.0 / np.sqrt(d_proj))
        self.clamp_logits = clamp_logits
        if self.clamp_logits:
            self.clamp_max_val = clamp_max_val

    def mean_pool_text(self, prompt, prompt_mask):
        """Mean-pool the prompt embeddings over the valid tokens only."""
        # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
        is_valid = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None]
        # num_valid has shape (bs, 1)
        num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
        # mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim)
        pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
        return pooled_prompt

    def forward(self, hs, prompt, prompt_mask):
        """Compute dot-product scores between hs and prompt."""
        # hs has shape (num_layer, bs, num_query, d_model)
        # prompt has shape (seq, bs, d_model)
        # prompt_mask has shape (bs, seq), where 1 is valid and 0 is padding
        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2

        # apply MLP on prompt if specified
        if self.prompt_mlp is not None:
            prompt = self.prompt_mlp(prompt.to(hs.dtype))

        # first, get the mean-pooled version of the prompt
        pooled_prompt = self.mean_pool_text(prompt, prompt_mask)

        # then, project pooled_prompt and hs to d_proj dimensions
        proj_pooled_prompt = self.prompt_proj(pooled_prompt)  # (bs, d_proj)
        proj_hs = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)

        # finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
        scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
        scores *= self.scale

        # clamp scores to a max value to avoid numerical issues in loss or matcher
        if self.clamp_logits:
            scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)

        return scores


class LayerScale(nn.Module):
    """LayerScale module that multiplies features by a learnable per-channel scale."""

    def __init__(
        self,
        dim: int,
        init_values: float | Tensor = 1e-5,
        inplace: bool = False,
    ) -> None:
        """Initialize the LayerScale module."""
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Apply LayerScale to the input tensor."""
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class TransformerWrapper(nn.Module):
    """A wrapper for the transformer consisting of an encoder and a decoder."""

    def __init__(
        self,
        encoder,
        decoder,
        d_model: int,
        two_stage_type="none",  # ["none"] only for now
        pos_enc_at_input_dec=True,
    ):
        """Initialize the TransformerWrapper."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_queries = decoder.num_queries if decoder is not None else None
        self.pos_enc_at_input_dec = pos_enc_at_input_dec

        # for two stage
        assert two_stage_type in ["none"], f"unknown param {two_stage_type} of two_stage_type"
        self.two_stage_type = two_stage_type

        self._reset_parameters()
        self.d_model = d_model

    def _reset_parameters(self):
        """Initialize the parameters of the model."""
        for n, p in self.named_parameters():
            if p.dim() > 1:
                if "box_embed" not in n and "query_embed" not in n and "reference_points" not in n:
                    nn.init.xavier_uniform_(p)


def get_valid_ratio(mask):
    """Compute the valid ratio of height and width from the mask."""
    _, H, W = mask.shape
    valid_H = torch.sum(~mask[:, :, 0], 1)
    valid_W = torch.sum(~mask[:, 0, :], 1)
    valid_ratio_h = valid_H.float() / H
    valid_ratio_w = valid_W.float() / W
    valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
    return valid_ratio


def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256):
    """Generate sinusoidal position embeddings for 2D or 4D coordinate tensors.

    This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to
    the positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors
    (x, y, w, h) for bounding box coordinates.

    Args:
        pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or
            (n_query, bs, 4) for 4D coordinates (bounding boxes).
        num_feats (int): Number of feature dimensions for the output embedding. Must be even. Defaults to 256.

    Returns:
        (torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or
            (n_query, bs, num_feats * 2) for 4D input.

    Raises:
        AssertionError: If num_feats is not even.
        ValueError: If pos_tensor.size(-1) is not 2 or 4.

    Examples:
        >>> pos_2d = torch.rand(100, 8, 2)  # 100 queries, batch size 8, 2D coordinates
        >>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256)
        >>> embeddings_2d.shape
        torch.Size([100, 8, 256])
        >>> pos_4d = torch.rand(50, 4, 4)  # 50 queries, batch size 4, 4D coordinates
        >>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128)
        >>> embeddings_4d.shape
        torch.Size([50, 4, 256])
    """
    assert num_feats % 2 == 0
    num_feats = num_feats // 2
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
    return pos
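`DotProductScoring.mean_pool_text` averages the prompt only over non-padding tokens; the mask convention is the same as in the code above (`True` marks padding). An equivalent NumPy sketch:

```python
import numpy as np


def mean_pool_valid(prompt, prompt_mask):
    """prompt: (seq, bs, d); prompt_mask: (bs, seq) with True = padding."""
    is_valid = (~prompt_mask).astype(prompt.dtype).T[..., None]  # (seq, bs, 1)
    num_valid = np.clip(is_valid.sum(axis=0), 1.0, None)  # (bs, 1), clamp avoids div by 0
    return (prompt * is_valid).sum(axis=0) / num_valid  # (bs, d)


prompt = np.array([[[2.0]], [[4.0]], [[9.0]]])  # seq=3, bs=1, d=1
mask = np.array([[False, False, True]])  # last token is padding
pooled = mean_pool_valid(prompt, mask)
# pooled is [[3.0]]: the mean of 2 and 4, with the padded 9 excluded
```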
129
ultralytics/models/sam/sam3/necks.py
Normal file
@@ -0,0 +1,129 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""Necks are the interface between a vision backbone and the rest of the detection model."""

from __future__ import annotations

from copy import deepcopy

import torch
import torch.nn as nn


class Sam3DualViTDetNeck(nn.Module):
    """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2)."""

    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).

        It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2) with
        different weights.

        :param trunk: the backbone
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()

        self.scale_factors = scale_factors
        use_bias = True
        dim: int = self.trunk.channel_list[-1]

        for _, scale in enumerate(scale_factors):
            current = nn.Sequential()

            if scale == 4.0:
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module(
                    "gelu",
                    nn.GELU(),
                )
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                out_dim = dim
            elif scale == 0.5:
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=use_bias,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                ),
            )
            self.convs.append(current)

        self.sam2_convs = None
        if add_sam2_neck:
            # Assumes the SAM2 neck is just a clone of the original neck
            self.sam2_convs = deepcopy(self.convs)

    def forward(
        self, tensor_list: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Get the feature maps and positional encodings from the neck."""
        xs = self.trunk(tensor_list)
        sam3_out, sam3_pos = [], []
        sam2_out, sam2_pos = None, None
        if self.sam2_convs is not None:
            sam2_out, sam2_pos = [], []
        x = xs[-1]  # simpleFPN
        for i in range(len(self.convs)):
            sam3_x_out = self.convs[i](x)
            sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
            sam3_out.append(sam3_x_out)
            sam3_pos.append(sam3_pos_out)

            if self.sam2_convs is not None:
                sam2_x_out = self.sam2_convs[i](x)
                sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
                sam2_out.append(sam2_x_out)
                sam2_pos.append(sam2_pos_out)
        return sam3_out, sam3_pos, sam2_out, sam2_pos

    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
        """Set the image size for the trunk backbone."""
        self.trunk.set_imgsz(imgsz)
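Each scale branch in the neck above first changes resolution and channel count, then projects to `d_model` with the shared 1x1/3x3 convs. The intermediate channel counts can be sketched as a small helper (`fpn_out_dim` is a hypothetical name mirroring the `__init__` branches, not part of the module):

```python
def fpn_out_dim(dim: int, scale: float) -> int:
    """Intermediate channel count before the 1x1 projection, per scale branch."""
    if scale == 4.0:  # two stride-2 transposed convs: dim -> dim//2 -> dim//4
        return dim // 4
    if scale == 2.0:  # one stride-2 transposed conv: dim -> dim//2
        return dim // 2
    if scale in (1.0, 0.5):  # identity, or a 2x2 max-pool: channels unchanged
        return dim
    raise NotImplementedError(f"scale_factor={scale} is not supported yet.")


# For a trunk with 1024 output channels and the default scale factors:
dims = [fpn_out_dim(1024, s) for s in (4.0, 2.0, 1.0, 0.5)]
# dims == [256, 512, 1024, 1024]; every branch is then projected to d_model
```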
357
ultralytics/models/sam/sam3/sam3_image.py
Normal file
@@ -0,0 +1,357 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from __future__ import annotations

from copy import deepcopy

import torch

from ultralytics.nn.modules.utils import inverse_sigmoid
from ultralytics.utils.ops import xywh2xyxy

from .geometry_encoders import Prompt
from .vl_combiner import SAM3VLBackbone


def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
    """Helper function to update output dictionary with main and auxiliary outputs."""
    out[out_name] = out_value[-1] if auxiliary else out_value
    if auxiliary and update_aux:
        if "aux_outputs" not in out:
            out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
        assert len(out["aux_outputs"]) == len(out_value) - 1
        for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
            aux_output[out_name] = aux_value


class SAM3SemanticModel(torch.nn.Module):
    """SAM3 model for semantic segmentation with vision-language backbone."""

    def __init__(
        self,
        backbone: SAM3VLBackbone,
        transformer,
        input_geometry_encoder,
        segmentation_head=None,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=None,
        use_instance_query: bool = True,
        multimask_output: bool = True,
        use_act_checkpoint_seg_head: bool = True,
        matcher=None,
        use_dot_prod_scoring=True,
        supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
        detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
        separate_scorer_for_instance: bool = False,
        num_interactive_steps_val: int = 0,
    ):
        """Initialize the SAM3SemanticModel."""
        super().__init__()
        self.backbone = backbone
        self.geometry_encoder = input_geometry_encoder
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.segmentation_head = segmentation_head

        self.o2m_mask_predict = o2m_mask_predict

        self.dot_prod_scoring = dot_prod_scoring
        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
        self.matcher = matcher

        self.num_interactive_steps_val = num_interactive_steps_val
        self.use_dot_prod_scoring = use_dot_prod_scoring

        if self.use_dot_prod_scoring:
            assert dot_prod_scoring is not None
            self.dot_prod_scoring = dot_prod_scoring
            self.instance_dot_prod_scoring = None
            if separate_scorer_for_instance:
                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
        else:
            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
            self.instance_class_embed = None
            if separate_scorer_for_instance:
                self.instance_class_embed = deepcopy(self.class_embed)

        self.supervise_joint_box_scores = supervise_joint_box_scores
        self.detach_presence_in_joint_score = detach_presence_in_joint_score

        # verify the number of queries for O2O and O2M
        num_o2o_static = self.transformer.decoder.num_queries
        num_o2m_static = self.transformer.decoder.num_o2m_queries
        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
        self.dac = self.transformer.decoder.dac

        self.use_instance_query = use_instance_query
        self.multimask_output = multimask_output

        self.text_embeddings = {}
        self.names = []

    def _prepare_backbone_features(self, backbone_out, num_prompts=1):
        """Prepare and flatten visual features from the image backbone output for further processing."""
        if num_prompts > 1:  # expand features if there's more than one prompt
            for i, feat in enumerate(backbone_out["backbone_fpn"]):
                backbone_out["backbone_fpn"][i] = feat.expand(num_prompts, -1, -1, -1)
            for i, pos in enumerate(backbone_out["vision_pos_enc"]):
                pos = pos.expand(num_prompts, -1, -1, -1)
                backbone_out["vision_pos_enc"][i] = pos
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def _encode_prompt(
        self,
        img_feats,
        img_pos_embeds,
        vis_feat_sizes,
        geometric_prompt,
        visual_prompt_embed=None,
        visual_prompt_mask=None,
        prev_mask_pred=None,
    ):
        """Encode the geometric and visual prompts."""
        if prev_mask_pred is not None:
            img_feats = [img_feats[-1] + prev_mask_pred]
        # Encode geometry
        geo_feats, geo_masks = self.geometry_encoder(
            geo_prompt=geometric_prompt,
            img_feats=img_feats,
            img_sizes=vis_feat_sizes,
            img_pos_embeds=img_pos_embeds,
        )
        if visual_prompt_embed is None:
            visual_prompt_embed = torch.zeros((0, *geo_feats.shape[1:]), device=geo_feats.device)
            visual_prompt_mask = torch.zeros(
                (*geo_masks.shape[:-1], 0),
                device=geo_masks.device,
                dtype=geo_masks.dtype,
            )
        prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
        prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
        return prompt, prompt_mask

    def _run_encoder(
        self,
        img_feats,
        img_pos_embeds,
        vis_feat_sizes,
        prompt,
        prompt_mask,
        encoder_extra_kwargs: dict | None = None,
    ):
        """Run the transformer encoder."""
        # Run the encoder
        # make a copy of the image feature lists since the encoder may modify these lists in-place
|
||||
memory = self.transformer.encoder(
|
||||
src=img_feats.copy(),
|
||||
src_key_padding_mask=None,
|
||||
src_pos=img_pos_embeds.copy(),
|
||||
prompt=prompt,
|
||||
prompt_key_padding_mask=prompt_mask,
|
||||
feat_sizes=vis_feat_sizes,
|
||||
encoder_extra_kwargs=encoder_extra_kwargs,
|
||||
)
|
||||
encoder_out = {
|
||||
# encoded image features
|
||||
"encoder_hidden_states": memory["memory"],
|
||||
"pos_embed": memory["pos_embed"],
|
||||
"padding_mask": memory["padding_mask"],
|
||||
"spatial_shapes": memory["spatial_shapes"],
|
||||
"valid_ratios": memory["valid_ratios"],
|
||||
"vis_feat_sizes": vis_feat_sizes,
|
||||
# encoded text features (or other prompts)
|
||||
"prompt_before_enc": prompt,
|
||||
"prompt_after_enc": memory.get("memory_text", prompt),
|
||||
"prompt_mask": prompt_mask,
|
||||
}
|
||||
return encoder_out
|
||||
|
||||
def _run_decoder(
|
||||
self,
|
||||
pos_embed,
|
||||
memory,
|
||||
src_mask,
|
||||
out,
|
||||
prompt,
|
||||
prompt_mask,
|
||||
encoder_out,
|
||||
):
|
||||
"""Run the transformer decoder."""
|
||||
bs = memory.shape[1]
|
||||
query_embed = self.transformer.decoder.query_embed.weight
|
||||
tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
hs, reference_boxes, dec_presence_out, _ = self.transformer.decoder(
|
||||
tgt=tgt,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=src_mask,
|
||||
pos=pos_embed,
|
||||
reference_boxes=None,
|
||||
spatial_shapes=encoder_out["spatial_shapes"],
|
||||
valid_ratios=encoder_out["valid_ratios"],
|
||||
tgt_mask=None,
|
||||
memory_text=prompt,
|
||||
text_attention_mask=prompt_mask,
|
||||
apply_dac=False,
|
||||
)
|
||||
hs = hs.transpose(1, 2) # seq-first to batch-first
|
||||
reference_boxes = reference_boxes.transpose(1, 2) # seq-first to batch-first
|
||||
if dec_presence_out is not None:
|
||||
# seq-first to batch-first
|
||||
dec_presence_out = dec_presence_out.transpose(1, 2)
|
||||
self._update_scores_and_boxes(
|
||||
out,
|
||||
hs,
|
||||
reference_boxes,
|
||||
prompt,
|
||||
prompt_mask,
|
||||
dec_presence_out=dec_presence_out,
|
||||
)
|
||||
return out, hs
|
||||
|
||||
def _update_scores_and_boxes(
|
||||
self,
|
||||
out,
|
||||
hs,
|
||||
reference_boxes,
|
||||
prompt,
|
||||
prompt_mask,
|
||||
dec_presence_out=None,
|
||||
is_instance_prompt=False,
|
||||
):
|
||||
"""Update output dict with class scores and box predictions."""
|
||||
num_o2o = hs.size(2)
|
||||
# score prediction
|
||||
if self.use_dot_prod_scoring:
|
||||
dot_prod_scoring_head = self.dot_prod_scoring
|
||||
if is_instance_prompt and self.instance_dot_prod_scoring is not None:
|
||||
dot_prod_scoring_head = self.instance_dot_prod_scoring
|
||||
outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
|
||||
else:
|
||||
class_embed_head = self.class_embed
|
||||
if is_instance_prompt and self.instance_class_embed is not None:
|
||||
class_embed_head = self.instance_class_embed
|
||||
outputs_class = class_embed_head(hs)
|
||||
|
||||
# box prediction
|
||||
box_head = self.transformer.decoder.bbox_embed
|
||||
if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None:
|
||||
box_head = self.transformer.decoder.instance_bbox_embed
|
||||
anchor_box_offsets = box_head(hs)
|
||||
reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
|
||||
outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
|
||||
outputs_boxes_xyxy = xywh2xyxy(outputs_coord)
|
||||
|
||||
if dec_presence_out is not None:
|
||||
_update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False)
|
||||
|
||||
if self.supervise_joint_box_scores:
|
||||
assert dec_presence_out is not None
|
||||
prob_dec_presence_out = dec_presence_out.clone().sigmoid()
|
||||
if self.detach_presence_in_joint_score:
|
||||
prob_dec_presence_out = prob_dec_presence_out.detach()
|
||||
|
||||
outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp(
|
||||
min=-10.0, max=10.0
|
||||
)
|
||||
|
||||
_update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False)
|
||||
_update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False)
|
||||
_update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False)
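The box-refinement step above (`(reference_boxes_inv_sig + anchor_box_offsets).sigmoid()`) adds the decoder's predicted offsets to the reference boxes in logit space, so refined coordinates always land back in (0, 1). A minimal pure-Python sketch of that arithmetic for a single scalar coordinate (the helper names here are illustrative, not the project's `inverse_sigmoid`):

```python
import math


def sigmoid(x: float) -> float:
    """Map a logit to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def inverse_sigmoid(p: float, eps: float = 1e-6) -> float:
    """Map a probability back to logit space, clamped away from 0 and 1."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


# A normalized reference coordinate of 0.5 is nudged by a predicted offset
# of +0.2 in logit space, then squashed back into (0, 1).
refined = sigmoid(inverse_sigmoid(0.5) + 0.2)
```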
    def _run_segmentation_heads(
        self,
        out,
        backbone_out,
        encoder_hidden_states,
        prompt,
        prompt_mask,
        hs,
    ):
        """Run segmentation heads and get masks."""
        if self.segmentation_head is not None:
            num_o2o = hs.size(2)
            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
            seg_head_outputs = self.segmentation_head(
                backbone_feats=backbone_out["backbone_fpn"],
                obj_queries=obj_queries,
                encoder_hidden_states=encoder_hidden_states,
                prompt=prompt,
                prompt_mask=prompt_mask,
            )
            for k, v in seg_head_outputs.items():
                if k in self.segmentation_head.instance_keys:
                    _update_out(out, k, v[:, :num_o2o], auxiliary=False)
                else:
                    out[k] = v
        else:
            backbone_out.pop("backbone_fpn", None)

    def forward_grounding(
        self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt = None
    ):
        """Forward pass for grounding (detection + segmentation) given input images and text."""
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = self._prepare_backbone_features(
            backbone_out, num_prompts=len(text_ids)
        )
        backbone_out.update({k: v for k, v in self.text_embeddings.items()})
        with torch.profiler.record_function("SAM3Image._encode_prompt"):
            prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
            # index text features (note that regardless of early or late fusion, the batch size of
            # `txt_feats` is always the number of *prompts* in the encoder)
            txt_feats = backbone_out["language_features"][:, text_ids]
            txt_masks = backbone_out["language_mask"][text_ids]
            # encode text
            prompt = torch.cat([txt_feats, prompt], dim=0)
            prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)

        # Run the encoder
        with torch.profiler.record_function("SAM3Image._run_encoder"):
            encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask)
        out = {"backbone_out": backbone_out}

        # Run the decoder
        with torch.profiler.record_function("SAM3Image._run_decoder"):
            out, hs = self._run_decoder(
                memory=encoder_out["encoder_hidden_states"],
                pos_embed=encoder_out["pos_embed"],
                src_mask=encoder_out["padding_mask"],
                out=out,
                prompt=prompt,
                prompt_mask=prompt_mask,
                encoder_out=encoder_out,
            )

        # Run segmentation heads
        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
            self._run_segmentation_heads(
                out=out,
                backbone_out=backbone_out,
                encoder_hidden_states=encoder_out["encoder_hidden_states"],
                prompt=prompt,
                prompt_mask=prompt_mask,
                hs=hs,
            )
        return out

    def set_classes(self, text: list[str]):
        """Set the text embeddings for the given class names."""
        self.text_embeddings = self.backbone.forward_text(text)
        self.names = text

    def set_imgsz(self, imgsz: tuple[int, int]):
        """Set the image size for the model."""
        self.backbone.set_imgsz(imgsz)
307	ultralytics/models/sam/sam3/text_encoder_ve.py	Normal file
@@ -0,0 +1,307 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from __future__ import annotations

from collections import OrderedDict
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .model_misc import LayerScale


class ResidualAttentionBlock(nn.Module):
    """Transformer block with multi-head attention, layer normalization, and MLP feed-forward network."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        """Initialize residual attention block with configurable dimensions and normalization."""
        super().__init__()
        # Attention
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # LayerNorm, LayerScale
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        # MLP
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

    def attention(
        self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Compute multi-head attention with optional cross-attention support and masking."""
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x
        if attn_mask is not None:
            # Leave boolean masks as is
            if not attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)

        return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(
        self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Apply residual attention with layer normalization and MLP, supporting optional cross-attention."""
        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    """Stack of residual attention blocks forming a transformer encoder with optional gradient checkpointing."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = False,
    ):
        """Initialize transformer with configurable depth, width, and optional compilation/checkpointing."""
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(layers)
            ]
        )

        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
            if self.grad_checkpointing:
                torch._dynamo.config.optimize_ddp = False

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
        """Process input through all transformer blocks with optional gradient checkpointing during training."""
        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting() and self.training:
                x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)
        return x


def text_global_pool(
    x: torch.Tensor, text: torch.Tensor = None, pool_type: str = "argmax"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract pooled representation and tokens from text embeddings using the specified pooling strategy
    (first/last/argmax/none).
    """
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
    else:
        pooled = tokens = x
    return pooled, tokens
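The `argmax` branch above selects the hidden state at the end-of-text position, exploiting the fact that the EOT token has the highest id in each zero-padded sequence. A minimal pure-Python sketch of that selection for a single sequence (ids and feature values are made up for illustration):

```python
def argmax_pool(hidden: list[list[float]], token_ids: list[int]) -> list[float]:
    """Return the hidden state at the position of the largest token id (the EOT token)."""
    eot_pos = max(range(len(token_ids)), key=token_ids.__getitem__)
    return hidden[eot_pos]


# <sot> a cat <eot> pad pad -- the EOT id (49407 in CLIP) is the largest in the row
token_ids = [49406, 320, 2368, 49407, 0, 0]
hidden = [[float(i)] * 2 for i in range(6)]  # stand-in per-position features
pooled = argmax_pool(hidden, token_ids)  # hidden state at index 3
```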


class TextTransformer(nn.Module):
    """Text transformer encoder with causal masking and flexible pooling strategies."""

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = False,
    ):
        """Initialize text transformer with embedding layers, transformer blocks, and pooling options."""
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer("attn_mask", self.build_causal_mask(), persistent=False)
        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self) -> torch.Tensor:
        """Create a causal attention mask to prevent attention to future tokens."""
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask
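PyTorch uses additive attention masks, so positions strictly above the diagonal get `-inf` (future tokens, blocked) and the rest get `0.0` (unchanged attention scores). The same structure can be sketched without torch:

```python
def causal_mask(n: int) -> list[list[float]]:
    """Additive causal mask: 0.0 on/below the diagonal, -inf strictly above it."""
    neg_inf = float("-inf")
    return [[0.0 if j <= i else neg_inf for j in range(n)] for i in range(n)]


mask = causal_mask(3)  # row i may attend to columns 0..i only
```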


    def forward(self, text: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the text transformer, returning pooled output and optionally token embeddings."""
        seq_len = text.shape[1]
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        attn_mask = self.attn_mask
        if attn_mask is not None:
            attn_mask = attn_mask[:seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)

        x = self.ln_final(x)
        pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection
        if self.output_tokens:
            return pooled, tokens
        return pooled


class VETextEncoder(nn.Module):
    """Text encoder for Vision Encoder (VE) models, combining a text transformer and a linear resizer."""

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = True,
    ):
        """Initialize VE text encoder with a text transformer and a linear resizer to match decoder dimensions."""
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer

        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self, text: list[str] | tuple[torch.Tensor, torch.Tensor, dict], input_boxes: list | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode text input, either raw strings or pre-encoded tensors, and resize to match decoder dimensions."""
        if isinstance(text[0], str):
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            # Encode the text
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                self.resizer.weight.device
            )  # [b, seq_len]
            text_attention_mask = (tokenized != 0).bool()

            # manually embed the tokens
            inputs_embeds = self.encoder.token_embedding(tokenized)  # [b, seq_len, d=1024]
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, d=1024]

            assert text_memory.shape[1] == inputs_embeds.shape[1]
            # Invert the attention mask because its meaning is the opposite in pytorch's transformer
            text_attention_mask = text_attention_mask.ne(1)
            # Transpose memory because pytorch's attention expects sequence first
            text_memory = text_memory.transpose(0, 1)
            # Resize the encoder hidden states to be of the same d_model as the decoder
            text_memory_resized = self.resizer(text_memory)
        else:
            # The text is already encoded, use as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert input_boxes is None or len(input_boxes) == 0, "Can't replace boxes in text if it's already encoded"

        # Note that the input_embeds are returned in pytorch's convention (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )
242	ultralytics/models/sam/sam3/tokenizer_ve.py	Normal file
@@ -0,0 +1,242 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""
Text Tokenizer.

Copied and lightly adapted from the VE repo, which in turn copied from open_clip and OpenAI CLIP.
"""

from __future__ import annotations

import gzip
import html
import io
import os
import string
from functools import lru_cache

import ftfy
import regex as re
import torch
from iopath.common.file_io import g_pathmgr


@lru_cache
def bytes_to_unicode():
    """Return a mapping from utf-8 bytes to unicode strings.

    The reversible BPE codes work on unicode strings, so the vocab needs a large number of unicode characters to
    avoid UNKs. At around a 10B-token dataset you end up needing roughly 5K characters for decent coverage, a
    significant fraction of a normal ~32K BPE vocab. Lookup tables between utf-8 bytes and unicode strings avoid
    that, and also avoid mapping to whitespace/control characters that the BPE code barfs on.
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
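The mapping above is a bijection over all 256 byte values: printable bytes map to themselves, and the remaining bytes are shifted into unused code points above 255. Restated standalone (same logic as above) so the property can be checked:

```python
from functools import lru_cache


@lru_cache
def bytes_to_unicode():
    """Bijective map from all 256 byte values to printable unicode characters."""
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)  # shift unprintable bytes into unused code points
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


table = bytes_to_unicode()
```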


def get_pairs(word):
    """Return the set of symbol pairs in a word, where the word is a tuple of variable-length symbol strings."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
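Restated standalone, `get_pairs` enumerates the adjacent symbol pairs that the BPE merge loop then ranks and merges:

```python
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word (a tuple of symbol strings)."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


# the word "lower" after the initial character split, with the end-of-word marker
pairs = get_pairs(("l", "o", "w", "er</w>"))
```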


def basic_clean(text):
    """Basic text cleaning: fix unicode and unescape HTML entities."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Remove redundant whitespace."""
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def _clean_canonicalize(x):
    """Clean text and canonicalize it (basic clean, remove whitespace and punctuation, lower case)."""
    return canonicalize_text(basic_clean(x))


def _clean_lower(x):
    """Clean text and return lowercase (basic clean, remove whitespace, lower case)."""
    return whitespace_clean(basic_clean(x)).lower()


def _clean_whitespace(x):
    """Clean text and remove redundant whitespace (basic clean, remove whitespace)."""
    return whitespace_clean(basic_clean(x))


def get_clean_fn(type: str):
    """Get a text cleaning function by name."""
    if type == "canonicalize":
        return _clean_canonicalize
    elif type == "lower":
        return _clean_lower
    elif type == "whitespace":
        return _clean_whitespace
    else:
        assert False, f"Invalid clean function ({type})."


def canonicalize_text(text, *, keep_punctuation_exact_string=None):
    """Return canonicalized `text` (lowercase and punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94.

    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, this exact string is kept. For example, providing '{}' will keep
            any occurrences of '{}' (but will still remove '{' and '}' that appear separately).
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string)
        )
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()
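A standalone sketch of the default path (no `keep_punctuation_exact_string`), using only the stdlib `re` instead of the `regex` package, which suffices for the `\s+` pattern used here:

```python
import re
import string


def canonicalize_text(text: str) -> str:
    """Lowercase, strip punctuation, and collapse whitespace (default path only)."""
    text = text.replace("_", " ")
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()


out = canonicalize_text("A_Cat,  sitting!")  # -> "a cat sitting"
```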


class SimpleTokenizer:
    """A simple BPE tokenizer for text inputs."""

    def __init__(
        self,
        bpe_path: str | os.PathLike,
        additional_special_tokens: list[str] | None = None,
        context_length: int = 77,
        clean: str = "lower",
    ):
        """Initialize the tokenizer from a gzipped BPE merges file."""
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
            merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        special_tokens = ["<start_of_text>", "<end_of_text>"]
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)

    def bpe(self, token):
        """Apply byte pair encoding merges to a single token."""
        if token in self.cache:
            return self.cache[token]
        word = (*tuple(token[:-1]), token[-1] + "</w>")
        pairs = get_pairs(word)
        if not pairs:
            return token + "</w>"
        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except Exception:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode text to a sequence of BPE token ids."""
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a sequence of token ids back into a text string."""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
        return text

    def __call__(self, texts: str | list[str], context_length: int | None = None) -> torch.LongTensor:
        """Return the tokenized representation of the given input string(s).

        Args:
            texts: An input string or a list of input strings to tokenize.
            context_length: The context length to use; all CLIP models use 77 as the context length.

        Returns:
            A two-dimensional tensor containing the resulting tokens, shape = [number of input strings,
            context_length].
        """
        if isinstance(texts, str):
            texts = [texts]
        context_length = context_length or self.context_length
        assert context_length, "Please set a valid context length"
        all_tokens = [[self.sot_token_id, *self.encode(text), self.eot_token_id] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = self.eot_token_id
            result[i, : len(tokens)] = torch.tensor(tokens)
        return result
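The padding/truncation policy in `__call__` above can be sketched without torch: every row is fixed to `context_length`, short rows are zero-padded, and truncated rows overwrite their last slot with the EOT id so the end-of-text marker always survives. The ids here are illustrative placeholders, not real vocab ids:

```python
SOT, EOT, CTX = 1, 2, 5  # illustrative sot/eot ids and context length


def pad_or_truncate(token_ids: list[int]) -> list[int]:
    """Wrap ids in sot/eot, then fix the row length to CTX (pad with 0, keep EOT on truncation)."""
    row = [SOT, *token_ids, EOT]
    if len(row) > CTX:
        row = row[:CTX]
        row[-1] = EOT  # truncation must not drop the end-of-text marker
    return row + [0] * (CTX - len(row))


short = pad_or_truncate([7])  # [1, 7, 2, 0, 0]
long = pad_or_truncate([7, 8, 9, 10])  # truncated, still ends with EOT
```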
546	ultralytics/models/sam/sam3/vitdet.py	Normal file
@@ -0,0 +1,546 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""
ViTDet backbone adapted from Detectron2.

This module implements the Vision Transformer (ViT) backbone for object detection.

Rope embedding code adapted from:
1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
2. https://github.com/naver-ai/rope-vit
3. https://github.com/lucidrains/rotary-embedding-torch
"""

from __future__ import annotations

import math
from functools import partial
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import Tensor

from ultralytics.models.sam.modules.blocks import PatchEmbed
from ultralytics.models.sam.modules.utils import (
    apply_rotary_enc,
    compute_axial_cis,
    concat_rel_pos,
    get_abs_pos,
    window_partition,
    window_unpartition,
)
from ultralytics.utils.checks import check_requirements

from .model_misc import LayerScale
class Attention(nn.Module):
    """Multi-head attention block with relative position embeddings and 2D rope."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: tuple[int, int] | None = None,
        cls_token: bool = False,
        use_rope: bool = False,
        rope_theta: float = 10000.0,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero-initialize the relative positional parameters.
            input_size (tuple[int, int] | None): Input resolution for calculating the relative positional
                parameter size or rope size.
            cls_token (bool): Whether a cls_token is present.
            use_rope (bool): Whether to use 2D rope (independent of use_rel_pos; both can be used together).
            rope_theta (float): Controls the frequencies of rope.
            rope_pt_size (tuple[int, int] | None): Size of rope in a previous stage of training, needed for
                interpolation or tiling.
            rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the input size.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.cls_token = cls_token

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # rel_pos embeddings and rope
        self.use_rel_pos = use_rel_pos
        self.input_size = input_size

        self.use_rope = use_rope
        self.rope_theta = rope_theta
        self.rope_pt_size = rope_pt_size
        self.rope_interp = rope_interp

        # init rel_pos embeddings and rope
        self._setup_rel_pos(rel_pos_zero_init, input_size)
        self._setup_rope_freqs(input_size)
    def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None:
        """Set up relative positional embeddings."""
        if not self.use_rel_pos:
            self.rel_pos_h = None
            self.rel_pos_w = None
            return

        assert input_size is not None
        assert self.cls_token is False, "not supported"
        # Initialize relative positional embeddings
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))

        if not rel_pos_zero_init:
            nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
            nn.init.trunc_normal_(self.rel_pos_w, std=0.02)

        # Precompute the relative coords
        H, W = input_size
        q_coords = torch.arange(H)[:, None]
        k_coords = torch.arange(W)[None, :]
        relative_coords = (q_coords - k_coords) + (H - 1)
        self.relative_coords = relative_coords.long()

    def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None:
        """Set up 2D rope frequencies."""
        if not self.use_rope:
            self.freqs_cis = None
            return

        assert input_size is not None
        # Determine rope input size
        if self.rope_pt_size is None:
            self.rope_pt_size = input_size

        # Initialize 2D rope freqs
        self.compute_cis = partial(
            compute_axial_cis,
            dim=self.head_dim,
            theta=self.rope_theta,
        )

        # Interpolate rope
        scale_pos = 1.0
        if self.rope_interp:
            scale_pos = self.rope_pt_size[0] / input_size[0]
        # Get scaled freqs_cis
        freqs_cis = self.compute_cis(
            end_x=input_size[0],
            end_y=input_size[1],
            scale_pos=scale_pos,
        )
        if self.cls_token:
            t = torch.zeros(
                self.head_dim // 2,
                dtype=torch.float32,
                device=freqs_cis.device,
            )
            cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
            freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)

        self.freqs_cis = freqs_cis

    def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]:
        """Apply 2D rope to q and k."""
        if not self.use_rope:
            return q, k

        assert self.freqs_cis is not None
        return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis.to(q.device))
    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the attention block."""
        s = 1 if self.cls_token else 0  # used to exclude cls_token
        if x.ndim == 4:
            B, H, W, _ = x.shape
            assert s == 0  # no cls_token
            L = H * W
            ndim = 4
        else:
            assert x.ndim == 3
            B, L, _ = x.shape
            ndim = 3
            H = W = int(math.sqrt(L - s))

        # qkv with shape (3, B, nHead, L, C)
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
        # q, k, v with shape (B, nHead, L, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Handle rope and rel pos embeddings
        q, k = self._apply_rope(q, k)
        if self.use_rel_pos:
            q, k = concat_rel_pos(
                q.flatten(0, 1),
                k.flatten(0, 1),
                (H, W),
                x.shape[1:3],
                self.rel_pos_h,
                self.rel_pos_w,
                rescale=True,
                relative_coords=self.relative_coords,
            )

        # sdpa expects [B, nheads, H*W, C], so we transpose back
        q = q.reshape(B, self.num_heads, H * W, -1)
        k = k.reshape(B, self.num_heads, H * W, -1)

        x = F.scaled_dot_product_attention(q, k, v)

        if ndim == 4:
            x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        else:
            x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)

        x = self.proj(x)

        return x
class Block(nn.Module):
    """Transformer block with support for window attention."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: tuple[int, int] | None = None,
        use_rope: bool = False,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
        cls_token: bool = False,
        dropout: float = 0.0,
        init_values: float | None = None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero-initialize the relative positional parameters.
            window_size (int): Window size for window attention blocks. If 0, window attention is not used.
            input_size (tuple[int, int] | None): Input resolution for calculating the relative positional
                parameter size.
            use_rope (bool): Whether to use 2D rope (independent of use_rel_pos; both can be used together).
            rope_pt_size (tuple[int, int] | None): Size of rope in a previous stage of training, needed for
                interpolation or tiling.
            rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the target input size;
                the source size is expected in rope_pt_size.
            cls_token (bool): Whether a cls_token is present.
            dropout (float): Dropout rate.
            init_values (float | None): Layer scale init value; None for no layer scale.
        """
        super().__init__()

        check_requirements("timm")
        from timm.layers import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
            use_rope=use_rope,
            rope_pt_size=rope_pt_size,
            rope_interp=rope_interp,
            cls_token=cls_token,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=(dropout, 0.0),
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the transformer block."""
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.ls1(self.attn(x))
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.dropout(self.drop_path(x))
        x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))

        return x
class ViT(nn.Module):
    """Vision Transformer (ViT) backbone from :paper:`vitdet`, "Exploring Plain Vision Transformer Backbones for
    Object Detection", https://arxiv.org/abs/2203.16527.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        tile_abs_pos: bool = True,
        rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
        use_rope: bool = False,
        rope_pt_size: int | None = None,
        use_interp_rope: bool = False,
        pretrain_img_size: int = 224,
        pretrain_use_cls_token: bool = True,
        retain_cls_token: bool = True,
        dropout: float = 0.0,
        return_interm_layers: bool = False,
        init_values: float | None = None,  # for layerscale
        ln_pre: bool = False,
        ln_post: bool = False,
        bias_patch_embed: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = True,
    ):
        """
        Args:
            img_size (int): Input image size. Only relevant for rel pos or rope.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module | str): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolating.
            rel_pos_blocks (tuple[int, ...] | bool): Blocks which have rel pos embeddings.
            rel_pos_zero_init (bool): If True, zero-initialize the relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_att_blocks (tuple[int, ...]): Indexes of blocks using global attention (other blocks use window
                attention).
            use_rope (bool): Whether to use 2D rope (independent of rel_pos_blocks; both can be used together).
            rope_pt_size (int | None): Size of rope in a previous stage of training, needed for interpolation or
                tiling.
            use_interp_rope (bool): Whether to interpolate (or extrapolate) rope to match the target input size;
                the source size is expected in rope_pt_size.
            pretrain_img_size (int): Input image size of the pretraining model.
            pretrain_use_cls_token (bool): If True, the pretraining model uses a class token.
            retain_cls_token (bool): Whether the cls_token should be retained.
            dropout (float): Dropout rate, applied in the residual blocks of attn and mlp, and inside the mlp.
            return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
            init_values (float | None): Layer scale init value; None for no layer scale.
            ln_pre (bool): If True, apply layer norm before the transformer blocks.
            ln_post (bool): If True, apply layer norm after the transformer blocks.
            bias_patch_embed (bool): Whether to use a bias in the patch embedding conv.
            compile_mode (str | None): Mode used to compile the forward pass.
            use_act_checkpoint (bool): If True, use activation checkpointing.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
        self.full_attn_ids = list(global_att_blocks)
        self.rel_pos_blocks = [False] * depth
        if isinstance(rel_pos_blocks, bool):
            self.rel_pos_blocks = [rel_pos_blocks] * depth
        else:
            for i in rel_pos_blocks:
                self.rel_pos_blocks[i] = True

        self.retain_cls_token = retain_cls_token
        if self.retain_cls_token:
            assert pretrain_use_cls_token
            assert len(window_block_indexes) == 0, "windowing not supported with cls token"
            assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"

            scale = embed_dim**-0.5
            self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias_patch_embed,
        )

        # Handle absolute positional embedding
        self.tile_abs_pos = tile_abs_pos
        self.use_abs_pos = use_abs_pos
        if self.tile_abs_pos:
            assert self.use_abs_pos

        if self.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.patch_size = patch_size
        self.window_size = window_size
        self.blocks = nn.ModuleList()
        cur_stage = 1
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=self.rel_pos_blocks[i],
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
                use_rope=use_rope,
                rope_pt_size=((window_size, window_size) if rope_pt_size is None else (rope_pt_size, rope_pt_size)),
                rope_interp=use_interp_rope,
                cls_token=self.retain_cls_token,
                dropout=dropout,
                init_values=init_values,
            )

            if i not in window_block_indexes:
                cur_stage += 1

            self.blocks.append(block)

        self.use_act_checkpoint = use_act_checkpoint

        self.return_interm_layers = return_interm_layers
        self.channel_list = [embed_dim] * len(self.full_attn_ids) if return_interm_layers else [embed_dim]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
        self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()

        self.apply(self._init_weights)

        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
            if self.use_act_checkpoint and self.training:
                torch._dynamo.config.optimize_ddp = False
    def _init_weights(self, m: nn.Module) -> None:
        """Initialize the weights."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """ViT forward pass returning feature maps."""
        x = self.patch_embed(x)
        h, w = x.shape[1], x.shape[2]

        s = 0
        if self.retain_cls_token:
            # If cls_token is retained, we don't maintain the spatial shape
            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
            s = 1

        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed,
                self.pretrain_use_cls_token,
                (h, w),
                self.retain_cls_token,
                tiling=self.tile_abs_pos,
            )

        x = self.ln_pre(x)

        outputs = []
        for i, blk in enumerate(self.blocks):
            if self.use_act_checkpoint and self.training:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if (i == self.full_attn_ids[-1]) or (self.return_interm_layers and i in self.full_attn_ids):
                if i == self.full_attn_ids[-1]:
                    x = self.ln_post(x)

                feats = x[:, s:]
                if feats.ndim == 4:
                    feats = feats.permute(0, 3, 1, 2)
                else:
                    assert feats.ndim == 3
                    h = w = int(math.sqrt(feats.shape[1]))
                    feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)

                outputs.append(feats)

        return outputs

    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
        """Set up rel pos embeddings and rope freqs for a new input image size."""
        for block in self.blocks:
            if block.window_size != 0:
                continue
            block.attn._setup_rel_pos(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
            block.attn._setup_rope_freqs(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
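The axial 2D RoPE used by the attention blocks above assigns half of each head's channel pairs to the x coordinate and half to the y coordinate. A pure-Python sketch of the frequency table (the function name and layout are illustrative; the real implementation is `compute_axial_cis`, imported from `ultralytics.models.sam.modules.utils`):

```python
import cmath


def axial_rope_cis(head_dim, end_x, end_y, theta=10000.0):
    """Illustrative sketch of 2D axial RoPE frequencies: of the head_dim // 2
    complex rotations per position, the first half rotate with the x
    coordinate and the second half with the y coordinate; positions are
    enumerated in row-major order."""
    quarter = head_dim // 4
    freqs = [1.0 / theta ** (4 * i / head_dim) for i in range(quarter)]
    table = []
    for pos in range(end_x * end_y):
        x, y = pos % end_x, pos // end_x
        table.append([cmath.exp(1j * x * f) for f in freqs] + [cmath.exp(1j * y * f) for f in freqs])
    return table  # [end_x * end_y] rows of head_dim // 2 unit-magnitude rotations
```

Because every entry has unit magnitude, applying the table to query/key pairs rotates them without changing their norms, which is what makes RoPE encode relative position inside the dot product.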
ultralytics/models/sam/sam3/vl_combiner.py (new file, 165 lines)
@@ -0,0 +1,165 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""Utilities to combine a vision backbone with a language backbone."""

from __future__ import annotations

from copy import copy

import torch
import torch.nn as nn
from torch.nn.attention import SDPBackend, sdpa_kernel

from .necks import Sam3DualViTDetNeck


class SAM3VLBackbone(nn.Module):
    """Combines a vision backbone and a language backbone without fusion; a convenience wrapper that handles the two
    backbones together.

    It adds support for activation checkpointing and compilation.
    """

    def __init__(
        self,
        visual: Sam3DualViTDetNeck,
        text,
        compile_visual: bool = False,
        act_ckpt_whole_vision_backbone: bool = False,
        act_ckpt_whole_language_backbone: bool = False,
        scalp=0,
    ):
        """Initialize the backbone combiner.

        :param visual: The vision backbone to use
        :param text: The text encoder to use
        """
        super().__init__()
        self.vision_backbone: Sam3DualViTDetNeck = torch.compile(visual) if compile_visual else visual
        self.language_backbone = text
        self.scalp = scalp
        # Allow running activation checkpointing on the entire vision and language backbones
        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone

    def forward(
        self,
        samples: torch.Tensor,
        captions: list[str],
        input_boxes: torch.Tensor = None,
        additional_text: list[str] | None = None,
    ):
        """Forward pass of the backbone combiner.

        :param samples: The input images
        :param captions: The input captions
        :param input_boxes: If the text contains placeholders for boxes, this parameter contains the tensor with
            their spatial features
        :param additional_text: Additional text (different from the captions) to encode in the same forward pass
            of the backbone
        :return: Output dictionary with the following keys:
            - vision_features: The output of the vision backbone
            - language_features: The output of the language backbone
            - language_mask: The attention mask of the language backbone
            - vision_pos_enc: The positional encoding of the vision backbone
            - (optional) additional_text_features: The output of the language backbone for the additional text
            - (optional) additional_text_mask: The attention mask of the language backbone for the additional text
        """
        output = self.forward_image(samples)
        output.update(self.forward_text(captions, input_boxes, additional_text))
        return output

    def forward_image(self, samples: torch.Tensor):
        """Forward pass of the vision backbone, returning both SAM3 and SAM2 features."""
        # Forward through backbone
        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(samples)
        if self.scalp > 0:
            # Discard the lowest-resolution features
            sam3_features, sam3_pos = (
                sam3_features[: -self.scalp],
                sam3_pos[: -self.scalp],
            )
            if sam2_features is not None and sam2_pos is not None:
                sam2_features, sam2_pos = (
                    sam2_features[: -self.scalp],
                    sam2_pos[: -self.scalp],
                )

        sam2_output = None

        if sam2_features is not None and sam2_pos is not None:
            sam2_src = sam2_features[-1]
            sam2_output = {
                "vision_features": sam2_src,
                "vision_pos_enc": sam2_pos,
                "backbone_fpn": sam2_features,
            }

        sam3_src = sam3_features[-1]
        return {
            "vision_features": sam3_src,
            "vision_pos_enc": sam3_pos,
            "backbone_fpn": sam3_features,
            "sam2_backbone_out": sam2_output,
        }

    def forward_image_sam2(self, samples: torch.Tensor):
        """Forward pass of the vision backbone to get SAM2 features only."""
        xs = self.vision_backbone.trunk(samples)
        sam2_features, sam2_pos = [], []
        x = xs[-1]  # simpleFPN

        assert self.vision_backbone.sam2_convs is not None, "SAM2 neck is not available."
        for i in range(len(self.vision_backbone.sam2_convs)):
            sam2_x_out = self.vision_backbone.sam2_convs[i](x)
            sam2_pos_out = self.vision_backbone.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
            sam2_features.append(sam2_x_out)
            sam2_pos.append(sam2_pos_out)

        if self.scalp > 0:
            # Discard the lowest-resolution features
            sam2_features, sam2_pos = (
                sam2_features[: -self.scalp],
                sam2_pos[: -self.scalp],
            )

        return {
            "vision_features": sam2_features[-1],
            "vision_pos_enc": sam2_pos,
            "backbone_fpn": sam2_features,
        }

    def forward_text(self, captions, input_boxes=None, additional_text=None):
        """Forward pass of the text encoder."""
        output = {}

        # Forward through text_encoder
        text_to_encode = copy(captions)
        if additional_text is not None:
            # If there is additional_text, we piggy-back it onto this forward pass.
            # It will be used later for output alignment.
            text_to_encode += additional_text

        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
            text_attention_mask, text_memory, text_embeds = self.language_backbone(text_to_encode, input_boxes)

        if additional_text is not None:
            output["additional_text_features"] = text_memory[:, -len(additional_text) :]
            output["additional_text_mask"] = text_attention_mask[-len(additional_text) :]

        text_memory = text_memory[:, : len(captions)]
        text_attention_mask = text_attention_mask[: len(captions)]
        text_embeds = text_embeds[:, : len(captions)]
        output["language_features"] = text_memory
        output["language_mask"] = text_attention_mask
        output["language_embeds"] = text_embeds  # Text embeddings before forwarding to the encoder

        return output

    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
        """Set the image size for the vision backbone."""
        self.vision_backbone.set_imgsz(imgsz)
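The caption/additional-text batching trick in `forward_text` above reduces to simple slicing: both prompt sets are encoded in one pass, then the joint output is split back apart. A minimal sketch on plain lists (the real code slices tensors along the prompt axis; names are illustrative):

```python
def split_text_features(features, captions, additional_text=None):
    """Sketch of forward_text's piggy-backing: `features` stands in for one
    feature row per encoded prompt, with captions first and any additional
    text appended at the end of the same batch."""
    out = {}
    if additional_text:
        # The piggy-backed prompts occupy the tail of the joint output
        out["additional_text_features"] = features[-len(additional_text):]
    # The caption outputs are the leading rows
    out["language_features"] = features[: len(captions)]
    return out
```

Encoding both sets in one call amortizes the text-encoder forward pass instead of running the language backbone twice.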
@@ -359,7 +359,15 @@ class MLP(nn.Module):
        """

    def __init__(
-        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        act=nn.ReLU,
+        sigmoid: bool = False,
+        residual: bool = False,
+        out_norm: nn.Module = None,
    ):
        """Initialize the MLP with specified input, hidden, output dimensions and number of layers.

@@ -370,6 +378,8 @@ class MLP(nn.Module):
            num_layers (int): Number of layers.
            act (nn.Module): Activation function.
            sigmoid (bool): Whether to apply sigmoid to the output.
+            residual (bool): Whether to use residual connections.
+            out_norm (nn.Module, optional): Normalization layer for the output.
        """
        super().__init__()
        self.num_layers = num_layers

@@ -377,6 +387,12 @@ class MLP(nn.Module):
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim]))
        self.sigmoid = sigmoid
        self.act = act()
+        if residual and input_dim != output_dim:
+            raise ValueError("residual is only supported if input_dim == output_dim")
+        self.residual = residual
+        # Whether to apply a normalization layer to the output
+        assert isinstance(out_norm, nn.Module) or out_norm is None
+        self.out_norm = out_norm or nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for the entire MLP.

@@ -387,8 +403,12 @@ class MLP(nn.Module):
        Returns:
            (torch.Tensor): Output tensor after MLP.
        """
+        orig_x = x
        for i, layer in enumerate(self.layers):
            x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if getattr(self, "residual", False):
+            x = x + orig_x
+        x = getattr(self, "out_norm", nn.Identity())(x)
        return x.sigmoid() if getattr(self, "sigmoid", False) else x
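The updated `MLP.forward` control flow (activation between layers but not after the last, then an optional residual add) can be sketched on scalars, with plain callables standing in for the linear layers (illustrative, not the Ultralytics API):

```python
def mlp_forward(x, layers, act=lambda v: max(v, 0.0), residual=False):
    """Scalar sketch of the MLP.forward control flow above: apply each layer,
    inserting the activation between layers but not after the final one, then
    optionally add the original input back (the residual connection)."""
    orig = x
    for i, layer in enumerate(layers):
        x = layer(x)
        if i < len(layers) - 1:
            x = act(x)  # ReLU-like activation between layers only
    if residual:
        x = x + orig  # requires input and output dims to match in the real MLP
    return x
```

This is why the real constructor raises when `residual=True` and `input_dim != output_dim`: the final add needs shape-compatible tensors.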
@@ -660,6 +660,4 @@ def clean_str(s):

 def empty_like(x):
-    """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
-    return (
-        torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
-    )
+    """Create an empty torch.Tensor or np.ndarray with the same shape and dtype as the input."""
+    return torch.empty_like(x, dtype=x.dtype) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=x.dtype)