mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
Fix AutoBatch with multispectral images (#23546)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Jing Qiu <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
6166eea721
commit
aacb357fbe
2 changed files with 3 additions and 1 deletions
|
|
@ -121,6 +121,7 @@ def test_train():
|
|||
device = tuple(DEVICES) if len(DEVICES) > 1 else DEVICES[0]
|
||||
# NVIDIA Jetson only has one GPU and therefore skipping checks
|
||||
if not IS_JETSON:
|
||||
results = YOLO(MODEL).train(data="coco8-grayscale.yaml", imgsz=64, epochs=1, device=DEVICES[0], batch=-1)
|
||||
results = YOLO(MODEL).train(data="coco8.yaml", imgsz=64, epochs=1, device=device, batch=15, compile=True)
|
||||
results = YOLO(MODEL).train(data="coco128.yaml", imgsz=64, epochs=1, device=device, batch=15, val=False)
|
||||
visible = eval(os.environ["CUDA_VISIBLE_DEVICES"])
|
||||
|
|
|
|||
|
|
@ -84,8 +84,9 @@ def autobatch(
|
|||
|
||||
# Profile batch sizes
|
||||
batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
|
||||
ch = model.yaml.get("channels", 3)
|
||||
try:
|
||||
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||
img = [torch.empty(b, ch, imgsz, imgsz) for b in batch_sizes]
|
||||
results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)
|
||||
|
||||
# Fit a solution
|
||||
|
|
|
|||
Loading…
Reference in a new issue