From f362e0c80045804045a1d0aa98f55231948b629d Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 8 Jan 2021 00:32:29 +0000 Subject: [PATCH 1/7] repeat op --- .../contrib/onnx/mx2onnx/_op_translations.py | 71 ++++++++++++++++++- tests/python-pytest/onnx/test_operators.py | 9 +++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 07537a3d99f5..8039f3b2d0ef 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2721,7 +2721,76 @@ def convert_arange(node, **kwargs): create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs), create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs), create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs), - make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name]) + make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name], name=name) ] return nodes + + +@mx_op.register('repeat') +def convert_arange(node, **kwargs): + """Map MXNet's repeat operator attributes to onnx's Tile operator. + """ + from onnx.helper import make_node + from onnx import TensorProto + name, input_nodes, attrs = get_inputs(node, kwargs) + + opset_version = kwargs['opset_version'] + if opset_version < 11: + raise AttributeError('ONNX opset 11 or greater is required to export this operator') + + repeats = int(attrs.get('repeats', 1)) + axis = attrs.get('axis', 'None') + + if repeats <= 0: + raise NotImplementedError('repeat operator does not support parameter repeats==0') + + nodes = [] + if axis == 'None': + print('dooooo') + nodes += [ + create_tensor([repeats], name+'_rep', kwargs['initializer']), + create_tensor([1, repeats], name+'_repeats', kwargs['initializer']), + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('ReduceProd', [name+'_shape'], [name+'_size']), + make_node('Reshape', [input_nodes[0], name+'_size'], [name+'_flat']), + make_node('Unsqueeze', [name+'_flat'], [name+'_unsqueeze'], axes=[-1]), + make_node('Tile', [name+'_unsqueeze', name+'_repeats'], [name+'_tile']), + make_node('Mul', [name+'_size', name+'_rep'], [name+'_new_size']), + make_node('Reshape', [name+'_tile', name+'_new_size'], [name], name=name) + ] + else: + axis = int(axis) + repeats -= 1 + nodes += [ + create_tensor([repeats], name+'_repeats', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([], name+'_void', kwargs['initializer']), + create_tensor([axis], name+'_axis', kwargs['initializer']), + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Shape', [name+'_shape'], [name+'_dim']), + make_node('Reshape', [name+'_dim', name+'_void'], [name+'_dim_s']), + make_node('Range', [name+'_0', name+'_dim_s', name+'_1'], [name+'_range']) + ] + if axis < 0: + nodes += [ + make_node('Add', [name+'_axis', name+'_dim'], [name+'_true_axis']), + make_node('Equal', [name+'_range', name+'_true_axis'], [name+'_one_hot']) + ] + else: + nodes += [ + make_node('Equal', [name+'_range', name+'_axis'], [name+'_one_hot']) + ] + nodes += [ + make_node('Cast', [name+'_one_hot'], [name+'_one_hot_int'], to=int(TensorProto.INT64)), + make_node('Mul', [name+'_repeats', name+'_one_hot_int'], [name+'_mul']), + make_node('Add', [name+'_mul', name+'_1'], [name+'_add']), + make_node('Concat', [name+'_1', name+'_add'], [name+'_repeats_tensor'], axis=0), + make_node('Unsqueeze', [input_nodes[0]], [name+'_unsqueeze'], axes=[axis+1]), + make_node('Tile', [name+'_unsqueeze', name+'_repeats_tensor'], [name+'_tile']), + make_node('Mul', [name+'_shape', name+'_add'], [name+'_new_shape']), + make_node('Reshape', [name+'_tile', name+'_new_shape'], [name], name=name) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 34838a081996..01284e5bb3c6 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -351,3 +351,12 @@ def test_onnx_export_softmax(tmp_path, dtype): M4 = def_model('softmax', use_length=True, axis=1) l4 = mx.nd.array([[2,0,3,1],[0,1,0,0]], dtype=int) op_export_test('softmax_4', M4, [x, l4], tmp_path) + + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +@pytest.mark.parametrize('axis', [None, 0, 1, 2]) +@pytest.mark.parametrize('repeats', [2, 1, 3]) +def test_onnx_export_repeat(tmp_path, dtype, axis, repeats): + x = mx.nd.arange(0, 27, dtype=dtype).reshape((3, 3, 3)) + M = def_model('repeat', axis=axis, repeats=repeats) + op_export_test('repeat', M, [x], tmp_path) From 4f4e8a4511e0aa5b1fc4b772dbbc5ae7b6014609 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 8 Jan 2021 00:36:51 +0000 Subject: [PATCH 2/7] remove extra print --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8039f3b2d0ef..3f215a4ce9f8 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2747,7 +2747,6 @@ def convert_arange(node, **kwargs): nodes = [] if axis == 'None': - print('dooooo') nodes += [ create_tensor([repeats], name+'_rep', kwargs['initializer']), create_tensor([1, repeats], name+'_repeats', kwargs['initializer']), From f096204f8654b0006568fe82280210c51c7b2015 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Thu, 7 Jan 2021 17:22:47 -0800 Subject: [PATCH 3/7] restore sanity --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3f215a4ce9f8..2973fe1a01de 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2728,7 +2728,7 @@ def convert_arange(node, **kwargs): @mx_op.register('repeat') -def convert_arange(node, **kwargs): +def convert_repeat(node, **kwargs): """Map MXNet's repeat operator attributes to onnx's Tile operator. """ from onnx.helper import make_node From 3069781496a36d67b320c65a26b9d660a084d09f Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Fri, 8 Jan 2021 15:06:42 -0800 Subject: [PATCH 4/7] Update test_operators.py --- tests/python-pytest/onnx/test_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 01284e5bb3c6..105c8e5198f6 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -354,7 +354,7 @@ def test_onnx_export_softmax(tmp_path, dtype): @pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) -@pytest.mark.parametrize('axis', [None, 0, 1, 2]) +@pytest.mark.parametrize('axis', [None, 0, 1, 2, -1, -2, -3]) @pytest.mark.parametrize('repeats', [2, 1, 3]) def test_onnx_export_repeat(tmp_path, dtype, axis, repeats): x = mx.nd.arange(0, 27, dtype=dtype).reshape((3, 3, 3)) From e168e4b2dc4885d706c1fb74e179cb53308e3b7e Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Sun, 10 Jan 2021 00:27:25 +0000 Subject: [PATCH 5/7] fix axis=1 case --- .../contrib/onnx/mx2onnx/_op_translations.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 2973fe1a01de..55aa65622403 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2785,8 +2785,19 @@ def convert_repeat(node, **kwargs): make_node('Cast', [name+'_one_hot'], [name+'_one_hot_int'], to=int(TensorProto.INT64)), make_node('Mul', [name+'_repeats', name+'_one_hot_int'], [name+'_mul']), make_node('Add', [name+'_mul', name+'_1'], [name+'_add']), - make_node('Concat', [name+'_1', name+'_add'], [name+'_repeats_tensor'], axis=0), - make_node('Unsqueeze', [input_nodes[0]], [name+'_unsqueeze'], axes=[axis+1]), + make_node('Concat', [name+'_1', name+'_add'], [name+'_repeats_tensor'], axis=0) + ] + if axis == -1: + nodes += [ + make_node('Concat', [name+'_shape', name+'_1'], [name+'_unsqueeze_shape'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_unsqueeze_shape'], + [name+'_unsqueeze']) + ] + else : + nodes += [ + make_node('Unsqueeze', [input_nodes[0]], [name+'_unsqueeze'], axes=[axis+1]) + ] + nodes += [ make_node('Tile', [name+'_unsqueeze', name+'_repeats_tensor'], [name+'_tile']), make_node('Mul', [name+'_shape', name+'_add'], [name+'_new_shape']), make_node('Reshape', [name+'_tile', name+'_new_shape'], [name], name=name) From 201789f4a803acfacc8c48ed986e8c212ca91d38 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Sun, 10 Jan 2021 18:24:49 -0800 Subject: [PATCH 6/7] Update _op_translations.py --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 74bd9802c2d2..57ef546de29f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2850,7 +2850,7 @@ def convert_repeat(node, **kwargs): make_node('Reshape', [input_nodes[0], name+'_unsqueeze_shape'], [name+'_unsqueeze']) ] - else : + else: nodes += [ make_node('Unsqueeze', [input_nodes[0]], [name+'_unsqueeze'], axes=[axis+1]) ] From 01102f39683b0dd645fb5d3782d775a63ed11837 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Mon, 11 Jan 2021 11:04:27 -0800 Subject: [PATCH 7/7] Update test_operators.py --- tests/python-pytest/onnx/test_operators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 8dcb6a81ff84..c17a03bc3276 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -362,6 +362,7 @@ def test_onnx_export_repeat(tmp_path, dtype, axis, repeats): op_export_test('repeat', M, [x], tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('params', [{'height': 7, 'width': 13}, {'height': 10, 'width': 16}, {'height': 3, 'width': 5},