mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
Fix FP16 inference crash from TinyViT cached bias dtype mismatch (#23780)
Signed-off-by: Edwin-Kevin <45322858+Edwin-Kevin@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Onuralp SEZER <onuralp@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
parent
ba983b5bf6
commit
7ad68fae11
2 changed files with 5 additions and 2 deletions
|
|
@@ -216,6 +216,7 @@ def test_predict_sam():
|
|||
imgsz=1024,
|
||||
model=WEIGHTS_DIR / "mobile_sam.pt",
|
||||
device=DEVICES[0],
|
||||
half=True,
|
||||
)
|
||||
)
|
||||
predictor.set_image(ASSETS / "zidane.jpg")
|
||||
|
|
|
|||
|
|
@@ -454,9 +454,11 @@ class Predictor(BasePredictor):
|
|||
device = select_device(self.args.device, verbose=verbose)
|
||||
if model is None:
|
||||
model = self.get_model()
|
||||
model.eval()
|
||||
# Move model to device first, then cast dtype, then set eval so any eval-time caches are created on-device.
|
||||
model = model.to(device)
|
||||
self.model = model.half() if self.args.half else model.float()
|
||||
model = model.half() if self.args.half else model.float()
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
|
||||
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
|
||||
|
|
|
|||
Loading…
Reference in a new issue