Skip to content

Commit

Permalink
fix in_shape is 1-dims error
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyang-star committed Jan 24, 2022
1 parent 38ba946 commit d356c3a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 9 additions & 2 deletions lite/core/optimizer/mir/fusion/keepdims_convert_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,16 @@ std::vector<int> KeepdimsConvertFuser::GetTensorDims(const Node::Stmt* inst) {
}

const auto& tensor = var->Get<Tensor>();
VLOG(4) << "tensor dims: " << tensor.dims();
std::vector<int> dims;
for (auto iter : tensor.dims().Vectorize()) {
dims.push_back(iter);
// Out dims may be empty. For example, argmax's in dims{3}, keepdims=false, axis=0.
// Set out dims manually.
if (tensor.dims().empty()) {
dims.push_back(1);
} else {
for (auto iter : tensor.dims().Vectorize()) {
dims.push_back(iter);
}
}
return dims;
}
Expand Down
2 changes: 2 additions & 0 deletions lite/kernels/opencl/argmax_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class ArgmaxComputeImage2D : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>();
argmax_param_ = param_.get_mutable<param_t>();
auto& x_dims = argmax_param_->X->dims();
bool keepdims = argmax_param_->keepdims;
CHECK(keepdims) << "OpenCL argmax kernel only support keepdims=true. keepdims=false case will be converted by keepdims_convert_pass.";

// padding to 4-dims
in_nchw_ = x_dims.Vectorize();
Expand Down

0 comments on commit d356c3a

Please sign in to comment.