Skip to content

Commit

Permalink
[XPU]. Modified based on comment.
Browse files Browse the repository at this point in the history
  • Loading branch information
wbn03 committed Jul 29, 2022
1 parent 5fc47d6 commit 233ffec
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
6 changes: 3 additions & 3 deletions lite/kernels/xpu/generate_proposals_v2_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ void GenerateProposalsV2Compute::Run() {
float min_size = param.min_size;
float eta = param.eta;
bool pixel_offset = param.pixel_offset;
if (std::fabs(eta - 1.0f) > 1e-7) {
LOG(FATAL) << "XPU Generate Proposals Don't Support Adaptive NMS.";
}

// XPU Generate Proposals Don't Support Adaptive NMS.
CHECK_LT(std::fabs(eta - 1.0f), 1e-7);

auto& scores_dim = scores->dims();
int num = static_cast<int>(scores_dim[0]);
Expand Down
12 changes: 5 additions & 7 deletions lite/tests/kernels/generate_proposals_v2_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,17 @@ class GenerateProposalsV2ComputeTester : public arena::TestCase {
template <typename T, typename IndexT = int>
void CPUGather(const Tensor& src, const Tensor& index, Tensor* output) {
if (index.dims().size() == 2) {
// index.dims()[1] should be 1 when index.dims().size() = 2
CHECK_NE(index.dims()[1], 1);
if (index.dims()[1] != 1) {
LOG(FATAL) << "index.dims()[1] should be 1 when index.dims().size() = 2"
LOG(FATAL) << "i"
"in gather_op, but received value is "
<< index.dims()[1];
}

} else {
if (index.dims().size() != 1) {
LOG(FATAL) << "index.dims().size() should be 1 or 2 in gather_op,"
"but received shape's size is "
<< index.dims().size();
}
// index.dims().size() should be 1 or 2 in gather_op
CHECK_NE(index.dims().size(), 1);
}
int64_t index_size = index.dims()[0];

Expand Down Expand Up @@ -702,7 +701,6 @@ class GenerateProposalsV2ComputeTester : public arena::TestCase {
SetCommonTensor(Anchors_, anchors.dims(), anchors.data<float>());

// Variances

SetCommonTensor(Variances_, variances.dims(), variances.data<float>());

// Scores
Expand Down

0 comments on commit 233ffec

Please sign in to comment.