mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
SAM3: Skip geometry token when using text prompts (#24244)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
518d31c1ee
commit
ca24fdd77d
2 changed files with 10 additions and 6 deletions
|
|
@ -2273,8 +2273,9 @@ class SAM3SemanticPredictor(SAM3Predictor):
|
|||
"""Run inference on the extracted features with optional bounding boxes and labels."""
|
||||
# NOTE: priority: bboxes > text > pre-set classes
|
||||
nc = 1 if bboxes is not None else len(text) if text is not None else len(self.model.names)
|
||||
geometric_prompt = self._get_dummy_prompt(nc)
|
||||
geometric_prompt = None
|
||||
if bboxes is not None:
|
||||
geometric_prompt = self._get_dummy_prompt(nc)
|
||||
for i in range(len(bboxes)):
|
||||
geometric_prompt.append_boxes(bboxes[[i]], labels[[i]])
|
||||
if text is None:
|
||||
|
|
|
|||
|
|
@ -290,15 +290,18 @@ class SAM3SemanticModel(torch.nn.Module):
|
|||
self, backbone_out, batch=len(text_ids)
|
||||
)
|
||||
backbone_out.update({k: v for k, v in self.text_embeddings.items()})
|
||||
with torch.profiler.record_function("SAM3Image._encode_prompt"):
|
||||
prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
|
||||
# index text features (note that regardless of early or late fusion, the batch size of
|
||||
# `txt_feats` is always the number of *prompts* in the encoder)
|
||||
txt_feats = backbone_out["language_features"][:, text_ids]
|
||||
txt_masks = backbone_out["language_mask"][text_ids]
|
||||
# encode text
|
||||
prompt = torch.cat([txt_feats, prompt], dim=0)
|
||||
prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)
|
||||
if geometric_prompt is not None:
|
||||
with torch.profiler.record_function("SAM3Image._encode_prompt"):
|
||||
geo_prompt, geo_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
|
||||
prompt = torch.cat([txt_feats, geo_prompt], dim=0)
|
||||
prompt_mask = torch.cat([txt_masks, geo_mask], dim=1)
|
||||
else:
|
||||
prompt = txt_feats
|
||||
prompt_mask = txt_masks
|
||||
|
||||
# Run the encoder
|
||||
with torch.profiler.record_function("SAM3Image._run_encoder"):
|
||||
|
|
|
|||
Loading…
Reference in a new issue