Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
reshape corner cases for mask rcnn (#19875)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zha0q1 authored Feb 16, 2021
1 parent c123c32 commit 95f3723
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
24 changes: 24 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1682,6 +1682,30 @@ def convert_reshape(node, **kwargs):
]
return nodes

if targ_shape == [-4, 1, -1, 0, 0, 0] and reverse != 'True':
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([-1], name+'_m1', kwargs['initializer'])
nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2',
name+'_dim3'], axis=0),
make_node('Concat', [name+'_1', name+'_m1', name+'_dim1', name+'_dim2', name+'_dim3'],
[name+'_shape_new'], axis=0),
make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
]
return nodes

if targ_shape == [-4, 1, 1000, 0, 0] and reverse != 'True':
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([1000], name+'_1000', kwargs['initializer'])
nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2'], axis=0),
make_node('Concat', [name+'_1', name+'_1000', name+'_dim1', name+'_dim2'],
[name+'_shape_new'], axis=0),
make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
]
return nodes

not_supported_shape = [-2, -3, -4]
for val in targ_shape:
Expand Down
8 changes: 8 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ def test_onnx_export_reshape_special_cases(tmp_path, dtype):
M7 = def_model('reshape', shape=(0, 0, -4, 2, 2, 0, 0))
op_export_test('reshape_spec_7', M7, [x5], tmp_path)

x6 = mx.nd.ones((8, 7, 6, 5), dtype=dtype)
M8 = def_model('reshape', shape=(-4, 1, -1, 0, 0, 0))
op_export_test('reshape_spec_8', M8, [x6], tmp_path)

x7 = mx.nd.ones((1000, 2, 3), dtype=dtype)
M9 = def_model('reshape', shape=(-4, 1, 1000, 0, 0))
op_export_test('reshape_spec_9', M9, [x7], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_onnx_export_embedding(tmp_path, dtype):
Expand Down

0 comments on commit 95f3723

Please sign in to comment.