From 5c9c6e6d399b722f8589fe83bcaa7e396002767d Mon Sep 17 00:00:00 2001 From: srinidhigoud Date: Fri, 16 Jul 2021 23:26:41 -0700 Subject: [PATCH] [Frontend][Tensorflow2] Stridedslice and concat_v2 fix (#8483) * fix for strided_slice when begin > end in case of shrinkaxis_mask * fix for name_hint missing error for concat_v2 op * removing a local fix * adding more testing capability to concat_v2 --- python/tvm/relay/frontend/tensorflow_ops.py | 10 ++++++++-- tests/python/frontend/tensorflow/test_forward.py | 2 ++ .../frontend/tensorflow2/test_functional_models.py | 3 ++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 797ff51ace7a..ba0fcca0197d 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1483,7 +1483,13 @@ def _impl(inputs, attr, params, mod): def _concatV2(): def _impl(inputs, attr, params, mod): pop_node = inputs.pop(len(inputs) - 1) - axis = int(_get_num_param(params, pop_node)) + try: + axis = int(_get_num_param(params, pop_node)) + except (IndexError, KeyError, AttributeError): + try: + axis = int(_infer_value(pop_node, params, mod).numpy()) + except Exception: + axis = int(pop_node) return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], extras={"axis": axis})( [inputs], attr ) @@ -2244,7 +2250,7 @@ def _transform_mask(stride_dim, ellipsis_mask): if begin[index] < 0 else begin[index] ) - m_end[final_index] = begin[index] + 1 + m_end[final_index] = m_begin[final_index] + 1 m_stride[final_index] = 1 fshape_indices.append(-2) else: diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c942411471cd..2341b11bd726 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2568,7 +2568,9 @@ def test_forward_stridedslice(): _test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1) _test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([4], [-1], [0], [1], "float32", shrink_axis_mask=1) _test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([2, 3, 4], [-2], [0], [1], "float32", shrink_axis_mask=8) _test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8) _test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32") _test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index b3504ff38328..a39ecb411f15 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -354,7 +354,8 @@ def get_input(self): @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) def func(self, x): a, b, c = tf.split(x, 3, axis=1) - return tf.raw_ops.ConcatV2(values=[a, b, c], axis=1) + axis = tf.add(tf.constant(1, dtype="int32"), tf.constant(0, dtype="int32")) + return tf.raw_ops.ConcatV2(values=[a, b, c], axis=axis) run_all(ConcatV2)