diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 341860146743..7652f77be290 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -146,11 +146,11 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {

 struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   Integer batch_dims;
-  Integer gather_dim;
+  Integer num_indices_per_tuple;

   TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
     TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
-    TVM_ATTR_FIELD(gather_dim)
+    TVM_ATTR_FIELD(num_indices_per_tuple)
         .set_default(Integer(-1))
         .describe(
             "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 4e77f38184d0..40958116517f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1436,8 +1436,8 @@ def _impl_common(cls, data, indices, batch_dims=0):
         indices_dims = len(infer_shape(indices))
         indices_shape = infer_shape(indices)
         indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
-        gather_dim = indices_shape[-1]
-        return _op.gather_nd(data, indices, batch_dims, gather_dim)
+        num_indices_per_tuple = indices_shape[-1]
+        return _op.gather_nd(data, indices, batch_dims, num_indices_per_tuple)

     @classmethod
     def _impl_v1(cls, inputs, attr, params):
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 0a6654206006..c158075f53a8 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1130,14 +1130,14 @@ def unique_shape_func(attrs, inputs, _):


 @script
-def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
+def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tuple):
     ndim = data_shape.shape[0]
     # using mdim = indices_shape[0] wouldn't work because a rank cannot
     # depend on a runtime shape dimension of indices tensor, even if the
     # dimension is always a known, fixed value. As a workaround, we assume that
     # the fixed gather dimension (the size of an indexing tuple) is recorded
     # in `gather_nd` op attribute.
-    mdim = gather_dim
+    mdim = num_indices_per_tuple
     kdim = indices_shape.shape[0] - 1
     out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
     for i in range(1, kdim + 1):
@@ -1153,6 +1153,6 @@ def gather_nd_shape_func(attrs, inputs, _):
     Shape func for gather_nd operator.
     """
     batch_dims = get_const_int(attrs.batch_dims)
-    gather_dim = get_const_int(attrs.gather_dim)
-    assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
-    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
+    num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple)
+    assert num_indices_per_tuple > 0, "num_indices_per_tuple needs to be specified for dynamic gather_nd"
+    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index ed0c66fe5c3f..7c7968cf5631 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1075,7 +1075,7 @@ def gather(data, axis, indices):
     return _make.gather(data, axis, indices)


-def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
+def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
     """Gather elements or slices from data and store to a tensor whose
     shape is defined by indices.

@@ -1090,7 +1090,7 @@ def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
     batch_dims : int
         The number of batch dimensions.

-    gather_dim : int
+    num_indices_per_tuple : int
         The size of an indexing tuple, which is a fixed value and the same as indices.shape[0].
         Only needed when other dimensions of indices are dynamic.

@@ -1115,7 +1115,7 @@ def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
         indices = [[1, 0]]
         relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
     """
-    return _make.gather_nd(data, indices, batch_dims, gather_dim)
+    return _make.gather_nd(data, indices, batch_dims, num_indices_per_tuple)


 def sequence_mask(data, valid_length, mask_value=0, axis=0):
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 98347b8e2cb9..7b06f2b7112e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3600,11 +3600,11 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
   return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
 }

-Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int gather_dim = -1) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int num_indices_per_tuple = -1) {
   static const Op& op = Op::Get("gather_nd");
   auto attrs = make_object<GatherNDAttrs>();
   attrs->batch_dims = batch_dims;
-  attrs->gather_dim = gather_dim;
+  attrs->num_indices_per_tuple = num_indices_per_tuple;
   return Call(op, {data, indices}, Attrs(attrs));
 }