Fix FP16 inference crash from TinyViT cached bias dtype mismatch (#23780)

Signed-off-by: Edwin-Kevin <45322858+Edwin-Kevin@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Onuralp SEZER <onuralp@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Edwin-Kevin 2026-03-23 22:39:15 +08:00 committed by GitHub
parent ba983b5bf6
commit 7ad68fae11
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View file

@@ -216,6 +216,7 @@ def test_predict_sam():
imgsz=1024,
model=WEIGHTS_DIR / "mobile_sam.pt",
device=DEVICES[0],
half=True,
)
)
predictor.set_image(ASSETS / "zidane.jpg")

View file

@@ -454,9 +454,11 @@ class Predictor(BasePredictor):
device = select_device(self.args.device, verbose=verbose)
if model is None:
model = self.get_model()
model.eval()
# Move model to device first, then cast dtype, then set eval so any eval-time caches are created on-device.
model = model.to(device)
self.model = model.half() if self.args.half else model.float()
model = model.half() if self.args.half else model.float()
model.eval()
self.model = model
self.device = device
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)