From e0764aa55d4e79aefdc917c7b1aeb97498440275 Mon Sep 17 00:00:00 2001 From: Jing Qiu <61612323+Laughing-q@users.noreply.github.com> Date: Fri, 12 Dec 2025 21:04:33 +0800 Subject: [PATCH] `ultralytics 8.3.237` SAM3 integration (#22897) Signed-off-by: Glenn Jocher Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher Co-authored-by: fatih akyon <34196005+fcakyon@users.noreply.github.com> --- docs/en/models/sam-3.md | 219 +- docs/en/reference/models/sam/build.md | 4 + docs/en/reference/models/sam/build_sam3.md | 32 + docs/en/reference/models/sam/modules/sam.md | 4 + docs/en/reference/models/sam/modules/utils.md | 8 + docs/en/reference/models/sam/predict.md | 16 + docs/en/reference/models/sam/sam3/decoder.md | 20 + docs/en/reference/models/sam/sam3/encoder.md | 28 + .../models/sam/sam3/geometry_encoders.md | 28 + .../sam/sam3/maskformer_segmentation.md | 32 + .../reference/models/sam/sam3/model_misc.md | 32 + docs/en/reference/models/sam/sam3/necks.md | 16 + .../reference/models/sam/sam3/sam3_image.md | 20 + .../models/sam/sam3/text_encoder_ve.md | 32 + .../reference/models/sam/sam3/tokenizer_ve.md | 52 + docs/en/reference/models/sam/sam3/vitdet.md | 24 + .../reference/models/sam/sam3/vl_combiner.md | 16 + mkdocs.yml | 13 + ultralytics/__init__.py | 2 +- ultralytics/engine/predictor.py | 5 +- ultralytics/models/sam/__init__.py | 15 +- ultralytics/models/sam/build.py | 25 +- ultralytics/models/sam/build_sam3.py | 374 +++ ultralytics/models/sam/model.py | 16 +- ultralytics/models/sam/modules/blocks.py | 28 +- ultralytics/models/sam/modules/decoders.py | 5 +- ultralytics/models/sam/modules/encoders.py | 5 +- .../models/sam/modules/memory_attention.py | 8 +- ultralytics/models/sam/modules/sam.py | 156 +- ultralytics/models/sam/modules/utils.py | 138 +- ultralytics/models/sam/predict.py | 2194 ++++++++++++++++- ultralytics/models/sam/sam3/__init__.py | 3 + ultralytics/models/sam/sam3/decoder.py | 546 ++++ ultralytics/models/sam/sam3/encoder.py | 535 ++++ 
.../models/sam/sam3/geometry_encoders.py | 415 ++++ .../sam/sam3/maskformer_segmentation.py | 286 +++ ultralytics/models/sam/sam3/model_misc.py | 198 ++ ultralytics/models/sam/sam3/necks.py | 129 + ultralytics/models/sam/sam3/sam3_image.py | 357 +++ .../models/sam/sam3/text_encoder_ve.py | 307 +++ ultralytics/models/sam/sam3/tokenizer_ve.py | 242 ++ ultralytics/models/sam/sam3/vitdet.py | 546 ++++ ultralytics/models/sam/sam3/vl_combiner.py | 165 ++ ultralytics/nn/modules/transformer.py | 22 +- ultralytics/utils/ops.py | 4 +- 45 files changed, 7070 insertions(+), 252 deletions(-) create mode 100644 docs/en/reference/models/sam/build_sam3.md create mode 100644 docs/en/reference/models/sam/sam3/decoder.md create mode 100644 docs/en/reference/models/sam/sam3/encoder.md create mode 100644 docs/en/reference/models/sam/sam3/geometry_encoders.md create mode 100644 docs/en/reference/models/sam/sam3/maskformer_segmentation.md create mode 100644 docs/en/reference/models/sam/sam3/model_misc.md create mode 100644 docs/en/reference/models/sam/sam3/necks.md create mode 100644 docs/en/reference/models/sam/sam3/sam3_image.md create mode 100644 docs/en/reference/models/sam/sam3/text_encoder_ve.md create mode 100644 docs/en/reference/models/sam/sam3/tokenizer_ve.md create mode 100644 docs/en/reference/models/sam/sam3/vitdet.md create mode 100644 docs/en/reference/models/sam/sam3/vl_combiner.md create mode 100644 ultralytics/models/sam/build_sam3.py create mode 100644 ultralytics/models/sam/sam3/__init__.py create mode 100644 ultralytics/models/sam/sam3/decoder.py create mode 100644 ultralytics/models/sam/sam3/encoder.py create mode 100644 ultralytics/models/sam/sam3/geometry_encoders.py create mode 100644 ultralytics/models/sam/sam3/maskformer_segmentation.py create mode 100644 ultralytics/models/sam/sam3/model_misc.py create mode 100644 ultralytics/models/sam/sam3/necks.py create mode 100644 ultralytics/models/sam/sam3/sam3_image.py create mode 100644 
ultralytics/models/sam/sam3/text_encoder_ve.py create mode 100644 ultralytics/models/sam/sam3/tokenizer_ve.py create mode 100644 ultralytics/models/sam/sam3/vitdet.py create mode 100644 ultralytics/models/sam/sam3/vl_combiner.py diff --git a/docs/en/models/sam-3.md b/docs/en/models/sam-3.md index 9ab1a5eb75..4eb1f9ab30 100644 --- a/docs/en/models/sam-3.md +++ b/docs/en/models/sam-3.md @@ -114,13 +114,17 @@ SAM 3 will be available directly in the Ultralytics package once integration lan pip install ultralytics ``` -Models will download automatically when first used. You can then use standard [predict mode](../modes/predict.md) and later [export](../modes/export.md) models to formats like [ONNX](../integrations/onnx.md) and [TensorRT](../integrations/tensorrt.md) for deployment. Watch for a package update with SAM-3 weights and configs soon. +!!! warning "SAM 3 Model Weights Required" + + Unlike other Ultralytics models, SAM 3 weights (`sam3.pt`) are **not automatically downloaded**. You must manually download the model weights from the [official SAM 3 repository](https://github.com/facebookresearch/sam3) before using SAM 3. Place the downloaded `sam3.pt` file in your working directory or specify the full path when loading the model. + +!!! note "BPE Vocabulary for Text Prompts" + + If you plan to use text-based concept segmentation with `SAM3SemanticPredictor`, you also need to download the BPE vocabulary file `bpe_simple_vocab_16e6.txt.gz` from the [SAM 3 assets](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz). ## How to Use SAM 3: Versatility in Concept Segmentation -!!! warning "Ultralytics API preview" - - The following examples show the intended Ultralytics API once SAM 3 ships in the package. Until integration lands, details may change. +SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual Segmentation (PVS) tasks through different predictor interfaces. 
### Supported Tasks and Models @@ -138,143 +142,176 @@ SAM 3 supports both Promptable Concept Segmentation (PCS) and Promptable Visual !!! example "Text-based Concept Segmentation" - Find and segment all instances of a concept using a text description. + Find and segment all instances of a concept using a text description. Text prompts require the `SAM3SemanticPredictor` interface. === "Python" ```python - from ultralytics import SAM + from ultralytics.models.sam.predict import SAM3SemanticPredictor - # Load SAM 3 model - model = SAM("sam3.pt") + # Initialize predictor with configuration + overrides = dict( + conf=0.25, + task="segment", + mode="predict", + model="sam3.pt", + half=True, # Use FP16 for faster inference + ) + predictor = SAM3SemanticPredictor( + overrides=overrides, + bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz", # Required for text encoding + ) - # Segment all instances of a concept - results = model("path/to/image.jpg", prompt="yellow school bus") + # Set image once for multiple queries + predictor.set_image("path/to/image.jpg") + + # Query with multiple text prompts + results = predictor(text=["person", "bus", "glasses"], save=True) # Works with descriptive phrases - results = model("path/to/image.jpg", prompt="person wearing a red hat") + results = predictor(text=["person with red cloth", "person with blue cloth"], save=True) - # Or simple object names - results = model("path/to/image.jpg", prompt="striped cat") + # Query with a single concept + results = predictor(text=["a person"], save=True) ``` - === "CLI" + !!! note "Text Encoding Requirement" - ```bash - # Segment all matching concepts in an image - yolo segment model=sam3.pt source=path/to/image.jpg prompt="yellow school bus" - ``` - - !!! warning "API Preview" - - This example shows intended usage. Actual implementation pending Ultralytics integration. + The `bpe_path` parameter is required for text prompt encoding. 
Download the BPE vocabulary file from the [bpe_simple_vocab_16e6.txt.gz](https://github.com/facebookresearch/sam3/blob/main/assets/bpe_simple_vocab_16e6.txt.gz). #### Segment with Image Exemplars !!! example "Image Exemplar-based Segmentation" - Use one or more example objects to find all similar instances. + Use bounding boxes as visual prompts to find all similar instances. This also requires `SAM3SemanticPredictor` for concept-based matching. === "Python" ```python - from ultralytics import SAM + from ultralytics.models.sam.predict import SAM3SemanticPredictor - model = SAM("sam3.pt") + # Initialize predictor + overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True) + predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz") - # Provide a positive example box - finds all similar objects - results = model("path/to/image.jpg", bboxes=[100, 150, 300, 400], labels=[1]) + # Set image + predictor.set_image("path/to/image.jpg") - # Add negative examples to exclude certain instances - results = model( - "path/to/image.jpg", - bboxes=[[100, 150, 300, 400], [500, 200, 600, 350]], # Two boxes - labels=[1, 0], # First is positive, second is negative - ) + # Provide bounding box examples to segment similar objects + results = predictor(bboxes=[[480.0, 290.0, 590.0, 650.0]], save=True) - # Combine text and image exemplars for precision - results = model("path/to/image.jpg", prompt="dog", bboxes=[100, 150, 300, 400], labels=[1]) + # Multiple bounding boxes for different concepts + results = predictor(bboxes=[[539, 599, 589, 639], [343, 267, 499, 662]], save=True) ``` - !!! warning "API Preview" +#### Feature-based Inference for Efficiency - This example shows intended usage. Actual implementation pending Ultralytics integration. +!!! example "Reusing Image Features for Multiple Queries" -#### Interactive Refinement - -!!! 
example "Iterative Refinement with Exemplars" - - Progressively improve results by adding exemplar prompts based on initial output. + Extract image features once and reuse them for multiple segmentation queries to improve efficiency. === "Python" ```python - from ultralytics import SAM + import cv2 - model = SAM("sam3.pt") + from ultralytics.models.sam.predict import SAM3SemanticPredictor + from ultralytics.utils.plotting import Annotator, colors - # Initial segmentation with text - results = model("path/to/image.jpg", prompt="car") + # Initialize predictors + overrides = dict(conf=0.50, task="segment", mode="predict", model="sam3.pt", verbose=False) + predictor = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz") + predictor2 = SAM3SemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz") - # If some cars are missed, add a positive exemplar - results = model( - "path/to/image.jpg", - prompt="car", - bboxes=[missed_car_box], - labels=[1], # Positive example - ) + # Extract features from the first predictor + source = "path/to/image.jpg" + predictor.set_image(source) + src_shape = cv2.imread(source).shape[:2] - # If false positives appear, add negative exemplars - results = model( - "path/to/image.jpg", - prompt="car", - bboxes=[false_positive_box], - labels=[0], # Negative example - ) + # Setup second predictor and reuse features + predictor2.setup_model() + + # Perform inference using shared features with text prompt + masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, text=["person"]) + + # Perform inference using shared features with bounding box prompt + masks, boxes = predictor2.inference_features(predictor.features, src_shape=src_shape, bboxes=[439, 437, 524, 709]) + + # Visualize results + masks, boxes = masks.cpu().numpy(), boxes.cpu().numpy() + im = cv2.imread(source) + annotator = Annotator(im, pil=False) + annotator.masks(masks, [colors(x, True) for 
x in range(len(masks))]) + + cv2.imshow("result", annotator.result()) + cv2.waitKey(0) ``` - !!! warning "API Preview" - - This example shows intended usage. Actual implementation pending Ultralytics integration. - ### Video Concept Segmentation -!!! example "Track Concepts Across Video" +#### Track Concepts Across Video with Bounding Boxes - Detect and track all instances of a concept throughout a video. +!!! example "Video Tracking with Visual Prompts" + + Detect and track object instances across video frames using bounding box prompts. === "Python" ```python - from ultralytics.models.sam import SAM3VideoPredictor + from ultralytics.models.sam.predict import SAM3VideoPredictor # Create video predictor - predictor = SAM3VideoPredictor(model="sam3.pt", imgsz=1024, conf=0.25) + overrides = dict(conf=0.25, task="segment", mode="predict", model="sam3.pt", half=True) + predictor = SAM3VideoPredictor(overrides=overrides) - # Track all instances of a concept - results = predictor(source="video.mp4", prompt="person wearing blue shirt") + # Track objects using bounding box prompts + results = predictor(source="path/to/video.mp4", bboxes=[[706.5, 442.5, 905.25, 555], [598, 635, 725, 750]], stream=True) - # Combine text with exemplar for precision + # Process and display results + for r in results: + r.show() # Display frame with segmentation masks + ``` + +#### Track Concepts with Text Prompts + +!!! example "Video Tracking with Semantic Queries" + + Track all instances of concepts specified by text across video frames. 
+ + === "Python" + + ```python + from ultralytics.models.sam.sam3_video_model import SAM3VideoSemanticPredictor + + # Initialize semantic video predictor + overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=640, model="sam3.pt", half=True) + predictor = SAM3VideoSemanticPredictor(overrides=overrides, bpe_path="path/to/bpe_simple_vocab_16e6.txt.gz") + + # Track concepts using text prompts + results = predictor(source="path/to/video.mp4", text=["person", "bicycle"], stream=True, save=True) + + # Process results + for r in results: + r.show() # Display frame with tracked objects + + # Alternative: Track with bounding box prompts results = predictor( - source="video.mp4", - prompt="kangaroo", - bboxes=[initial_box], # Exemplar from first frame - labels=[1], + source="path/to/video.mp4", + bboxes=[[864, 383, 975, 620], [705, 229, 782, 402]], + labels=[1, 1], # Positive labels + stream=True, + save=True, ) ``` - !!! warning "API Preview" - - This example shows intended usage. Actual implementation pending Ultralytics integration. - -For broader streaming and production setups, see [object tracking](../guides/object-counting.md) and [view results in terminal](../guides/view-results-in-terminal.md). - ### Visual Prompts (SAM 2 Compatibility) -SAM 3 maintains full backward compatibility with SAM 2's visual prompting: +SAM 3 maintains full backward compatibility with SAM 2's visual prompting for single-object segmentation: !!! example "SAM 2 Style Visual Prompts" + The basic `SAM` interface behaves exactly like SAM 2, segmenting only the specific area indicated by visual prompts (points, boxes, or masks). 
+ === "Python" ```python @@ -282,19 +319,21 @@ SAM 3 maintains full backward compatibility with SAM 2's visual prompting: model = SAM("sam3.pt") - # Single point prompt (SAM 2 style) - results = model(points=[900, 370], labels=[1]) + # Single point prompt - segments object at specific location + results = model.predict(source="path/to/image.jpg", points=[900, 370], labels=[1]) + results[0].show() - # Multiple points - results = model(points=[[400, 370], [900, 370]], labels=[1, 1]) + # Multiple points - segments single object with multiple point hints + results = model.predict(source="path/to/image.jpg", points=[[400, 370], [900, 370]], labels=[1, 1]) - # Box prompt - results = model(bboxes=[100, 150, 300, 400]) + # Box prompt - segments object within bounding box + results = model.predict(source="path/to/image.jpg", bboxes=[100, 150, 300, 400]) + results[0].show() ``` - !!! warning "API Preview" + !!! warning "Visual Prompts vs Concept Segmentation" - This example shows intended usage. Actual implementation pending Ultralytics integration. + Using `SAM("sam3.pt")` with visual prompts (points/boxes/masks) will segment **only the specific object** at that location, just like SAM 2. To segment **all instances of a concept**, use `SAM3SemanticPredictor` with text or exemplar prompts as shown above. ## Performance Benchmarks diff --git a/docs/en/reference/models/sam/build.md b/docs/en/reference/models/sam/build.md index ce233acfb3..a2dbc09cdf 100644 --- a/docs/en/reference/models/sam/build.md +++ b/docs/en/reference/models/sam/build.md @@ -11,6 +11,10 @@ keywords: Ultralytics, SAM model, Segment Anything Model, SAM 2 model, Segment A
+## ::: ultralytics.models.sam.build._load_checkpoint + +



+ ## ::: ultralytics.models.sam.build.build_sam_vit_h



diff --git a/docs/en/reference/models/sam/build_sam3.md b/docs/en/reference/models/sam/build_sam3.md new file mode 100644 index 0000000000..8b2ec7298d --- /dev/null +++ b/docs/en/reference/models/sam/build_sam3.md @@ -0,0 +1,32 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/build_sam3.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/build_sam3.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.build_sam3._create_vision_backbone + +



+ +## ::: ultralytics.models.sam.build_sam3._create_sam3_transformer + +



+ +## ::: ultralytics.models.sam.build_sam3.build_sam3_image_model + +



+ +## ::: ultralytics.models.sam.build_sam3.build_interactive_sam3 + +



+ +## ::: ultralytics.models.sam.build_sam3._load_checkpoint + +

diff --git a/docs/en/reference/models/sam/modules/sam.md b/docs/en/reference/models/sam/modules/sam.md index 0a1b61e6e9..d69ea2b75f 100644 --- a/docs/en/reference/models/sam/modules/sam.md +++ b/docs/en/reference/models/sam/modules/sam.md @@ -17,4 +17,8 @@ keywords: Ultralytics, SAM Module, SAM 2 Module, object segmentation, image enco ## ::: ultralytics.models.sam.modules.sam.SAM2Model +



+ +## ::: ultralytics.models.sam.modules.sam.SAM3Model +

diff --git a/docs/en/reference/models/sam/modules/utils.md b/docs/en/reference/models/sam/modules/utils.md index 97ead6dc67..7241ef0404 100644 --- a/docs/en/reference/models/sam/modules/utils.md +++ b/docs/en/reference/models/sam/modules/utils.md @@ -49,4 +49,12 @@ keywords: Ultralytics, SAM, SAM 2, API Reference, models, window partition, data ## ::: ultralytics.models.sam.modules.utils.add_decomposed_rel_pos +



+ +## ::: ultralytics.models.sam.modules.utils.get_abs_pos + +



+ +## ::: ultralytics.models.sam.modules.utils.concat_rel_pos +

diff --git a/docs/en/reference/models/sam/predict.md b/docs/en/reference/models/sam/predict.md index c5b08777a9..bb5096c048 100644 --- a/docs/en/reference/models/sam/predict.md +++ b/docs/en/reference/models/sam/predict.md @@ -25,4 +25,20 @@ keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Mode ## ::: ultralytics.models.sam.predict.SAM2DynamicInteractivePredictor +



+ +## ::: ultralytics.models.sam.predict.SAM3Predictor + +



+ +## ::: ultralytics.models.sam.predict.SAM3SemanticPredictor + +



+ +## ::: ultralytics.models.sam.predict.SAM3VideoPredictor + +



+ +## ::: ultralytics.models.sam.predict.SAM3VideoSemanticPredictor +

diff --git a/docs/en/reference/models/sam/sam3/decoder.md b/docs/en/reference/models/sam/sam3/decoder.md new file mode 100644 index 0000000000..fdbf8c4958 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/decoder.md @@ -0,0 +1,20 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/decoder.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/decoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoderLayer + +



+ +## ::: ultralytics.models.sam.sam3.decoder.TransformerDecoder + +

diff --git a/docs/en/reference/models/sam/sam3/encoder.md b/docs/en/reference/models/sam/sam3/encoder.md new file mode 100644 index 0000000000..2751822db4 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/encoder.md @@ -0,0 +1,28 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/encoder.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/encoder.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderLayer + +



+ +## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoder + +



+ +## ::: ultralytics.models.sam.sam3.encoder.TransformerEncoderFusion + +



+ +## ::: ultralytics.models.sam.sam3.encoder.pool_text_feat + +

diff --git a/docs/en/reference/models/sam/sam3/geometry_encoders.md b/docs/en/reference/models/sam/sam3/geometry_encoders.md new file mode 100644 index 0000000000..1b3eb75139 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/geometry_encoders.md @@ -0,0 +1,28 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/geometry_encoders.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/geometry_encoders.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.geometry_encoders.Prompt + +



+ +## ::: ultralytics.models.sam.sam3.geometry_encoders.SequenceGeometryEncoder + +



+ +## ::: ultralytics.models.sam.sam3.geometry_encoders.is_right_padded + +



+ +## ::: ultralytics.models.sam.sam3.geometry_encoders.concat_padded_sequences + +

diff --git a/docs/en/reference/models/sam/sam3/maskformer_segmentation.md b/docs/en/reference/models/sam/sam3/maskformer_segmentation.md new file mode 100644 index 0000000000..7218d86ecb --- /dev/null +++ b/docs/en/reference/models/sam/sam3/maskformer_segmentation.md @@ -0,0 +1,32 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/maskformer_segmentation.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/maskformer_segmentation.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.maskformer_segmentation.LinearPresenceHead + +



+ +## ::: ultralytics.models.sam.sam3.maskformer_segmentation.MaskPredictor + +



+ +## ::: ultralytics.models.sam.sam3.maskformer_segmentation.SegmentationHead + +



+ +## ::: ultralytics.models.sam.sam3.maskformer_segmentation.PixelDecoder + +



+ +## ::: ultralytics.models.sam.sam3.maskformer_segmentation.UniversalSegmentationHead + +

diff --git a/docs/en/reference/models/sam/sam3/model_misc.md b/docs/en/reference/models/sam/sam3/model_misc.md new file mode 100644 index 0000000000..bef97aa7d5 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/model_misc.md @@ -0,0 +1,32 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/model_misc.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/model_misc.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.model_misc.DotProductScoring + +



+ +## ::: ultralytics.models.sam.sam3.model_misc.LayerScale + +



+ +## ::: ultralytics.models.sam.sam3.model_misc.TransformerWrapper + +



+ +## ::: ultralytics.models.sam.sam3.model_misc.get_valid_ratio + +



+ +## ::: ultralytics.models.sam.sam3.model_misc.gen_sineembed_for_position + +

diff --git a/docs/en/reference/models/sam/sam3/necks.md b/docs/en/reference/models/sam/sam3/necks.md new file mode 100644 index 0000000000..f9c13e2ffb --- /dev/null +++ b/docs/en/reference/models/sam/sam3/necks.md @@ -0,0 +1,16 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/necks.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/necks.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.necks.Sam3DualViTDetNeck + +

diff --git a/docs/en/reference/models/sam/sam3/sam3_image.md b/docs/en/reference/models/sam/sam3/sam3_image.md new file mode 100644 index 0000000000..c773796697 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/sam3_image.md @@ -0,0 +1,20 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/sam3_image.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/sam3_image.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.sam3_image.SAM3SemanticModel + +



+ +## ::: ultralytics.models.sam.sam3.sam3_image._update_out + +

diff --git a/docs/en/reference/models/sam/sam3/text_encoder_ve.md b/docs/en/reference/models/sam/sam3/text_encoder_ve.md new file mode 100644 index 0000000000..91e8a3a92b --- /dev/null +++ b/docs/en/reference/models/sam/sam3/text_encoder_ve.md @@ -0,0 +1,32 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/text_encoder_ve.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/text_encoder_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.text_encoder_ve.ResidualAttentionBlock + +



+ +## ::: ultralytics.models.sam.sam3.text_encoder_ve.Transformer + +



+ +## ::: ultralytics.models.sam.sam3.text_encoder_ve.TextTransformer + +



+ +## ::: ultralytics.models.sam.sam3.text_encoder_ve.VETextEncoder + +



+ +## ::: ultralytics.models.sam.sam3.text_encoder_ve.text_global_pool + +

diff --git a/docs/en/reference/models/sam/sam3/tokenizer_ve.md b/docs/en/reference/models/sam/sam3/tokenizer_ve.md new file mode 100644 index 0000000000..d38207bd92 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/tokenizer_ve.md @@ -0,0 +1,52 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/tokenizer_ve.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/tokenizer_ve.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve.SimpleTokenizer + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve.bytes_to_unicode + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_pairs + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve.basic_clean + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve.whitespace_clean + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_canonicalize + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_lower + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve._clean_whitespace + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve.get_clean_fn + +



+ +## ::: ultralytics.models.sam.sam3.tokenizer_ve.canonicalize_text + +

diff --git a/docs/en/reference/models/sam/sam3/vitdet.md b/docs/en/reference/models/sam/sam3/vitdet.md new file mode 100644 index 0000000000..5632234de8 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/vitdet.md @@ -0,0 +1,24 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/vitdet.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vitdet.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.vitdet.Attention + +



+ +## ::: ultralytics.models.sam.sam3.vitdet.Block + +



+ +## ::: ultralytics.models.sam.sam3.vitdet.ViT + +

diff --git a/docs/en/reference/models/sam/sam3/vl_combiner.md b/docs/en/reference/models/sam/sam3/vl_combiner.md new file mode 100644 index 0000000000..f5a7d84ff5 --- /dev/null +++ b/docs/en/reference/models/sam/sam3/vl_combiner.md @@ -0,0 +1,16 @@ +--- +description: TODO ADD DESCRIPTION +keywords: TODO ADD KEYWORDS +--- + +# Reference for `ultralytics/models/sam/sam3/vl_combiner.py` + +!!! success "Improvements" + + This page is sourced from [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/sam3/vl_combiner.py). Have an improvement or example to add? Open a [Pull Request](https://docs.ultralytics.com/help/contributing/) โ€” thank you! ๐Ÿ™ + +
+ +## ::: ultralytics.models.sam.sam3.vl_combiner.SAM3VLBackbone + +

diff --git a/mkdocs.yml b/mkdocs.yml index 30d7967a34..3c53138a5d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -592,6 +592,7 @@ nav: - sam: - amg: reference/models/sam/amg.md - build: reference/models/sam/build.md + - build_sam3: reference/models/sam/build_sam3.md - model: reference/models/sam/model.md - modules: - blocks: reference/models/sam/modules/blocks.md @@ -603,6 +604,18 @@ nav: - transformer: reference/models/sam/modules/transformer.md - utils: reference/models/sam/modules/utils.md - predict: reference/models/sam/predict.md + - sam3: + - decoder: reference/models/sam/sam3/decoder.md + - encoder: reference/models/sam/sam3/encoder.md + - geometry_encoders: reference/models/sam/sam3/geometry_encoders.md + - maskformer_segmentation: reference/models/sam/sam3/maskformer_segmentation.md + - model_misc: reference/models/sam/sam3/model_misc.md + - necks: reference/models/sam/sam3/necks.md + - sam3_image: reference/models/sam/sam3/sam3_image.md + - text_encoder_ve: reference/models/sam/sam3/text_encoder_ve.md + - tokenizer_ve: reference/models/sam/sam3/tokenizer_ve.md + - vitdet: reference/models/sam/sam3/vitdet.md + - vl_combiner: reference/models/sam/sam3/vl_combiner.md - utils: - loss: reference/models/utils/loss.md - ops: reference/models/utils/ops.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index dbc57048ef..50fc67bba9 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license -__version__ = "8.3.236" +__version__ = "8.3.237" import importlib import os diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py index 1b11ca12be..dfe0328a05 100644 --- a/ultralytics/engine/predictor.py +++ b/ultralytics/engine/predictor.py @@ -244,14 +244,15 @@ class BasePredictor: for _ in gen: # sourcery skip: remove-empty-nested-block, noqa pass - def setup_source(self, source): + def setup_source(self, source, stride: int | None = None): 
"""Set up source and inference mode. Args: source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for inference. + stride (int, optional): Model stride for image size checking. """ - self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size + self.imgsz = check_imgsz(self.args.imgsz, stride=stride or self.model.stride, min_dim=2) # check image size self.dataset = load_inference_source( source=source, batch=self.args.batch, diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py index 8188c74d9e..e8723bcd97 100644 --- a/ultralytics/models/sam/__init__.py +++ b/ultralytics/models/sam/__init__.py @@ -1,7 +1,16 @@ # Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license from .model import SAM -from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor +from .predict import ( + Predictor, + SAM2DynamicInteractivePredictor, + SAM2Predictor, + SAM2VideoPredictor, + SAM3Predictor, + SAM3SemanticPredictor, + SAM3VideoPredictor, + SAM3VideoSemanticPredictor, +) __all__ = ( "SAM", @@ -9,4 +18,8 @@ __all__ = ( "SAM2DynamicInteractivePredictor", "SAM2Predictor", "SAM2VideoPredictor", + "SAM3Predictor", + "SAM3SemanticPredictor", + "SAM3VideoPredictor", + "SAM3VideoSemanticPredictor", ) # tuple or list of exportable items diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py index 8e47502cda..ff9f0f56aa 100644 --- a/ultralytics/models/sam/build.py +++ b/ultralytics/models/sam/build.py @@ -21,6 +21,21 @@ from .modules.tiny_encoder import TinyViT from .modules.transformer import TwoWayTransformer +def _load_checkpoint(model, checkpoint): + """Load checkpoint into model from file path.""" + if checkpoint is None: + return model + + checkpoint = attempt_download_asset(checkpoint) + with open(checkpoint, "rb") as f: + state_dict = torch_load(f) + # Handle nested "model" key + if "model" in 
state_dict and isinstance(state_dict["model"], dict): + state_dict = state_dict["model"] + model.load_state_dict(state_dict) + return model + + def build_sam_vit_h(checkpoint=None): """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters.""" return _build_sam( @@ -205,10 +220,7 @@ def _build_sam( pixel_std=[58.395, 57.12, 57.375], ) if checkpoint is not None: - checkpoint = attempt_download_asset(checkpoint) - with open(checkpoint, "rb") as f: - state_dict = torch_load(f) - sam.load_state_dict(state_dict) + sam = _load_checkpoint(sam, checkpoint) sam.eval() return sam @@ -299,10 +311,7 @@ def _build_sam2( ) if checkpoint is not None: - checkpoint = attempt_download_asset(checkpoint) - with open(checkpoint, "rb") as f: - state_dict = torch_load(f)["model"] - sam2.load_state_dict(state_dict) + sam2 = _load_checkpoint(sam2, checkpoint) sam2.eval() return sam2 diff --git a/ultralytics/models/sam/build_sam3.py b/ultralytics/models/sam/build_sam3.py new file mode 100644 index 0000000000..5e939c1849 --- /dev/null +++ b/ultralytics/models/sam/build_sam3.py @@ -0,0 +1,374 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. 
All Rights Reserved + +import torch.nn as nn + +from ultralytics.nn.modules.transformer import MLP +from ultralytics.utils.patches import torch_load + +from .modules.blocks import PositionEmbeddingSine, RoPEAttention +from .modules.encoders import MemoryEncoder +from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer +from .modules.sam import SAM3Model +from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer +from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer +from .sam3.geometry_encoders import SequenceGeometryEncoder +from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead +from .sam3.model_misc import DotProductScoring, TransformerWrapper +from .sam3.necks import Sam3DualViTDetNeck +from .sam3.sam3_image import SAM3SemanticModel +from .sam3.text_encoder_ve import VETextEncoder +from .sam3.tokenizer_ve import SimpleTokenizer +from .sam3.vitdet import ViT +from .sam3.vl_combiner import SAM3VLBackbone + + +def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck: + """Create SAM3 visual backbone with ViT and neck.""" + # Position encoding + position_encoding = PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + ) + + # ViT backbone + vit_backbone = ViT( + img_size=1008, + pretrain_img_size=336, + patch_size=14, + embed_dim=1024, + depth=32, + num_heads=16, + mlp_ratio=4.625, + norm_layer="LayerNorm", + drop_path_rate=0.1, + qkv_bias=True, + use_abs_pos=True, + tile_abs_pos=True, + global_att_blocks=(7, 15, 23, 31), + rel_pos_blocks=(), + use_rope=True, + use_interp_rope=True, + window_size=24, + pretrain_use_cls_token=True, + retain_cls_token=False, + ln_pre=True, + ln_post=False, + return_interm_layers=False, + bias_patch_embed=False, + compile_mode=compile_mode, + ) + return Sam3DualViTDetNeck( + position_encoding=position_encoding, + d_model=256, + scale_factors=[4.0, 2.0, 1.0, 0.5], + 
trunk=vit_backbone, + add_sam2_neck=enable_inst_interactivity, + ) + + +def _create_sam3_transformer() -> TransformerWrapper: + """Create SAM3 detector encoder and decoder.""" + encoder: TransformerEncoderFusion = TransformerEncoderFusion( + layer=TransformerEncoderLayer( + d_model=256, + dim_feedforward=2048, + dropout=0.1, + pos_enc_at_attn=True, + pos_enc_at_cross_attn_keys=False, + pos_enc_at_cross_attn_queries=False, + pre_norm=True, + self_attention=nn.MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + batch_first=True, + ), + cross_attention=nn.MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + batch_first=True, + ), + ), + num_layers=6, + d_model=256, + num_feature_levels=1, + frozen=False, + use_act_checkpoint=True, + add_pooled_text_to_img_feat=False, + pool_text_with_mask=True, + ) + decoder: TransformerDecoder = TransformerDecoder( + layer=TransformerDecoderLayer( + d_model=256, + dim_feedforward=2048, + dropout=0.1, + cross_attention=nn.MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + ), + n_heads=8, + use_text_cross_attention=True, + ), + num_layers=6, + num_queries=200, + return_intermediate=True, + box_refine=True, + num_o2m_queries=0, + dac=True, + boxRPB="log", + d_model=256, + frozen=False, + interaction_layer=None, + dac_use_selfatt_ln=True, + use_act_checkpoint=True, + presence_token=True, + ) + + return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256) + + +def build_sam3_image_model( + checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False +): + """Build SAM3 image model. 
+ + Args: + checkpoint_path: Optional path to model checkpoint + bpe_path: Path to the BPE tokenizer vocabulary + enable_segmentation: Whether to enable segmentation head + compile: To enable compilation, set to "default" + + Returns: + A SAM3 image model + """ + # Create visual components + compile_mode = "default" if compile else None + vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True) + + # Create text components + text_encoder = VETextEncoder( + tokenizer=SimpleTokenizer(bpe_path=bpe_path), + d_model=256, + width=1024, + heads=16, + layers=24, + ) + + # Create visual-language backbone + backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1) + + # Create transformer components + transformer = _create_sam3_transformer() + + # Create dot product scoring + dot_prod_scoring = DotProductScoring( + d_model=256, + d_proj=256, + prompt_mlp=MLP( + input_dim=256, + hidden_dim=2048, + output_dim=256, + num_layers=2, + residual=True, + out_norm=nn.LayerNorm(256), + ), + ) + + # Create segmentation head if enabled + segmentation_head = ( + UniversalSegmentationHead( + hidden_dim=256, + upsampling_stages=3, + aux_masks=False, + presence_head=False, + dot_product_scorer=None, + act_ckpt=True, + cross_attend_prompt=nn.MultiheadAttention( + num_heads=8, + dropout=0, + embed_dim=256, + ), + pixel_decoder=PixelDecoder( + num_upsampling_stages=3, + interpolation_mode="nearest", + hidden_dim=256, + compile_mode=compile_mode, + ), + ) + if enable_segmentation + else None + ) + + # Create geometry encoder + input_geometry_encoder = SequenceGeometryEncoder( + pos_enc=PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + ), + encode_boxes_as_points=False, + boxes_direct_project=True, + boxes_pool=True, + boxes_pos_enc=True, + d_model=256, + num_layers=3, + layer=TransformerEncoderLayer( + d_model=256, + dim_feedforward=2048, + dropout=0.1, + pos_enc_at_attn=False, + 
pre_norm=True, + pos_enc_at_cross_attn_queries=False, + pos_enc_at_cross_attn_keys=True, + ), + use_act_ckpt=True, + add_cls=True, + add_post_encode_proj=True, + ) + + # Create the SAM3SemanticModel model + model = SAM3SemanticModel( + backbone=backbone, + transformer=transformer, + input_geometry_encoder=input_geometry_encoder, + segmentation_head=segmentation_head, + num_feature_levels=1, + o2m_mask_predict=True, + dot_prod_scoring=dot_prod_scoring, + use_instance_query=False, + multimask_output=True, + ) + + # Load checkpoint + model = _load_checkpoint(model, checkpoint_path) + model.eval() + return model + + +def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model: + """Build the SAM3 Tracker module for video tracking. + + Returns: + Sam3TrackerPredictor: Wrapped SAM3 Tracker module + """ + # Create model components + memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152]) + memory_attention = MemoryAttention( + batch_first=True, + d_model=256, + pos_enc_at_input=True, + layer=MemoryAttentionLayer( + dim_feedforward=2048, + dropout=0.1, + pos_enc_at_attn=False, + pos_enc_at_cross_attn_keys=True, + pos_enc_at_cross_attn_queries=False, + self_attn=RoPEAttention( + embedding_dim=256, + num_heads=1, + downsample_rate=1, + rope_theta=10000.0, + feat_sizes=[72, 72], + ), + d_model=256, + cross_attn=RoPEAttention( + embedding_dim=256, + num_heads=1, + downsample_rate=1, + kv_in_dim=64, + rope_theta=10000.0, + feat_sizes=[72, 72], + rope_k_repeat=True, + ), + ), + num_layers=4, + ) + + backbone = ( + SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None) + if with_backbone + else None + ) + model = SAM3Model( + image_size=1008, + image_encoder=backbone, + memory_attention=memory_attention, + memory_encoder=memory_encoder, + backbone_stride=14, + num_maskmem=7, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + use_mask_input_as_output_without_sam=True, + 
directly_add_no_mem_embed=True, + use_high_res_features_in_sam=True, + multimask_output_in_sam=True, + iou_prediction_use_sigmoid=True, + use_obj_ptrs_in_encoder=True, + add_tpos_enc_to_obj_ptrs=True, + only_obj_ptrs_in_the_past_for_eval=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + fixed_no_obj_ptr=True, + multimask_output_for_tracking=True, + use_multimask_token_for_obj_ptr=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + use_mlp_for_obj_ptr_proj=True, + compile_image_encoder=False, + no_obj_embed_spatial=True, + proj_tpos_enc_in_obj_ptrs=True, + use_signed_tpos_enc_to_obj_ptrs=True, + sam_mask_decoder_extra_args=dict( + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + ), + ) + + # Load checkpoint if provided + model = _load_checkpoint(model, checkpoint_path, interactive=True) + + # Setup device and mode + model.eval() + return model + + +def _load_checkpoint(model, checkpoint, interactive=False): + """Load SAM3 model checkpoint from file.""" + with open(checkpoint, "rb") as f: + ckpt = torch_load(f) + if "model" in ckpt and isinstance(ckpt["model"], dict): + ckpt = ckpt["model"] + sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k} + if interactive: + sam3_image_ckpt.update( + { + k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v + for k, v in sam3_image_ckpt.items() + if "backbone.vision_backbone" in k + } + ) + sam3_image_ckpt.update( + { + k.replace("tracker.transformer.encoder", "memory_attention"): v + for k, v in ckpt.items() + if "tracker.transformer" in k + } + ) + sam3_image_ckpt.update( + { + k.replace("tracker.maskmem_backbone", "memory_encoder"): v + for k, v in ckpt.items() + if "tracker.maskmem_backbone" in k + } + ) + sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." 
in k}) + model.load_state_dict(sam3_image_ckpt, strict=False) + return model diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index dd366c115e..1026f6b49c 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -21,7 +21,7 @@ from pathlib import Path from ultralytics.engine.model import Model from ultralytics.utils.torch_utils import model_info -from .predict import Predictor, SAM2Predictor +from .predict import Predictor, SAM2Predictor, SAM3Predictor class SAM(Model): @@ -59,6 +59,7 @@ class SAM(Model): if model and Path(model).suffix not in {".pt", ".pth"}: raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") self.is_sam2 = "sam2" in Path(model).stem + self.is_sam3 = "sam3" in Path(model).stem super().__init__(model=model, task="segment") def _load(self, weights: str, task=None): @@ -72,9 +73,14 @@ class SAM(Model): >>> sam = SAM("sam_b.pt") >>> sam._load("path/to/custom_weights.pt") """ - from .build import build_sam # slow import + if self.is_sam3: + from .build_sam3 import build_interactive_sam3 - self.model = build_sam(weights) + self.model = build_interactive_sam3(weights) + else: + from .build import build_sam # slow import + + self.model = build_sam(weights) def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs): """Perform segmentation prediction on the given image or video source. 
@@ -158,4 +164,6 @@ class SAM(Model): >>> print(task_map) {'segment': {'predictor': }} """ - return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}} + return { + "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor} + } diff --git a/ultralytics/models/sam/modules/blocks.py b/ultralytics/models/sam/modules/blocks.py index e27f0d0fcf..6ff9ece752 100644 --- a/ultralytics/models/sam/modules/blocks.py +++ b/ultralytics/models/sam/modules/blocks.py @@ -79,6 +79,7 @@ class MaskDownSampler(nn.Module): padding: int = 0, total_stride: int = 16, activation: type[nn.Module] = nn.GELU, + interpol_size: tuple[int, int] | None = None, ): """Initialize a mask downsampler module for progressive downsampling and channel expansion.""" super().__init__() @@ -102,9 +103,24 @@ class MaskDownSampler(nn.Module): mask_in_chans = mask_out_chans self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + self.interpol_size = interpol_size + if self.interpol_size is not None: + assert isinstance(self.interpol_size, (list, tuple)), ( + f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple." 
+ ) + self.interpol_size = list(interpol_size) + assert len(self.interpol_size) == 2 def forward(self, x: Tensor) -> Tensor: """Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d.""" + if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]): + x = F.interpolate( + x.float(), + size=self.interpol_size, + align_corners=False, + mode="bilinear", + antialias=True, + ).to(x.dtype) return self.encoder(x) @@ -429,13 +445,7 @@ class RoPEAttention(Attention): ) # Attention - _, _, _, c_per_head = q.shape - attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens - attn = attn / math.sqrt(c_per_head) - attn = torch.softmax(attn, dim=-1) - - # Get output - out = attn @ v + out = F.scaled_dot_product_attention(q, k, v) out = self._recombine_heads(out) out = self.out_proj(out) @@ -1033,6 +1043,7 @@ class PatchEmbed(nn.Module): padding: tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768, + bias: bool = True, ) -> None: """Initialize the PatchEmbed module for converting image patches to embeddings. @@ -1045,10 +1056,11 @@ class PatchEmbed(nn.Module): padding (tuple[int, int]): Padding applied to the input before convolution. in_chans (int): Number of input image channels. embed_dim (int): Dimensionality of the output patch embeddings. + bias (bool): Whether to include a bias term in the convolutional layer. 
""" super().__init__() - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: """Compute patch embedding by applying convolution and transposing resulting tensor.""" diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py index 69adf48b5c..b845f6b6e2 100644 --- a/ultralytics/models/sam/modules/decoders.py +++ b/ultralytics/models/sam/modules/decoders.py @@ -436,9 +436,8 @@ class SAM2MaskDecoder(nn.Module): def _get_stability_scores(self, mask_logits): """Compute mask stability scores based on IoU between upper and lower thresholds.""" mask_logits = mask_logits.flatten(-2) - stability_delta = self.dynamic_multimask_stability_delta - area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() - area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + area_i = torch.sum(mask_logits > self.dynamic_multimask_stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -self.dynamic_multimask_stability_delta, dim=-1).float() return torch.where(area_u > 0, area_i / area_u, 1.0) def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py index d5066ce1bd..8fc3da75d5 100644 --- a/ultralytics/models/sam/modules/encoders.py +++ b/ultralytics/models/sam/modules/encoders.py @@ -361,6 +361,7 @@ class MemoryEncoder(nn.Module): self, out_dim, in_dim=256, # in_dim of pix_feats + interpol_size: tuple[int, int] | None = None, ): """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations. @@ -370,10 +371,12 @@ class MemoryEncoder(nn.Module): Args: out_dim (int): Output dimension of the encoded features. 
in_dim (int): Input dimension of the pixel features. + interpol_size (tuple[int, int] | None): Size to interpolate masks to. If None, uses the size of pixel + features. """ super().__init__() - self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1) + self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1, interpol_size=interpol_size) self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) self.fuser = Fuser(CXBlock(dim=256), num_layers=2) diff --git a/ultralytics/models/sam/modules/memory_attention.py b/ultralytics/models/sam/modules/memory_attention.py index d229a4f326..63c573d697 100644 --- a/ultralytics/models/sam/modules/memory_attention.py +++ b/ultralytics/models/sam/modules/memory_attention.py @@ -59,6 +59,8 @@ class MemoryAttentionLayer(nn.Module): pos_enc_at_attn: bool = False, pos_enc_at_cross_attn_keys: bool = True, pos_enc_at_cross_attn_queries: bool = False, + self_attn: nn.Module | None = None, + cross_attn: nn.Module | None = None, ): """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components. @@ -69,13 +71,15 @@ class MemoryAttentionLayer(nn.Module): pos_enc_at_attn (bool): Whether to add positional encoding at attention. pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys. pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries. + self_attn (nn.Module | None): Custom self-attention module. If None, a default RoPEAttention is used. + cross_attn (nn.Module | None): Custom cross-attention module. If None, a default RoPEAttention is used. 
""" super().__init__() self.d_model = d_model self.dim_feedforward = dim_feedforward self.dropout_value = dropout - self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1) - self.cross_attn_image = RoPEAttention( + self.self_attn = self_attn or RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1) + self.cross_attn_image = cross_attn or RoPEAttention( rope_k_repeat=True, embedding_dim=256, num_heads=1, diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 9b5efe8068..701369a15b 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_ from ultralytics.nn.modules import MLP from ultralytics.utils import LOGGER -from .blocks import SAM2TwoWayTransformer +from .blocks import SAM2TwoWayTransformer, TwoWayTransformer from .decoders import MaskDecoder, SAM2MaskDecoder from .encoders import ImageEncoderViT, PromptEncoder from .utils import get_1d_sine_pe, select_closest_cond_frames @@ -329,6 +329,7 @@ class SAM2Model(torch.nn.Module): self._build_sam_heads() self.max_cond_frames_in_attn = max_cond_frames_in_attn + self.add_all_frames_to_correct_as_cond = True # Model compilation if compile_image_encoder: @@ -473,7 +474,7 @@ class SAM2Model(torch.nn.Module): assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: sam_mask_prompt = F.interpolate( - mask_inputs.float(), + mask_inputs.to(backbone_features.dtype), size=self.sam_prompt_encoder.mask_input_size, align_corners=False, mode="bilinear", @@ -571,7 +572,7 @@ class SAM2Model(torch.nn.Module): # produce an object pointer using the SAM decoder from the mask input _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( backbone_features=backbone_features, - mask_inputs=self.mask_downsample(mask_inputs_float), + 
mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)), high_res_features=high_res_features, ) # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; @@ -818,7 +819,6 @@ class SAM2Model(torch.nn.Module): mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied maskmem_features = maskmem_out["vision_features"] - maskmem_pos_enc = maskmem_out["vision_pos_enc"] # add a no-object embedding to the spatial memory to indicate that the frame # is predicted to be occluded (i.e. no object is appearing in the frame) if self.no_obj_embed_spatial is not None: @@ -827,7 +827,7 @@ class SAM2Model(torch.nn.Module): ..., None, None ].expand(*maskmem_features.shape) - return maskmem_features, maskmem_pos_enc + return maskmem_features, maskmem_out["vision_pos_enc"] def _track_step( self, @@ -1005,7 +1005,151 @@ class SAM2Model(torch.nn.Module): def set_imgsz(self, imgsz): """Set image size to make model compatible with different image sizes.""" + if hasattr(self.image_encoder, "set_imgsz"): + self.image_encoder.set_imgsz(imgsz) self.image_size = imgsz[0] self.sam_prompt_encoder.input_image_size = imgsz - self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16 + self.sam_prompt_encoder.image_embedding_size = [ + x // self.backbone_stride for x in imgsz + ] # fixed ViT patch size of 16 + self.sam_prompt_encoder.mask_input_size = [ + x // self.backbone_stride * 4 for x in imgsz + ] # fixed ViT patch size of 16 self.sam_image_embedding_size = self.image_size // self.backbone_stride # update image embedding size + + +class SAM3Model(SAM2Model): + """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities.""" + + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, + 
image_size=1008, + backbone_stride=14, + sigmoid_scale_for_mem_enc=1, + sigmoid_bias_for_mem_enc=0, + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, + max_cond_frames_in_attn=-1, + directly_add_no_mem_embed=False, + use_high_res_features_in_sam=False, + multimask_output_in_sam=False, + multimask_min_pt_num=1, + multimask_max_pt_num=1, + multimask_output_for_tracking=False, + use_multimask_token_for_obj_ptr: bool = False, + iou_prediction_use_sigmoid=False, + memory_temporal_stride_for_eval=1, + non_overlap_masks_for_mem_enc=False, + use_obj_ptrs_in_encoder=False, + max_obj_ptrs_in_encoder=16, + add_tpos_enc_to_obj_ptrs=True, + proj_tpos_enc_in_obj_ptrs=False, + use_signed_tpos_enc_to_obj_ptrs=False, + only_obj_ptrs_in_the_past_for_eval=False, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + fixed_no_obj_ptr: bool = False, + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + no_obj_embed_spatial: bool = False, + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities.""" + super().__init__( + image_encoder, + memory_attention, + memory_encoder, + num_maskmem, + image_size, + backbone_stride, + sigmoid_scale_for_mem_enc, + sigmoid_bias_for_mem_enc, + binarize_mask_from_pts_for_mem_enc, + use_mask_input_as_output_without_sam, + max_cond_frames_in_attn, + directly_add_no_mem_embed, + use_high_res_features_in_sam, + multimask_output_in_sam, + multimask_min_pt_num, + multimask_max_pt_num, + multimask_output_for_tracking, + use_multimask_token_for_obj_ptr, + iou_prediction_use_sigmoid, + memory_temporal_stride_for_eval, + non_overlap_masks_for_mem_enc, + use_obj_ptrs_in_encoder, + max_obj_ptrs_in_encoder, + add_tpos_enc_to_obj_ptrs, + proj_tpos_enc_in_obj_ptrs, + use_signed_tpos_enc_to_obj_ptrs, + only_obj_ptrs_in_the_past_for_eval, + pred_obj_scores, + 
pred_obj_scores_mlp, + fixed_no_obj_ptr, + soft_no_obj_ptr, + use_mlp_for_obj_ptr_proj, + no_obj_embed_spatial, + sam_mask_decoder_extra_args, + compile_image_encoder, + ) + self.sam_mask_decoder = SAM2MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + + def forward_image(self, img_batch: torch.Tensor): + """Process image batch through encoder to extract multi-level features for SAM model.""" + backbone_out = self.image_encoder.forward_image_sam2(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) + return backbone_out + + def set_imgsz(self, imgsz: tuple[int, int]): + """Set the image size for the model and mask downsampler.""" + super().set_imgsz(imgsz) + self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz] + + @staticmethod + def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3): + """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints.""" + area_before = (pred_masks > 0).sum(dim=(-1, -2)) + area_after = (new_pred_masks > 0).sum(dim=(-1, -2)) + area_before = torch.clamp(area_before, min=1.0) + area_ratio = area_after / area_before 
+ keep = area_ratio >= shrink_threshold + keep_mask = keep[..., None, None].expand_as(pred_masks) + pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks_after + + def _suppress_object_pw_area_shrinkage(self, pred_masks): + """This function suppresses masks that shrink in area after applying pixelwise non-overlapping constraints. Note + that the final output can still be overlapping. + """ + # Apply pixel-wise non-overlapping constraint based on mask scores + pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks) + # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints + # NOTE: The output of this function can be a no op if none of the masks shrinked by a large factor. + pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks) + return pred_masks diff --git a/ultralytics/models/sam/modules/utils.py b/ultralytics/models/sam/modules/utils.py index e934817934..86ea23ac9f 100644 --- a/ultralytics/models/sam/modules/utils.py +++ b/ultralytics/models/sam/modules/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math from typing import Any import torch @@ -86,7 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000) return pos_embed -def init_t_xy(end_x: int, end_y: int): +def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0): """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions. This function creates coordinate tensors for a grid with dimensions end_x ร— end_y. It generates a linear index @@ -95,6 +96,8 @@ def init_t_xy(end_x: int, end_y: int): Args: end_x (int): Width of the grid (number of columns). end_y (int): Height of the grid (number of rows). + scale (float): Scaling factor to apply to the coordinates. + offset (int): Offset to add to the coordinates. 
Returns: t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y). @@ -110,10 +113,10 @@ def init_t_xy(end_x: int, end_y: int): t = torch.arange(end_x * end_y, dtype=torch.float32) t_x = (t % end_x).float() t_y = torch.div(t, end_x, rounding_mode="floor").float() - return t_x, t_y + return t_x * scale + offset, t_y * scale + offset -def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0): """Compute axial complex exponential positional encodings for 2D spatial positions in a grid. This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate @@ -124,6 +127,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): end_x (int): Width of the 2D grid. end_y (int): Height of the 2D grid. theta (float, optional): Scaling factor for frequency computation. + scale_pos (float, optional): Scaling factor for position coordinates. Returns: (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2). @@ -137,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) - t_x, t_y = init_t_xy(end_x, end_y) + t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos) freqs_x = torch.outer(t_x, freqs_x) freqs_y = torch.outer(t_y, freqs_y) freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) @@ -375,3 +379,129 @@ def add_decomposed_rel_pos( ) return attn + + +def get_abs_pos( + abs_pos: torch.Tensor, + has_cls_token: bool, + hw: tuple[int, int], + retain_cls_token: bool = False, + tiling: bool = False, +) -> torch.Tensor: + """Calculate absolute positional embeddings. 
If needed, resize embeddings and remove cls_token dimension for the + original embeddings. + + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + retain_cls_token: whether to retain the cls_token + tiling: whether to tile the embeddings, *instead* of interpolation (a la abs_win) + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C),: if retain_cls_token is False, + otherwise (1, 1+H*W, C). + """ + if retain_cls_token: + assert has_cls_token + + h, w = hw + if has_cls_token: + cls_pos = abs_pos[:, :1] + abs_pos = abs_pos[:, 1:] + + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2) + if tiling: + new_abs_pos = new_abs_pos.tile([1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])])[ + :, :, :h, :w + ] + else: + new_abs_pos = F.interpolate( + new_abs_pos, + size=(h, w), + mode="bicubic", + align_corners=False, + ) + + if not retain_cls_token: + return new_abs_pos.permute(0, 2, 3, 1) + else: + # add cls_token back, flatten spatial dims + assert has_cls_token + return torch.cat( + [cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)], + dim=1, + ) + + else: + if not retain_cls_token: + return abs_pos.reshape(1, h, w, -1) + else: + assert has_cls_token + return torch.cat([cls_pos, abs_pos], dim=1) + + +def concat_rel_pos( + q: torch.Tensor, + k: torch.Tensor, + q_hw: tuple[int, int], + k_hw: tuple[int, int], + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + rescale: bool = False, + relative_coords: torch.Tensor = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Concatenate rel pos coeffs to the q & k tensors, so that qk^T is now effectively including rel pos biases. 
+ + Args: + q (torch.Tensor): q tensor with shape (B, L_q, C). + k (torch.Tensor): k tensor with shape (B, L_k, C). + q_hw: These are spatial size of q tensors. + k_hw: These are spatial size of k tensors. + rel_pos_h: These are relative pos embeddings/params of height. + rel_pos_w: These are relative pos embeddings/params of width. + rescale (bool): whether to rescale. e.g. for use when using sdpa, pytorch will scale by the wrong factor due to + the concat. + relative_coords (torch.Tensor, optional): Precomputed relative coords index tensor. + + Returns: + q, k: But, padded so that qk^T accounts for rel pos biases. + """ + q_h, q_w = q_hw + k_h, k_w = k_hw + + assert (q_h == q_w) and (k_h == k_w), "only square inputs supported" + + if relative_coords is not None: + Rh = rel_pos_h[relative_coords] + Rw = rel_pos_w[relative_coords] + else: + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + + old_scale = dim**0.5 + new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale # for sdpa + # attn will be divided by new_scale, but we want to divide q by old_scale + scale_ratio = new_scale / old_scale + + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale # (B, q_h, q_w, k_h) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale # (B, q_h, q_w, k_w) + + eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device) + eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device) + + eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h]) + eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w]) + + q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1) + k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(B, k_h * k_w, -1) + + return q, k diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index 8f421d977a..b236b3d9d3 100644 --- a/ultralytics/models/sam/predict.py +++ 
b/ultralytics/models/sam/predict.py @@ -10,7 +10,8 @@ segmentation tasks. from __future__ import annotations -from collections import OrderedDict +from collections import OrderedDict, defaultdict +from copy import deepcopy from typing import Any import cv2 @@ -21,7 +22,8 @@ import torch.nn.functional as F from ultralytics.data.augment import LetterBox from ultralytics.engine.predictor import BasePredictor from ultralytics.engine.results import Results -from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.utils import DEFAULT_CFG, LOGGER, ops +from ultralytics.utils.metrics import box_iou, mask_iou from ultralytics.utils.torch_utils import select_device, smart_inference_mode from .amg import ( @@ -35,6 +37,7 @@ from .amg import ( uncrop_boxes_xyxy, uncrop_masks, ) +from .sam3.geometry_encoders import Prompt class Predictor(BasePredictor): @@ -79,6 +82,8 @@ class Predictor(BasePredictor): >>> results = predictor(bboxes=bboxes) """ + stride = 16 + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): """Initialize the Predictor with configuration, overrides, and callbacks. @@ -156,7 +161,7 @@ class Predictor(BasePredictor): 1 """ assert len(im) == 1, "SAM model does not currently support batched inference" - letterbox = LetterBox(self.args.imgsz, auto=False, center=False) + letterbox = LetterBox(self.imgsz, auto=False, center=False) return [letterbox(image=x) for x in im] def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): @@ -520,30 +525,6 @@ class Predictor(BasePredictor): self.segment_all = False return results - def setup_source(self, source): - """Set up the data source for inference. - - This method configures the data source from which images will be fetched for inference. It supports various - input types such as image files, directories, video files, and other compatible data sources. 
- - Args: - source (str | Path | None): The path or identifier for the image data source. Can be a file path, directory - path, URL, or other supported source types. - - Examples: - >>> predictor = Predictor() - >>> predictor.setup_source("path/to/images") - >>> predictor.setup_source("video.mp4") - >>> predictor.setup_source(None) # Uses default source if available - - Notes: - - If source is None, the method may use a default source if configured. - - The method adapts to different source types and prepares them for subsequent inference steps. - - Supported source types may include local files, directories, URLs, and video streams. - """ - if source is not None: - super().setup_source(source) - def set_image(self, image): """Preprocess and set a single image for inference. @@ -576,12 +557,18 @@ class Predictor(BasePredictor): self.features = self.get_im_features(im) break - def get_im_features(self, im): - """Extract image features using the SAM model's image encoder for subsequent mask prediction.""" + def setup_source(self, source): + """Set up the data source for SAM inference.""" + if source is None: # handle the situation when set_imgsz in advance + return + super().setup_source(source, self.stride) assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], ( f"SAM models only support square image size, but got {self.imgsz}." 
) self.model.set_imgsz(self.imgsz) + + def get_im_features(self, im): + """Extract image features using the SAM model's image encoder for subsequent mask prediction.""" return self.model.image_encoder(im) def set_prompts(self, prompts): @@ -726,6 +713,7 @@ class SAM2Predictor(Predictor): (128, 128), (64, 64), ] + stride = 16 def get_model(self): """Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks.""" @@ -767,45 +755,13 @@ class SAM2Predictor(Predictor): points, labels = bboxes, bbox_labels return points, labels, masks - def set_image(self, image): - """Preprocess and set a single image for inference using the SAM2 model. - - This method initializes the model if not already done, configures the data source to the specified image, and - preprocesses the image for feature extraction. It supports setting only one image at a time. - - Args: - image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image. - - Raises: - AssertionError: If more than one image is attempted to be set. - - Examples: - >>> predictor = SAM2Predictor() - >>> predictor.set_image("path/to/image.jpg") - >>> predictor.set_image(np.array([...])) # Using a numpy array - - Notes: - - This method must be called before performing any inference on a new image. - - The method caches the extracted features for efficient subsequent inferences on the same image. - - Only one image can be set at a time. To process multiple images, call this method for each new image. - """ - if self.model is None: - self.setup_model(model=None) - self.setup_source(image) - assert len(self.dataset) == 1, "`set_image` only supports setting one image!" 
- for batch in self.dataset: - im = self.preprocess(batch[1]) - self.features = self.get_im_features(im) - break + def setup_source(self, source): + """Set up the data source and image size for SAM2 inference.""" + super().setup_source(source) + self._bb_feat_sizes = [[int(x / (self.stride * i)) for x in self.imgsz] for i in [1 / 4, 1 / 2, 1]] def get_im_features(self, im): """Extract image features from the SAM image encoder for subsequent processing.""" - assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], ( - f"SAM 2 models only support square image size, but got {self.imgsz}." - ) - self.model.set_imgsz(self.imgsz) - self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]] - backbone_out = self.model.forward_image(im) _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) if self.model.directly_add_no_mem_embed: @@ -1037,6 +993,7 @@ class SAM2VideoPredictor(SAM2Predictor): labels=None, masks=None, frame_idx=0, + inference_state: dict[str, Any] | None = None, ): """Add new points or masks to a specific frame for a given object ID. @@ -1051,6 +1008,8 @@ class SAM2VideoPredictor(SAM2Predictor): labels (torch.Tensor, optional): The labels corresponding to the points. masks (torch.Tensor, optional): Binary masks for the object. frame_idx (int, optional): The index of the frame to which the prompts are applied. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. Returns: pred_masks (torch.Tensor): The flattened predicted masks. @@ -1064,24 +1023,25 @@ class SAM2VideoPredictor(SAM2Predictor): - If the frame is being tracked for the first time, it is treated as an initial conditioning frame. - The method handles the consolidation of outputs and resizing of masks to the original video resolution. 
""" + inference_state = inference_state or self.inference_state assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other." - obj_idx = self._obj_id_to_idx(obj_id) + obj_idx = self._obj_id_to_idx(obj_id, inference_state) point_inputs = None pop_key = "point_inputs_per_obj" if points is not None: point_inputs = {"point_coords": points, "point_labels": labels} - self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs + inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs pop_key = "mask_inputs_per_obj" - self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks - self.inference_state[pop_key][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks + inference_state[pop_key][obj_idx].pop(frame_idx, None) # If this frame hasn't been tracked before, we treat it as an initial conditioning # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. - is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"] - obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] - obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or # if the model sees all frames receiving clicks/mask as conditioning frames. is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond @@ -1119,6 +1079,7 @@ class SAM2VideoPredictor(SAM2Predictor): # them into memory. 
run_mem_encoder=False, prev_sam_mask_logits=prev_sam_mask_logits, + inference_state=inference_state, ) # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out @@ -1128,31 +1089,37 @@ class SAM2VideoPredictor(SAM2Predictor): frame_idx, is_cond=is_cond, run_mem_encoder=False, + inference_state=inference_state, ) pred_masks = consolidated_out["pred_masks"].flatten(0, 1) return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device) @smart_inference_mode() - def propagate_in_video_preflight(self): + def propagate_in_video_preflight(self, inference_state: dict[str, Any] | None = None): """Prepare inference_state and consolidate temporary outputs before tracking. This method marks the start of tracking, disallowing the addition of new objects until the session is reset. It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent with the provided inputs. + + Args: + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. """ + inference_state = inference_state or self.inference_state # Tracking has started and we don't allow adding new objects until session is reset. - self.inference_state["tracking_has_started"] = True - batch_size = len(self.inference_state["obj_idx_to_id"]) + inference_state["tracking_has_started"] = True + batch_size = len(inference_state["obj_idx_to_id"]) # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and # add them into "output_dict". 
- temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"] - output_dict = self.inference_state["output_dict"] + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] # "consolidated_frame_inds" contains indices of those frames where consolidated # temporary outputs have been added (either in this call or any previous calls # to `propagate_in_video_preflight`). - consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] for is_cond in {False, True}: # Separately consolidate conditioning and non-conditioning temp outputs storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" @@ -1166,11 +1133,11 @@ class SAM2VideoPredictor(SAM2Predictor): # consolidate the temporary output across all objects on this frame for frame_idx in temp_frame_inds: consolidated_out = self._consolidate_temp_output_across_obj( - frame_idx, is_cond=is_cond, run_mem_encoder=True + frame_idx, is_cond=is_cond, run_mem_encoder=True, inference_state=inference_state ) # merge them into "output_dict" and also create per-object slices output_dict[storage_key][frame_idx] = consolidated_out - self._add_output_per_object(frame_idx, consolidated_out, storage_key) + self._add_output_per_object(frame_idx, consolidated_out, storage_key, inference_state=inference_state) if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): # clear non-conditioning memory of the surrounding frames self._clear_non_cond_mem_around_input(frame_idx) @@ -1183,7 +1150,7 @@ class SAM2VideoPredictor(SAM2Predictor): # output on the same frame in "non_cond_frame_outputs" for frame_idx in output_dict["cond_frame_outputs"]: output_dict["non_cond_frame_outputs"].pop(frame_idx, None) - for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + for obj_output_dict in 
inference_state["output_dict_per_obj"].values(): for frame_idx in obj_output_dict["cond_frame_outputs"]: obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: @@ -1196,9 +1163,9 @@ class SAM2VideoPredictor(SAM2Predictor): consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"] ) input_frames_inds = set() - for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values(): + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): input_frames_inds.update(point_inputs_per_frame.keys()) - for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values(): + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): input_frames_inds.update(mask_inputs_per_frame.keys()) assert all_consolidated_frame_inds == input_frames_inds @@ -1217,9 +1184,21 @@ class SAM2VideoPredictor(SAM2Predictor): return assert predictor.dataset is not None assert predictor.dataset.mode == "video" + predictor.inference_state = predictor._init_state(predictor.dataset.frames) + @staticmethod + def _init_state(num_frames): + """Initialize an inference state. + + This function sets up the initial state required for performing inference on video data. It includes + initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata + relevant to the tracking process. + + Args: + num_frames (int): The number of frames in the video. 
+ """ inference_state = { - "num_frames": predictor.dataset.frames, + "num_frames": num_frames, # TODO: see if there's any chance to remove it "point_inputs_per_obj": {}, # inputs points on each frame "mask_inputs_per_obj": {}, # inputs mask on each frame "constants": {}, # values that don't change across frames (so we only need to hold one copy of them) @@ -1247,7 +1226,7 @@ class SAM2VideoPredictor(SAM2Predictor): "tracking_has_started": False, "frames_already_tracked": [], } - predictor.inference_state = inference_state + return inference_state def get_im_features(self, im, batch=1): """Extract and process image features using SAM2's image encoder for subsequent segmentation tasks. @@ -1265,7 +1244,6 @@ class SAM2VideoPredictor(SAM2Predictor): - If `batch` is greater than 1, the features are expanded to fit the batch size. - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features. """ - self.model.set_imgsz(self.imgsz) backbone_out = self.model.forward_image(im) if batch > 1: # expand features if there's more than one prompt for i, feat in enumerate(backbone_out["backbone_fpn"]): @@ -1276,11 +1254,13 @@ class SAM2VideoPredictor(SAM2Predictor): _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out) return vis_feats, vis_pos_embed, feat_sizes - def _obj_id_to_idx(self, obj_id): + def _obj_id_to_idx(self, obj_id, inference_state: dict[str, Any] | None = None): """Map client-side object id to model-side object index. Args: obj_id (int): The unique identifier of the object provided by the client side. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. Returns: (int): The index of the object on the model side. @@ -1295,27 +1275,28 @@ class SAM2VideoPredictor(SAM2Predictor): - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`). 
- Additional data structures are initialized for the new object to store inputs and outputs. """ - obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None) + inference_state = inference_state or self.inference_state + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) if obj_idx is not None: return obj_idx # This is a new object id not sent to the server before. We only allow adding # new objects *before* the tracking starts. - allow_new_object = not self.inference_state["tracking_has_started"] + allow_new_object = not inference_state["tracking_has_started"] if allow_new_object: # get the next object slot - obj_idx = len(self.inference_state["obj_id_to_idx"]) - self.inference_state["obj_id_to_idx"][obj_id] = obj_idx - self.inference_state["obj_idx_to_id"][obj_idx] = obj_id - self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"]) + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) # set up input and output structures for this object - self.inference_state["point_inputs_per_obj"][obj_idx] = {} - self.inference_state["mask_inputs_per_obj"][obj_idx] = {} - self.inference_state["output_dict_per_obj"][obj_idx] = { + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } - self.inference_state["temp_output_dict_per_obj"][obj_idx] = { + inference_state["temp_output_dict_per_obj"][obj_idx] = { "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } @@ -1323,7 +1304,7 @@ class SAM2VideoPredictor(SAM2Predictor): else: raise RuntimeError( f"Cannot add new object id {obj_id} 
after tracking starts. " - f"All existing object ids: {self.inference_state['obj_ids']}. " + f"All existing object ids: {inference_state['obj_ids']}. " f"Please call 'reset_state' to restart from scratch." ) @@ -1338,6 +1319,7 @@ class SAM2VideoPredictor(SAM2Predictor): reverse, run_mem_encoder, prev_sam_mask_logits=None, + inference_state: dict[str, Any] | None = None, ): """Run tracking on a single frame based on current inputs and previous memory. @@ -1351,6 +1333,8 @@ class SAM2VideoPredictor(SAM2Predictor): reverse (bool): Indicates if the tracking should be performed in reverse order. run_mem_encoder (bool): Indicates if the memory encoder should be executed. prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. Returns: (dict): A dictionary containing the output of the tracking step, including updated features and predictions. @@ -1364,9 +1348,10 @@ class SAM2VideoPredictor(SAM2Predictor): - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored. - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements. 
""" + inference_state = inference_state or self.inference_state # Retrieve correct image features current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features( - self.inference_state["im"], batch_size + inference_state["im"], batch_size ) # point and mask should not appear as input simultaneously on the same frame @@ -1380,7 +1365,7 @@ class SAM2VideoPredictor(SAM2Predictor): point_inputs=point_inputs, mask_inputs=mask_inputs, output_dict=output_dict, - num_frames=self.inference_state["num_frames"], + num_frames=inference_state["num_frames"], track_in_reverse=reverse, run_mem_encoder=run_mem_encoder, prev_sam_mask_logits=prev_sam_mask_logits, @@ -1398,10 +1383,10 @@ class SAM2VideoPredictor(SAM2Predictor): # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area) # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it - current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"]) + current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"], inference_state) return current_out - def _get_maskmem_pos_enc(self, out_maskmem_pos_enc): + def _get_maskmem_pos_enc(self, out_maskmem_pos_enc, inference_state: dict[str, Any] | None = None): """Cache and manage the positional encoding for mask memory across frames and objects. This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for mask memory, which is @@ -1413,6 +1398,8 @@ class SAM2VideoPredictor(SAM2Predictor): Args: out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory. Should be a list of tensors or None. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. Returns: (list[torch.Tensor]): The positional encoding for mask memory, either cached or expanded. 
@@ -1423,7 +1410,8 @@ class SAM2VideoPredictor(SAM2Predictor): - The method checks if the positional encoding has already been cached in the session's constants. - If the batch size is greater than one, the cached encoding is expanded to fit the batch size. """ - model_constants = self.inference_state["constants"] + inference_state = inference_state or self.inference_state + model_constants = inference_state["constants"] # "out_maskmem_pos_enc" should be either a list of tensors or None if out_maskmem_pos_enc is not None: if "maskmem_pos_enc" not in model_constants: @@ -1444,6 +1432,7 @@ class SAM2VideoPredictor(SAM2Predictor): frame_idx, is_cond=False, run_mem_encoder=False, + inference_state: dict[str, Any] | None = None, ): """Consolidate per-object temporary outputs into a single output for all objects. @@ -1457,6 +1446,8 @@ class SAM2VideoPredictor(SAM2Predictor): is_cond (bool, optional): Indicates if the frame is considered a conditioning frame. run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after consolidating the outputs. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. Returns: (dict): A consolidated output dictionary containing the combined results for all objects. @@ -1467,7 +1458,8 @@ class SAM2VideoPredictor(SAM2Predictor): - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder. - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True. """ - batch_size = len(self.inference_state["obj_idx_to_id"]) + inference_state = inference_state or self.inference_state + batch_size = len(inference_state["obj_idx_to_id"]) storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Initialize `consolidated_out`. 
Its "maskmem_features" and "maskmem_pos_enc" @@ -1478,7 +1470,8 @@ class SAM2VideoPredictor(SAM2Predictor): "maskmem_features": None, "maskmem_pos_enc": None, "pred_masks": torch.full( - size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4), + # size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4), + size=(batch_size, 1, *self._bb_feat_sizes[0]), fill_value=-1024.0, dtype=self.torch_dtype, device=self.device, @@ -1499,8 +1492,8 @@ class SAM2VideoPredictor(SAM2Predictor): ), } for obj_idx in range(batch_size): - obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] - obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] out = ( obj_temp_output_dict[storage_key].get(frame_idx) # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, @@ -1540,21 +1533,25 @@ class SAM2VideoPredictor(SAM2Predictor): high_res_masks=high_res_masks, is_mask_from_pts=True, # these frames are what the user interacted with object_score_logits=consolidated_out["object_score_logits"], + inference_state=inference_state, ) return consolidated_out - def _get_empty_mask_ptr(self, frame_idx): + def _get_empty_mask_ptr(self, frame_idx, inference_state: dict[str, Any] | None = None): """Get a dummy object pointer based on an empty mask on the current frame. Args: frame_idx (int): The index of the current frame for which to generate the dummy object pointer. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. Returns: (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask. 
""" + inference_state = inference_state or self.inference_state # Retrieve correct image features - current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"]) + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(inference_state["im"]) # Feed the empty mask and image feature above to get a dummy object pointer current_out = self.model.track_step( @@ -1567,14 +1564,21 @@ class SAM2VideoPredictor(SAM2Predictor): # A dummy (empty) mask with a single object mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=self.torch_dtype, device=self.device), output_dict={}, - num_frames=self.inference_state["num_frames"], + num_frames=inference_state["num_frames"], track_in_reverse=False, run_mem_encoder=False, prev_sam_mask_logits=None, ) return current_out["obj_ptr"] - def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts): + def _run_memory_encoder( + self, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + inference_state: dict[str, Any] | None = None, + ): """Run the memory encoder on masks. This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their @@ -1585,13 +1589,16 @@ class SAM2VideoPredictor(SAM2Predictor): high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory. object_score_logits (torch.Tensor): Logits representing the object scores. is_mask_from_pts (bool): Indicates if the mask is derived from point interactions. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. Returns: maskmem_features (torch.Tensor): The encoded mask features. maskmem_pos_enc (torch.Tensor): The positional encoding. 
""" + inference_state = inference_state or self.inference_state # Retrieve correct image features - current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size) + current_vision_feats, _, feat_sizes = self.get_im_features(inference_state["im"], batch_size) maskmem_features, maskmem_pos_enc = self.model._encode_new_memory( current_vision_feats=current_vision_feats, feat_sizes=feat_sizes, @@ -1601,12 +1608,14 @@ class SAM2VideoPredictor(SAM2Predictor): ) # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it - maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc) + maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc, inference_state) return maskmem_features.to( dtype=torch.float16, device=self.device, non_blocking=self.device.type == "cuda" ), maskmem_pos_enc - def _add_output_per_object(self, frame_idx, current_out, storage_key): + def _add_output_per_object( + self, frame_idx, current_out, storage_key, inference_state: dict[str, Any] | None = None + ): """Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj. The resulting slices share the same tensor storage. @@ -1615,14 +1624,17 @@ class SAM2VideoPredictor(SAM2Predictor): frame_idx (int): The index of the current frame. current_out (dict): The current output dictionary containing multi-object outputs. storage_key (str): The key used to store the output in the per-object output dictionary. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. 
""" + inference_state = inference_state or self.inference_state maskmem_features = current_out["maskmem_features"] assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) maskmem_pos_enc = current_out["maskmem_pos_enc"] assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) - for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items(): + for obj_idx, obj_output_dict in inference_state["output_dict_per_obj"].items(): obj_slice = slice(obj_idx, obj_idx + 1) obj_out = { "maskmem_features": None, @@ -1636,7 +1648,7 @@ class SAM2VideoPredictor(SAM2Predictor): obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] obj_output_dict[storage_key][frame_idx] = obj_out - def _clear_non_cond_mem_around_input(self, frame_idx): + def _clear_non_cond_mem_around_input(self, frame_idx, inference_state: dict[str, Any] | None = None): """Remove the non-conditioning memory around the input frame. When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain @@ -1646,15 +1658,179 @@ class SAM2VideoPredictor(SAM2Predictor): Args: frame_idx (int): The index of the current frame where user interaction occurred. + inference_state (dict[str, Any], optional): The current inference state. If None, uses the instance's + inference state. 
""" + inference_state = inference_state or self.inference_state r = self.model.memory_temporal_stride_for_eval frame_idx_begin = frame_idx - r * self.model.num_maskmem frame_idx_end = frame_idx + r * self.model.num_maskmem for t in range(frame_idx_begin, frame_idx_end + 1): - self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None) - for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): obj_output_dict["non_cond_frame_outputs"].pop(t, None) + @smart_inference_mode() + def remove_object(self, inference_state, obj_id, strict=False): + """Remove an object id from the tracking state. If strict is True, we check whether the object id actually + exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"] + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.clear_all_points_in_video(inference_state) + return inference_state["obj_ids"] + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. 
+ # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update(inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]) + obj_input_frames_inds.update(inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]) + for frame_idx in obj_input_frames_inds: + self.clear_all_points_in_frame(inference_state, frame_idx, obj_id) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. 
+ def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] + out["maskmem_pos_enc"] = [x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(out["maskmem_pos_enc"], inference_state) + out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] + out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] + out["object_score_logits"] = out["object_score_logits"][remain_old_obj_inds] + # also update the per-object slices + self._add_output_per_object(frame_idx, out, storage_key, inference_state=inference_state) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + return inference_state["obj_ids"] + + @smart_inference_mode() + def clear_all_points_in_frame(self, inference_state, frame_idx, obj_id): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(obj_id, inference_state) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = len(inference_state["obj_idx_to_id"]) + frame_has_input = False + for obj_idx2 in range(batch_size): + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + 
frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + @smart_inference_mode() + def clear_all_points_in_video(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, 
 inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + inference_state["first_ann_frame_idx"] = None + + class SAM2DynamicInteractivePredictor(SAM2Predictor): """SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a @@ -1986,3 +2162,1785 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor): "obj_ptr": obj_ptr, "object_score_logits": object_score_logits, } + + + class SAM3Predictor(SAM2Predictor): + """Segment Anything Model 3 (SAM3) Interactive Predictor for image segmentation tasks.""" + + _bb_feat_sizes = [ + (288, 288), + (144, 144), + (72, 72), + ] + stride = 14 + + def setup_model(self, model=None, verbose=True): + """Setup the SAM3 model with appropriate mean and standard deviation for preprocessing.""" + super().setup_model(model, verbose) + # update mean and std + self.mean = torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1).to(self.device) + self.std = torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1).to(self.device) + + def get_model(self): + """Retrieve and initialize the Segment Anything Model 3 (SAM3) for image segmentation tasks.""" + from .build_sam3 import build_interactive_sam3
# slow import + + return build_interactive_sam3(self.args.model, compile=self.args.compile) + + +class SAM3SemanticPredictor(SAM3Predictor): + """Segment Anything Model 3 (SAM3) Predictor for image segmentation tasks.""" + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None, bpe_path=None): + """Initialize the SAM3SemanticPredictor with configuration and optional overrides.""" + super().__init__(cfg, overrides, _callbacks) + self.bpe_path = bpe_path + + def get_model(self): + """Retrieve and initialize the Segment Anything Model 3 (SAM3) for image segmentation tasks.""" + from .build_sam3 import build_sam3_image_model # slow import + + return build_sam3_image_model(self.args.model, bpe_path=self.bpe_path, compile=self.args.compile) + + @smart_inference_mode() + def get_im_features(self, im): + """Extract image features using the model's backbone.""" + return self.model.backbone.forward_image(im) + + def pre_transform(self, im): + """Perform initial transformations on the input image for preprocessing. + + This method applies transformations such as resizing to prepare the image for further preprocessing. Currently, + batched inference is not supported; hence the list length should be 1. + + Args: + im (list[np.ndarray]): List containing a single image in HWC numpy array format. + + Returns: + (list[np.ndarray]): List containing the transformed image. + + Raises: + AssertionError: If the input list contains more than one image. 
+ + Examples: + >>> predictor = Predictor() + >>> image = np.random.rand(480, 640, 3) # Single HWC image + >>> transformed = predictor.pre_transform([image]) + >>> print(len(transformed)) + 1 + """ + assert len(im) == 1, "SAM model does not currently support batched inference" + letterbox = LetterBox(self.imgsz, auto=False, center=False, scale_fill=True) # hardcode here for sam3 + return [letterbox(image=x) for x in im] + + def _prepare_geometric_prompts(self, src_shape, bboxes=None, labels=None): + """Prepare prompts by normalizing bounding boxes and points to the destination shape.""" + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=self.torch_dtype, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + # needs xywh as input + bboxes = ops.xyxy2xywh(bboxes) + bboxes[:, 0::2] /= src_shape[1] + bboxes[:, 1::2] /= src_shape[0] + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(bboxes.shape[:-1]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert bboxes.shape[-2] == labels.shape[-1], ( + f"Number of points {bboxes.shape[-2]} should match number of labels {labels.shape[-1]}." 
+ ) + bboxes = bboxes.view(-1, 1, 4) # (N, 1, 4) + labels = labels.view(-1, 1) # (N, 1) + return bboxes, labels + + def _inference_features(self, features, bboxes=None, labels=None, text: list[str] | None = None): + """Run inference on the extracted features with optional bounding boxes and labels.""" + # NOTE: priority: bboxes > text > pre-set classes + nc = 1 if bboxes is not None else len(text) if text is not None else len(self.model.names) + geometric_prompt = self._get_dummy_prompt(nc) + if bboxes is not None: + for i in range(len(bboxes)): + geometric_prompt.append_boxes(bboxes[[i]], labels[[i]]) + if text is None: + text = ["visual"] # bboxes needs this `visual` text prompt if no text passed + if text is not None and self.model.names != text: + self.model.set_classes(text=text) + outputs = self.model.forward_grounding( + backbone_out=features, + text_ids=torch.arange(nc, device=self.device, dtype=torch.long), + geometric_prompt=geometric_prompt, + ) + return outputs + + def postprocess(self, preds, img, orig_imgs): + """Post-process the predictions to apply non-overlapping constraints if required.""" + pred_boxes = preds["pred_boxes"] # (nc, num_query, 4) + pred_logits = preds["pred_logits"] + pred_masks = preds["pred_masks"] + pred_scores = pred_logits.sigmoid() + presence_score = preds["presence_logit_dec"].sigmoid().unsqueeze(1) + pred_scores = (pred_scores * presence_score).squeeze(-1) + pred_cls = torch.tensor( + list(range(pred_scores.shape[0])), + dtype=pred_scores.dtype, + device=pred_scores.device, + )[:, None].expand_as(pred_scores) + pred_boxes = torch.cat([pred_boxes, pred_scores[..., None], pred_cls[..., None]], dim=-1) + + keep = pred_scores > self.args.conf + pred_masks = pred_masks[keep] + pred_boxes = pred_boxes[keep] + pred_boxes[:, :4] = ops.xywh2xyxy(pred_boxes[:, :4]) + + names = getattr(self.model, "names", [str(i) for i in range(pred_scores.shape[0])]) + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list 
+ orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + results = [] + for masks, boxes, orig_img, img_path in zip([pred_masks], [pred_boxes], orig_imgs, self.batch[0]): + if masks.shape[0] == 0: + masks, boxes = None, torch.zeros((0, 6), device=pred_masks.device) + else: + masks = F.interpolate(masks.float()[None], orig_img.shape[:2], mode="bilinear")[0] > 0.5 + boxes[..., [0, 2]] *= orig_img.shape[1] + boxes[..., [1, 3]] *= orig_img.shape[0] + results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=boxes)) + return results + + def inference(self, im, bboxes=None, labels=None, text: list[str] | None = None, *args, **kwargs): + """Perform inference on a single image with optional prompts.""" + bboxes = self.prompts.pop("bboxes", bboxes) + labels = self.prompts.pop("labels", labels) + text = self.prompts.pop("text", text) + features = self.get_im_features(im) if self.features is None else self.features + prompts = self._prepare_geometric_prompts(self.batch[1][0].shape[:2], bboxes, labels) + return self._inference_features(features, *prompts, text=text) + + @smart_inference_mode() + def inference_features( + self, + features, + src_shape, + bboxes=None, + labels=None, + text: list[str] | None = None, + ): + """Perform prompts preprocessing and inference on provided image features using the SAM model. + + Args: + features (dict[str, Any]): Extracted image features from the SAM3 model image encoder. + src_shape (tuple[int, int]): The source shape (height, width) of the input image. + bboxes (np.ndarray | list[list[float]] | None): Bounding boxes in xyxy format with shape (N, 4), in pixels. + labels (np.ndarray | list[int] | None): Point prompt labels with shape (N, ). + text (list[str] | None): List of text prompts corresponding to the classes. + + Returns: + pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
+ pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 6), where N is the number of boxes. + Each box is in xyxy format with additional columns for score and class. + + Notes: + - The input features is a torch.Tensor of shape (B, C, H, W) if performing on SAM, or a dict[str, Any] if performing on SAM2. + """ + prompts = self._prepare_geometric_prompts(src_shape[:2], bboxes, labels) + preds = self._inference_features(features, *prompts, text=text) + pred_boxes = preds["pred_boxes"] # (nc, num_query, 4) + pred_logits = preds["pred_logits"] + pred_masks = preds["pred_masks"] + pred_scores = pred_logits.sigmoid() + presence_score = preds["presence_logit_dec"].sigmoid().unsqueeze(1) + pred_scores = (pred_scores * presence_score).squeeze(-1) + pred_cls = torch.tensor( + list(range(pred_scores.shape[0])), + dtype=pred_scores.dtype, + device=pred_scores.device, + )[:, None].expand_as(pred_scores) + pred_boxes = torch.cat([pred_boxes, pred_scores[..., None], pred_cls[..., None]], dim=-1) + + keep = pred_scores > self.args.conf + pred_masks = pred_masks[keep] + pred_boxes = pred_boxes[keep] + pred_boxes[:, :4] = ops.xywh2xyxy(pred_boxes[:, :4]) + + if pred_masks.shape[0] == 0: + pred_masks, pred_boxes = None, torch.zeros((0, 6), device=pred_masks.device) + else: + pred_masks = F.interpolate(pred_masks.float()[None], src_shape[:2], mode="bilinear")[0] > 0.5 + pred_boxes[..., 0] *= src_shape[1] + pred_boxes[..., 1] *= src_shape[0] + pred_boxes[..., 2] *= src_shape[1] + pred_boxes[..., 3] *= src_shape[0] + return pred_masks, pred_boxes + + def reset_prompts(self): + """Reset the prompts for the predictor.""" + self.prompts = {} + self.model.text_embeddings = {} + + def _get_dummy_prompt(self, num_prompts=1): + """Get a dummy geometric prompt with zero boxes.""" + geometric_prompt = Prompt( + box_embeddings=torch.zeros(0, num_prompts, 4, device=self.device), + box_mask=torch.zeros(num_prompts, 0, device=self.device, dtype=torch.bool), + ) + return 
geometric_prompt + + +class SAM3VideoPredictor(SAM2VideoPredictor, SAM3Predictor): + """Segment Anything Model 3 (SAM3) Video Predictor for video segmentation tasks.""" + + def propagate_in_video(self, inference_state, frame_idx): + """Perform image segmentation inference based on the given input cues, using the currently loaded image. This + method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt + encoder, and mask decoder for real-time and promptable segmentation tasks. + + Args: + inference_state (dict): The current state of inference, including input cues and previous outputs. + frame_idx (int): The index of the current frame in the video sequence. + """ + frame = frame_idx + output_dict = inference_state["output_dict"] + obj_ids = inference_state["obj_ids"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + batch_size = len(inference_state["obj_idx_to_id"]) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + + if frame in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame] + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame) + elif frame in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame] + else: + storage_key = "non_cond_frame_outputs" + current_out = self._run_single_frame_inference( + output_dict=output_dict, + frame_idx=frame, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + inference_state=inference_state, + ) + output_dict[storage_key][frame] = current_out + # Create slices of per-object outputs for subsequent 
interaction with each + # individual object after tracking. + self._add_output_per_object(frame, current_out, storage_key, inference_state=inference_state) + inference_state["frames_already_tracked"].append(frame) + pred_masks = current_out["pred_masks"].flatten(0, 1) + obj_scores = current_out["object_score_logits"] + + return obj_ids, pred_masks, obj_scores + + def get_im_features(self, im, batch=1): + """A wrapper to get image features, supporting pre-extracted backbone outputs.""" + if getattr(self, "backbone_out", None): + backbone_out = self.backbone_out + if batch > 1: # expand features if there's more than one prompt + backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(backbone_out["backbone_fpn"]): + backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1) + for i, pos in enumerate(backbone_out["vision_pos_enc"]): + pos = pos.expand(batch, -1, -1, -1) + backbone_out["vision_pos_enc"][i] = pos + _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out) + return vis_feats, vis_pos_embed, feat_sizes + return super().get_im_features(im, batch) + + +class SAM3VideoSemanticPredictor(SAM3SemanticPredictor): + """Segment Anything Model 3 (SAM3) Video Semantic Predictor.""" + + HIGH_CONF_THRESH = 0.8 + HIGH_IOU_THRESH = 0.8 + NO_OBJ_LOGIT = -10.0 + NEVER_OCCLUDED = -1 + ALWAYS_OCCLUDED = 100000 + + UNCONFIRMED = 1 # newly added masklet, not confirmed by any detection yet + CONFIRMED = 2 # confirmed by at least one detection + _bb_feat_sizes = [ + (288, 288), + (144, 144), + (72, 72), + ] + stride = 14 + + def __init__( + self, + cfg=DEFAULT_CFG, + overrides=None, + _callbacks=None, + bpe_path="bpe_simple_vocab_16e6.txt.gz", + # prob threshold for detection outputs -- only keep detections above this threshold + # enters NMS and det-to-track matching + score_threshold_detection=0.5, + # IoU threshold for detection 
NMS + det_nms_thresh=0.0, + # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it + # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1 + assoc_iou_thresh=0.5, + # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched" + # by any detections -- it is often a stricter threshold like 0.5 + trk_assoc_iou_thresh=0.5, + # prob threshold for a detection to be added as a new object + new_det_thresh=0.0, + # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and + # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh` + # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh` + hotstart_delay=0, + hotstart_unmatch_thresh=3, + hotstart_dup_thresh=3, + # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period. + suppress_unmatched_only_within_hotstart=True, + init_trk_keep_alive=0, + max_trk_keep_alive=8, + min_trk_keep_alive=-4, + # Threshold for suppressing overlapping objects based on recent occlusion + suppress_overlapping_based_on_recent_occlusion_threshold=0.0, + decrease_trk_keep_alive_for_empty_masklets=False, + o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets + suppress_det_close_to_boundary=False, + fill_hole_area=16, + # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1) + max_num_objects=-1, + recondition_every_nth_frame=-1, + # masket confirmation status (to suppress unconfirmed masklets) + masklet_confirmation_enable=False, + # a masklet is confirmed after being consecutively detected and matched for + # `masklet_confirmation_consecutive_det_thresh` + masklet_confirmation_consecutive_det_thresh=3, + # bbox heuristic parameters + reconstruction_bbox_iou_thresh=0.0, + reconstruction_bbox_det_score=0.0, 
+ ): + """Initialize the SAM3VideoSemanticPredictor with configuration and optional overrides.""" + super().__init__(cfg, overrides, _callbacks, bpe_path=bpe_path) + self.score_threshold_detection = score_threshold_detection + self.det_nms_thresh = det_nms_thresh + self.assoc_iou_thresh = assoc_iou_thresh + self.trk_assoc_iou_thresh = trk_assoc_iou_thresh + self.new_det_thresh = new_det_thresh + + # hotstart parameters + if hotstart_delay > 0: + assert hotstart_unmatch_thresh <= hotstart_delay + assert hotstart_dup_thresh <= hotstart_delay + self.hotstart_delay = hotstart_delay + self.hotstart_unmatch_thresh = hotstart_unmatch_thresh + self.hotstart_dup_thresh = hotstart_dup_thresh + self.suppress_unmatched_only_within_hotstart = suppress_unmatched_only_within_hotstart + self.init_trk_keep_alive = init_trk_keep_alive + self.max_trk_keep_alive = max_trk_keep_alive + self.min_trk_keep_alive = min_trk_keep_alive + self.suppress_overlapping_based_on_recent_occlusion_threshold = ( + suppress_overlapping_based_on_recent_occlusion_threshold + ) + self.suppress_det_close_to_boundary = suppress_det_close_to_boundary + self.decrease_trk_keep_alive_for_empty_masklets = decrease_trk_keep_alive_for_empty_masklets + self.o2o_matching_masklets_enable = o2o_matching_masklets_enable + self.fill_hole_area = fill_hole_area + self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use) + + max_num_objects = 10000 # no limit + num_obj_for_compile = 16 + self.max_num_objects = max_num_objects + self.num_obj_for_compile = num_obj_for_compile + self.recondition_every_nth_frame = recondition_every_nth_frame + self.masklet_confirmation_enable = masklet_confirmation_enable + self.masklet_confirmation_consecutive_det_thresh = masklet_confirmation_consecutive_det_thresh + self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh + self.reconstruction_bbox_det_score = reconstruction_bbox_det_score + + # build SAM3 tracker + self.tracker = 
SAM3VideoPredictor(overrides=overrides) + + self.inference_state = {} + self.callbacks["on_predict_start"].append(self.init_state) + + def setup_model(self, model=None, verbose=True): + """Setup the SAM3VideoSemanticPredictor model.""" + super().setup_model(model, verbose) + from .build_sam3 import build_interactive_sam3 + + # Initialize the SAM3 tracker model without backbone (backbone is handled in the detector) + model = build_interactive_sam3(self.args.model, with_backbone=False) + self.tracker.setup_model(model=model, verbose=False) + + def setup_source(self, source): + """Setup the source for the SAM3VideoSemanticPredictor model.""" + super().setup_source(source) + self.tracker.imgsz = self.imgsz + self.tracker.model.set_imgsz(self.imgsz) + self.tracker._bb_feat_sizes = [[int(x / (self.stride * i)) for x in self.imgsz] for i in [1 / 4, 1 / 2, 1]] + self.interpol_size = self.tracker.model.memory_encoder.mask_downsampler.interpol_size + + @staticmethod + def init_state(predictor): + """Initialize an inference state for the predictor. + + This function sets up the initial state required for performing inference on video data. It includes + initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata + relevant to the tracking process. + + Args: + predictor (SAM3VideoSemanticPredictor): The predictor object for which to initialize the state. 
+ """ + if len(predictor.inference_state) > 0: # means initialized + return + assert predictor.dataset is not None + assert predictor.dataset.mode == "video" + num_frames = predictor.dataset.frames + inference_state = { + "num_frames": num_frames, + "tracker_inference_states": [], + "tracker_metadata": {}, + } + inference_state["text_prompt"] = None + inference_state["per_frame_geometric_prompt"] = [None] * num_frames + predictor.inference_state = inference_state + + def inference(self, im, bboxes=None, labels=None, text: list[str] | None = None, *args, **kwargs): + """Perform inference on a video sequence with optional prompts.""" + frame = self.dataset.frame - 1 # align frame index to be 0-based + self.inference_state["im"] = im # only pass image for subsequent frames + if "text_ids" not in self.inference_state: # first frame processing + self.add_prompt(frame_idx=frame, text=text, bboxes=bboxes, labels=labels) + return self._run_single_frame_inference(frame, reverse=False) + + def postprocess(self, preds, img, orig_imgs): + """Post-process the predictions to apply non-overlapping constraints if required.""" + obj_id_to_mask = preds["obj_id_to_mask"] # low res masks + curr_obj_ids = sorted(obj_id_to_mask.keys()) + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + if len(curr_obj_ids) == 0: + pred_masks, pred_boxes = None, torch.zeros((0, 7), device=self.device) + else: + pred_masks = torch.cat([obj_id_to_mask[obj_id] for obj_id in curr_obj_ids], dim=0) + pred_masks = F.interpolate(pred_masks.float()[None], orig_imgs[0].shape[:2], mode="bilinear")[0] > 0.5 + pred_ids = torch.tensor(curr_obj_ids, dtype=torch.int32, device=pred_masks.device) + pred_scores = torch.tensor( + [preds["obj_id_to_score"][obj_id] for obj_id in curr_obj_ids], device=pred_masks.device + ) + pred_cls = torch.tensor( + [preds["obj_id_to_cls"][obj_id] for obj_id in curr_obj_ids], device=pred_masks.device 
+ ) + keep = (pred_scores > self.args.conf) & pred_masks.any(dim=(1, 2)) + pred_masks = pred_masks[keep] + pred_boxes = batched_mask_to_box(pred_masks) + pred_boxes = torch.cat( + [pred_boxes, pred_ids[keep][:, None], pred_scores[keep][..., None], pred_cls[keep][..., None]], dim=-1 + ) + if pred_masks.shape[0] > 1: + tracker_scores = torch.tensor( + [ + ( + preds["obj_id_to_tracker_score"][obj_id] + if obj_id in preds["obj_id_to_tracker_score"] + else 0.0 + ) + for obj_id in curr_obj_ids + ], + device=pred_masks.device, + )[keep] + pred_masks = ( + self._apply_object_wise_non_overlapping_constraints( + pred_masks.unsqueeze(1), + tracker_scores.unsqueeze(1), + background_value=0, + ).squeeze(1) + ) > 0 + + # names = getattr(self.model, "names", [str(i) for i in range(pred_scores.shape[0])]) + names = dict(enumerate(str(i) for i in range(pred_masks.shape[0]))) + results = [] + for masks, boxes, orig_img, img_path in zip([pred_masks], [pred_boxes], orig_imgs, self.batch[0]): + results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=boxes)) + return results + + def _run_single_frame_inference(self, frame_idx, reverse=False, inference_state=None): + """Perform inference on a single frame and get its inference results.""" + inference_state = inference_state or self.inference_state + # prepare inputs + tracker_states_local = inference_state["tracker_inference_states"] + has_text_prompt = inference_state["text_prompt"] is not None + has_geometric_prompt = inference_state["per_frame_geometric_prompt"][frame_idx] is not None + # run inference for the current frame + ( + obj_id_to_mask, + obj_id_to_score, + obj_id_to_cls, + tracker_states_local_new, + tracker_metadata_new, + frame_stats, + _, + ) = self._det_track_one_frame( + frame_idx=frame_idx, + num_frames=inference_state["num_frames"], + reverse=reverse, + im=inference_state["im"], + text_ids=inference_state["text_ids"], + geometric_prompt=( + 
self._get_dummy_prompt(num_prompts=len(inference_state["text_ids"]))
+                if not has_geometric_prompt
+                else inference_state["per_frame_geometric_prompt"][frame_idx]
+            ),
+            tracker_states_local=tracker_states_local,
+            tracker_metadata_prev=inference_state["tracker_metadata"],
+            allow_new_detections=has_text_prompt or has_geometric_prompt,
+        )
+        # update inference state
+        inference_state["tracker_inference_states"] = tracker_states_local_new
+        inference_state["tracker_metadata"] = tracker_metadata_new
+
+        out = {
+            "obj_id_to_mask": obj_id_to_mask,
+            "obj_id_to_score": obj_id_to_score,  # first frame detection score
+            "obj_id_to_cls": obj_id_to_cls,  # first frame detection class
+            "obj_id_to_tracker_score": tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx],
+        }
+        # removed_obj_ids is only needed on rank 0 to handle hotstart delay buffer
+        metadata = tracker_metadata_new["metadata"]
+        removed_obj_ids = metadata["removed_obj_ids"]
+        out["removed_obj_ids"] = removed_obj_ids
+        out["suppressed_obj_ids"] = metadata["suppressed_obj_ids"][frame_idx]
+        out["frame_stats"] = frame_stats
+        if self.masklet_confirmation_enable:
+            status = metadata["masklet_confirmation"]["status"]
+            is_unconfirmed = status == self.UNCONFIRMED
+            out["unconfirmed_obj_ids"] = tracker_metadata_new["obj_ids_all_gpu"][is_unconfirmed].tolist()
+        else:
+            out["unconfirmed_obj_ids"] = []
+        return out
+
+    @smart_inference_mode()
+    def add_prompt(
+        self,
+        frame_idx,
+        text=None,
+        bboxes=None,
+        labels=None,
+        inference_state=None,
+    ):
+        """Add text, point or box prompts on a single frame. This method returns the inference outputs only on the
+        prompted frame.
+
+        Note that text prompts are NOT associated with a particular frame (i.e. they apply
+        to all frames). However, we only run inference on the frame specified in `frame_idx`.
+ """ + inference_state = inference_state or self.inference_state + assert text is not None or bboxes is not None, "at least one type of prompt (text, boxes) must be provided" + + # 1) handle text prompt + use_text = text is not None + text = text if use_text else "visual" + text_batch = [text] if isinstance(text, str) else text + inference_state["text_prompt"] = text if use_text else None + n = len(text_batch) + text_ids = torch.arange(n, device=self.device, dtype=torch.long) + inference_state["text_ids"] = text_ids + if text is not None and self.model.names != text: + self.model.set_classes(text=text) + + # 2) handle box prompt + bboxes, labels = self._prepare_geometric_prompts(self.batch[1][0].shape[:2], bboxes, labels) + assert (bboxes is not None) == (labels is not None) + geometric_prompt = self._get_dummy_prompt(num_prompts=n) + if bboxes is not None: + for i in range(len(bboxes)): + geometric_prompt.append_boxes(bboxes[[i]], labels[[i]]) + inference_state["per_frame_geometric_prompt"][frame_idx] = geometric_prompt + out = self._run_single_frame_inference(frame_idx, reverse=False, inference_state=inference_state) + return frame_idx, out + + def _apply_object_wise_non_overlapping_constraints(self, pred_masks, obj_scores, background_value=-10.0): + """Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region).""" + # Replace pixel scores with object scores + pred_masks_single_score = torch.where(pred_masks > 0, obj_scores[..., None, None], background_value) + # Apply pixel-wise non-overlapping constraint based on mask scores + pixel_level_non_overlapping_masks = self.tracker.model._apply_non_overlapping_constraints( + pred_masks_single_score + ) + # Replace object scores with pixel scores. 
Note, that now only one object can claim the overlapping region + pred_masks = torch.where( + pixel_level_non_overlapping_masks > 0, + pred_masks, + torch.clamp(pred_masks, max=background_value), + ) + return pred_masks + + def _det_track_one_frame( + self, + im: torch.Tensor, + text_ids: torch.Tensor, + frame_idx: int, + num_frames: int, + reverse: bool, + geometric_prompt: Prompt, + tracker_states_local: list[Any], + tracker_metadata_prev: dict[str, Any], + allow_new_detections: bool = True, + ): + """This function handles one-step inference for the DenseTracking model in an SPMD manner. At a high-level, all + GPUs execute the same function calls as if it's done on a single GPU, while under the hood, some + function calls involve distributed computation based on sharded SAM2 states. + + - `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs + - `tracker_states_local` holds the local masklet information in this GPU shard + - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is hold on which GPUs + it contains both global and local masklet information + """ + # Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU, + # a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner. + det_out = self.run_backbone_and_detection( + im=im, + text_ids=text_ids, + geometric_prompt=geometric_prompt, + allow_new_detections=allow_new_detections, + ) + + # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks. + # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions + # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only + # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks; + # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics. 
+ if tracker_metadata_prev == {}: + # initialize masklet metadata if it's uninitialized (empty dict) + tracker_metadata_prev.update(self._initialize_metadata()) + tracker_low_res_masks_global, tracker_obj_scores_global = self.run_tracker_propagation( + frame_idx=frame_idx, + tracker_states_local=tracker_states_local, + tracker_metadata_prev=tracker_metadata_prev, + ) + + # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans + # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc). + # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints. + # **This step should involve all the heuristics needed for any updates.** Most of the update + # planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is + # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the + # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`). + tracker_update_plan, tracker_metadata_new = self.run_tracker_update_planning_phase( + frame_idx=frame_idx, + reverse=reverse, + det_out=det_out, + tracker_low_res_masks_global=tracker_low_res_masks_global, + tracker_obj_scores_global=tracker_obj_scores_global, + tracker_metadata_prev=tracker_metadata_prev, + tracker_states_local=tracker_states_local, + ) + + # Get reconditioning info from the update plan + reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set()) + + # Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. 
its local SAM2 inference states + tracker_states_local_new = self.run_tracker_update_execution_phase( + frame_idx=frame_idx, + num_frames=num_frames, + det_out=det_out, + tracker_states_local=tracker_states_local, + tracker_update_plan=tracker_update_plan, + ) + + # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since + # only GPU 0 will send outputs to the server). + obj_id_to_mask = self.build_outputs( + det_out=det_out, + tracker_low_res_masks_global=tracker_low_res_masks_global, + tracker_metadata_prev=tracker_metadata_prev, + tracker_update_plan=tracker_update_plan, + reconditioned_obj_ids=reconditioned_obj_ids, + ) + obj_id_to_score = tracker_metadata_new["obj_id_to_score"] + obj_id_to_cls = tracker_metadata_new["obj_id_to_cls"] + # a few statistics for the current frame as a part of the output + frame_stats = { + "num_obj_tracked": np.sum(tracker_metadata_new["num_obj"]), + "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"], + } + # add tracker scores to metadata, it should be fired for frames except the first frame + if tracker_obj_scores_global.shape[0] > 0: + # Convert tracker_obj_scores_global to sigmoid scores before updating + tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist() + tracker_obj_ids = tracker_metadata_prev["obj_ids"] + tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx].update( + dict(zip(tracker_obj_ids, tracker_obj_scores_global)) + ) + return ( + obj_id_to_mask, # a dict: obj_id --> output mask + obj_id_to_score, # a dict: obj_id --> output score (prob) + obj_id_to_cls, # a dict: obj_id --> output cls (int) + tracker_states_local_new, + tracker_metadata_new, + frame_stats, + tracker_obj_scores_global, # a dict: obj_id --> tracker frame-level scores + ) + + def _suppress_detections_close_to_boundary(self, boxes, margin=0.025): + """Suppress detections too close to image edges (for normalized boxes). 
+ + boxes: (N, 4) in xyxy format, normalized [0,1] + margin: fraction of image + """ + x_min, y_min, x_max, y_max = boxes.unbind(-1) + x_c = (x_min + x_max) / 2 + y_c = (y_min + y_max) / 2 + keep = (x_c > margin) & (x_c < 1.0 - margin) & (y_c > margin) & (y_c < 1.0 - margin) + + return keep + + def run_backbone_and_detection( + self, im: torch.Tensor, text_ids: torch.Tensor, geometric_prompt: Prompt, allow_new_detections: bool + ): + """Run backbone and detection for a single frame.""" + features = self.get_im_features(im) + sam3_image_out = self.model.forward_grounding( + backbone_out=features, text_ids=text_ids, geometric_prompt=geometric_prompt + ) + det_out = self._extract_detection_outputs(sam3_image_out, allow_new_detections) + self._cache_backbone_features(sam3_image_out) + return det_out + + def _extract_detection_outputs(self, sam3_image_out, allow_new_detections): + """Extract and filter detection outputs.""" + pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid() + if not allow_new_detections: + pred_probs = pred_probs - 1e8 + + pred_cls = torch.tensor( + list(range(pred_probs.shape[0])), + dtype=pred_probs.dtype, + device=pred_probs.device, + )[:, None].expand_as(pred_probs) + + pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"] + pred_masks = sam3_image_out["pred_masks"] + + keep = pred_probs > self.score_threshold_detection + return { + "bbox": pred_boxes_xyxy[keep], + "mask": pred_masks[keep], + "scores": pred_probs[keep], + "cls": pred_cls[keep], + } + + def _cache_backbone_features(self, sam3_image_out): + """Build and cache SAM2 backbone features.""" + sam_mask_decoder = self.tracker.model.sam_mask_decoder + feats = sam3_image_out["backbone_out"]["sam2_backbone_out"] + tracker_backbone_fpn = [ + sam_mask_decoder.conv_s0(feats["backbone_fpn"][0]), + sam_mask_decoder.conv_s1(feats["backbone_fpn"][1]), + feats["backbone_fpn"][2], + ] + tracker_backbone_out = { + "vision_features": tracker_backbone_fpn[-1], + "vision_pos_enc": 
feats["vision_pos_enc"], + "backbone_fpn": tracker_backbone_fpn, + } + # cache the SAM2 backbone features for `frame_idx` in the tracker + self.tracker.backbone_out = tracker_backbone_out + + def run_tracker_propagation( + self, frame_idx: int, tracker_states_local: list[Any], tracker_metadata_prev: dict[str, np.ndarray] + ): + """Run the tracker propagation phase for a single frame in an SPMD manner.""" + # Step 1: propagate the local SAM2 states to get the current frame's prediction + # `low_res_masks_local` of the existing masklets on this GPU + # - obj_ids_local: list[int] -- list of object IDs + # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask) + obj_ids_local, low_res_masks_local, obj_scores_local = self._propogate_tracker_one_frame_local_gpu( + tracker_states_local, frame_idx=frame_idx + ) + + assert np.all(obj_ids_local == tracker_metadata_prev["obj_ids"]), "{} != {}".format( + obj_ids_local, tracker_metadata_prev["obj_ids"] + ) + + # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global` + # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask) + low_res_masks_global = low_res_masks_local + obj_scores_global = obj_scores_local + return low_res_masks_global, obj_scores_global + + def _recondition_masklets( + self, + frame_idx, + det_out: dict[str, torch.Tensor], + trk_id_to_max_iou_high_conf_det: list[int], + tracker_states_local: list[Any], + tracker_metadata: dict[str, np.ndarray], + tracker_obj_scores_global: torch.Tensor, + ): + """Recondition masklets based on new high-confidence detections.""" + # Recondition the masklets based on the new detections + for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items(): + new_mask = det_out["mask"][det_idx : det_idx + 1] + new_mask_binary = ( + F.interpolate(new_mask.unsqueeze(1), size=self.interpol_size, mode="bilinear", align_corners=False) > 0 + ) + HIGH_CONF_THRESH = 0.8 + reconditioned_states_idx = set() + obj_idx = np.where(tracker_metadata["obj_ids"] == 
trk_obj_id)[0].item()
+            obj_score = tracker_obj_scores_global[obj_idx]
+            for state_idx, inference_state in enumerate(tracker_states_local):
+                if (
+                    trk_obj_id in inference_state["obj_ids"]
+                    # NOTE: Goal of this condition is to avoid reconditioning masks that are occluded/low quality.
+                    # Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics.
+                    and obj_score > HIGH_CONF_THRESH
+                ):
+                    LOGGER.debug(
+                        f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned."
+                    )
+                    self.tracker.add_new_prompts(
+                        inference_state=inference_state,
+                        frame_idx=frame_idx,
+                        obj_id=trk_obj_id,
+                        masks=new_mask_binary,
+                    )
+                    reconditioned_states_idx.add(state_idx)
+
+        for idx in reconditioned_states_idx:
+            self.tracker.propagate_in_video_preflight(tracker_states_local[idx])
+        return tracker_states_local
+
+    def run_tracker_update_planning_phase(
+        self,
+        frame_idx: int,
+        reverse: bool,
+        det_out: dict[str, torch.Tensor],
+        tracker_low_res_masks_global: torch.Tensor,
+        tracker_obj_scores_global: torch.Tensor,
+        tracker_metadata_prev: dict[str, np.ndarray],
+        tracker_states_local: list[Any],
+    ):
+        """Run the tracker update planning phase for a single frame in an SPMD manner."""
+        # initialize new metadata from previous metadata (its values will be updated later)
+        tracker_metadata_new = {
+            "obj_ids": deepcopy(tracker_metadata_prev["obj_ids"]),
+            "num_obj": deepcopy(tracker_metadata_prev["num_obj"]),
+            "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]),
+            "obj_id_to_cls": deepcopy(tracker_metadata_prev["obj_id_to_cls"]),
+            "obj_id_to_tracker_score_frame_wise": deepcopy(tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"]),
+            "obj_id_to_last_occluded": {},  # will be filled later
+            "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]),
+        }
+
+        # Initialize reconditioned_obj_ids early to avoid UnboundLocalError
+        reconditioned_obj_ids = 
set() + + # Step 1: make the update plan and resolve heuristics on GPU 0 + det_mask_preds: torch.Tensor = det_out["mask"] # low-res mask logits + det_scores_np: np.ndarray = det_out["scores"].float().cpu().numpy() + det_cls_np: np.ndarray = det_out["cls"].float().cpu().numpy() + det_bbox_xyxy: torch.Tensor = det_out["bbox"] + # a) match detector and tracker masks and find new objects + ( + new_det_fa_inds, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids, + ) = self._associate_det_trk( + det_masks=det_mask_preds, + det_scores_np=det_scores_np, + trk_masks=tracker_low_res_masks_global, + trk_obj_ids=tracker_metadata_prev["obj_ids"], + ) + if self.suppress_det_close_to_boundary: + keep = self._suppress_detections_close_to_boundary(det_bbox_xyxy[new_det_fa_inds]) + new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()] + + # check whether we've hit the maximum number of objects we can track (and if so, drop some detections) + prev_obj_num = np.sum(tracker_metadata_prev["num_obj"]) + new_det_num = len(new_det_fa_inds) + num_obj_dropped_due_to_limit = 0 + if prev_obj_num + new_det_num > self.max_num_objects: + LOGGER.warning(f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}") + new_det_num_to_keep = self.max_num_objects - prev_obj_num + num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep + new_det_fa_inds = self._drop_new_det_with_obj_limit(new_det_fa_inds, det_scores_np, new_det_num_to_keep) + assert len(new_det_fa_inds) == new_det_num_to_keep + new_det_num = len(new_det_fa_inds) + + # assign object IDs to new detections and decide which GPU to place them + new_det_obj_ids = tracker_metadata_prev["max_obj_id"] + 1 + np.arange(new_det_num) + + # b) handle hotstart heuristics to remove objects + # here `metadata` contains metadata stored on (and only accessible to) GPU 0; + # we avoid broadcasting them to other GPUs to save communication cost, assuming + # that `metadata` is 
not needed by other GPUs + metadata_new = deepcopy(tracker_metadata_prev["metadata"]) + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + obj_ids_newly_removed, metadata_new = self._process_hotstart( + frame_idx=frame_idx, + reverse=reverse, + det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, + new_det_obj_ids=new_det_obj_ids, + empty_trk_obj_ids=empty_trk_obj_ids, + unmatched_trk_obj_ids=unmatched_trk_obj_ids, + metadata=metadata_new, + ) + else: + # if warm-up is not complete, we don't remove any objects + obj_ids_newly_removed = set() + tracker_metadata_new["metadata"] = metadata_new + + # `tracker_update_plan` should be identical on all GPUs after broadcasting + tracker_update_plan = { + "new_det_fa_inds": new_det_fa_inds, # np.ndarray + "new_det_obj_ids": new_det_obj_ids, # np.ndarray + # "new_det_gpu_ids": new_det_gpu_ids, # np.ndarray + "unmatched_trk_obj_ids": unmatched_trk_obj_ids, # np.ndarray + "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids, # dict + "obj_ids_newly_removed": obj_ids_newly_removed, # set + "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int + "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det, # dict + "reconditioned_obj_ids": reconditioned_obj_ids, # set + } + + # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding + # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results + should_recondition_iou = False + + # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections + if self.reconstruction_bbox_iou_thresh > 0 and len(trk_id_to_max_iou_high_conf_det) > 0: + for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items(): + det_box = det_out["bbox"][det_idx] + det_score = det_out["scores"][det_idx] + + try: + trk_idx = list(tracker_metadata_prev["obj_ids"]).index(trk_obj_id) + except ValueError: + continue # Skip if tracklet not found + + tracker_mask = 
tracker_low_res_masks_global[trk_idx] + mask_binary = tracker_mask > 0 + mask_area = mask_binary.sum().item() + + if mask_area == 0: + continue # Skip tracklets with zero mask area + + # Get bounding box from SAM2 mask and convert to normalized coordinates + tracker_box_pixels = batched_mask_to_box(mask_binary.unsqueeze(0)).squeeze(0) + mask_height, mask_width = tracker_mask.shape[-2:] + tracker_box_normalized = torch.tensor( + [ + tracker_box_pixels[0] / mask_width, + tracker_box_pixels[1] / mask_height, + tracker_box_pixels[2] / mask_width, + tracker_box_pixels[3] / mask_height, + ], + device=tracker_box_pixels.device, + ) + + # Compute IoU between detection and SAM2 tracklet bounding boxes + det_box_batch = det_box.unsqueeze(0) + tracker_box_batch = tracker_box_normalized.unsqueeze(0) + iou = box_iou(det_box_batch, tracker_box_batch)[0] + + if iou < self.reconstruction_bbox_iou_thresh and det_score >= self.reconstruction_bbox_det_score: + should_recondition_iou = True + reconditioned_obj_ids.add(trk_obj_id) + + should_recondition_periodic = ( + self.recondition_every_nth_frame > 0 + and frame_idx % self.recondition_every_nth_frame == 0 + and len(trk_id_to_max_iou_high_conf_det) > 0 + ) + + # Recondition if periodic or IoU condition met + if should_recondition_periodic or should_recondition_iou: + self._recondition_masklets( + frame_idx, + det_out, + trk_id_to_max_iou_high_conf_det, + tracker_states_local, + tracker_metadata_prev, + tracker_obj_scores_global, + ) + + # Step 4: Run SAM2 memory encoder on the current frame's prediction masks + # This is done on all GPUs + batch_size = tracker_low_res_masks_global.size(0) + if batch_size > 0: + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0: + # NOTE: tracker_low_res_masks_global is updated in-place then returned + tracker_low_res_masks_global = self._suppress_overlapping_based_on_recent_occlusion( + frame_idx, + 
tracker_low_res_masks_global,
+                        tracker_metadata_prev,
+                        tracker_metadata_new,
+                        obj_ids_newly_removed,
+                        reverse,
+                    )
+
+            self._tracker_update_memories(tracker_states_local, frame_idx, low_res_masks=tracker_low_res_masks_global)
+
+        # Step 5: update the SAM2 metadata based on the update plan
+        updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids"]
+        if len(new_det_obj_ids) > 0:
+            updated_obj_ids_this_gpu = np.concatenate([updated_obj_ids_this_gpu, new_det_obj_ids])
+        if len(obj_ids_newly_removed) > 0:
+            is_removed = np.isin(updated_obj_ids_this_gpu, list(obj_ids_newly_removed))
+            updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed]
+        tracker_metadata_new["obj_ids"] = updated_obj_ids_this_gpu
+        tracker_metadata_new["num_obj"] = len(updated_obj_ids_this_gpu)
+        # update object scores and the maximum object ID assigned so far
+        if len(new_det_obj_ids) > 0:
+            tracker_metadata_new["obj_id_to_score"].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]))
+            tracker_metadata_new["obj_id_to_cls"].update(zip(new_det_obj_ids, det_cls_np[new_det_fa_inds]))
+            # tracker scores are not available for new objects, use det score instead. 
+ tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx].update( + zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]) + ) + tracker_metadata_new["max_obj_id"] = max(tracker_metadata_new["max_obj_id"], np.max(new_det_obj_ids)) + # for removed objects, we set their scores to a very low value (-1e4) but still + # keep them in "obj_id_to_score" (it's easier to handle outputs this way) + for obj_id in obj_ids_newly_removed: + tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4 + tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][obj_id] = -1e4 + tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None) + # check that "metadata" is in tracker_metadata_new if and only if it's GPU 0 + assert "metadata" in tracker_metadata_new + if self.masklet_confirmation_enable: + metadata = self.update_masklet_confirmation_status( + metadata=tracker_metadata_new["metadata"], + obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids"], + obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids"], + det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, + new_det_obj_ids=new_det_obj_ids, + ) + tracker_metadata_new["metadata"] = metadata + + return tracker_update_plan, tracker_metadata_new + + def _suppress_overlapping_based_on_recent_occlusion( + self, + frame_idx: int, + tracker_low_res_masks_global: torch.Tensor, + tracker_metadata_prev: dict[str, Any], + tracker_metadata_new: dict[str, Any], + obj_ids_newly_removed: set[int], + reverse: bool = False, + ): + """Suppress overlapping masks based on the most recent occlusion information. If an object is removed by + hotstart, we always suppress it if it overlaps with any other object. + + Args: + frame_idx (int): The current frame index. + tracker_low_res_masks_global (torch.Tensor): The low-resolution masks for the current frame. + tracker_metadata_prev (dict[str, Any]): The metadata from the previous frame. + tracker_metadata_new (dict[str, Any]): The metadata for the current frame. 
+ obj_ids_newly_removed (set[int]): The object IDs that have been removed. + reverse (bool): Whether the tracking is in reverse order. + + Returns: + (torch.Tensor): The updated low-resolution masks with some objects suppressed. + """ + obj_ids_global = tracker_metadata_prev["obj_ids"] + binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 + batch_size = tracker_low_res_masks_global.size(0) + if batch_size > 0: + assert len(obj_ids_global) == batch_size, ( + f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" + ) + last_occluded_prev = torch.cat( + [ + tracker_metadata_prev["obj_id_to_last_occluded"].get( + obj_id, + torch.full( + (1,), + fill_value=( + self.NEVER_OCCLUDED if obj_id not in obj_ids_newly_removed else self.ALWAYS_OCCLUDED + ), + device=binary_tracker_low_res_masks_global.device, + dtype=torch.long, + ), + ) + for obj_id in obj_ids_global + ], + dim=0, + ) + to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded( + binary_tracker_low_res_masks_global, + last_occluded_prev, + obj_ids_global, + frame_idx, + reverse, + ) + + # Update metadata with occlusion information + is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2))) + is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress + last_occluded_new = last_occluded_prev.clone() + last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx + # Slice out the last occluded frame for each object + tracker_metadata_new["obj_id_to_last_occluded"] = { + obj_id: last_occluded_new[obj_idx : obj_idx + 1] for obj_idx, obj_id in enumerate(obj_ids_global) + } + + # Zero out suppressed masks before memory encoding + tracker_low_res_masks_global[to_suppress] = self.NO_OBJ_LOGIT + + return tracker_low_res_masks_global + + def run_tracker_update_execution_phase( + self, + frame_idx: int, + num_frames: int, + det_out: dict[str, torch.Tensor], + tracker_states_local: list[Any], + tracker_update_plan: dict[str, np.ndarray], + ): + 
"""Execute the tracker update plan for a single frame in an SPMD manner.""" + # initialize tracking scores with detection scores + new_det_fa_inds: np.ndarray = tracker_update_plan["new_det_fa_inds"] + new_det_obj_ids: np.ndarray = tracker_update_plan["new_det_obj_ids"] + # new_det_gpu_ids: np.ndarray = tracker_update_plan["new_det_gpu_ids"] + new_det_obj_ids_local: np.ndarray = new_det_obj_ids + new_det_fa_inds_local: np.ndarray = new_det_fa_inds + obj_ids_newly_removed: set[int] = tracker_update_plan["obj_ids_newly_removed"] + + # Step 1: add new objects from the detector to SAM2 inference states + if len(new_det_fa_inds_local) > 0: + new_det_fa_inds_local_t = torch.from_numpy(new_det_fa_inds_local) + new_det_masks: torch.Tensor = det_out["mask"][new_det_fa_inds_local_t] + # initialize SAM2 with new object masks + tracker_states_local = self._tracker_add_new_objects( + frame_idx=frame_idx, + num_frames=num_frames, + new_obj_ids=new_det_obj_ids_local, + new_obj_masks=new_det_masks, + tracker_states_local=tracker_states_local, + ) + + # Step 2: remove from SAM2 inference states those objects removed by heuristics + if len(obj_ids_newly_removed) > 0: + self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed) + + return tracker_states_local + + def build_outputs( + self, + det_out: dict[str, torch.Tensor], + tracker_low_res_masks_global: torch.Tensor, + tracker_metadata_prev: dict[str, np.ndarray], + tracker_update_plan: dict[str, np.ndarray], + reconditioned_obj_ids: set | None = None, + ): + """Build the output masks for the current frame.""" + new_det_fa_inds: np.ndarray = tracker_update_plan["new_det_fa_inds"] + new_det_obj_ids: np.ndarray = tracker_update_plan["new_det_obj_ids"] + obj_id_to_mask = {} # obj_id --> output mask tensor + + # Part 1: masks from previous SAM2 propagation + existing_masklet_obj_ids = tracker_metadata_prev["obj_ids"] + existing_masklet_binary = tracker_low_res_masks_global.unsqueeze(1) + assert 
len(existing_masklet_obj_ids) == len(existing_masklet_binary) + for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary): + obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) + + # Part 2: masks from new detections + new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds) + new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1) + assert len(new_det_obj_ids) == len(new_det_low_res_masks) + for obj_id, mask in zip(new_det_obj_ids, new_det_low_res_masks): + obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) + + # Part 3: Override masks for reconditioned objects using detection masks + if reconditioned_obj_ids is not None and len(reconditioned_obj_ids) > 0: + trk_id_to_max_iou_high_conf_det = tracker_update_plan.get("trk_id_to_max_iou_high_conf_det", {}) + + for obj_id in reconditioned_obj_ids: + det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id) + + if det_idx is not None: + obj_id_to_mask[obj_id] = det_out["mask"][det_idx].unsqueeze(0) + + return obj_id_to_mask + + def _get_objects_to_suppress_based_on_most_recently_occluded( + self, + binary_low_res_masks: torch.Tensor, + last_occluded: list[int], + obj_ids: list[int], + frame_idx: int | None = None, + reverse: bool = False, + ): + # Suppress overlapping masks for objects that were most recently occluded + assert binary_low_res_masks.dtype == torch.bool, f"Expected boolean tensor, got {binary_low_res_masks.dtype}" + to_suppress = torch.zeros( + binary_low_res_masks.size(0), + device=binary_low_res_masks.device, + dtype=torch.bool, + ) + if len(obj_ids) <= 1: + return to_suppress + + iou = mask_iou(binary_low_res_masks.flatten(1), binary_low_res_masks.flatten(1)) # [N,N] + + # Create masks for upper triangular matrix (i < j) and IoU threshold + mask_iou_thresh = iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold + overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) # [N,N] + + last_occ_expanded_i = last_occluded.unsqueeze(1) # (N, 1) + 
last_occ_expanded_j = last_occluded.unsqueeze(0) # (1, N) + # Suppress most recently occluded + cmp_op = torch.gt if not reverse else torch.lt + suppress_i_mask = ( + overlapping_pairs + & cmp_op(last_occ_expanded_i, last_occ_expanded_j) # (last_occ_expanded_i > last_occ_expanded_j) + & (last_occ_expanded_j > -1) # j can suppress i only if i was previously occluded + ) + suppress_j_mask = ( + overlapping_pairs + & cmp_op(last_occ_expanded_j, last_occ_expanded_i) + & (last_occ_expanded_i > -1) # i can suppress j only if j was previously occluded + ) + # Apply suppression + to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0) + + # Log for debugging + if LOGGER.isEnabledFor(10) and frame_idx is not None: + suppress_i_mask = suppress_i_mask.cpu().numpy() + suppress_j_mask = suppress_j_mask.cpu().numpy() + last_occluded = last_occluded.cpu().numpy() + + # Find all suppression pairs without using torch.where + batch_size = suppress_i_mask.shape[0] + + # Log i-suppression cases (where i gets suppressed in favor of j) + for i in range(batch_size): + for j in range(batch_size): + if suppress_i_mask[i, j]: + LOGGER.debug( + f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}" + ) + + # Log j-suppression cases (where j gets suppressed in favor of i) + for i in range(batch_size): + for j in range(batch_size): + if suppress_j_mask[i, j]: + LOGGER.debug( + f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}" + ) + + return to_suppress + + def _propogate_tracker_one_frame_local_gpu(self, inference_states: list[Any], frame_idx: int): + """Inference_states: list of inference states, each state corresponds to a different set of objects.""" + obj_ids_local = [] + low_res_masks_list = [] + obj_scores_list = [] + for inference_state in inference_states: + if len(inference_state["obj_ids"]) == 0: + 
continue # skip propagation on empty inference states + + out_obj_ids, out_low_res_masks, out_obj_scores = self.tracker.propagate_in_video( + inference_state, frame_idx=frame_idx + ) + assert isinstance(out_obj_ids, list) + obj_ids_local.extend(out_obj_ids) + low_res_masks_list.append(out_low_res_masks.squeeze(1)) + obj_scores_list.append(out_obj_scores.squeeze(1)) + + # concatenate the output masklets from all local inference states + if len(low_res_masks_list) > 0: + low_res_masks_local = torch.cat(low_res_masks_list, dim=0) + obj_scores_local = torch.cat(obj_scores_list, dim=0) + low_res_masks_local = low_res_masks_local.squeeze(1) + else: + low_res_masks_local = torch.zeros(0, *self._bb_feat_sizes[0], device=self.device) + obj_scores_local = torch.zeros(0, device=self.device) + + return obj_ids_local, low_res_masks_local, obj_scores_local + + def _associate_det_trk( + self, + det_masks: torch.Tensor, + det_scores_np: np.ndarray, + trk_masks: torch.Tensor, + trk_obj_ids: np.ndarray, + ): + """Match detections on the current frame with the existing masklets. + + Args: + det_masks: (N, H, W) tensor of predicted masks + det_scores_np: (N,) array of detection scores + trk_masks: (M, H, W) tensor of track masks + trk_obj_ids: (M,) array of object IDs corresponding to trk_masks + + Returns: + new_det_fa_inds: array of new object indices. 
+ unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched to any detections on this + frame (for unmatched, we only count masklets with >0 area) + det_to_matched_trk_obj_ids: dict[int, np.ndarray]: mapping from detector's detection indices to the list of + matched tracklet object IDs + empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction + """ + iou_threshold = self.assoc_iou_thresh + iou_threshold_trk = self.trk_assoc_iou_thresh + new_det_thresh = self.new_det_thresh + + assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" + assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" + assert trk_masks.size(0) == len(trk_obj_ids), ( + f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" + ) + if trk_masks.size(0) == 0: + # all detections are new + new_det_fa_inds = np.arange(det_masks.size(0)) + unmatched_trk_obj_ids = np.array([], np.int64) + empty_trk_obj_ids = np.array([], np.int64) + det_to_matched_trk_obj_ids = {} + trk_id_to_max_iou_high_conf_det = {} + return ( + new_det_fa_inds, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids, + ) + elif det_masks.size(0) == 0: + # all previous tracklets are unmatched if they have a non-zero area + new_det_fa_inds = np.array([], np.int64) + trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy() + unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty] + empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty] + det_to_matched_trk_obj_ids = {} + trk_id_to_max_iou_high_conf_det = {} + return ( + new_det_fa_inds, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids, + ) + + if det_masks.shape[-2:] != trk_masks.shape[-2:]: + # resize to the smaller size to save GPU memory + if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]): + trk_masks = 
F.interpolate( + trk_masks.unsqueeze(1), + size=det_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + else: + # resize detections to track size + det_masks = F.interpolate( + det_masks.unsqueeze(1), + size=trk_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + + det_masks_binary = det_masks > 0 + trk_masks_binary = trk_masks > 0 + ious = mask_iou(det_masks_binary.flatten(1).float(), trk_masks_binary.flatten(1).float()) # (N, M) + + ious_np = ious.cpu().numpy() + if self.o2o_matching_masklets_enable: + from scipy.optimize import linear_sum_assignment + + # Hungarian matching for tracks (one-to-one: each track matches at most one detection) + cost_matrix = 1 - ious_np # Hungarian solves for minimum cost + row_ind, col_ind = linear_sum_assignment(cost_matrix) + trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool) + for d, t in zip(row_ind, col_ind): + if ious_np[d, t] >= iou_threshold_trk: + trk_is_matched[t] = True + else: + trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0) + # Non-empty tracks not matched by Hungarian assignment above threshold are unmatched + trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy() + trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched) + unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched] + # also record masklets that have zero area in SAM 2 prediction + empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty] + + # For detections: allow many tracks to match to the same detection (many-to-one) + # So, a detection is 'new' if it does not match any track above threshold + is_new_det = np.logical_and( + det_scores_np >= new_det_thresh, + np.logical_not(np.any(ious_np >= iou_threshold, axis=1)), + ) + new_det_fa_inds = np.nonzero(is_new_det)[0] + + # for each detection, which tracks it matched to (above threshold) + det_to_matched_trk_obj_ids = {} + trk_id_to_max_iou_high_conf_det = {} # trk id --> exactly one detection idx + det_to_max_iou_trk_idx = 
np.argmax(ious_np, axis=1) + det_is_high_conf = (det_scores_np >= self.HIGH_CONF_THRESH) & ~is_new_det + det_is_high_iou = np.max(ious_np, axis=1) >= self.HIGH_IOU_THRESH + det_is_high_conf_and_iou = set(np.nonzero(det_is_high_conf & det_is_high_iou)[0]) + for d in range(det_masks.size(0)): + det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold] + if d in det_is_high_conf_and_iou: + trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item() + trk_id_to_max_iou_high_conf_det[trk_obj_id] = d + + return ( + new_det_fa_inds, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids, + ) + + def _process_hotstart( + self, + frame_idx: int, + reverse: bool, + det_to_matched_trk_obj_ids: dict[int, np.ndarray], + new_det_obj_ids: np.ndarray, + empty_trk_obj_ids: np.ndarray, + unmatched_trk_obj_ids: np.ndarray, + metadata: dict[str, Any], + ): + """Handle hotstart heuristics to remove unmatched or duplicated objects.""" + # obj_id --> first frame index where the object was detected + obj_first_frame_idx = metadata["obj_first_frame_idx"] + # obj_id --> [mismatched frame indices] + unmatched_frame_inds = metadata["unmatched_frame_inds"] + trk_keep_alive = metadata["trk_keep_alive"] + # (first_appear_obj_id, obj_id) --> [overlap frame indices] + overlap_pair_to_frame_inds = metadata["overlap_pair_to_frame_inds"] + # removed_obj_ids: object IDs that are suppressed via hot-start + removed_obj_ids = metadata["removed_obj_ids"] + suppressed_obj_ids = metadata["suppressed_obj_ids"][frame_idx] + + obj_ids_newly_removed = set() # object IDs to be newly removed on this frame + hotstart_diff = frame_idx - self.hotstart_delay if not reverse else frame_idx + self.hotstart_delay + + # Step 1: log the frame index where each object ID first appears + for obj_id in new_det_obj_ids: + if obj_id not in obj_first_frame_idx: + obj_first_frame_idx[obj_id] = frame_idx + assert obj_id not in trk_keep_alive + 
trk_keep_alive[obj_id] = self.init_trk_keep_alive + + matched_trks = set() + # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded + for matched_trks_per_det in det_to_matched_trk_obj_ids.values(): + matched_trks.update(matched_trks_per_det) + for obj_id in matched_trks: + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive + trk_keep_alive[obj_id] = min(self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1) + for obj_id in unmatched_trk_obj_ids: + unmatched_frame_inds[obj_id].append(frame_idx) + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive + # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough. + trk_keep_alive[obj_id] = max(self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1) + if self.decrease_trk_keep_alive_for_empty_masklets: + for obj_id in empty_trk_obj_ids: + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive + trk_keep_alive[obj_id] = max(self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1) + + # Step 2: removed tracks that has not matched with detections for `hotstart_unmatch_thresh` frames with hotstart period + # a) add unmatched frame indices for each existing object ID + # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask + # doesn't match any detection; it excludes those frames where SAM2 gives an empty mask + # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more + # than `self.hotstart_unmatch_thresh` frames + for obj_id, frame_indices in unmatched_frame_inds.items(): + if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: + continue # skip if the object is already removed + if 
len(frame_indices) >= self.hotstart_unmatch_thresh: + is_within_hotstart = (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or ( + obj_first_frame_idx[obj_id] < hotstart_diff and reverse + ) + if is_within_hotstart: + obj_ids_newly_removed.add(obj_id) + LOGGER.debug( + f"Removing object {obj_id} at frame {frame_idx} " + f"since it is unmatched for frames: {frame_indices}" + ) + if ( + trk_keep_alive[obj_id] <= 0 # Object has not been matched for too long + and not self.suppress_unmatched_only_within_hotstart + and obj_id not in removed_obj_ids + and obj_id not in obj_ids_newly_removed + ): + LOGGER.debug(f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched") + suppressed_obj_ids.add(obj_id) + + # Step 3: removed tracks that overlaps with another track for `hotstart_dup_thresh` frames + # a) find overlaps tracks -- we consider overlap if they match to the same detection + for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items(): + if len(matched_trk_obj_ids) < 2: + continue # only count detections that are matched to multiple (>=2) masklets + # if there are multiple matched track ids, we need to find the one that appeared first; + # these later appearing ids may be removed since they may be considered as duplicates + first_appear_obj_id = ( + min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) + if not reverse + else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) + ) + for obj_id in matched_trk_obj_ids: + if obj_id != first_appear_obj_id: + key = (first_appear_obj_id, obj_id) + overlap_pair_to_frame_inds[key].append(frame_idx) + + # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another + # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames + for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items(): + if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: + continue # skip if the object is already 
removed + if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or ( + obj_first_frame_idx[obj_id] < hotstart_diff and reverse + ): + if len(frame_indices) >= self.hotstart_dup_thresh: + obj_ids_newly_removed.add(obj_id) + LOGGER.debug( + f"Removing object {obj_id} at frame {frame_idx} " + f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}" + ) + + removed_obj_ids.update(obj_ids_newly_removed) + return obj_ids_newly_removed, metadata + + def _tracker_update_memories( + self, tracker_inference_states: list[Any], frame_idx: int, low_res_masks: torch.Tensor + ): + """Run Sam2 memory encoder, enforcing non-overlapping constraints globally.""" + if len(tracker_inference_states) == 0: + return + # NOTE: inspect this part if we observe OOMs in the demo + high_res_masks = F.interpolate( + low_res_masks.unsqueeze(1), + size=self.interpol_size, + mode="bilinear", + align_corners=False, + ) + # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics. + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + high_res_masks = self.tracker.model._suppress_object_pw_area_shrinkage(high_res_masks) + # Instead of gathering the predicted object scores, we use mask areas as a proxy. 
+ object_score_logits = torch.where((high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0) + + # Run the memory encoder on local slices for each GPU + start_idx_gpu = 0 + start_idx_state = start_idx_gpu + for tracker_state in tracker_inference_states: + num_obj_per_state = len(tracker_state["obj_ids"]) + if num_obj_per_state == 0: + continue + # Get the local high-res masks and object score logits for this inference state + end_idx_state = start_idx_state + num_obj_per_state + local_high_res_masks = high_res_masks[start_idx_state:end_idx_state] + local_object_score_logits = object_score_logits[start_idx_state:end_idx_state] + local_batch_size = local_high_res_masks.size(0) + # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default + + encoded_mem = self.tracker._run_memory_encoder( + local_batch_size, + local_high_res_masks, + local_object_score_logits, + is_mask_from_pts=False, + inference_state=tracker_state, + ) + local_maskmem_features, local_maskmem_pos_enc = encoded_mem + # Store encoded memories in the local inference state + output_dict = tracker_state["output_dict"] + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + if frame_idx not in output_dict[storage_key]: + continue + output_dict[storage_key][frame_idx]["maskmem_features"] = local_maskmem_features + output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [pos for pos in local_maskmem_pos_enc] + # for batched inference state, we also need to add per-object + # memory slides to support instance interactivity + self.tracker._add_output_per_object( + inference_state=tracker_state, + frame_idx=frame_idx, + current_out=output_dict[storage_key][frame_idx], + storage_key=storage_key, + ) + start_idx_state += num_obj_per_state + + def _tracker_add_new_objects( + self, + frame_idx: int, + num_frames: int, + new_obj_ids: list[int], + new_obj_masks: torch.Tensor, + tracker_states_local: list[Any], + ): + """Add a new object to 
SAM2 inference states.""" + prev_tracker_state = tracker_states_local[0] if len(tracker_states_local) > 0 else None + + # prepare inference_state + # batch objects that first appear on the same frame together + # Clear inference state. Keep the cached image features if available. + new_tracker_state = self.tracker._init_state(num_frames=num_frames) + # NOTE: adding image placeholder + new_tracker_state["im"] = None + new_tracker_state["backbone_out"] = ( + prev_tracker_state.get("backbone_out", None) if prev_tracker_state is not None else None + ) + + assert len(new_obj_ids) == new_obj_masks.size(0) + assert new_obj_masks.is_floating_point() + new_obj_masks = F.interpolate( + new_obj_masks.unsqueeze(0), + size=self.interpol_size, + mode="bilinear", + align_corners=False, + ).squeeze(0) + new_obj_masks = new_obj_masks > 0 + + # add object one by one + for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks): + self.tracker.add_new_prompts( + inference_state=new_tracker_state, + frame_idx=frame_idx, + obj_id=new_obj_id, + masks=new_mask[None, None], # add bs, channel + ) + # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects. + self.tracker.propagate_in_video_preflight(new_tracker_state) + tracker_states_local.append(new_tracker_state) + return tracker_states_local + + def _tracker_remove_objects(self, tracker_states_local: list[Any], obj_ids: list[int]): + """Remove an object from SAM2 inference states. 
This would remove the object from all frames in the video.""" + if not obj_ids: + return + # Filter out states that become empty after removal + active_states = [] + for state in tracker_states_local: + for obj_id in obj_ids: + # we try to remove `obj_id` on every inference state with `strict=False` + # it will not do anything if an inference state doesn't contain `obj_id` + self.tracker.remove_object(state, obj_id, strict=False) + + if len(state["obj_ids"]) > 0: + active_states.append(state) + + # Update the list in-place + tracker_states_local[:] = active_states + + def _initialize_metadata(self): + """Initialize metadata for the masklets.""" + tracker_metadata = { + "obj_ids": np.array([], np.int32), + "num_obj": np.zeros(1, np.int32), + "max_obj_id": -1, + "obj_id_to_score": {}, + "obj_id_to_cls": {}, + "obj_id_to_tracker_score_frame_wise": defaultdict(dict), + "obj_id_to_last_occluded": {}, + } + # "metadata" contains metadata that is only stored on (and accessible to) GPU 0 + # - obj_first_frame_idx: obj_id --> first frame index where the object was detected + # - unmatched_frame_inds: obj_id --> [mismatched frame indices] + # - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices] + # - removed_obj_ids: object IDs that are suppressed via hot-start + metadata = { + "obj_first_frame_idx": {}, + "unmatched_frame_inds": defaultdict(list), + "trk_keep_alive": defaultdict(int), # This is used only for object suppression not for removal + "overlap_pair_to_frame_inds": defaultdict(list), + "removed_obj_ids": set(), + # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked + "suppressed_obj_ids": defaultdict(set), + } + if self.masklet_confirmation_enable: + # all the following are np.ndarray with the same shape as `obj_ids_all_gpu` + metadata["masklet_confirmation"] = { + # "status" is the confirmation status of each masklet + "status": np.array([], np.int64), + # "consecutive_det_num" is the number of 
consecutive frames where the masklet is + # detected by the detector (with a matched detection) + "consecutive_det_num": np.array([], np.int64), + } + tracker_metadata["metadata"] = metadata + + return tracker_metadata + + def update_masklet_confirmation_status( + self, + metadata: dict[str, Any], + obj_ids_all_gpu_prev: np.ndarray, + obj_ids_all_gpu_updated: np.ndarray, + det_to_matched_trk_obj_ids: dict[int, np.ndarray], + new_det_obj_ids: np.ndarray, + ): + """Update the confirmation status of masklets based on the current frame's detection results.""" + confirmation_data = metadata["masklet_confirmation"] + + # a) first, expand "confirmation_data" to include new masklets added in this frame + status_prev = confirmation_data["status"] + consecutive_det_num_prev = confirmation_data["consecutive_det_num"] + assert status_prev.shape == obj_ids_all_gpu_prev.shape, ( + f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}" + ) + + obj_id_to_updated_idx = {obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)} + prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated) + prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated] + prev_elem_inds_in_updated = np.array( + [obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated], + dtype=np.int64, + ) + # newly added masklets are initialized to "UNCONFIRMED" status + unconfirmed_val = self.UNCONFIRMED + status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val) + status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated] + consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated) + consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[prev_elem_is_in_updated] + + # b) update the confirmation status of all masklets based on the current frame + # b.1) update "consecutive_det_num" + # "is_matched": whether a masklet is matched to a detection on this frame + is_matched = 
np.isin(obj_ids_all_gpu_updated, new_det_obj_ids) + for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values(): + is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids) + consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0) + + # b.2) update "status" + change_to_confirmed = consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh + status[change_to_confirmed] = self.CONFIRMED + + confirmation_data["status"] = status + confirmation_data["consecutive_det_num"] = consecutive_det_num + return metadata + + def _load_checkpoint(self, ckpt_path: str, strict: bool = True): + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = self.load_state_dict(sd, strict=strict) + if len(missing_keys) > 0 or len(unexpected_keys) > 0: + LOGGER.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}") + else: + LOGGER.info("Loaded ckpt successfully without missing or unexpected keys") + + def _encode_prompt(self, **kwargs): + return self.model._encode_prompt(**kwargs) + + def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep): + """Drop a few new detections based on the maximum number of objects. We drop new objects based on their + detection scores, keeping the high-scoring ones and dropping the low-scoring ones. 
+ """ + assert 0 <= num_to_keep <= len(new_det_fa_inds) + if num_to_keep == 0: + return np.array([], np.int64) # keep none + if num_to_keep == len(new_det_fa_inds): + return new_det_fa_inds # keep all + + # keep the top-scoring detections + score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1] + new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]] + return new_det_fa_inds diff --git a/ultralytics/models/sam/sam3/__init__.py b/ultralytics/models/sam/sam3/__init__.py new file mode 100644 index 0000000000..fbe077b618 --- /dev/null +++ b/ultralytics/models/sam/sam3/__init__.py @@ -0,0 +1,3 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved diff --git a/ultralytics/models/sam/sam3/decoder.py b/ultralytics/models/sam/sam3/decoder.py new file mode 100644 index 0000000000..f9d789b48f --- /dev/null +++ b/ultralytics/models/sam/sam3/decoder.py @@ -0,0 +1,546 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +""" +Transformer decoder. +Inspired from Pytorch's version, adds the pre-norm variant. 
+""" + +from __future__ import annotations + +import numpy as np +import torch +from torch import nn +from torchvision.ops.roi_align import RoIAlign + +from ultralytics.nn.modules.transformer import MLP +from ultralytics.nn.modules.utils import _get_clones, inverse_sigmoid +from ultralytics.utils.ops import xywh2xyxy + +from .model_misc import gen_sineembed_for_position + + +class TransformerDecoderLayer(nn.Module): + """TransformerDecoderLayer is made up of self-attn, cross-attn, and feedforward network (FFN).""" + + def __init__( + self, + d_model: int, + dim_feedforward: int, + dropout: float, + cross_attention: nn.Module, + n_heads: int, + use_text_cross_attention: bool = False, + ): + """Initialize the TransformerDecoderLayer.""" + super().__init__() + # cross attention + self.cross_attn = cross_attention + self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm1 = nn.LayerNorm(d_model) + + # cross attention text + self.use_text_cross_attention = use_text_cross_attention + if use_text_cross_attention: + self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.catext_norm = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = nn.ReLU() + self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + """Add positional embedding to the tensor.""" + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + """Feedforward network forward pass.""" 
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward( + self, + # for tgt + tgt: torch.Tensor, # nq, bs, d_model + tgt_query_pos: torch.Tensor = None, # pos for query. MLP(Sine(pos)) + memory_text: torch.Tensor = None, # num_token, bs, d_model + text_attention_mask: torch.Tensor = None, # bs, num_token + # for memory + memory: torch.Tensor = None, # hw, bs, d_model + memory_key_padding_mask: torch.Tensor = None, + memory_pos: torch.Tensor = None, # pos for memory + # sa + self_attn_mask: torch.Tensor = None, # mask used for self-attention + cross_attn_mask: torch.Tensor = None, # mask used for cross-attention + # dac + dac=False, + dac_use_selfatt_ln=True, + presence_token=None, + # skip inside deformable attn + **kwargs, # additional kwargs for compatibility + ): + """Input: - tgt/tgt_query_pos: nq, bs, d_model. -.""" + # self attention + tgt, tgt_query_pos = self._apply_self_attention( + tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask + ) + + if self.use_text_cross_attention: + tgt2 = self.ca_text( + self.with_pos_embed(tgt, tgt_query_pos), + memory_text.to(tgt.dtype), + memory_text.to(tgt.dtype), + key_padding_mask=text_attention_mask, + )[0] + tgt = tgt + self.catext_dropout(tgt2) + tgt = self.catext_norm(tgt) + + if presence_token is not None: + presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :]) + cross_attn_mask = torch.cat([presence_token_mask, cross_attn_mask], dim=1) # (bs*nheads, 1+nq, hw) + + # Cross attention to image + tgt2 = self.cross_attn( + query=self.with_pos_embed(tgt, tgt_query_pos), + key=self.with_pos_embed(memory, memory_pos), + value=memory, + attn_mask=cross_attn_mask, + key_padding_mask=(memory_key_padding_mask.transpose(0, 1) if memory_key_padding_mask is not None else None), + need_weights=False, + )[0] + + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = 
self.forward_ffn(tgt.to(memory.dtype)) + + presence_token_out = None + if presence_token is not None: + presence_token_out = tgt[:1] + tgt = tgt[1:] + + return tgt, presence_token_out + + def _apply_self_attention(self, tgt, tgt_query_pos, dac, dac_use_selfatt_ln, presence_token, self_attn_mask): + """Apply self-attention with optional DAC splitting.""" + if self.self_attn is None: + return tgt + + if dac: + # Split queries for DAC (detect-and-classify) + assert tgt.shape[0] % 2 == 0, "DAC requires even number of queries" + num_o2o_queries = tgt.shape[0] // 2 + tgt_o2o = tgt[:num_o2o_queries] + tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries] + tgt_o2m = tgt[num_o2o_queries:] + else: + tgt_o2o = tgt + tgt_query_pos_o2o = tgt_query_pos + + # Handle presence token + if presence_token is not None: + tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0) + tgt_query_pos_o2o = torch.cat([torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0).to( + tgt_o2o.dtype + ) + tgt_query_pos = torch.cat([torch.zeros_like(presence_token), tgt_query_pos], dim=0) + + # Self-attention + q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o) + tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0].to(tgt.dtype) + tgt_o2o = tgt_o2o + self.dropout2(tgt2) + + # Recombine and normalize + if dac: + if not dac_use_selfatt_ln: + tgt_o2o = self.norm2(tgt_o2o) + tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0) + if dac_use_selfatt_ln: + tgt = self.norm2(tgt) + else: + tgt = tgt_o2o + tgt = self.norm2(tgt) + + return tgt, tgt_query_pos + + +class TransformerDecoder(nn.Module): + """Transformer Decoder consisting of multiple layers.""" + + def __init__( + self, + d_model: int, + frozen: bool, + interaction_layer, + layer, + num_layers: int, + num_queries: int, + return_intermediate: bool, + box_refine: bool = False, + num_o2m_queries: int = 0, + dac: bool = False, + boxRPB: str = "none", + # Experimental: An object query for SAM 2 tasks + instance_query: bool = False, + # Defines 
the number of additional instance queries, + # 1 or 4 are the most likely for single vs multi mask support + num_instances: int = 1, # Irrelevant if instance_query is False + dac_use_selfatt_ln: bool = True, + use_act_checkpoint: bool = False, + compile_mode=None, + presence_token: bool = False, + clamp_presence_logits: bool = True, + clamp_presence_logit_max_val: float = 10.0, + use_normed_output_consistently: bool = True, + separate_box_head_instance: bool = False, + separate_norm_instance: bool = False, + ): + """Initialize the TransformerDecoder.""" + super().__init__() + self.d_model = d_model + self.layers = _get_clones(layer, num_layers) + self.fine_layers = ( + _get_clones(interaction_layer, num_layers) if interaction_layer is not None else [None] * num_layers + ) + self.num_layers = num_layers + self.num_queries = num_queries + self.dac = dac + if dac: + self.num_o2m_queries = num_queries + tot_num_queries = num_queries + else: + self.num_o2m_queries = num_o2m_queries + tot_num_queries = num_queries + num_o2m_queries + self.norm = nn.LayerNorm(d_model) + self.return_intermediate = return_intermediate + self.bbox_embed = MLP(d_model, d_model, 4, 3) + self.query_embed = nn.Embedding(tot_num_queries, d_model) + self.instance_query_embed = None + self.instance_query_reference_points = None + self.use_instance_query = instance_query + self.num_instances = num_instances + self.use_normed_output_consistently = use_normed_output_consistently + + self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None + self.instance_bbox_embed = None + if separate_box_head_instance: + self.instance_bbox_embed = MLP(d_model, d_model, 4, 3) + if instance_query: + self.instance_query_embed = nn.Embedding(num_instances, d_model) + self.box_refine = box_refine + if box_refine: + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + + self.reference_points = nn.Embedding(num_queries, 4) + if 
instance_query: + self.instance_reference_points = nn.Embedding(num_instances, 4) + + assert boxRPB in ["none", "log", "linear", "both"] + self.boxRPB = boxRPB + if boxRPB != "none": + try: + nheads = self.layers[0].cross_attn_image.num_heads + except AttributeError: + nheads = self.layers[0].cross_attn.num_heads + + n_input = 4 if boxRPB == "both" else 2 + self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2) + self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2) + self.compilable_cord_cache = None + self.compilable_stored_size = None + self.coord_cache = {} + + self.roi_pooler = ( + RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True) + if interaction_layer is not None + else None + ) + if frozen: + for p in self.parameters(): + p.requires_grad_(False) + + self.presence_token = None + self.clamp_presence_logits = clamp_presence_logits + self.clamp_presence_logit_max_val = clamp_presence_logit_max_val + if presence_token: + self.presence_token = nn.Embedding(1, d_model) + self.presence_token_head = MLP(d_model, d_model, 1, 3) + self.presence_token_out_norm = nn.LayerNorm(d_model) + + self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2) + self.dac_use_selfatt_ln = dac_use_selfatt_ln + self.use_act_checkpoint = use_act_checkpoint + + nn.init.normal_(self.query_embed.weight.data) + if self.instance_query_embed is not None: + nn.init.normal_(self.instance_query_embed.weight.data) + + assert self.roi_pooler is None + assert self.return_intermediate, "support return_intermediate only" + assert self.box_refine, "support box refine only" + + self.compile_mode = compile_mode + self.compiled = False + # We defer compilation till after the first forward, to first warm-up the boxRPB cache + + # assign layer index to each layer so that some layers can decide what to do + # based on which layer index they are (e.g. 
cross attention to memory bank only + # in selected layers) + for layer_idx, layer in enumerate(self.layers): + layer.layer_idx = layer_idx + + @staticmethod + def _get_coords(H, W, device, dtype): + """Get normalized coordinates for height and width.""" + coords_h = torch.arange(0, H, dtype=dtype, device=device) / H + coords_w = torch.arange(0, W, dtype=dtype, device=device) / W + return coords_h, coords_w + + def _get_rpb_matrix(self, reference_boxes, feat_size): + """Get the relative position bias (RPB) matrix for box-relative position bias.""" + H, W = feat_size + boxes_xyxy = xywh2xyxy(reference_boxes).transpose(0, 1) + bs, num_queries, _ = boxes_xyxy.shape + if self.compilable_cord_cache is None: + self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device, reference_boxes.dtype) + self.compilable_stored_size = (H, W) + + if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == ( + H, + W, + ): + # good, hitting the cache, will be compilable + coords_h, coords_w = self.compilable_cord_cache + else: + # cache miss, will create compilation issue + # In case we're not compiling, we'll still rely on the dict-based cache + if feat_size not in self.coord_cache: + self.coord_cache[feat_size] = self._get_coords(H, W, reference_boxes.device) + coords_h, coords_w = self.coord_cache[feat_size] + + assert coords_h.shape == (H,) + assert coords_w.shape == (W,) + + deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2] + deltas_y = deltas_y.view(bs, num_queries, -1, 2) + deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2] + deltas_x = deltas_x.view(bs, num_queries, -1, 2) + + if self.boxRPB in ["log", "both"]: + deltas_x_log = deltas_x * 8 # normalize to -8, 8 + deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / np.log2(8) + + deltas_y_log = deltas_y * 8 # normalize to -8, 8 + deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 
1.0) / np.log2(8) + if self.boxRPB == "log": + deltas_x = deltas_x_log + deltas_y = deltas_y_log + else: + deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1) + deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1) + + if self.training: + assert self.use_act_checkpoint, "activation ckpt not enabled in decoder" + deltas_x = self.boxRPB_embed_x(x=deltas_x) # bs, num_queries, W, n_heads + deltas_y = self.boxRPB_embed_y(x=deltas_y) # bs, num_queries, H, n_heads + + if not torch.compiler.is_dynamo_compiling(): + assert deltas_x.shape[:3] == (bs, num_queries, W) + assert deltas_y.shape[:3] == (bs, num_queries, H) + + B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2) # bs, num_queries, H, W, n_heads + if not torch.compiler.is_dynamo_compiling(): + assert B.shape[:4] == (bs, num_queries, H, W) + B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads + B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W + B = B.contiguous() # memeff attn likes ordered strides + if not torch.compiler.is_dynamo_compiling(): + assert B.shape[2:] == (num_queries, H * W) + return B + + def forward( + self, + tgt, + memory, + tgt_mask: torch.Tensor = None, + memory_mask: torch.Tensor = None, + memory_key_padding_mask: torch.Tensor = None, + pos: torch.Tensor = None, + reference_boxes: torch.Tensor = None, # num_queries, bs, 4 + # for memory + spatial_shapes: torch.Tensor = None, # bs, num_levels, 2 + valid_ratios: torch.Tensor = None, + # for text + memory_text: torch.Tensor = None, + text_attention_mask: torch.Tensor = None, + # if `apply_dac` is None, it will default to `self.dac` + apply_dac: bool | None = None, + is_instance_prompt=False, + decoder_extra_kwargs: dict | None = None, + # ROI memory bank + obj_roi_memory_feat=None, + obj_roi_memory_mask=None, + box_head_trk=None, + ): + """Forward pass of the TransformerDecoder.""" + if memory_mask is not None: + assert self.boxRPB == "none", ( + "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented" + ) + + 
apply_dac = apply_dac if apply_dac is not None else self.dac + if apply_dac: + assert (tgt.shape[0] == self.num_queries) or ( + self.use_instance_query and (tgt.shape[0] == self.instance_query_embed.num_embeddings) + ) + + tgt = tgt.repeat(2, 1, 1) + # note that we don't tile tgt_mask, since DAC doesn't + # use self-attention in o2m queries + if reference_boxes is not None: + assert (reference_boxes.shape[0] == self.num_queries) or ( + self.use_instance_query and (reference_boxes.shape[0] == self.instance_query_embed.num_embeddings) + ) + reference_boxes = reference_boxes.repeat(2, 1, 1) + + bs = tgt.shape[1] + intermediate = [] + intermediate_presence_logits = [] + presence_feats = None + + if self.box_refine: + if reference_boxes is None: + # In this case, we're in a one-stage model, so we generate the reference boxes + reference_boxes = self.reference_points.weight.unsqueeze(1) + reference_boxes = reference_boxes.repeat(2, bs, 1) if apply_dac else reference_boxes.repeat(1, bs, 1) + reference_boxes = reference_boxes.sigmoid() + intermediate_ref_boxes = [reference_boxes] + else: + reference_boxes = None + intermediate_ref_boxes = None + + output = tgt + presence_out = None + if self.presence_token is not None and is_instance_prompt is False: + # expand to batch dim + presence_out = self.presence_token.weight[None].expand(1, bs, -1) + + box_head = self.bbox_embed + if is_instance_prompt and self.instance_bbox_embed is not None: + box_head = self.instance_bbox_embed + + out_norm = self.norm + if is_instance_prompt and self.instance_norm is not None: + out_norm = self.instance_norm + + for layer_idx, layer in enumerate(self.layers): + reference_points_input = ( + reference_boxes[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :] + ) # nq, bs, nlevel, 4 + + query_sine_embed = gen_sineembed_for_position( + reference_points_input[:, :, 0, :], self.d_model + ) # nq, bs, d_model*2 + + # conditional query + query_pos = self.ref_point_head(query_sine_embed) 
# nq, bs, d_model + + if self.boxRPB != "none" and reference_boxes is not None: + assert spatial_shapes.shape[0] == 1, "only single scale support implemented" + memory_mask = self._get_rpb_matrix( + reference_boxes, + (spatial_shapes[0, 0], spatial_shapes[0, 1]), + ) + memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W) + if self.training: + assert self.use_act_checkpoint, "Activation checkpointing not enabled in the decoder" + output, presence_out = layer( + tgt=output, + tgt_query_pos=query_pos, + memory_text=memory_text, + text_attention_mask=text_attention_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + memory_pos=pos, + self_attn_mask=tgt_mask, + cross_attn_mask=memory_mask, + dac=apply_dac, + dac_use_selfatt_ln=self.dac_use_selfatt_ln, + presence_token=presence_out, + **(decoder_extra_kwargs or {}), + # ROI memory bank + obj_roi_memory_feat=obj_roi_memory_feat, + obj_roi_memory_mask=obj_roi_memory_mask, + ) + + # iter update + if self.box_refine: + reference_before_sigmoid = inverse_sigmoid(reference_boxes) + if box_head_trk is None: + # delta_unsig = self.bbox_embed(output) + if not self.use_normed_output_consistently: + delta_unsig = box_head(output) + else: + delta_unsig = box_head(out_norm(output)) + else: + # box_head_trk use a separate box head for tracking queries + Q_det = decoder_extra_kwargs["Q_det"] + assert output.size(0) >= Q_det + delta_unsig_det = self.bbox_embed(output[:Q_det]) + delta_unsig_trk = box_head_trk(output[Q_det:]) + delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0) + outputs_unsig = delta_unsig + reference_before_sigmoid + new_reference_points = outputs_unsig.sigmoid() + + reference_boxes = new_reference_points.detach() + if layer_idx != self.num_layers - 1: + intermediate_ref_boxes.append(new_reference_points) + else: + raise NotImplementedError("not implemented yet") + + intermediate.append(out_norm(output)) + if self.presence_token is not None and is_instance_prompt is 
False: + # norm, mlp head + intermediate_layer_presence_logits = self.presence_token_head( + self.presence_token_out_norm(presence_out) + ).squeeze(-1) + + # clamp to mitigate numerical issues + if self.clamp_presence_logits: + intermediate_layer_presence_logits.clamp( + min=-self.clamp_presence_logit_max_val, + max=self.clamp_presence_logit_max_val, + ) + + intermediate_presence_logits.append(intermediate_layer_presence_logits) + presence_feats = presence_out.clone() + + if not self.compiled and self.compile_mode is not None: + self.forward = torch.compile(self.forward, mode=self.compile_mode, fullgraph=True) + self.compiled = True + + return ( + torch.stack(intermediate), + torch.stack(intermediate_ref_boxes), + ( + torch.stack(intermediate_presence_logits) + if self.presence_token is not None and is_instance_prompt is False + else None + ), + presence_feats, + ) diff --git a/ultralytics/models/sam/sam3/encoder.py b/ultralytics/models/sam/sam3/encoder.py new file mode 100644 index 0000000000..4300023e78 --- /dev/null +++ b/ultralytics/models/sam/sam3/encoder.py @@ -0,0 +1,535 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +# Based on https://github.com/IDEA-Research/GroundingDINO +from __future__ import annotations + +import torch +from torch import nn + +from ultralytics.nn.modules.utils import _get_clones + +from .model_misc import get_valid_ratio + + +class TransformerEncoderLayer(nn.Module): + """Transformer encoder layer that performs self-attention followed by cross-attention. + + This layer was previously called TransformerDecoderLayer but was renamed to better reflect its role in the + architecture. It processes input sequences through self-attention and then cross-attention with another input + (typically image features). 
+ + The layer supports both pre-norm and post-norm configurations, as well as positional encoding at different stages of + the attention mechanism. + """ + + def __init__( + self, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + pre_norm: bool, + self_attention: nn.Module = None, + cross_attention: nn.Module = None, + ): + """Initialize a transformer encoder layer. + + Args: + cross_attention: Cross-attention module for attending to image features + d_model: Model dimension/hidden size + dim_feedforward: Dimension of the feedforward network + dropout: Dropout probability + pos_enc_at_attn: Whether to add positional encodings at self-attention + pos_enc_at_cross_attn_keys: Whether to add positional encodings to keys in cross-attention + pos_enc_at_cross_attn_queries: Whether to add positional encodings to queries in cross-attention + pre_norm: Whether to use pre-norm (True) or post-norm (False) architecture + self_attention: Self-attention module + """ + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256) + self.cross_attn_image = cross_attention or nn.MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = nn.ReLU() + self.pre_norm = pre_norm + + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + 
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + self.layer_idx = None + + def forward_post( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: torch.Tensor = None, + memory_mask: torch.Tensor = None, + tgt_key_padding_mask: torch.Tensor = None, + memory_key_padding_mask: torch.Tensor = None, + pos: torch.Tensor = None, + query_pos: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """Forward pass for post-norm architecture. + + In post-norm architecture, normalization is applied after attention and feedforward operations. + + Args: + tgt: Input tensor to be processed + memory: Memory tensor for cross-attention + tgt_mask: Mask for self-attention + memory_mask: Mask for cross-attention + tgt_key_padding_mask: Key padding mask for self-attention + memory_key_padding_mask: Key padding mask for cross-attention + pos: Positional encoding for memory + query_pos: Positional encoding for query + **kwargs: Additional keyword arguments + + Returns: + Processed tensor + """ + q = k = tgt + query_pos if self.pos_enc_at_attn else tgt + + # Self attention + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask, need_weights=False + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # Cross attention to image + tgt2 = self.cross_attn_image( + query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt, + key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + need_weights=False, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # FFN + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + dac: bool = False, + tgt_mask: torch.Tensor = None, + memory_mask: torch.Tensor = None, + tgt_key_padding_mask: 
torch.Tensor = None, + memory_key_padding_mask: torch.Tensor = None, + pos: torch.Tensor = None, + query_pos: torch.Tensor = None, + # **kwargs, + ) -> torch.Tensor: + """Forward pass for pre-norm architecture. + + In pre-norm architecture, normalization is applied before attention and feedforward operations. + + Args: + tgt: Input tensor to be processed + memory: Memory tensor for cross-attention + dac: Whether to use Divide-and-Conquer attention + tgt_mask: Mask for self-attention + memory_mask: Mask for cross-attention + tgt_key_padding_mask: Key padding mask for self-attention + memory_key_padding_mask: Key padding mask for cross-attention + pos: Positional encoding for memory + query_pos: Positional encoding for query + attn_bias: Optional attention bias tensor + **kwargs: Additional keyword arguments + + Returns: + Processed tensor + """ + if dac: + # we only apply self attention to the first half of the queries + assert tgt.shape[0] % 2 == 0 + other_tgt = tgt[tgt.shape[0] // 2 :] + tgt = tgt[: tgt.shape[0] // 2] + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + if dac: + # Recombine + tgt = torch.cat((tgt, other_tgt), dim=0) + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + key=memory.to(tgt2.dtype) + pos if self.pos_enc_at_cross_attn_keys else memory.to(tgt2.dtype), + value=memory.to(tgt2.dtype), + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + # attn_bias=attn_bias, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + dac: bool = False, + tgt_mask: torch.Tensor = None, + 
memory_mask: torch.Tensor = None, + tgt_key_padding_mask: torch.Tensor = None, + memory_key_padding_mask: torch.Tensor = None, + pos: torch.Tensor = None, + query_pos: torch.Tensor = None, + # **kwds: Any, + ) -> torch.Tensor: + """Forward pass for the transformer encoder layer. + + Args: + tgt: Input tensor to be processed + memory: Memory tensor (e.g., image features) for cross-attention + dac: Whether to use Divide-and-Conquer attention (only apply self-attention to first half) + tgt_mask: Mask for self-attention + memory_mask: Mask for cross-attention + tgt_key_padding_mask: Key padding mask for self-attention + memory_key_padding_mask: Key padding mask for cross-attention + pos: Positional encoding for memory + query_pos: Positional encoding for query + attn_bias: Optional attention bias tensor + **kwds: Additional keyword arguments + + Returns: + Processed tensor after self-attention, cross-attention, and feedforward network + """ + fwd_fn = self.forward_pre if self.pre_norm else self.forward_post + return fwd_fn( + tgt, + memory, + dac=dac, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + # attn_bias=attn_bias, + # **kwds, + ) + + +class TransformerEncoder(nn.Module): + """Transformer encoder that processes multi-level features. + + This encoder takes multi-level features (e.g., from a backbone network) and processes them through a stack of + transformer encoder layers. It supports features from multiple levels (e.g., different resolutions) and can apply + activation checkpointing for memory efficiency during training. 
+ + Args: + layer: The encoder layer to be stacked multiple times + num_layers: Number of encoder layers to stack + d_model: Model dimension/hidden size + num_feature_levels: Number of feature levels to process + frozen: Whether to freeze the parameters of this module + use_act_checkpoint: Whether to use activation checkpointing during training + """ + + def __init__( + self, + layer: nn.Module, + num_layers: int, + d_model: int, + num_feature_levels: int, + frozen: bool = False, + use_act_checkpoint: bool = False, + ): + """Initialize the transformer encoder.""" + super().__init__() + self.layers = _get_clones(layer, num_layers) + self.num_layers = num_layers + + self.num_feature_levels = num_feature_levels + self.level_embed = None + if num_feature_levels > 1: + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if frozen: + for p in self.parameters(): + p.requires_grad_(False) + + self.use_act_checkpoint = use_act_checkpoint + + # assign layer index to each layer so that some layers can decide what to do + # based on which layer index they are (e.g. 
cross attention to memory bank only + # in selected layers) + for layer_idx, layer in enumerate(self.layers): + layer.layer_idx = layer_idx + + def _prepare_multilevel_features(self, srcs, masks, pos_embeds): + """Prepare multi-level features for transformer encoder.""" + assert len(srcs) == self.num_feature_levels, "mismatch between expected and received # of feature levels" + + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + has_mask = masks is not None and masks[0] is not None + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + _, _, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + + src = src.flatten(2).transpose(1, 2) # bs, hw, c + if has_mask: + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c + if self.level_embed is not None: + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + else: + lvl_pos_embed = pos_embed + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + if has_mask: + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c + mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None # bs, \sum{hxw} + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c + spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat( + ( + spatial_shapes.new_zeros((1,)), + spatial_shapes.prod(1).cumsum(0)[:-1], + ) + ) + if has_mask: + valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1) + else: + valid_ratios = torch.ones( + (src_flatten.shape[0], self.num_feature_levels, 2), + device=src_flatten.device, + dtype=src_flatten.dtype, + ) + + return ( + src_flatten, + mask_flatten, + lvl_pos_embed_flatten, + level_start_index, + valid_ratios, + spatial_shapes, + ) + + def forward( + self, + src: list[torch.Tensor], + src_key_padding_masks: list[torch.Tensor] 
| None = None, + pos: list[torch.Tensor] | None = None, + prompt: torch.Tensor = None, + prompt_key_padding_mask: torch.Tensor = None, + encoder_extra_kwargs: dict | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Process multi-level features through the transformer encoder. + + Args: + src: List of multi-level features, each with shape (batch_size, channels, height, width) + src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height, + width) + pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height, + width) + prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model) + prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len) + encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer + + Returns: + A tuple containing: + - output: Processed features with shape (seq_len, batch_size, d_model) + - key_padding_masks_flatten: Flattened padding masks + - lvl_pos_embed_flatten: Flattened positional embeddings + - level_start_index: Starting indices for each feature level + - spatial_shapes: Spatial dimensions of each feature level + - valid_ratios: Valid ratios for each feature level + """ + assert len(src) == self.num_feature_levels, "must be equal to num_feature_levels" + if src_key_padding_masks is not None: + assert len(src_key_padding_masks) == self.num_feature_levels + if pos is not None: + assert len(pos) == self.num_feature_levels + # Flatten multilevel feats and add level pos embeds + ( + src_flatten, + key_padding_masks_flatten, + lvl_pos_embed_flatten, + level_start_index, + valid_ratios, + spatial_shapes, + ) = self._prepare_multilevel_features(src, src_key_padding_masks, pos) + + output = src_flatten + for layer in self.layers: + layer_kwargs = {} + + assert isinstance(layer, TransformerEncoderLayer) + 
layer_kwargs["memory"] = prompt + layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask + layer_kwargs["query_pos"] = lvl_pos_embed_flatten + layer_kwargs["tgt"] = output + layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten + + if self.training: + assert self.use_act_checkpoint, "activation ckpt not enabled in encoder" + if encoder_extra_kwargs is not None: + layer_kwargs.update(encoder_extra_kwargs) + output = layer(**layer_kwargs) + # return as seq first + return ( + output.transpose(0, 1), + (key_padding_masks_flatten.transpose(0, 1) if key_padding_masks_flatten is not None else None), + lvl_pos_embed_flatten.transpose(0, 1), + level_start_index, + spatial_shapes, + valid_ratios, + ) + + +class TransformerEncoderFusion(TransformerEncoder): + """Transformer encoder that fuses text and image features. + + This encoder extends TransformerEncoder to handle both text and image features, with the ability to add pooled text + features to image features for better cross-modal fusion. It supports torch.compile for performance optimization. 
+ + Args: + layer: The encoder layer to be stacked multiple times + num_layers: Number of encoder layers to stack + d_model: Model dimension/hidden size + num_feature_levels: Number of feature levels to process + add_pooled_text_to_img_feat: Whether to add pooled text features to image features + pool_text_with_mask: Whether to use the mask when pooling text features + compile_mode: Mode for torch.compile, or None to disable compilation + **kwargs: Additional arguments to pass to the parent class + """ + + def __init__( + self, + layer: nn.Module, + num_layers: int, + d_model: int, + num_feature_levels: int, + add_pooled_text_to_img_feat: bool = True, + pool_text_with_mask: bool = False, + compile_mode: str | None = None, + **kwargs, + ): + """Initialize the transformer encoder with text-image fusion.""" + super().__init__( + layer, + num_layers, + d_model, + num_feature_levels, + **kwargs, + ) + self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat + if self.add_pooled_text_to_img_feat: + self.text_pooling_proj = nn.Linear(d_model, d_model) + self.pool_text_with_mask = pool_text_with_mask + if compile_mode is not None: + self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True) + + def forward( + self, + src: list[torch.Tensor], + prompt: torch.Tensor, + src_key_padding_mask: list[torch.Tensor] | None = None, + src_pos: list[torch.Tensor] | None = None, + prompt_key_padding_mask: torch.Tensor = None, + feat_sizes: list[int] | None = None, + encoder_extra_kwargs: dict | None = None, + ): + """Forward pass for the transformer encoder with text-image fusion.""" + # Restore spatial shapes of vision + bs = src[0].shape[1] # seq first + if feat_sizes is not None: + assert len(feat_sizes) == len(src) + if src_key_padding_mask is None: + src_key_padding_mask = [None] * len(src) + for i, (h, w) in enumerate(feat_sizes): + src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1) + src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 
1) + src_key_padding_mask[i] = ( + src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1) + if src_key_padding_mask[i] is not None + else None + ) + else: + assert all(x.dim == 4 for x in src), "expected list of (bs, c, h, w) tensors" + + if self.add_pooled_text_to_img_feat: + # Fusion: Add mean pooled text to image features + pooled_text = pool_text_feat(prompt, prompt_key_padding_mask, self.pool_text_with_mask) + pooled_text = self.text_pooling_proj(pooled_text)[..., None, None] # prompt is seq first + src = [x.add_(pooled_text) for x in src] + + ( + out, + key_padding_masks_flatten, + lvl_pos_embed_flatten, + level_start_index, + spatial_shapes, + valid_ratios, + ) = super().forward( + src, + src_key_padding_masks=src_key_padding_mask, + pos=src_pos, + prompt=prompt.transpose(0, 1), + prompt_key_padding_mask=prompt_key_padding_mask, + encoder_extra_kwargs=encoder_extra_kwargs, + ) + + return { + "memory": out, + "padding_mask": key_padding_masks_flatten, + "pos_embed": lvl_pos_embed_flatten, + "memory_text": prompt, + "level_start_index": level_start_index, + "spatial_shapes": spatial_shapes, + "valid_ratios": valid_ratios, + } + + +def pool_text_feat(prompt, prompt_mask, pool_with_mask): + """Mean-pool the prompt embeddings over the valid tokens only.""" + # prompt has shape (seq, bs, dim) + if not pool_with_mask: + return prompt.mean(dim=0) + + # prompt_mask has shape (bs, seq), where False is valid and True is padding + assert prompt_mask.dim() == 2 + # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding + is_valid = (~prompt_mask).float().permute(1, 0)[..., None] + # num_valid has shape (bs, 1) + num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0) + + # mean pool over all the valid tokens + pooled_text = (prompt * is_valid).sum(dim=0) / num_valid + return pooled_text diff --git a/ultralytics/models/sam/sam3/geometry_encoders.py b/ultralytics/models/sam/sam3/geometry_encoders.py new file mode 100644 index 0000000000..433c392d78 --- 
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torchvision

from ultralytics.nn.modules.utils import _get_clones
from ultralytics.utils.ops import xywh2xyxy


def is_right_padded(mask: torch.Tensor):
    """Return a scalar bool tensor: True if every row of the padding mask is right-padded.

    Follows the PyTorch convention where 1/True marks padded values; a row is right-padded when all padded
    entries come after all valid entries, i.e. each row is already sorted ascending.
    """
    return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all()


def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """Concatenate two right-padded sequences into one contiguous, right-padded sequence.

    Tensors are sequence-first and masks are batch-first, with 1/True marking padded values
    (PyTorch convention).

    Args:
        seq1: Tensor of shape (seq1_length, batch_size, hidden_size).
        mask1: Tensor of shape (batch_size, seq1_length).
        seq2: Tensor of shape (seq2_length, batch_size, hidden_size).
        mask2: Tensor of shape (batch_size, seq2_length).
        return_index: If True, also return the indices of seq2's elements within the concatenated
            sequence, which can be used to retrieve them afterwards.

    Returns:
        (tuple): (concatenated_sequence, concatenated_mask), or
            (concatenated_sequence, concatenated_mask, index) if return_index is True.
    """
    seq1_length, batch_size, hidden_size = seq1.shape
    seq2_length, batch_size, hidden_size = seq2.shape

    assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
    assert hidden_size == seq1.size(2) == seq2.size(2)
    assert seq1_length == mask1.size(1)
    assert seq2_length == mask2.size(1)

    # Async asserts: both inputs must be right-padded for the shift-and-scatter below to be valid
    torch._assert_async(is_right_padded(mask1))
    torch._assert_async(is_right_padded(mask2))

    actual_seq1_lengths = (~mask1).sum(dim=-1)
    actual_seq2_lengths = (~mask2).sum(dim=-1)

    final_lengths = actual_seq1_lengths + actual_seq2_lengths
    max_length = seq1_length + seq2_length
    concatenated_mask = (
        torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) >= final_lengths[:, None]
    )

    # (max_len, batch_size, hidden_size)
    concatenated_sequence = torch.zeros((max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype)
    concatenated_sequence[:seq1_length, :, :] = seq1

    # seq1 entries are already in place; shift each batch column of seq2 right by that
    # sample's true (unpadded) seq1 length, then scatter it in
    index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
    index = index + actual_seq1_lengths[None]

    concatenated_sequence = concatenated_sequence.scatter(0, index[:, :, None].expand(-1, -1, hidden_size), seq2)

    if return_index:
        return concatenated_sequence, concatenated_mask, index

    return concatenated_sequence, concatenated_mask


class Prompt:
    """Utility class to manipulate geometric prompts.

    Sequences follow the PyTorch convention (sequence first, batch second):
        box_embeddings: N_boxes x B x C_box
        box_mask: B x N_boxes (may be None if nothing is masked out)

    Positive/negative labels are stored batch-first; if None, positive labels are assumed everywhere:
        box_labels: long tensor of shape N_boxes x B
    """

    def __init__(self, box_embeddings=None, box_mask=None, box_labels=None):
        """Initialize the Prompt object, filling in default labels/mask when only embeddings are given."""
        # Check for null prompt
        if box_embeddings is None:
            self.box_embeddings = None
            self.box_labels = None
            self.box_mask = None
            return

        # Get sequence length, batch size, and device
        box_seq_len = box_embeddings.shape[0]
        bs = box_embeddings.shape[1]
        device = box_embeddings.device

        # Initialize labels (all positive) and attention mask (nothing padded) if not provided
        if box_labels is None:
            box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
        if box_mask is None:
            box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)

        # Dimension checks
        assert list(box_embeddings.shape[:2]) == [box_seq_len, bs], (
            f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
        )
        assert box_embeddings.shape[-1] == 4, (
            f"Expected box embeddings to have 4 coordinates, got {box_embeddings.shape[-1]}"
        )
        assert list(box_mask.shape) == [bs, box_seq_len], (
            f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
        )
        assert list(box_labels.shape) == [box_seq_len, bs], (
            f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
        )

        # Device checks
        assert box_embeddings.device == device, (
            f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
        )
        assert box_mask.device == device, f"Expected box mask to be on device {device}, got {box_mask.device}"
        assert box_labels.device == device, f"Expected box labels to be on device {device}, got {box_labels.device}"

        self.box_embeddings = box_embeddings
        self.box_mask = box_mask
        self.box_labels = box_labels

    def append_boxes(self, boxes, labels=None, mask=None):
        """Append box prompts to existing prompts.

        Args:
            boxes: Tensor of shape (N_new_boxes, B, 4) with normalized box coordinates.
            labels: Optional tensor of shape (N_new_boxes, B) with positive/negative labels.
            mask: Optional tensor of shape (B, N_new_boxes) for attention mask.
        """
        if self.box_embeddings is None:
            # First boxes - initialize
            self.box_embeddings = boxes
            bs = boxes.shape[1]
            box_seq_len = boxes.shape[0]

            if labels is None:
                labels = torch.ones(box_seq_len, bs, device=boxes.device, dtype=torch.long)
            if mask is None:
                mask = torch.zeros(bs, box_seq_len, device=boxes.device, dtype=torch.bool)

            self.box_labels = labels
            self.box_mask = mask
            return

        # Append to existing boxes
        bs = self.box_embeddings.shape[1]
        assert boxes.shape[1] == bs, f"Batch size mismatch: expected {bs}, got {boxes.shape[1]}"

        if labels is None:
            labels = torch.ones(boxes.shape[0], bs, device=boxes.device, dtype=torch.long)
        if mask is None:
            mask = torch.zeros(bs, boxes.shape[0], dtype=torch.bool, device=boxes.device)

        assert list(boxes.shape[:2]) == list(labels.shape[:2]), (
            f"Shape mismatch between boxes {boxes.shape} and labels {labels.shape}"
        )

        # Concatenate labels first (unsqueezed to look like a feature dim), then the embeddings;
        # both use the same right-padded concat helper so padding stays contiguous
        self.box_labels, _ = concat_padded_sequences(
            self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
        )
        self.box_labels = self.box_labels.squeeze(-1)
        self.box_embeddings, self.box_mask = concat_padded_sequences(self.box_embeddings, self.box_mask, boxes, mask)


class SequenceGeometryEncoder(nn.Module):
    """Encoder for geometric box prompts. Assumes boxes are passed in the "normalized CxCyWH" format.

    Boxes can be encoded with any of the three possibilities:
    - direct projection: linear projection from coordinate space to d_model
    - pooling: RoI align features from the backbone
    - pos encoder: position encoding of the box center

    These three options are mutually compatible and will be summed if multiple are selected.

    As an alternative, boxes can be encoded as two corner points (top-left and bottom-right).

    The encoded sequence can be further processed with a transformer.
    """

    def __init__(
        self,
        encode_boxes_as_points: bool,
        boxes_direct_project: bool,
        boxes_pool: bool,
        boxes_pos_enc: bool,
        d_model: int,
        pos_enc,
        num_layers: int,
        layer: nn.Module,
        roi_size: int = 7,
        add_cls: bool = True,
        add_post_encode_proj: bool = True,
        use_act_ckpt: bool = False,
    ):
        """Initialize the SequenceGeometryEncoder.

        Args:
            encode_boxes_as_points: Encode each box as its two corner points instead of as a box.
            boxes_direct_project: Linearly project the 4 box coordinates to d_model.
            boxes_pool: RoI-align backbone features inside each box.
            boxes_pos_enc: Use a position encoding of the box.
            d_model: Model/embedding dimension.
            pos_enc: Position encoder providing `encode_boxes(cx, cy, w, h)`.
            num_layers: Number of transformer layers applied on top of the encoded sequence (0 disables).
            layer: Transformer layer to clone `num_layers` times.
            roi_size: RoI-align output resolution.
            add_cls: Prepend a learnable CLS token (required when num_layers > 0).
            add_post_encode_proj: Apply a final Linear + LayerNorm projection.
            use_act_ckpt: Stored flag for activation checkpointing.
        """
        super().__init__()

        self.d_model = d_model
        self.pos_enc = pos_enc
        self.encode_boxes_as_points = encode_boxes_as_points
        self.roi_size = roi_size

        # Label embeddings: 2 labels if encoding as boxes (pos/neg)
        # 6 labels if encoding as points (regular pos/neg, top-left pos/neg, bottom-right pos/neg)
        num_labels = 6 if self.encode_boxes_as_points else 2
        self.label_embed = torch.nn.Embedding(num_labels, self.d_model)

        # CLS token for pooling
        self.cls_embed = None
        if add_cls:
            self.cls_embed = torch.nn.Embedding(1, self.d_model)

        # Point encoding (used when encode_boxes_as_points is True)
        if encode_boxes_as_points:
            self.points_direct_project = nn.Linear(2, self.d_model)
            self.points_pool_project = None
            self.points_pos_enc_project = None
        else:
            # Box encoding modules
            assert boxes_direct_project or boxes_pos_enc or boxes_pool, "Error: need at least one way to encode boxes"
            self.points_direct_project = None
            self.points_pool_project = None
            self.points_pos_enc_project = None

        self.boxes_direct_project = None
        self.boxes_pool_project = None
        self.boxes_pos_enc_project = None

        if boxes_direct_project:
            self.boxes_direct_project = nn.Linear(4, self.d_model)
        if boxes_pool:
            self.boxes_pool_project = nn.Conv2d(self.d_model, self.d_model, self.roi_size)
        if boxes_pos_enc:
            self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)

        self.final_proj = None
        if add_post_encode_proj:
            self.final_proj = nn.Linear(self.d_model, self.d_model)
            self.norm = nn.LayerNorm(self.d_model)

        self.img_pre_norm = nn.Identity()
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            self.img_pre_norm = nn.LayerNorm(self.d_model)

        self.encode = None
        if num_layers > 0:
            assert add_cls, "It's currently highly recommended to add a CLS when using a transformer"
            self.encode = _get_clones(layer, num_layers)
            self.encode_norm = nn.LayerNorm(self.d_model)

        self.use_act_ckpt = use_act_ckpt

    def _encode_points(self, points, points_mask, points_labels, img_feats):
        """Encode points (used when boxes are converted to corner points)."""
        # Direct projection of coordinates
        points_embed = self.points_direct_project(points.to(img_feats.dtype))

        # Add label embeddings
        type_embed = self.label_embed(points_labels.long())
        return type_embed + points_embed, points_mask

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats: torch.Tensor):
        """Encode boxes using configured encoding methods; active encodings are summed."""
        boxes_embed = None
        n_boxes, bs = boxes.shape[:2]

        if self.boxes_direct_project is not None:
            proj = self.boxes_direct_project(boxes.to(img_feats.dtype))
            boxes_embed = proj

        if self.boxes_pool_project is not None:
            H, W = img_feats.shape[-2:]

            # Convert boxes to xyxy format and denormalize to feature-map coordinates.
            # new_tensor keeps the scale on the right device/dtype without pin_memory(),
            # which raises on CUDA-less builds.
            boxes_xyxy = xywh2xyxy(boxes.to(img_feats.dtype))
            scale = boxes_xyxy.new_tensor([W, H, W, H]).view(1, 1, 4)
            boxes_xyxy = boxes_xyxy * scale

            # RoI align
            sampled = torchvision.ops.roi_align(img_feats, boxes_xyxy.transpose(0, 1).unbind(0), self.roi_size)
            assert list(sampled.shape) == [
                bs * n_boxes,
                self.d_model,
                self.roi_size,
                self.roi_size,
            ]
            proj = self.boxes_pool_project(sampled)
            proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)

            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        if self.boxes_pos_enc_project is not None:
            cx, cy, w, h = boxes.unbind(-1)
            enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
            enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])

            proj = self.boxes_pos_enc_project(enc.to(img_feats.dtype))
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        # Add label embeddings
        type_embed = self.label_embed(boxes_labels.long())
        return type_embed + boxes_embed, boxes_mask

    def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
        """Encode geometric box prompts.

        Args:
            geo_prompt: Prompt object containing box embeddings, masks, and labels.
            img_feats: List of image features from backbone, each of shape (H*W, B, C).
            img_sizes: List of (H, W) tuples for each feature level.
            img_pos_embeds: Optional position embeddings for image features.

        Returns:
            Tuple of (encoded_embeddings, attention_mask).
        """
        boxes = geo_prompt.box_embeddings
        boxes_mask = geo_prompt.box_mask
        boxes_labels = geo_prompt.box_labels

        seq_first_img_feats = img_feats[-1]  # [H*W, B, C]
        seq_first_img_pos_embeds = (
            img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(seq_first_img_feats)
        )

        # Prepare image features for pooling if needed (explicit None checks: these are nn.Modules or None)
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            assert len(img_feats) == len(img_sizes)
            cur_img_feat = img_feats[-1]
            cur_img_feat = self.img_pre_norm(cur_img_feat)
            H, W = img_sizes[-1]
            assert cur_img_feat.shape[0] == H * W
            N, C = cur_img_feat.shape[-2:]
            # Reshape to NxCxHxW
            cur_img_feat = cur_img_feat.permute(1, 2, 0)
            cur_img_feat = cur_img_feat.view(N, C, H, W)
            img_feats = cur_img_feat

        if self.encode_boxes_as_points:
            # Convert boxes to corner points
            assert boxes is not None and boxes.shape[-1] == 4

            boxes_xyxy = xywh2xyxy(boxes)
            top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)

            # Adjust labels for corner points (offset by 2 and 4)
            labels_tl = boxes_labels + 2
            labels_br = boxes_labels + 4

            # Concatenate top-left and bottom-right points
            points = torch.cat([top_left, bottom_right], dim=0)
            points_labels = torch.cat([labels_tl, labels_br], dim=0)
            points_mask = torch.cat([boxes_mask, boxes_mask], dim=1)

            final_embeds, final_mask = self._encode_points(
                points=points,
                points_mask=points_mask,
                points_labels=points_labels,
                img_feats=img_feats,
            )
        else:
            # Encode boxes directly
            final_embeds, final_mask = self._encode_boxes(
                boxes=boxes,
                boxes_mask=boxes_mask,
                boxes_labels=boxes_labels,
                img_feats=img_feats,
            )

        bs = final_embeds.shape[1]
        assert final_mask.shape[0] == bs

        # Add CLS token if configured
        if self.cls_embed is not None:
            cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
            cls_mask = torch.zeros(bs, 1, dtype=final_mask.dtype, device=final_mask.device)
            final_embeds, final_mask = concat_padded_sequences(final_embeds, final_mask, cls, cls_mask)

        # Final projection
        if self.final_proj is not None:
            final_embeds = self.norm(self.final_proj(final_embeds))

        # Transformer encoding layers
        if self.encode is not None:
            for lay in self.encode:
                final_embeds = lay(
                    tgt=final_embeds,
                    memory=seq_first_img_feats,
                    tgt_key_padding_mask=final_mask,
                    pos=seq_first_img_pos_embeds,
                )
            final_embeds = self.encode_norm(final_embeds)

        return final_embeds, final_mask
All Rights Reserved + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from ultralytics.nn.modules.transformer import MLP + + +class LinearPresenceHead(nn.Sequential): + """Linear presence head for predicting the presence of classes in an image.""" + + def __init__(self, d_model): + """Initializes the LinearPresenceHead.""" + # a hack to make `LinearPresenceHead` compatible with old checkpoints + super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1)) + + def forward(self, hs, prompt, prompt_mask): + """Forward pass of the presence head.""" + return super().forward(hs) + + +class MaskPredictor(nn.Module): + """Predicts masks from object queries and pixel embeddings.""" + + def __init__(self, hidden_dim, mask_dim): + """Initializes the MaskPredictor.""" + super().__init__() + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + def forward(self, obj_queries, pixel_embed): + """Predicts masks from object queries and pixel embeddings.""" + if len(obj_queries.shape) == 3: + if pixel_embed.ndim == 3: + # batch size was omitted + mask_preds = torch.einsum("bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed) + else: + mask_preds = torch.einsum("bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed) + else: + # Assumed to have aux masks + if pixel_embed.ndim == 3: + # batch size was omitted + mask_preds = torch.einsum("lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed) + else: + mask_preds = torch.einsum("lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed) + + return mask_preds + + +class SegmentationHead(nn.Module): + """Segmentation head that predicts masks from backbone features and object queries.""" + + def __init__( + self, + hidden_dim, + upsampling_stages, + use_encoder_inputs=False, + aux_masks=False, + no_dec=False, + pixel_decoder=None, + act_ckpt=False, + shared_conv=False, + 
compile_mode_pixel_decoder=None, + ): + """Initializes the SegmentationHead.""" + super().__init__() + self.use_encoder_inputs = use_encoder_inputs + self.aux_masks = aux_masks + if pixel_decoder is not None: + self.pixel_decoder = pixel_decoder + else: + self.pixel_decoder = PixelDecoder( + hidden_dim, + upsampling_stages, + shared_conv=shared_conv, + compile_mode=compile_mode_pixel_decoder, + ) + self.no_dec = no_dec + if no_dec: + self.mask_predictor = nn.Conv2d(hidden_dim, 1, kernel_size=3, stride=1, padding=1) + else: + self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim) + + self.act_ckpt = act_ckpt + + # used to update the output dictionary + self.instance_keys = ["pred_masks"] + + def _embed_pixels(self, backbone_feats: list[torch.Tensor], encoder_hidden_states) -> torch.Tensor: + """Embeds pixels using the pixel decoder.""" + if self.use_encoder_inputs: + backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats] + # Extract visual embeddings + encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0) + spatial_dim = math.prod(backbone_feats[-1].shape[-2:]) + encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(-1, *backbone_feats[-1].shape[1:]) + + backbone_visual_feats[-1] = encoder_visual_embed + if self.act_ckpt: + pixel_embed = checkpoint.checkpoint(self.pixel_decoder, backbone_visual_feats, use_reentrant=False) + else: + pixel_embed = self.pixel_decoder(backbone_visual_feats) + else: + backbone_feats = [x for x in backbone_feats] + pixel_embed = self.pixel_decoder(backbone_feats) + if pixel_embed.shape[0] == 1: + # For batch_size=1 training, we can avoid the indexing to save memory + pixel_embed = pixel_embed.squeeze(0) + else: + pixel_embed = pixel_embed[[0], ...] 
+ return pixel_embed + + def forward( + self, + backbone_feats: list[torch.Tensor], + obj_queries: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + **kwargs, + ) -> dict[str, torch.Tensor]: + """Forward pass of the SegmentationHead.""" + if self.use_encoder_inputs: + assert encoder_hidden_states is not None + + pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states) + + if self.no_dec: + mask_pred = self.mask_predictor(pixel_embed) + elif self.aux_masks: + mask_pred = self.mask_predictor(obj_queries, pixel_embed) + else: + mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed) + + return {"pred_masks": mask_pred} + + +class PixelDecoder(nn.Module): + """Pixel decoder module that upsamples backbone features.""" + + def __init__( + self, + hidden_dim, + num_upsampling_stages, + interpolation_mode="nearest", + shared_conv=False, + compile_mode=None, + ): + """Initializes the PixelDecoder.""" + super().__init__() + self.hidden_dim = hidden_dim + self.num_upsampling_stages = num_upsampling_stages + self.interpolation_mode = interpolation_mode + conv_layers = [] + norms = [] + num_convs = 1 if shared_conv else num_upsampling_stages + for _ in range(num_convs): + conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1)) + norms.append(nn.GroupNorm(8, self.hidden_dim)) + + self.conv_layers = nn.ModuleList(conv_layers) + self.norms = nn.ModuleList(norms) + self.shared_conv = shared_conv + self.out_dim = self.conv_layers[-1].out_channels + if compile_mode is not None: + self.forward = torch.compile(self.forward, mode=compile_mode, dynamic=True, fullgraph=True) + # Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default. 
+ torch._dynamo.config.optimize_ddp = False + + def forward(self, backbone_feats: list[torch.Tensor]): + """Forward pass of the PixelDecoder.""" + prev_fpn = backbone_feats[-1] + fpn_feats = backbone_feats[:-1] + for layer_idx, bb_feat in enumerate(fpn_feats[::-1]): + curr_fpn = bb_feat + prev_fpn = curr_fpn + F.interpolate(prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode) + if self.shared_conv: + # only one conv layer + layer_idx = 0 + prev_fpn = self.conv_layers[layer_idx](prev_fpn) + prev_fpn = F.relu(self.norms[layer_idx](prev_fpn)) + + return prev_fpn + + +class UniversalSegmentationHead(SegmentationHead): + """This module handles semantic+instance segmentation.""" + + def __init__( + self, + hidden_dim, + upsampling_stages, + pixel_decoder, + aux_masks=False, + no_dec=False, + act_ckpt=False, + presence_head: bool = False, + dot_product_scorer=None, + cross_attend_prompt=None, + ): + """Initializes the UniversalSegmentationHead.""" + super().__init__( + hidden_dim=hidden_dim, + upsampling_stages=upsampling_stages, + use_encoder_inputs=True, + aux_masks=aux_masks, + no_dec=no_dec, + pixel_decoder=pixel_decoder, + act_ckpt=act_ckpt, + ) + self.d_model = hidden_dim + + if dot_product_scorer is not None: + assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake" + + self.presence_head = None + if presence_head: + self.presence_head = ( + dot_product_scorer if dot_product_scorer is not None else LinearPresenceHead(self.d_model) + ) + + self.cross_attend_prompt = cross_attend_prompt + if self.cross_attend_prompt is not None: + self.cross_attn_norm = nn.LayerNorm(self.d_model) + + self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1) + self.instance_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, self.d_model, kernel_size=1) + + def forward( + self, + backbone_feats: list[torch.Tensor], + obj_queries: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + prompt: 
torch.Tensor = None, + prompt_mask: torch.Tensor = None, + **kwargs, + ) -> dict[str, torch.Tensor]: + """Forward pass of the UniversalSegmentationHead.""" + assert encoder_hidden_states is not None + bs = encoder_hidden_states.shape[1] + + if self.cross_attend_prompt is not None: + tgt2 = self.cross_attn_norm(encoder_hidden_states) + tgt2 = self.cross_attend_prompt( + query=tgt2, + key=prompt.to(tgt2.dtype), + value=prompt.to(tgt2.dtype), + key_padding_mask=prompt_mask, + need_weights=False, + )[0] + encoder_hidden_states = tgt2 + encoder_hidden_states + + presence_logit = None + if self.presence_head is not None: + pooled_enc = encoder_hidden_states.mean(0) + presence_logit = ( + self.presence_head( + pooled_enc.view(1, bs, 1, self.d_model), + prompt=prompt, + prompt_mask=prompt_mask, + ) + .squeeze(0) + .squeeze(1) + ) + + pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states) + + instance_embeds = self.instance_seg_head(pixel_embed) + + if self.no_dec: + mask_pred = self.mask_predictor(instance_embeds) + elif self.aux_masks: + mask_pred = self.mask_predictor(obj_queries, instance_embeds) + else: + mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds) + + return { + "pred_masks": mask_pred, + "semantic_seg": self.semantic_seg_head(pixel_embed), + "presence_logit": presence_logit, + } diff --git a/ultralytics/models/sam/sam3/model_misc.py b/ultralytics/models/sam/sam3/model_misc.py new file mode 100644 index 0000000000..1b66b05a5f --- /dev/null +++ b/ultralytics/models/sam/sam3/model_misc.py @@ -0,0 +1,198 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. 
All Rights Reserved + +"""Various utility models.""" + +from __future__ import annotations + +import math + +import numpy as np +import torch +from torch import Tensor, nn + + +class DotProductScoring(torch.nn.Module): + """A module that computes dot-product scores between a set of query features and a.""" + + def __init__( + self, + d_model, + d_proj, + prompt_mlp=None, + clamp_logits=True, + clamp_max_val=12.0, + ): + """Initialize the DotProductScoring module.""" + super().__init__() + self.d_proj = d_proj + assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None + self.prompt_mlp = prompt_mlp # an optional MLP projection for prompt + self.prompt_proj = torch.nn.Linear(d_model, d_proj) + self.hs_proj = torch.nn.Linear(d_model, d_proj) + self.scale = float(1.0 / np.sqrt(d_proj)) + self.clamp_logits = clamp_logits + if self.clamp_logits: + self.clamp_max_val = clamp_max_val + + def mean_pool_text(self, prompt, prompt_mask): + """Mean-pool the prompt embeddings over the valid tokens only.""" + # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding + is_valid = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None] + # num_valid has shape (bs, 1) + num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0) + # mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim) + pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid + return pooled_prompt + + def forward(self, hs, prompt, prompt_mask): + """Compute dot-product scores between hs and prompt.""" + # hs has shape (num_layer, bs, num_query, d_model) + # prompt has shape (seq, bs, d_model) + # prompt_mask has shape (bs, seq), where 1 is valid and 0 is padding + assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2 + + # apply MLP on prompt if specified + if self.prompt_mlp is not None: + prompt = self.prompt_mlp(prompt.to(hs.dtype)) + + # first, get the mean-pooled version of the prompt + pooled_prompt = self.mean_pool_text(prompt, prompt_mask) 
+ + # then, project pooled_prompt and hs to d_proj dimensions + proj_pooled_prompt = self.prompt_proj(pooled_prompt) # (bs, d_proj) + proj_hs = self.hs_proj(hs) # (num_layer, bs, num_query, d_proj) + + # finally, get dot-product scores of shape (num_layer, bs, num_query, 1) + scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1)) + scores *= self.scale + + # clamp scores to a max value to avoid numerical issues in loss or matcher + if self.clamp_logits: + scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val) + + return scores + + +class LayerScale(nn.Module): + """LayerScale module as introduced in "Meta Pseudo Labels" and used in.""" + + def __init__( + self, + dim: int, + init_values: float | Tensor = 1e-5, + inplace: bool = False, + ) -> None: + """Initialize the LayerScale module.""" + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + """Apply LayerScale to the input tensor.""" + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class TransformerWrapper(nn.Module): + """A wrapper for the transformer consisting of an encoder and a decoder.""" + + def __init__( + self, + encoder, + decoder, + d_model: int, + two_stage_type="none", # ["none"] only for now + pos_enc_at_input_dec=True, + ): + """Initialize the TransformerWrapper.""" + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.num_queries = decoder.num_queries if decoder is not None else None + self.pos_enc_at_input_dec = pos_enc_at_input_dec + + # for two stage + assert two_stage_type in ["none"], f"unknown param {two_stage_type} of two_stage_type" + self.two_stage_type = two_stage_type + + self._reset_parameters() + self.d_model = d_model + + def _reset_parameters(self): + """Initialize the parameters of the model.""" + for n, p in self.named_parameters(): + if p.dim() > 1: + if "box_embed" not in n and "query_embed" not in n and "reference_points" 
not in n: + nn.init.xavier_uniform_(p) + + +def get_valid_ratio(mask): + """Compute the valid ratio of height and width from the mask.""" + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + +def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256): + """Generate sinusoidal position embeddings for 2D or 4D coordinate tensors. + + This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to the + positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors (x, y, w, + h) for bounding box coordinates. + + Args: + pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or (n_query, bs, + 4) for 4D coordinates (bounding boxes). + num_feats (int): Number of feature dimensions for the output embedding. Must be even. Defaults to 256. + + Returns: + (torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or (n_query, bs, + num_feats * 2) for 4D input. + + Raises: + AssertionError: If num_feats is not even. + ValueError: If pos_tensor.size(-1) is not 2 or 4. 
+ + Examples: + >>> pos_2d = torch.rand(100, 8, 2) # 100 queries, batch size 8, 2D coordinates + >>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256) + >>> embeddings_2d.shape + torch.Size([100, 8, 256]) + >>> pos_4d = torch.rand(50, 4, 4) # 50 queries, batch size 4, 4D coordinates + >>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128) + >>> embeddings_4d.shape + torch.Size([50, 4, 256]) + """ + assert num_feats % 2 == 0 + num_feats = num_feats // 2 + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device) + dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + return pos diff --git a/ultralytics/models/sam/sam3/necks.py b/ultralytics/models/sam/sam3/necks.py new file mode 100644 index 0000000000..db2036ffde --- /dev/null +++ b/ultralytics/models/sam/sam3/necks.py @@ -0,0 +1,129 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# 
Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +"""Necks are the interface between a vision backbone and the rest of the detection model.""" + +from __future__ import annotations + +from copy import deepcopy + +import torch +import torch.nn as nn + + +class Sam3DualViTDetNeck(nn.Module): + """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2).""" + + def __init__( + self, + trunk: nn.Module, + position_encoding: nn.Module, + d_model: int, + scale_factors=(4.0, 2.0, 1.0, 0.5), + add_sam2_neck: bool = False, + ): + """ + SimpleFPN neck a la ViTDet + (From detectron2, very lightly adapted) + It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2), with different weights. + + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + """ + super().__init__() + self.trunk = trunk + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + + self.scale_factors = scale_factors + use_bias = True + dim: int = self.trunk.channel_list[-1] + + for _, scale in enumerate(scale_factors): + current = nn.Sequential() + + if scale == 4.0: + current.add_module( + "dconv_2x2_0", + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + ) + current.add_module( + "gelu", + nn.GELU(), + ) + current.add_module( + "dconv_2x2_1", + nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), + ) + out_dim = dim // 4 + elif scale == 2.0: + current.add_module( + "dconv_2x2", + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + ) + out_dim = dim // 2 + elif scale == 1.0: + out_dim = dim + elif scale == 0.5: + current.add_module( + "maxpool_2x2", + nn.MaxPool2d(kernel_size=2, stride=2), + ) + out_dim = dim + else: + raise NotImplementedError(f"scale_factor={scale} is not supported yet.") + + current.add_module( + "conv_1x1", + nn.Conv2d( + in_channels=out_dim, + 
out_channels=d_model, + kernel_size=1, + bias=use_bias, + ), + ) + current.add_module( + "conv_3x3", + nn.Conv2d( + in_channels=d_model, + out_channels=d_model, + kernel_size=3, + padding=1, + bias=use_bias, + ), + ) + self.convs.append(current) + + self.sam2_convs = None + if add_sam2_neck: + # Assumes sam2 neck is just a clone of the original neck + self.sam2_convs = deepcopy(self.convs) + + def forward( + self, tensor_list: list[torch.Tensor] + ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: + """Get the feature maps and positional encodings from the neck.""" + xs = self.trunk(tensor_list) + sam3_out, sam3_pos = [], [] + sam2_out, sam2_pos = None, None + if self.sam2_convs is not None: + sam2_out, sam2_pos = [], [] + x = xs[-1] # simpleFPN + for i in range(len(self.convs)): + sam3_x_out = self.convs[i](x) + sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype) + sam3_out.append(sam3_x_out) + sam3_pos.append(sam3_pos_out) + + if self.sam2_convs is not None: + sam2_x_out = self.sam2_convs[i](x) + sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype) + sam2_out.append(sam2_x_out) + sam2_pos.append(sam2_pos_out) + return sam3_out, sam3_pos, sam2_out, sam2_pos + + def set_imgsz(self, imgsz: list[int] = [1008, 1008]): + """Set the image size for the trunk backbone.""" + self.trunk.set_imgsz(imgsz) diff --git a/ultralytics/models/sam/sam3/sam3_image.py b/ultralytics/models/sam/sam3/sam3_image.py new file mode 100644 index 0000000000..c8bccc92f4 --- /dev/null +++ b/ultralytics/models/sam/sam3/sam3_image.py @@ -0,0 +1,357 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. 
All Rights Reserved + +from __future__ import annotations + +from copy import deepcopy + +import torch + +from ultralytics.nn.modules.utils import inverse_sigmoid +from ultralytics.utils.ops import xywh2xyxy + +from .geometry_encoders import Prompt +from .vl_combiner import SAM3VLBackbone + + +def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True): + """Helper function to update output dictionary with main and auxiliary outputs.""" + out[out_name] = out_value[-1] if auxiliary else out_value + if auxiliary and update_aux: + if "aux_outputs" not in out: + out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)] + assert len(out["aux_outputs"]) == len(out_value) - 1 + for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]): + aux_output[out_name] = aux_value + + +class SAM3SemanticModel(torch.nn.Module): + """SAM3 model for semantic segmentation with vision-language backbone.""" + + def __init__( + self, + backbone: SAM3VLBackbone, + transformer, + input_geometry_encoder, + segmentation_head=None, + num_feature_levels=1, + o2m_mask_predict=True, + dot_prod_scoring=None, + use_instance_query: bool = True, + multimask_output: bool = True, + use_act_checkpoint_seg_head: bool = True, + matcher=None, + use_dot_prod_scoring=True, + supervise_joint_box_scores: bool = False, # only relevant if using presence token/score + detach_presence_in_joint_score: bool = False, # only relevant if using presence token/score + separate_scorer_for_instance: bool = False, + num_interactive_steps_val: int = 0, + ): + """Initialize the SAM3SemanticModel.""" + super().__init__() + self.backbone = backbone + self.geometry_encoder = input_geometry_encoder + self.transformer = transformer + self.hidden_dim = transformer.d_model + self.num_feature_levels = num_feature_levels + self.segmentation_head = segmentation_head + + self.o2m_mask_predict = o2m_mask_predict + + self.dot_prod_scoring = dot_prod_scoring + self.use_act_checkpoint_seg_head = 
use_act_checkpoint_seg_head + self.matcher = matcher + + self.num_interactive_steps_val = num_interactive_steps_val + self.use_dot_prod_scoring = use_dot_prod_scoring + + if self.use_dot_prod_scoring: + assert dot_prod_scoring is not None + self.dot_prod_scoring = dot_prod_scoring + self.instance_dot_prod_scoring = None + if separate_scorer_for_instance: + self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring) + else: + self.class_embed = torch.nn.Linear(self.hidden_dim, 1) + self.instance_class_embed = None + if separate_scorer_for_instance: + self.instance_class_embed = deepcopy(self.class_embed) + + self.supervise_joint_box_scores = supervise_joint_box_scores + self.detach_presence_in_joint_score = detach_presence_in_joint_score + + # verify the number of queries for O2O and O2M + num_o2o_static = self.transformer.decoder.num_queries + num_o2m_static = self.transformer.decoder.num_o2m_queries + assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0) + self.dac = self.transformer.decoder.dac + + self.use_instance_query = use_instance_query + self.multimask_output = multimask_output + + self.text_embeddings = {} + self.names = [] + + def _prepare_backbone_features(self, backbone_out, num_prompts=1): + """Prepare and flatten visual features from the image backbone output for further processing.""" + if num_prompts > 1: # expand features if there's more than one prompt + for i, feat in enumerate(backbone_out["backbone_fpn"]): + backbone_out["backbone_fpn"][i] = feat.expand(num_prompts, -1, -1, -1) + for i, pos in enumerate(backbone_out["vision_pos_enc"]): + pos = pos.expand(num_prompts, -1, -1, -1) + backbone_out["vision_pos_enc"][i] = pos + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = 
backbone_out["vision_pos_enc"][-self.num_feature_levels :] + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _encode_prompt( + self, + img_feats, + img_pos_embeds, + vis_feat_sizes, + geometric_prompt, + visual_prompt_embed=None, + visual_prompt_mask=None, + prev_mask_pred=None, + ): + """Encode the geometric and visual prompts.""" + if prev_mask_pred is not None: + img_feats = [img_feats[-1] + prev_mask_pred] + # Encode geometry + geo_feats, geo_masks = self.geometry_encoder( + geo_prompt=geometric_prompt, + img_feats=img_feats, + img_sizes=vis_feat_sizes, + img_pos_embeds=img_pos_embeds, + ) + if visual_prompt_embed is None: + visual_prompt_embed = torch.zeros((0, *geo_feats.shape[1:]), device=geo_feats.device) + visual_prompt_mask = torch.zeros( + (*geo_masks.shape[:-1], 0), + device=geo_masks.device, + dtype=geo_masks.dtype, + ) + prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0) + prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1) + return prompt, prompt_mask + + def _run_encoder( + self, + img_feats, + img_pos_embeds, + vis_feat_sizes, + prompt, + prompt_mask, + encoder_extra_kwargs: dict | None = None, + ): + """Run the transformer encoder.""" + # Run the encoder + # make a copy of the image feature lists since the encoder may modify these lists in-place + memory = self.transformer.encoder( + src=img_feats.copy(), + src_key_padding_mask=None, + src_pos=img_pos_embeds.copy(), + prompt=prompt, + prompt_key_padding_mask=prompt_mask, + feat_sizes=vis_feat_sizes, + encoder_extra_kwargs=encoder_extra_kwargs, + ) + encoder_out = { + # encoded image features + "encoder_hidden_states": memory["memory"], + "pos_embed": memory["pos_embed"], + "padding_mask": 
memory["padding_mask"], + "spatial_shapes": memory["spatial_shapes"], + "valid_ratios": memory["valid_ratios"], + "vis_feat_sizes": vis_feat_sizes, + # encoded text features (or other prompts) + "prompt_before_enc": prompt, + "prompt_after_enc": memory.get("memory_text", prompt), + "prompt_mask": prompt_mask, + } + return encoder_out + + def _run_decoder( + self, + pos_embed, + memory, + src_mask, + out, + prompt, + prompt_mask, + encoder_out, + ): + """Run the transformer decoder.""" + bs = memory.shape[1] + query_embed = self.transformer.decoder.query_embed.weight + tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) + + hs, reference_boxes, dec_presence_out, _ = self.transformer.decoder( + tgt=tgt, + memory=memory, + memory_key_padding_mask=src_mask, + pos=pos_embed, + reference_boxes=None, + spatial_shapes=encoder_out["spatial_shapes"], + valid_ratios=encoder_out["valid_ratios"], + tgt_mask=None, + memory_text=prompt, + text_attention_mask=prompt_mask, + apply_dac=False, + ) + hs = hs.transpose(1, 2) # seq-first to batch-first + reference_boxes = reference_boxes.transpose(1, 2) # seq-first to batch-first + if dec_presence_out is not None: + # seq-first to batch-first + dec_presence_out = dec_presence_out.transpose(1, 2) + self._update_scores_and_boxes( + out, + hs, + reference_boxes, + prompt, + prompt_mask, + dec_presence_out=dec_presence_out, + ) + return out, hs + + def _update_scores_and_boxes( + self, + out, + hs, + reference_boxes, + prompt, + prompt_mask, + dec_presence_out=None, + is_instance_prompt=False, + ): + """Update output dict with class scores and box predictions.""" + num_o2o = hs.size(2) + # score prediction + if self.use_dot_prod_scoring: + dot_prod_scoring_head = self.dot_prod_scoring + if is_instance_prompt and self.instance_dot_prod_scoring is not None: + dot_prod_scoring_head = self.instance_dot_prod_scoring + outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask) + else: + class_embed_head = self.class_embed + if is_instance_prompt 
and self.instance_class_embed is not None: + class_embed_head = self.instance_class_embed + outputs_class = class_embed_head(hs) + + # box prediction + box_head = self.transformer.decoder.bbox_embed + if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None: + box_head = self.transformer.decoder.instance_bbox_embed + anchor_box_offsets = box_head(hs) + reference_boxes_inv_sig = inverse_sigmoid(reference_boxes) + outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid() + outputs_boxes_xyxy = xywh2xyxy(outputs_coord) + + if dec_presence_out is not None: + _update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False) + + if self.supervise_joint_box_scores: + assert dec_presence_out is not None + prob_dec_presence_out = dec_presence_out.clone().sigmoid() + if self.detach_presence_in_joint_score: + prob_dec_presence_out = prob_dec_presence_out.detach() + + outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp( + min=-10.0, max=10.0 + ) + + _update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False) + _update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False) + _update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False) + + def _run_segmentation_heads( + self, + out, + backbone_out, + encoder_hidden_states, + prompt, + prompt_mask, + hs, + ): + """Run segmentation heads and get masks.""" + if self.segmentation_head is not None: + num_o2o = hs.size(2) + obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o] + seg_head_outputs = self.segmentation_head( + backbone_feats=backbone_out["backbone_fpn"], + obj_queries=obj_queries, + encoder_hidden_states=encoder_hidden_states, + prompt=prompt, + prompt_mask=prompt_mask, + ) + for k, v in seg_head_outputs.items(): + if k in self.segmentation_head.instance_keys: + _update_out(out, k, v[:, :num_o2o], auxiliary=False) + else: + out[k] = v + else: + 
backbone_out.pop("backbone_fpn", None) + + def forward_grounding( + self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt = None + ): + """Forward pass for grounding (detection + segmentation) given input images and text.""" + backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = self._prepare_backbone_features( + backbone_out, num_prompts=len(text_ids) + ) + backbone_out.update({k: v for k, v in self.text_embeddings.items()}) + with torch.profiler.record_function("SAM3Image._encode_prompt"): + prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt) + # index text features (note that regardless of early or late fusion, the batch size of + # `txt_feats` is always the number of *prompts* in the encoder) + txt_feats = backbone_out["language_features"][:, text_ids] + txt_masks = backbone_out["language_mask"][text_ids] + # encode text + prompt = torch.cat([txt_feats, prompt], dim=0) + prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1) + + # Run the encoder + with torch.profiler.record_function("SAM3Image._run_encoder"): + encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask) + out = {"backbone_out": backbone_out} + + # Run the decoder + with torch.profiler.record_function("SAM3Image._run_decoder"): + out, hs = self._run_decoder( + memory=encoder_out["encoder_hidden_states"], + pos_embed=encoder_out["pos_embed"], + src_mask=encoder_out["padding_mask"], + out=out, + prompt=prompt, + prompt_mask=prompt_mask, + encoder_out=encoder_out, + ) + + # Run segmentation heads + with torch.profiler.record_function("SAM3Image._run_segmentation_heads"): + self._run_segmentation_heads( + out=out, + backbone_out=backbone_out, + encoder_hidden_states=encoder_out["encoder_hidden_states"], + prompt=prompt, + prompt_mask=prompt_mask, + hs=hs, + ) + return out + + def set_classes(self, text: list[str]): + """Set the text embeddings for the given class 
names.""" + self.text_embeddings = self.backbone.forward_text(text) + self.names = text + + def set_imgsz(self, imgsz: tuple[int, int]): + """Set the image size for the model.""" + self.backbone.set_imgsz(imgsz) diff --git a/ultralytics/models/sam/sam3/text_encoder_ve.py b/ultralytics/models/sam/sam3/text_encoder_ve.py new file mode 100644 index 0000000000..7c45a2ed27 --- /dev/null +++ b/ultralytics/models/sam/sam3/text_encoder_ve.py @@ -0,0 +1,307 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +from __future__ import annotations + +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from .model_misc import LayerScale + + +class ResidualAttentionBlock(nn.Module): + """Transformer block with multi-head attention, layer normalization, and MLP feed-forward network.""" + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float | None = None, + act_layer: Callable[[], nn.Module] = nn.GELU, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + ): + """Initialize residual attention block with configurable dimensions and normalization.""" + super().__init__() + # Attention + self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) + + # LayerNorm, LayerScale + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + # MLP + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) + + def attention( + self, q_x: torch.Tensor, k_x: torch.Tensor = None, 
v_x: torch.Tensor = None, attn_mask: torch.Tensor = None + ) -> torch.Tensor: + """Compute multi-head attention with optional cross-attention support and masking.""" + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + if attn_mask is not None: + # Leave boolean masks as is + if not attn_mask.dtype == torch.bool: + attn_mask = attn_mask.to(q_x.dtype) + + return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0] + + def forward( + self, q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None + ) -> torch.Tensor: + """Apply residual attention with layer normalization and MLP, supporting optional cross-attention.""" + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + """Stack of residual attention blocks forming a transformer encoder with optional gradient checkpointing.""" + + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float | None = None, + act_layer: Callable[[], nn.Module] = nn.GELU, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + compile_mode: str | None = None, + use_act_checkpoint: bool = False, + ): + """Initialize transformer with configurable depth, width, and optional compilation/checkpointing.""" + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = use_act_checkpoint + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + for _ in range(layers) + ] + ) + + if compile_mode is not None: + self.forward = torch.compile(self.forward, 
mode=compile_mode, fullgraph=True) + if self.grad_checkpointing: + torch._dynamo.config.optimize_ddp = False + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor: + """Process input through all transformer blocks with optional gradient checkpointing during training.""" + for _, r in enumerate(self.resblocks): + if self.grad_checkpointing and not torch.jit.is_scripting() and self.training: + x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False) + else: + x = r(x, attn_mask=attn_mask) + return x + + +def text_global_pool( + x: torch.Tensor, text: torch.Tensor = None, pool_type: str = "argmax" +) -> tuple[torch.Tensor, torch.Tensor]: + """Extract pooled representation and tokens from text embeddings using specified pooling strategy + (first/last/argmax/none). + """ + if pool_type == "first": + pooled, tokens = x[:, 0], x[:, 1:] + elif pool_type == "last": + pooled, tokens = x[:, -1], x[:, :-1] + elif pool_type == "argmax": + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + else: + pooled = tokens = x + return pooled, tokens + + +class TextTransformer(nn.Module): + """Text transformer encoder with causal masking and flexible pooling strategies.""" + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + mlp_ratio: float = 4.0, + ls_init_value: float | None = None, + output_dim: int = 512, + no_causal_mask: bool = False, + pool_type: str = "none", # no pooling + proj_bias: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + output_tokens: bool = False, + use_ln_post: bool = True, + compile_mode: str | None = None, + use_act_checkpoint: bool = False, + ): + """Initialize text transformer with embedding layers, transformer blocks, and pooling options.""" + 
super().__init__() + assert pool_type in ("first", "last", "argmax", "none") + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pool_type = pool_type + + self.token_embedding = nn.Embedding(self.vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + compile_mode=compile_mode, + use_act_checkpoint=use_act_checkpoint, + ) + self.ln_final = norm_layer(width) if use_ln_post else nn.Identity() + if no_causal_mask: + self.attn_mask = None + else: + self.register_buffer("attn_mask", self.build_causal_mask(), persistent=False) + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def build_causal_mask(self) -> torch.Tensor: + """Create a causal attention mask to prevent attention to future tokens.""" + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Forward pass through the text transformer, returning pooled output and optionally token embeddings.""" + seq_len = text.shape[1] + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + attn_mask = self.attn_mask + if attn_mask is not None: + attn_mask = attn_mask[:seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len] + x = self.transformer(x, attn_mask=attn_mask) + + x = self.ln_final(x) + pooled, tokens = 
text_global_pool(x, text, pool_type=self.pool_type) + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection + if self.output_tokens: + return pooled, tokens + return pooled + + +class VETextEncoder(nn.Module): + """Text encoder for Vision Encoder (VE) models, combining a text transformer and a linear resizer.""" + + def __init__( + self, + d_model: int, + tokenizer: Callable, + width: int = 1024, + heads: int = 16, + layers: int = 24, + context_length: int = 32, + vocab_size: int = 49408, + use_ln_post: bool = True, + compile_mode: str | None = None, + use_act_checkpoint: bool = True, + ): + """Initialize VE text encoder with a text transformer and a linear resizer to match decoder dimensions.""" + super().__init__() + self.context_length = context_length + self.use_ln_post = use_ln_post + self.tokenizer = tokenizer + + self.encoder = TextTransformer( + context_length=self.context_length, + vocab_size=vocab_size, + width=width, + heads=heads, + layers=layers, + # we want the tokens, not just the pooled output + output_tokens=True, + use_ln_post=use_ln_post, + compile_mode=compile_mode, + use_act_checkpoint=use_act_checkpoint, + ) + self.resizer = nn.Linear(self.encoder.width, d_model) + + def forward( + self, text: list[str] | tuple[torch.Tensor, torch.Tensor, dict], input_boxes: list | None = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode text input, either raw strings or pre-encoded tensors, and resize to match decoder dimensions.""" + if isinstance(text[0], str): + # no use case for this + assert input_boxes is None or len(input_boxes) == 0, "not supported" + + # Encode the text + tokenized = self.tokenizer(text, context_length=self.context_length).to( + self.resizer.weight.device + ) # [b, seq_len] + text_attention_mask = (tokenized != 0).bool() + + # manually embed the tokens + inputs_embeds = 
self.encoder.token_embedding(tokenized) # [b, seq_len, d=1024] + _, text_memory = self.encoder(tokenized) # [b, seq_len, d=1024] + + assert text_memory.shape[1] == inputs_embeds.shape[1] + # Invert attention mask because it's the opposite in pytorch transformer + text_attention_mask = text_attention_mask.ne(1) + # Transpose memory because pytorch's attention expects sequence first + text_memory = text_memory.transpose(0, 1) + # Resize the encoder hidden states to be of the same d_model as the decoder + text_memory_resized = self.resizer(text_memory) + else: + # The text is already encoded, use as is. + text_attention_mask, text_memory_resized, tokenized = text + inputs_embeds = tokenized["inputs_embeds"] + assert input_boxes is None or len(input_boxes) == 0, "Can't replace boxes in text if it's already encoded" + + # Note that the input_embeds are returned in pytorch's convention (sequence first) + return ( + text_attention_mask, + text_memory_resized, + inputs_embeds.transpose(0, 1), + ) diff --git a/ultralytics/models/sam/sam3/tokenizer_ve.py b/ultralytics/models/sam/sam3/tokenizer_ve.py new file mode 100644 index 0000000000..05810b4dc4 --- /dev/null +++ b/ultralytics/models/sam/sam3/tokenizer_ve.py @@ -0,0 +1,242 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +""" +Text Tokenizer. + +Copied and lightly adapted from VE repo, which in turn copied +from open_clip and openAI CLIP. +""" + +from __future__ import annotations + +import gzip +import html +import io +import os +import string +from functools import lru_cache + +import ftfy +import regex as re +import torch +from iopath.common.file_io import g_pathmgr + + +@lru_cache +def bytes_to_unicode(): + """Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode + strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When + you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a + significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 + bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("ยก"), ord("ยฌ") + 1)) + list(range(ord("ยฎ"), ord("รฟ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length + strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + """Basic text cleaning: fix unicode and unescape HTML entities.""" + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + """Remove redundant whitespace.""" + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def _clean_canonicalize(x): + """Clean text and canonicalize it.""" + # basic, remove whitespace, remove punctuation, lower case + return canonicalize_text(basic_clean(x)) + + +def _clean_lower(x): + """Clean text and return lowercase.""" + # basic, remove whitespace, lower case + return whitespace_clean(basic_clean(x)).lower() + + +def _clean_whitespace(x): + """Clean text and remove redundant whitespace.""" + # basic, remove whitespace + return whitespace_clean(basic_clean(x)) + + +def get_clean_fn(type: str): + """Get text cleaning function by name.""" + if type == "canonicalize": + return _clean_canonicalize + elif type == "lower": + return 
_clean_lower + elif type == "whitespace": + return _clean_whitespace + else: + assert False, f"Invalid clean function ({type})." + + +def canonicalize_text(text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (lowercase and punctuation removed). From: + https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94. + + Args: + text: string to be canonicalized. + keep_punctuation_exact_string: If provided, then this exact string kept. For example providing '{}' will keep + any occurrences of '{}' (but will still remove '{' and '}' that appear separately). + """ + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class SimpleTokenizer: + """A simple tokenizer for text inputs.""" + + def __init__( + self, + bpe_path: str | os.PathLike, + additional_special_tokens: list[str] | None = None, + context_length: int = 77, + clean: str = "lower", + ): + """The tokenizer for text inputs.""" + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with g_pathmgr.open(bpe_path, "rb") as fh: + bpe_bytes = io.BytesIO(fh.read()) + merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + # merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + special_tokens = ["", ""] + if additional_special_tokens: + special_tokens += 
def bpe(self, token):
    """Apply byte-pair-encoding merges to a single pre-tokenized token.

    Follows the CLIP BPE scheme: the end-of-word marker "</w>" is appended to the final character so
    that word-final merges are ranked separately from word-internal ones. Results are memoized in
    ``self.cache``.

    Args:
        token (str): A single byte-encoded token produced by the pre-tokenization regex.

    Returns:
        (str): Space-separated BPE sub-tokens, ending with the "</w>" marker.
    """
    if token in self.cache:
        return self.cache[token]
    # FIX: restore the CLIP end-of-word marker "</w>" (was garbled to an empty string).
    word = (*tuple(token[:-1]), token[-1] + "</w>")
    pairs = get_pairs(word)
    if not pairs:
        return token + "</w>"
    while True:
        # Lowest-rank (most frequent) bigram is merged first; unknown bigrams rank as +inf.
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:  # FIX: narrowed from bare Exception; index() raises ValueError only
                new_word.extend(word[i:])
                break
            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = " ".join(word)
    self.cache[token] = word
    return word

def encode(self, text):
    """Encode cleaned text into a list of BPE token ids.

    NOTE(review): ``re`` here is presumably the third-party ``regex`` module imported by this file
    (the pre-tokenization pattern uses ``\\p{...}`` classes) — confirm against the module imports.

    Args:
        text (str): Raw input text; cleaned by ``self.clean_fn`` before tokenization.

    Returns:
        (list[int]): BPE token ids.
    """
    bpe_tokens = []
    text = self.clean_fn(text)
    for token in re.findall(self.pat, text):
        # Map raw UTF-8 bytes to the printable byte-encoder alphabet before BPE.
        token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
        bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
    return bpe_tokens

def decode(self, tokens):
    """Decode a sequence of BPE token ids back into a text string.

    Args:
        tokens (Iterable[int]): Token ids to decode.

    Returns:
        (str): Decoded text; end-of-word markers become spaces, invalid bytes are replaced.
    """
    text = "".join(self.decoder[token] for token in tokens)
    # FIX: replace the "</w>" end-of-word marker with a space (was a garbled empty-string replace,
    # which would insert a space between every character).
    return bytearray(self.byte_decoder[c] for c in text).decode("utf-8", errors="replace").replace("</w>", " ")

def __call__(self, texts: str | list[str], context_length: int | None = None) -> torch.LongTensor:
    """Tokenize input string(s) into a fixed-size tensor of token ids.

    Args:
        texts (str | list[str]): An input string or a list of input strings to tokenize.
        context_length (int, optional): Sequence length to pad/truncate to; defaults to
            ``self.context_length`` (CLIP models use 77).

    Returns:
        (torch.LongTensor): Token ids with shape [len(texts), context_length]; sequences longer than
            the context length are truncated with the final token forced to EOT.
    """
    if isinstance(texts, str):
        texts = [texts]
    context_length = context_length or self.context_length
    assert context_length, "Please set a valid context length"
    all_tokens = [[self.sot_token_id, *self.encode(text), self.eot_token_id] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = self.eot_token_id
        result[i, : len(tokens)] = torch.tensor(tokens)
    return result
class Attention(nn.Module):
    """Multi-head attention block with optional relative position embeddings and 2D rotary (RoPE) embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: tuple[int, int] | None = None,
        cls_token: bool = False,
        use_rope: bool = False,
        rope_theta: float = 10000.0,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
    ):
        """Initialize the attention block.

        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to the query/key/value projection.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero-initialize relative positional parameters.
            input_size (tuple[int, int] | None): Input resolution used to size the relative positional
                parameters and/or the RoPE frequency table.
            cls_token (bool): Whether a cls token is present (not supported together with rel-pos).
            use_rope (bool): Whether to use 2D RoPE (independent of use_rel_pos; both can be enabled).
            rope_theta (float): Base controlling RoPE frequencies.
            rope_pt_size (tuple[int, int] | None): RoPE grid size from a previous training stage,
                needed for interpolation or tiling.
            rope_interp (bool): Whether to interpolate (or extrapolate) RoPE to match the input size.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.cls_token = cls_token

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # rel_pos embeddings and rope
        self.use_rel_pos = use_rel_pos
        self.input_size = input_size

        self.use_rope = use_rope
        self.rope_theta = rope_theta
        self.rope_pt_size = rope_pt_size
        self.rope_interp = rope_interp

        # init rel_pos embeddings and rope tables
        self._setup_rel_pos(rel_pos_zero_init, input_size)
        self._setup_rope_freqs(input_size)

    def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None:
        """Create (or clear) relative positional embeddings and precompute relative coordinates."""
        if not self.use_rel_pos:
            self.rel_pos_h = None
            self.rel_pos_w = None
            return

        assert input_size is not None
        assert self.cls_token is False, "not supported"
        # (2*H-1) x head_dim and (2*W-1) x head_dim tables of learnable offsets
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))

        if not rel_pos_zero_init:
            nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
            nn.init.trunc_normal_(self.rel_pos_w, std=0.02)

        # Precompute the relative coords.
        # NOTE(review): stored as a plain attribute, not a buffer, so it will not follow
        # .to(device)/.state_dict(); presumably handled downstream in concat_rel_pos — confirm.
        H, W = input_size
        q_coords = torch.arange(H)[:, None]
        k_coords = torch.arange(W)[None, :]
        relative_coords = (q_coords - k_coords) + (H - 1)
        self.relative_coords = relative_coords.long()

    def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None:
        """Create (or clear) the 2D RoPE frequency table for the given input size."""
        if not self.use_rope:
            self.freqs_cis = None
            return

        assert input_size is not None
        # determine rope input size
        if self.rope_pt_size is None:
            self.rope_pt_size = input_size

        # initialize 2d rope freqs
        self.compute_cis = partial(
            compute_axial_cis,
            dim=self.head_dim,
            theta=self.rope_theta,
        )

        # interpolate rope positions when the grid changed between pretraining and now
        scale_pos = 1.0
        if self.rope_interp:
            scale_pos = self.rope_pt_size[0] / input_size[0]
        freqs_cis = self.compute_cis(
            end_x=input_size[0],
            end_y=input_size[1],
            scale_pos=scale_pos,
        )
        if self.cls_token:
            # cls token gets a zero-angle (identity) rotation prepended
            t = torch.zeros(
                self.head_dim // 2,
                dtype=torch.float32,
                device=freqs_cis.device,
            )
            cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
            freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)

        self.freqs_cis = freqs_cis

    def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]:
        """Apply 2D RoPE to q and k (no-op when RoPE is disabled)."""
        if not self.use_rope:
            return q, k

        assert self.freqs_cis is not None
        return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis.to(q.device))

    def forward(self, x: Tensor) -> Tensor:
        """Compute self-attention over a (B, H, W, C) feature map or a (B, L, C) token sequence."""
        s = 1 if self.cls_token else 0  # number of leading non-spatial (cls) tokens
        if x.ndim == 4:
            B, H, W, _ = x.shape
            assert s == 0  # cls token is incompatible with a spatial (B, H, W, C) layout
            L = H * W
            ndim = 4
        else:
            assert x.ndim == 3
            B, L, _ = x.shape
            ndim = 3
            # FIX: math.sqrt returns a float; the reshapes below require integer spatial dims.
            # Assumes the (non-cls) sequence forms a square grid.
            H = W = int(math.sqrt(L - s))

        # qkv with shape (3, B, nHead, L, C)
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
        # q, k, v with shape (B, nHead, L, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # handle rope and rel pos embeddings
        q, k = self._apply_rope(q, k)
        if self.use_rel_pos:
            q, k = concat_rel_pos(
                q.flatten(0, 1),
                k.flatten(0, 1),
                (H, W),
                x.shape[1:3],
                self.rel_pos_h,
                self.rel_pos_w,
                rescale=True,
                relative_coords=self.relative_coords,
            )

        # sdpa expects [B, nHeads, seq, C]; FIX: keep the cls token (if any) in the sequence
        # length — reshaping to H*W alone would misalign q/k against v when s == 1.
        seq_len = H * W + s
        q = q.reshape(B, self.num_heads, seq_len, -1)
        k = k.reshape(B, self.num_heads, seq_len, -1)

        x = F.scaled_dot_product_attention(q, k, v)

        if ndim == 4:
            x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        else:
            x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)

        return self.proj(x)
class Block(nn.Module):
    """Transformer block with optional windowed attention, layer scale and stochastic depth."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: tuple[int, int] | None = None,
        use_rope: bool = False,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
        cls_token: bool = False,
        dropout: float = 0.0,
        init_values: float | None = None,
    ):
        """Initialize the transformer block.

        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (Callable): Normalization layer constructor.
            act_layer (Callable): Activation layer constructor.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero-initialize relative positional parameters.
            window_size (int): Window size for window attention; 0 disables windowing.
            input_size (tuple[int, int] | None): Input resolution for sizing positional parameters.
            use_rope (bool): Whether to use 2D RoPE (can be combined with use_rel_pos).
            rope_pt_size (tuple[int, int] | None): RoPE grid size from a previous training stage.
            rope_interp (bool): Whether to interpolate/extrapolate RoPE to the target input size.
            cls_token (bool): Whether a cls token is present.
            dropout (float): Dropout rate applied on both residual branches and inside the MLP.
            init_values (float | None): Layer-scale init value; None disables layer scale.
        """
        super().__init__()

        check_requirements("timm")
        from timm.layers import DropPath, Mlp

        windowed = window_size > 0
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=(window_size, window_size) if windowed else input_size,
            use_rope=use_rope,
            rope_pt_size=rope_pt_size,
            rope_interp=rope_interp,
            cls_token=cls_token,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=(dropout, 0.0),
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

    def forward(self, x: Tensor) -> Tensor:
        """Apply (optionally windowed) attention and the MLP, each with residual, dropout and drop-path."""
        residual = x
        x = self.norm1(x)

        # Partition into windows before attention, then stitch back afterwards.
        if self.window_size > 0:
            orig_hw = (x.shape[1], x.shape[2])
            x, pad_hw = window_partition(x, self.window_size)

        x = self.ls1(self.attn(x))

        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, orig_hw)

        x = residual + self.dropout(self.drop_path(x))
        return x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
class ViT(nn.Module):
    """Plain Vision Transformer (ViT) backbone from ViTDet.

    "Exploring Plain Vision Transformer Backbones for Object Detection", https://arxiv.org/abs/2203.16527.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        tile_abs_pos: bool = True,
        rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
        use_rope: bool = False,
        rope_pt_size: int | None = None,
        use_interp_rope: bool = False,
        pretrain_img_size: int = 224,
        pretrain_use_cls_token: bool = True,
        retain_cls_token: bool = True,
        dropout: float = 0.0,
        return_interm_layers: bool = False,
        init_values: float | None = None,  # for layerscale
        ln_pre: bool = False,
        ln_post: bool = False,
        bias_patch_embed: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = True,
    ):
        """Initialize the ViTDet backbone.

        Args:
            img_size (int): Input image size; only relevant for rel-pos or RoPE sizing.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Number of transformer blocks.
            num_heads (int): Number of attention heads per block.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate (linearly increased per block).
            norm_layer (Callable | str): Normalization layer or its nn-module name.
            act_layer (Callable): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolating.
            rel_pos_blocks (tuple[int, ...] | bool): Block indexes with rel-pos embeddings, or a bool for all/none.
            rel_pos_zero_init (bool): If True, zero-initialize relative positional parameters.
            window_size (int): Window size for window-attention blocks.
            global_att_blocks (tuple[int, ...]): Indexes of global-attention blocks (others use windows).
            use_rope (bool): Whether to use 2D RoPE (independent of rel_pos_blocks).
            rope_pt_size (int | None): RoPE grid size from a previous training stage.
            use_interp_rope (bool): Whether to interpolate/extrapolate RoPE to the target input size.
            pretrain_img_size (int): Input image size used during pretraining.
            pretrain_use_cls_token (bool): Whether the pretrained model used a class token.
            retain_cls_token (bool): Whether the class token should be retained at runtime.
            dropout (float): Dropout rate applied in residual branches and inside the MLP.
            return_interm_layers (bool): Whether to return all global-attention block outputs.
            init_values (float | None): Layer-scale init value; None disables layer scale.
            ln_pre (bool): If True, apply layer norm before the transformer blocks.
            ln_post (bool): If True, apply layer norm after the transformer blocks.
            bias_patch_embed (bool): Whether the patch-embed conv uses a bias.
            compile_mode (str | None): torch.compile mode for the forward, or None.
            use_act_checkpoint (bool): If True, use activation checkpointing during training.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
        self.full_attn_ids = list(global_att_blocks)
        # FIX: a plain `rel_pos_blocks=False` previously fell through to `for i in rel_pos_blocks`,
        # iterating over a bool and raising TypeError.
        if isinstance(rel_pos_blocks, bool):
            self.rel_pos_blocks = [rel_pos_blocks] * depth
        else:
            self.rel_pos_blocks = [False] * depth
            for i in rel_pos_blocks:
                self.rel_pos_blocks[i] = True

        self.retain_cls_token = retain_cls_token
        if self.retain_cls_token:
            assert pretrain_use_cls_token
            assert len(window_block_indexes) == 0, "windowing not supported with cls token"
            assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"

            scale = embed_dim**-0.5
            self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias_patch_embed,
        )

        # Handle absolute positional embedding
        self.tile_abs_pos = tile_abs_pos
        self.use_abs_pos = use_abs_pos
        if self.tile_abs_pos:
            assert self.use_abs_pos

        if self.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) ** 2
            num_positions = num_patches + 1 if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = torch.linspace(0, drop_path_rate, depth).tolist()

        self.patch_size = patch_size
        self.window_size = window_size
        self.use_act_checkpoint = use_act_checkpoint
        self.blocks = nn.ModuleList()
        for i in range(depth):
            self.blocks.append(
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_rel_pos=self.rel_pos_blocks[i],
                    rel_pos_zero_init=rel_pos_zero_init,
                    window_size=window_size if i in window_block_indexes else 0,
                    input_size=(img_size // patch_size, img_size // patch_size),
                    use_rope=use_rope,
                    rope_pt_size=(window_size, window_size) if rope_pt_size is None else (rope_pt_size, rope_pt_size),
                    rope_interp=use_interp_rope,
                    cls_token=self.retain_cls_token,
                    dropout=dropout,
                    init_values=init_values,
                )
            )

        self.return_interm_layers = return_interm_layers
        self.channel_list = [embed_dim] * len(self.full_attn_ids) if return_interm_layers else [embed_dim]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
        self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()

        self.apply(self._init_weights)

        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
            if self.use_act_checkpoint and self.training:
                torch._dynamo.config.optimize_ddp = False

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize Linear weights with truncated normal and LayerNorm to identity (weight=1, bias=0)."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run the ViT trunk and return feature maps as (B, C, H, W) tensors."""
        x = self.patch_embed(x)  # (B, H, W, C)
        h, w = x.shape[1], x.shape[2]

        s = 0  # number of leading cls tokens in the flattened sequence
        if self.retain_cls_token:
            # With a cls token the spatial layout is dropped: (B, 1 + H*W, C).
            # FIX: expand the class embedding across the batch; torch.cat does not broadcast,
            # so the previous code failed for batch size > 1.
            cls_embed = self.class_embedding.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_embed, x.flatten(1, 2)], dim=1)
            s = 1

        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed,
                self.pretrain_use_cls_token,
                (h, w),
                self.retain_cls_token,
                tiling=self.tile_abs_pos,
            )

        x = self.ln_pre(x)

        outputs = []
        for i, blk in enumerate(self.blocks):
            if self.use_act_checkpoint and self.training:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if (i == self.full_attn_ids[-1]) or (self.return_interm_layers and i in self.full_attn_ids):
                if i == self.full_attn_ids[-1]:
                    x = self.ln_post(x)

                feats = x[:, s:]
                if feats.ndim == 4:
                    feats = feats.permute(0, 3, 1, 2)
                else:
                    assert feats.ndim == 3
                    # FIX: math.sqrt returns a float; reshape requires integer spatial dims.
                    # Assumes the (non-cls) sequence forms a square grid.
                    h = w = int(math.sqrt(feats.shape[1]))
                    feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)

                outputs.append(feats)

        return outputs

    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
        """Rebuild rel-pos embeddings and RoPE frequency tables of global-attention blocks for a new image size."""
        for block in self.blocks:
            if block.window_size != 0:
                continue  # windowed blocks are sized by window_size, not by the image size
            grid = (imgsz[0] // self.patch_size, imgsz[1] // self.patch_size)
            block.attn._setup_rel_pos(input_size=grid)
            block.attn._setup_rope_freqs(input_size=grid)
-0,0 +1,165 @@ +# Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +"""Provides utility to combine a vision backbone with a language backbone.""" + +from __future__ import annotations + +from copy import copy + +import torch +import torch.nn as nn +from torch.nn.attention import SDPBackend, sdpa_kernel + +from .necks import Sam3DualViTDetNeck + + +class SAM3VLBackbone(nn.Module): + """This backbone combines a vision backbone and a language backbone without fusion. As such it is more of a + convenience wrapper to handle the two backbones together. + + It adds support for activation checkpointing and compilation. + """ + + def __init__( + self, + visual: Sam3DualViTDetNeck, + text, + compile_visual: bool = False, + act_ckpt_whole_vision_backbone: bool = False, + act_ckpt_whole_language_backbone: bool = False, + scalp=0, + ): + """Initialize the backbone combiner. + + :param visual: The vision backbone to use + :param text: The text encoder to use + """ + super().__init__() + self.vision_backbone: Sam3DualViTDetNeck = torch.compile(visual) if compile_visual else visual + self.language_backbone = text + self.scalp = scalp + # allow running activation checkpointing on the entire vision and language backbones + self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone + self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone + + def forward( + self, + samples: torch.Tensor, + captions: list[str], + input_boxes: torch.Tensor = None, + additional_text: list[str] | None = None, + ): + """Forward pass of the backbone combiner. 
+ + :param samples: The input images + :param captions: The input captions + :param input_boxes: If the text contains place-holders for boxes, this + parameter contains the tensor containing their spatial features + :param additional_text: This can be used to encode some additional text + (different from the captions) in the same forward of the backbone + :return: Output dictionary with the following keys: + - vision_features: The output of the vision backbone + - language_features: The output of the language backbone + - language_mask: The attention mask of the language backbone + - vision_pos_enc: The positional encoding of the vision backbone + - (optional) additional_text_features: The output of the language + backbone for the additional text + - (optional) additional_text_mask: The attention mask of the + language backbone for the additional text + """ + output = self.forward_image(samples) + output.update(self.forward_text(captions, input_boxes, additional_text)) + return output + + def forward_image(self, samples: torch.Tensor): + """Forward pass of the vision backbone and get both SAM3 and SAM2 features.""" + # Forward through backbone + sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(samples) + if self.scalp > 0: + # Discard the lowest resolution features + sam3_features, sam3_pos = ( + sam3_features[: -self.scalp], + sam3_pos[: -self.scalp], + ) + if sam2_features is not None and sam2_pos is not None: + sam2_features, sam2_pos = ( + sam2_features[: -self.scalp], + sam2_pos[: -self.scalp], + ) + + sam2_output = None + + if sam2_features is not None and sam2_pos is not None: + sam2_src = sam2_features[-1] + sam2_output = { + "vision_features": sam2_src, + "vision_pos_enc": sam2_pos, + "backbone_fpn": sam2_features, + } + + sam3_src = sam3_features[-1] + return { + "vision_features": sam3_src, + "vision_pos_enc": sam3_pos, + "backbone_fpn": sam3_features, + "sam2_backbone_out": sam2_output, + } + + def forward_image_sam2(self, 
samples: torch.Tensor): + """Forward pass of the vision backbone to get SAM2 features only.""" + xs = self.vision_backbone.trunk(samples) + sam2_features, sam2_pos = [], [] + x = xs[-1] # simpleFPN + + assert self.vision_backbone.sam2_convs is not None, "SAM2 neck is not available." + for i in range(len(self.vision_backbone.sam2_convs)): + sam2_x_out = self.vision_backbone.sam2_convs[i](x) + sam2_pos_out = self.vision_backbone.position_encoding(sam2_x_out).to(sam2_x_out.dtype) + sam2_features.append(sam2_x_out) + sam2_pos.append(sam2_pos_out) + + if self.scalp > 0: + # Discard the lowest resolution features + sam2_features, sam2_pos = ( + sam2_features[: -self.scalp], + sam2_pos[: -self.scalp], + ) + + return { + "vision_features": sam2_features[-1], + "vision_pos_enc": sam2_pos, + "backbone_fpn": sam2_features, + } + + def forward_text(self, captions, input_boxes=None, additional_text=None): + """Forward pass of the text encoder.""" + output = {} + + # Forward through text_encoder + text_to_encode = copy(captions) + if additional_text is not None: + # if there are additional_text, we piggy-back them into this forward. 
+ # They'll be used later for output alignment + text_to_encode += additional_text + + with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]): + text_attention_mask, text_memory, text_embeds = self.language_backbone(text_to_encode, input_boxes) + + if additional_text is not None: + output["additional_text_features"] = text_memory[:, -len(additional_text) :] + output["additional_text_mask"] = text_attention_mask[-len(additional_text) :] + + text_memory = text_memory[:, : len(captions)] + text_attention_mask = text_attention_mask[: len(captions)] + text_embeds = text_embeds[:, : len(captions)] + output["language_features"] = text_memory + output["language_mask"] = text_attention_mask + output["language_embeds"] = text_embeds # Text embeddings before forward to the encoder + + return output + + def set_imgsz(self, imgsz: list[int] = [1008, 1008]): + """Set the image size for the vision backbone.""" + self.vision_backbone.set_imgsz(imgsz) diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py index c8b5f04eb7..0e54e21198 100644 --- a/ultralytics/nn/modules/transformer.py +++ b/ultralytics/nn/modules/transformer.py @@ -359,7 +359,15 @@ class MLP(nn.Module): """ def __init__( - self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + act=nn.ReLU, + sigmoid: bool = False, + residual: bool = False, + out_norm: nn.Module = None, ): """Initialize the MLP with specified input, hidden, output dimensions and number of layers. @@ -370,6 +378,8 @@ class MLP(nn.Module): num_layers (int): Number of layers. act (nn.Module): Activation function. sigmoid (bool): Whether to apply sigmoid to the output. + residual (bool): Whether to use residual connections. + out_norm (nn.Module, optional): Normalization layer for the output. 
def empty_like(x):
    """Create an uninitialized torch.Tensor or np.ndarray with the same shape and dtype as the input.

    Args:
        x (torch.Tensor | np.ndarray): Template tensor or array.

    Returns:
        (torch.Tensor | np.ndarray): Uninitialized container matching x's shape and dtype.
    """
    # Both torch.empty_like and np.empty_like preserve dtype by default, so passing
    # dtype=x.dtype was redundant. Docstring updated: the result is no longer forced to float32.
    return torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)