From d20ad0b2825af7aa72d9f3a0f415ed1009c7f241 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 9 Apr 2021 11:16:26 -0600 Subject: [PATCH 1/2] Support optional outputs for ONNX nodes --- python/tvm/relay/frontend/onnx.py | 18 ++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 6 ------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 09525a64ac05..f4d4d250581e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3202,6 +3202,24 @@ def from_onnx(self, graph, opset, get_output_expr=False): outputs_num = 1 else: outputs_num = len(op) + if outputs_num > 1: + valid_outputs = [False] * outputs_num + for i in range(len(node_output)): + if node_output[i] != "": + valid_outputs[i] = True + if not all(valid_outputs): + tup = op.astuple() + if isinstance(tup, _expr.Tuple): + outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid] + else: + outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid] + + if len(outputs) == 1: + op = outputs[0] + else: + op = _expr.TupleWrapper(outputs, len(outputs)) + outputs_num = len(outputs) + node_output = [output for output in node_output if output != ""] assert ( len(node_output) == outputs_num ), "Number of output mismatch {} vs {} in {}.".format( diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a491ed130418..8a63bac33923 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4138,9 +4138,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): unsupported_onnx_tests = [ "test_basic_convinteger/", "test_cast_DOUBLE_to_FLOAT16/", - "test_cast_FLOAT16_to_DOUBLE/", - "test_cast_FLOAT16_to_FLOAT/", - "test_cast_FLOAT_to_FLOAT16/", "test_cast_FLOAT_to_STRING/", "test_cast_STRING_to_FLOAT/", "test_compress_0/", @@ -4171,9 +4168,6 @@ def 
verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_hardmax_one_hot/", "test_isinf_negative/", "test_isinf_positive/", - "test_lstm_defaults/", - "test_lstm_with_initial_bias/", - "test_lstm_with_peepholes/", "test_matmulinteger/", "test_maxpool_2d_dilations/", "test_maxpool_2d_same_lower/", From 4d153b8eaca4f8a489e68db6685dfda97f63aed9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 9 Apr 2021 13:49:38 -0600 Subject: [PATCH 2/2] add comments --- python/tvm/relay/frontend/onnx.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f4d4d250581e..85fe01905b6e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3203,21 +3203,29 @@ def from_onnx(self, graph, opset, get_output_expr=False): else: outputs_num = len(op) if outputs_num > 1: + # ONNX supports optional outputs for some nodes. + # This block searches for missing outputs in the ONNX graph + # and removes any unneeded ops valid_outputs = [False] * outputs_num - for i in range(len(node_output)): - if node_output[i] != "": + for i, output in enumerate(node_output): + if output != "": valid_outputs[i] = True + # If we have outputs ONNX isn't expecting, we need to drop them if not all(valid_outputs): tup = op.astuple() + # TupleWrapper can also wrap ops with TupleType outputs if isinstance(tup, _expr.Tuple): + # For tuples, we extract the fields instead of using TupleGetItem outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid] else: + # For call nodes, we need to use TupleGetItem outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid] - + # Create the new op with valid outputs if len(outputs) == 1: op = outputs[0] else: op = _expr.TupleWrapper(outputs, len(outputs)) + # Drop invalid outputs for the ONNX node outputs_num = len(outputs) node_output = [output for output in node_output if output != ""] assert (