Fix AutoBatch with multispectral images (#23546)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Mohammed Yasin 2026-02-04 13:30:18 +08:00 committed by GitHub
parent 6166eea721
commit aacb357fbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletions

View file

@ -121,6 +121,7 @@ def test_train():
device = tuple(DEVICES) if len(DEVICES) > 1 else DEVICES[0]
# NVIDIA Jetson only has one GPU and therefore skipping checks
if not IS_JETSON:
results = YOLO(MODEL).train(data="coco8-grayscale.yaml", imgsz=64, epochs=1, device=DEVICES[0], batch=-1)
results = YOLO(MODEL).train(data="coco8.yaml", imgsz=64, epochs=1, device=device, batch=15, compile=True)
results = YOLO(MODEL).train(data="coco128.yaml", imgsz=64, epochs=1, device=device, batch=15, val=False)
visible = eval(os.environ["CUDA_VISIBLE_DEVICES"])

View file

@ -84,8 +84,9 @@ def autobatch(
# Profile batch sizes
batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
ch = model.yaml.get("channels", 3)
try:
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
img = [torch.empty(b, ch, imgsz, imgsz) for b in batch_sizes]
results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)
# Fit a solution