fix: 🐞 Add requirement check for TensorRT in test_export_engine_matrix (#23496)

Signed-off-by: Onuralp SEZER <onuralp@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Onuralp SEZER 2026-01-30 18:31:49 +03:00 committed by GitHub
parent a8b639bc35
commit f8e24c56cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 15 additions and 4 deletions

View file

@ -63,6 +63,10 @@ keywords: Ultralytics, YOLO, utility functions, version checks, requirements, im
<br><br><hr><br>
## ::: ultralytics.utils.checks.check_tensorrt
<br><br><hr><br>
## ::: ultralytics.utils.checks.check_torchvision
<br><br><hr><br>

View file

@ -12,7 +12,7 @@ from ultralytics import YOLO
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
from ultralytics.utils import ASSETS, IS_JETSON, WEIGHTS_DIR
from ultralytics.utils.autodevice import GPUInfo
from ultralytics.utils.checks import check_amp
from ultralytics.utils.checks import check_amp, check_tensorrt
from ultralytics.utils.torch_utils import TORCH_1_13
# Try to find idle devices if CUDA is available
@ -91,6 +91,7 @@ def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
)
def test_export_engine_matrix(task, dynamic, int8, half, batch):
"""Test YOLO model export to TensorRT format for various configurations and run inference."""
check_tensorrt()
import tensorrt as trt
is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10

View file

@ -110,6 +110,7 @@ from ultralytics.utils.checks import (
check_executorch_requirements,
check_imgsz,
check_requirements,
check_tensorrt,
check_version,
is_intel,
is_sudo_available,
@ -1005,9 +1006,7 @@ class Exporter:
try:
import tensorrt as trt
except ImportError:
if LINUX:
cuda_version = torch.version.cuda.split(".")[0]
check_requirements(f"tensorrt-cu{cuda_version}>7.0.0,!=10.1.0")
check_tensorrt()
import tensorrt as trt
check_version(trt.__version__, ">=7.0.0", hard=True)
check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")

View file

@ -507,6 +507,13 @@ def check_executorch_requirements():
check_requirements("numpy<=2.3.5")
def check_tensorrt():
"""Check and install TensorRT requirements including platform-specific dependencies."""
if LINUX:
cuda_version = torch.version.cuda.split(".")[0]
check_requirements(f"tensorrt-cu{cuda_version}>7.0.0,!=10.1.0")
def check_torchvision():
"""Check the installed versions of PyTorch and Torchvision to ensure they're compatible.