SAM3: Skip geometry token when using text prompts (#24244)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Mohammed Yasin 2026-04-16 08:29:35 +06:00 committed by GitHub
parent 518d31c1ee
commit ca24fdd77d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 10 additions and 6 deletions

View file

@@ -2273,8 +2273,9 @@ class SAM3SemanticPredictor(SAM3Predictor):
"""Run inference on the extracted features with optional bounding boxes and labels."""
# NOTE: priority: bboxes > text > pre-set classes
nc = 1 if bboxes is not None else len(text) if text is not None else len(self.model.names)
geometric_prompt = self._get_dummy_prompt(nc)
geometric_prompt = None
if bboxes is not None:
geometric_prompt = self._get_dummy_prompt(nc)
for i in range(len(bboxes)):
geometric_prompt.append_boxes(bboxes[[i]], labels[[i]])
if text is None:

View file

@@ -290,15 +290,18 @@ class SAM3SemanticModel(torch.nn.Module):
self, backbone_out, batch=len(text_ids)
)
backbone_out.update({k: v for k, v in self.text_embeddings.items()})
with torch.profiler.record_function("SAM3Image._encode_prompt"):
prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
# index text features (note that regardless of early or late fusion, the batch size of
# `txt_feats` is always the number of *prompts* in the encoder)
txt_feats = backbone_out["language_features"][:, text_ids]
txt_masks = backbone_out["language_mask"][text_ids]
# encode text
prompt = torch.cat([txt_feats, prompt], dim=0)
prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)
if geometric_prompt is not None:
with torch.profiler.record_function("SAM3Image._encode_prompt"):
geo_prompt, geo_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
prompt = torch.cat([txt_feats, geo_prompt], dim=0)
prompt_mask = torch.cat([txt_masks, geo_mask], dim=1)
else:
prompt = txt_feats
prompt_mask = txt_masks
# Run the encoder
with torch.profiler.record_function("SAM3Image._run_encoder"):