Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactor code (#19887)
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Chu <[email protected]>
  • Loading branch information
waytrue17 and Wei Chu authored Feb 16, 2021
1 parent 95f3723 commit 3b470d1
Showing 1 changed file with 65 additions and 21 deletions.
86 changes: 65 additions & 21 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,25 +272,62 @@ def obj_detection_test_images(tmpdir_factory):
'center_net_resnet18_v1b_coco',
'center_net_resnet50_v1b_coco',
'center_net_resnet101_v1b_coco',
# the following models are failing due to onnxruntime errors
#'ssd_300_vgg16_atrous_voc',
#'ssd_512_vgg16_atrous_voc',
#'ssd_512_resnet50_v1_voc',
#'ssd_512_mobilenet1.0_voc',
#'faster_rcnn_resnet50_v1b_voc',
#'yolo3_darknet53_voc',
#'yolo3_mobilenet1.0_voc',
#'ssd_300_vgg16_atrous_coco',
#'ssd_512_vgg16_atrous_coco',
#'ssd_300_resnet34_v1b_coco',
#'ssd_512_resnet50_v1_coco',
#'ssd_512_mobilenet1.0_coco',
#'faster_rcnn_resnet50_v1b_coco',
#'faster_rcnn_resnet101_v1d_coco',
#'yolo3_darknet53_coco',
#'yolo3_mobilenet1.0_coco',
'ssd_300_vgg16_atrous_voc',
'ssd_512_vgg16_atrous_voc',
'ssd_512_resnet50_v1_voc',
'ssd_512_mobilenet1.0_voc',
'faster_rcnn_resnet50_v1b_voc',
'yolo3_darknet53_voc',
'yolo3_mobilenet1.0_voc',
'ssd_300_vgg16_atrous_coco',
'ssd_512_vgg16_atrous_coco',
# 'ssd_300_resnet34_v1b_coco', #cannot import
'ssd_512_resnet50_v1_coco',
'ssd_512_mobilenet1.0_coco',
'faster_rcnn_resnet50_v1b_coco',
'faster_rcnn_resnet101_v1d_coco',
'yolo3_darknet53_coco',
'yolo3_mobilenet1.0_coco',
])
def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detection_test_images):
def assert_obj_detetion_result(mx_ids, mx_scores, mx_boxes,
onnx_ids, onnx_scores, onnx_boxes,
score_thresh=0.6, score_tol=1e-4):
def assert_bbox(mx_boxe, onnx_boxe, box_tol=1e-2):
def assert_scalar(a, b, tol=box_tol):
return np.abs(a-b) <= tol
return assert_scalar(mx_boxe[0], onnx_boxe[0]) and assert_scalar(mx_boxe[1], onnx_boxe[1]) \
and assert_scalar(mx_boxe[2], onnx_boxe[2]) and assert_scalar(mx_boxe[3], onnx_boxe[3])

found_match = False
for i in range(len(onnx_ids)):
onnx_id = onnx_ids[i][0]
onnx_score = onnx_scores[i][0]
onnx_boxe = onnx_boxes[i]

if onnx_score < score_thresh:
break
for j in range(len(mx_ids)):
mx_id = mx_ids[j].asnumpy()[0]
mx_score = mx_scores[j].asnumpy()[0]
mx_boxe = mx_boxes[j].asnumpy()
# check socre
if onnx_score < mx_score - score_tol:
continue
if onnx_score > mx_score + score_tol:
return False
# check id
if onnx_id != mx_id:
continue
# check bounding box
if assert_bbox(mx_boxe, onnx_boxe):
found_match = True
break
if not found_match:
return False
found_match = False
return True

def normalize_image(imgfile):
img = mx.image.imread(imgfile)
img, _ = mx.image.center_crop(img, size=(512, 512))
Expand All @@ -310,10 +347,17 @@ def normalize_image(imgfile):
for img in obj_detection_test_images:
img_data = normalize_image(img)
mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
assert_almost_equal(mx_class_ids, onnx_class_ids)
assert_almost_equal(mx_scores, onnx_scores)
assert_almost_equal(mx_boxes, onnx_boxes)
# center_net_resnet models have different output format
if 'center_net_resnet' in model:
onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
assert_almost_equal(mx_class_ids, onnx_class_ids)
assert_almost_equal(mx_scores, onnx_scores)
assert_almost_equal(mx_boxes, onnx_boxes)
else:
onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
if not assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], \
onnx_class_ids[0], onnx_scores[0], onnx_boxes[0]):
raise AssertionError("Assertion error on model: " + model)

finally:
shutil.rmtree(tmp_path)
Expand Down

0 comments on commit 3b470d1

Please sign in to comment.