mirror of
https://github.com/ultralytics/ultralytics
synced 2026-04-21 14:07:18 +00:00
fix: 🐞 Add requirement check for TensorRT in test_export_engine_matrix (#23496)
Signed-off-by: Onuralp SEZER <onuralp@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
a8b639bc35
commit
f8e24c56cc
4 changed files with 15 additions and 4 deletions
|
|
@ -63,6 +63,10 @@ keywords: Ultralytics, YOLO, utility functions, version checks, requirements, im
|
|||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.utils.checks.check_tensorrt
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.utils.checks.check_torchvision
|
||||
|
||||
<br><br><hr><br>
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from ultralytics import YOLO
|
|||
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
|
||||
from ultralytics.utils import ASSETS, IS_JETSON, WEIGHTS_DIR
|
||||
from ultralytics.utils.autodevice import GPUInfo
|
||||
from ultralytics.utils.checks import check_amp
|
||||
from ultralytics.utils.checks import check_amp, check_tensorrt
|
||||
from ultralytics.utils.torch_utils import TORCH_1_13
|
||||
|
||||
# Try to find idle devices if CUDA is available
|
||||
|
|
@ -91,6 +91,7 @@ def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
|
|||
)
|
||||
def test_export_engine_matrix(task, dynamic, int8, half, batch):
|
||||
"""Test YOLO model export to TensorRT format for various configurations and run inference."""
|
||||
check_tensorrt()
|
||||
import tensorrt as trt
|
||||
|
||||
is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ from ultralytics.utils.checks import (
|
|||
check_executorch_requirements,
|
||||
check_imgsz,
|
||||
check_requirements,
|
||||
check_tensorrt,
|
||||
check_version,
|
||||
is_intel,
|
||||
is_sudo_available,
|
||||
|
|
@ -1005,9 +1006,7 @@ class Exporter:
|
|||
try:
|
||||
import tensorrt as trt
|
||||
except ImportError:
|
||||
if LINUX:
|
||||
cuda_version = torch.version.cuda.split(".")[0]
|
||||
check_requirements(f"tensorrt-cu{cuda_version}>7.0.0,!=10.1.0")
|
||||
check_tensorrt()
|
||||
import tensorrt as trt
|
||||
check_version(trt.__version__, ">=7.0.0", hard=True)
|
||||
check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
|
||||
|
|
|
|||
|
|
@ -507,6 +507,13 @@ def check_executorch_requirements():
|
|||
check_requirements("numpy<=2.3.5")
|
||||
|
||||
|
||||
def check_tensorrt():
|
||||
"""Check and install TensorRT requirements including platform-specific dependencies."""
|
||||
if LINUX:
|
||||
cuda_version = torch.version.cuda.split(".")[0]
|
||||
check_requirements(f"tensorrt-cu{cuda_version}>7.0.0,!=10.1.0")
|
||||
|
||||
|
||||
def check_torchvision():
|
||||
"""Check the installed versions of PyTorch and Torchvision to ensure they're compatible.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue