diff --git a/demo/predict.py b/demo/predict.py index 872d018f..df7743ed 100644 --- a/demo/predict.py +++ b/demo/predict.py @@ -37,6 +37,7 @@ def get_parser_infer(parents=None): parser.add_argument( "--single_cls", type=ast.literal_eval, default=False, help="train multi-class data as single-class" ) + parser.add_argument("--exec_nms", type=ast.literal_eval, default=True, help="whether to execute NMS or not") parser.add_argument("--nms_time_limit", type=float, default=60.0, help="time limit for NMS") parser.add_argument("--conf_thres", type=float, default=0.25, help="object confidence threshold") parser.add_argument("--iou_thres", type=float, default=0.65, help="IOU threshold for NMS") @@ -94,6 +95,7 @@ def detect( conf_thres: float = 0.25, iou_thres: float = 0.65, conf_free: bool = False, + exec_nms: bool = True, nms_time_limit: float = 60.0, img_size: int = 640, stride: int = 32, @@ -129,14 +131,15 @@ def detect( # Run NMS t = time.time() out = out.asnumpy() - out = non_max_suppression( - out, - conf_thres=conf_thres, - iou_thres=iou_thres, - conf_free=conf_free, - multi_label=True, - time_limit=nms_time_limit, - ) + if exec_nms: + out = non_max_suppression( + out, + conf_thres=conf_thres, + iou_thres=iou_thres, + conf_free=conf_free, + multi_label=True, + time_limit=nms_time_limit, + ) nms_times = time.time() - t result_dict = {"category_id": [], "bbox": [], "score": []} @@ -305,6 +308,7 @@ def infer(args): conf_thres=args.conf_thres, iou_thres=args.iou_thres, conf_free=args.conf_free, + exec_nms=args.exec_nms, nms_time_limit=args.nms_time_limit, img_size=args.img_size, stride=max(max(args.network.stride), 32),