Skip to content

Commit

Permalink
TF argmax - handling int64 datatype (apache#6674)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
2 people authored and trevor-m committed Oct 19, 2020
1 parent 8eff7f2 commit 602fd9b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ def _impl(inputs, attr, params, mod):
raise TypeError(
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name)
)
return func(inputs[0], axis=axis_input_value, keepdims=False)
out = func(inputs[0], axis=axis_input_value, keepdims=False)
dtype = attr["output_type"].name
if dtype != "int32":
out = _op.cast(out, dtype=dtype)
return out

return _impl

Expand Down
12 changes: 6 additions & 6 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,16 +1616,16 @@ def _test_argx(func, data, **kwargs):

with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="argx0", output_type=tf.int32, **kwargs)

func(inp, name="argx0", **kwargs)
compare_tf_with_tvm(data, "c0:0", "argx0:0")


def test_forward_argminmax():
for axis in [None, 0, 1, 2]:
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
_test_argx(tf.argmax, data=data, axis=axis)
_test_argx(tf.argmin, data=data, axis=axis)
for output_type in [tf.int64, tf.int32]:
for axis in [None, 0, 1, 2]:
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
_test_argx(tf.argmax, data=data, axis=axis, output_type=output_type)
_test_argx(tf.argmin, data=data, axis=axis, output_type=output_type)


#######################################################################
Expand Down

0 comments on commit 602fd9b

Please sign in to comment.