Skip to content

Commit

Permalink
[ONNX] Fix issues for Clip and RoiAlign (apache#7237)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored and masahi committed Jan 14, 2021
1 parent 5810588 commit 3c73e7c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
8 changes: 6 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2046,7 +2046,7 @@ def _impl_v1(cls, inputs, attr, params):
x = inputs[0]
rois = inputs[1]
batch_indices = inputs[2]
mode = attr.get("mode", "avg")
mode = attr.get("mode", b"avg")
if mode != b"avg":
raise ValueError("RoiAlign in Relay only uses avg mode")
output_height = attr.get("output_height", 1)
Expand All @@ -2056,7 +2056,7 @@ def _impl_v1(cls, inputs, attr, params):
spatial_scale = attr.get("spatial_scale", 1.0)

batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1)
batch_indices = _op.cast(batch_indices, infer_type(rois).type_annotation.dtype)
batch_indices = _op.cast(batch_indices, infer_type(rois).checked_type.dtype)
rois = _op.concatenate([batch_indices, rois], 1)

return _vision.roi_align(
Expand All @@ -2074,6 +2074,10 @@ def convert_attributes(inputs, attr, params):

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "min" not in attr:
attr["min"] = -np.inf
if "max" not in attr:
attr["max"] = np.inf
return Clip.convert_attributes(inputs, attr, params)

@classmethod
Expand Down
24 changes: 22 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def test_slice():
)


def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs, opset=None):
indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
outdata = outfunc(indata, **npargs)

Expand All @@ -856,7 +856,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
model = helper.make_model(graph, producer_name=opname + "_test")

for target, ctx in tvm.testing.enabled_targets():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype, opset=opset)
tvm.testing.assert_allclose(outdata, tvm_out)


Expand All @@ -881,6 +881,26 @@ def test_clip():
{"min": -1.0, "max": 1.0},
)

_test_onnx_op_elementwise(
(2, 4, 5, 6),
np.clip,
{"a_min": -np.inf, "a_max": 1.0},
"float32",
"Clip",
{"max": 1.0},
opset=1,
)

_test_onnx_op_elementwise(
(2, 4, 5, 6),
np.clip,
{"a_min": -1.0, "a_max": np.inf},
"float32",
"Clip",
{"min": -1.0},
opset=1,
)


@tvm.testing.uses_gpu
def test_clip_min_max_as_inputs():
Expand Down

0 comments on commit 3c73e7c

Please sign in to comment.