diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index ce372c78e250..86e19fa121af 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -272,25 +272,62 @@ def obj_detection_test_images(tmpdir_factory): 'center_net_resnet18_v1b_coco', 'center_net_resnet50_v1b_coco', 'center_net_resnet101_v1b_coco', - # the following models are failing due to onnxruntime errors - #'ssd_300_vgg16_atrous_voc', - #'ssd_512_vgg16_atrous_voc', - #'ssd_512_resnet50_v1_voc', - #'ssd_512_mobilenet1.0_voc', - #'faster_rcnn_resnet50_v1b_voc', - #'yolo3_darknet53_voc', - #'yolo3_mobilenet1.0_voc', - #'ssd_300_vgg16_atrous_coco', - #'ssd_512_vgg16_atrous_coco', - #'ssd_300_resnet34_v1b_coco', - #'ssd_512_resnet50_v1_coco', - #'ssd_512_mobilenet1.0_coco', - #'faster_rcnn_resnet50_v1b_coco', - #'faster_rcnn_resnet101_v1d_coco', - #'yolo3_darknet53_coco', - #'yolo3_mobilenet1.0_coco', + 'ssd_300_vgg16_atrous_voc', + 'ssd_512_vgg16_atrous_voc', + 'ssd_512_resnet50_v1_voc', + 'ssd_512_mobilenet1.0_voc', + 'faster_rcnn_resnet50_v1b_voc', + 'yolo3_darknet53_voc', + 'yolo3_mobilenet1.0_voc', + 'ssd_300_vgg16_atrous_coco', + 'ssd_512_vgg16_atrous_coco', + # 'ssd_300_resnet34_v1b_coco', #cannot import + 'ssd_512_resnet50_v1_coco', + 'ssd_512_mobilenet1.0_coco', + 'faster_rcnn_resnet50_v1b_coco', + 'faster_rcnn_resnet101_v1d_coco', + 'yolo3_darknet53_coco', + 'yolo3_mobilenet1.0_coco', ]) def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detection_test_images): + def assert_obj_detetion_result(mx_ids, mx_scores, mx_boxes, + onnx_ids, onnx_scores, onnx_boxes, + score_thresh=0.6, score_tol=1e-4): + def assert_bbox(mx_boxe, onnx_boxe, box_tol=1e-2): + def assert_scalar(a, b, tol=box_tol): + return np.abs(a-b) <= tol + return assert_scalar(mx_boxe[0], onnx_boxe[0]) and assert_scalar(mx_boxe[1], onnx_boxe[1]) \ + and assert_scalar(mx_boxe[2], onnx_boxe[2]) and assert_scalar(mx_boxe[3], onnx_boxe[3]) + + found_match = False + for i in range(len(onnx_ids)): + onnx_id = onnx_ids[i][0] + onnx_score = onnx_scores[i][0] + onnx_boxe = onnx_boxes[i] + + if onnx_score < score_thresh: + break + for j in range(len(mx_ids)): + mx_id = mx_ids[j].asnumpy()[0] + mx_score = mx_scores[j].asnumpy()[0] + mx_boxe = mx_boxes[j].asnumpy() + # check socre + if onnx_score < mx_score - score_tol: + continue + if onnx_score > mx_score + score_tol: + return False + # check id + if onnx_id != mx_id: + continue + # check bounding box + if assert_bbox(mx_boxe, onnx_boxe): + found_match = True + break + if not found_match: + return False + found_match = False + return True + def normalize_image(imgfile): img = mx.image.imread(imgfile) img, _ = mx.image.center_crop(img, size=(512, 512)) @@ -310,10 +347,17 @@ def normalize_image(imgfile): for img in obj_detection_test_images: img_data = normalize_image(img) mx_class_ids, mx_scores, mx_boxes = M.predict(img_data) - onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()}) - assert_almost_equal(mx_class_ids, onnx_class_ids) - assert_almost_equal(mx_scores, onnx_scores) - assert_almost_equal(mx_boxes, onnx_boxes) + # center_net_resnet models have different output format + if 'center_net_resnet' in model: + onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()}) + assert_almost_equal(mx_class_ids, onnx_class_ids) + assert_almost_equal(mx_scores, onnx_scores) + assert_almost_equal(mx_boxes, onnx_boxes) + else: + onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()}) + if not assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], \ + onnx_class_ids[0], onnx_scores[0], onnx_boxes[0]): + raise AssertionError("Assertion error on model: " + model) finally: shutil.rmtree(tmp_path)