diff --git a/docs/en/models/yolo-nas.md b/docs/en/models/yolo-nas.md index aa7e200f0b..6e6f09c05a 100644 --- a/docs/en/models/yolo-nas.md +++ b/docs/en/models/yolo-nas.md @@ -6,6 +6,10 @@ keywords: YOLO-NAS, Deci AI, object detection, deep learning, Neural Architectur # YOLO-NAS +!!! note "Important Update" + + Please note that [Deci](https://www.linkedin.com/company/deciai/), the original creators of YOLO-NAS, have been acquired by NVIDIA. As a result, these models are no longer actively maintained by Deci. Ultralytics continues to support the usage of these models, but no further updates from the original team are expected. + ## Overview Developed by Deci AI, YOLO-NAS is a groundbreaking object detection foundational model. It is the product of advanced [Neural Architecture Search](https://www.ultralytics.com/glossary/neural-architecture-search-nas) technology, meticulously designed to address the limitations of previous YOLO models. With significant improvements in quantization support and [accuracy](https://www.ultralytics.com/glossary/accuracy)-latency trade-offs, YOLO-NAS represents a major leap in object detection. diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index 0d03b62bc1..27fc1186de 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -81,6 +81,7 @@ class NAS(Model): self.model.pt_path = weights # for export() self.model.task = "detect" # for export() self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # for export() + self.model.eval() def info(self, detailed: bool = False, verbose: bool = True): """ diff --git a/ultralytics/models/nas/predict.py b/ultralytics/models/nas/predict.py index ab0b447161..44a0691914 100644 --- a/ultralytics/models/nas/predict.py +++ b/ultralytics/models/nas/predict.py @@ -2,16 +2,15 @@ import torch -from ultralytics.engine.predictor import BasePredictor -from ultralytics.engine.results import Results +from ultralytics.models.yolo.detect.predict import DetectionPredictor from ultralytics.utils import ops -class NASPredictor(BasePredictor): +class NASPredictor(DetectionPredictor): """ Ultralytics YOLO NAS Predictor for object detection. - This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the + This class extends the `DetectionPredictor` from Ultralytics engine and is responsible for post-processing the raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and scaling the bounding boxes to fit the original image dimensions. @@ -38,23 +37,4 @@ class NASPredictor(BasePredictor): # Convert boxes from xyxy to xywh format and concatenate with class scores boxes = ops.xyxy2xywh(preds_in[0][0]) preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) - - # Apply non-maximum suppression to filter overlapping detections - preds = ops.non_max_suppression( - preds, - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - classes=self.args.classes, - ) - - if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list - orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) - - results = [] - for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): - # Scale bounding boxes to match original image dimensions - pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) - results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred)) - return results + return super().postprocess(preds, img, orig_imgs) diff --git a/ultralytics/models/nas/val.py b/ultralytics/models/nas/val.py index b45064b907..2ca1065160 100644 --- a/ultralytics/models/nas/val.py +++ b/ultralytics/models/nas/val.py @@ -36,7 +36,4 @@ class NASValidator(DetectionValidator): """Apply Non-maximum suppression to prediction outputs.""" boxes = ops.xyxy2xywh(preds_in[0][0]) # Convert bounding box format from xyxy to xywh preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) # Concatenate boxes with scores and permute - return super().postprocess( - preds, - max_time_img=0.5, - ) + return super().postprocess(preds)