Skip to content

Commit

Permalink
TF argmax - handling int64 datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Oct 12, 2020
1 parent b277f18 commit 59e3fcd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ def _impl(inputs, attr, params, mod):
raise TypeError(
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name)
)
return func(inputs[0], axis=axis_input_value, keepdims=False)
out = func(inputs[0], axis=axis_input_value, keepdims=False)
dtype = attr["output_type"].name
if dtype != "int32":
out = _op.cast(out, dtype=dtype)
return out

return _impl

Expand Down
12 changes: 6 additions & 6 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,16 +1601,16 @@ def _test_argx(func, data, **kwargs):

with tf.Graph().as_default():
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
func(inp, name="argx0", output_type=tf.int32, **kwargs)

func(inp, name="argx0", **kwargs)
compare_tf_with_tvm(data, "c0:0", "argx0:0")


def test_forward_argminmax():
for axis in [None, 0, 1, 2]:
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
_test_argx(tf.argmax, data=data, axis=axis)
_test_argx(tf.argmin, data=data, axis=axis)
for output_type in [tf.int64, tf.int32]:
for axis in [None, 0, 1, 2]:
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
_test_argx(tf.argmax, data=data, axis=axis, output_type=output_type)
_test_argx(tf.argmin, data=data, axis=axis, output_type=output_type)


#######################################################################
Expand Down

0 comments on commit 59e3fcd

Please sign in to comment.