[XPU] fix conv2d lstm maxptr size bug (PaddlePaddle#8839)
shanliang1992 authored and newway committed Jun 1, 2022
1 parent 7e20bbb commit 1c6a716
Showing 2 changed files with 13 additions and 6 deletions.
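
Both kernels had the same bug: the device-side buffer holding a tensor's max value was hardcoded to 4 floats, while the length xdnn actually expects is device-dependent and is reported by xdnn::get_max_ptr_size (it can exceed 4 on newer XPU generations, presumably the overflow this commit fixes). The fix queries the size once in PrepareForRun and derives every resize, allocation, and memcpy length from it. Below is a minimal sketch of the shared pattern, not code from this commit; the helper name FillMaxBuffer is hypothetical, and it assumes the XPUScratchPadGuard type that Lite's TargetWrapperXPU::MallocScratchPad returns.

#include <vector>

// Minimal sketch: broadcast one host-side max value into a device buffer
// whose length is queried from xdnn instead of being hardcoded to 4.
void FillMaxBuffer(xdnn::Context* raw_ctx,
                   float max_value,
                   XPUScratchPadGuard* guard) {
  int max_ptr_size = xdnn::get_max_ptr_size(raw_ctx);
  std::vector<float> max_v(max_ptr_size, max_value);
  *guard = TargetWrapperXPU::MallocScratchPad(max_ptr_size * sizeof(float));
  XPU_CALL(xpu_memcpy(reinterpret_cast<float*>((*guard)->addr_),
                      max_v.data(),
                      max_ptr_size * sizeof(float),
                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
}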
3 changes: 3 additions & 0 deletions lite/kernels/xpu/__xpu__conv2d_compute.cc
@@ -53,6 +53,9 @@ bool QuantFilter<int8_t>(const float* filter_on_host,
 template <typename T, PrecisionType PType>
 void XPUConv2dCompute<T, PType>::PrepareForRun() {
   auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  int max_ptr_size = xdnn::get_max_ptr_size(ctx.GetRawContext());
+  param.output_max->Resize({max_ptr_size});
   auto filter_ptr = param.filter->template data<float>();
   auto filter_dims = param.filter->dims();

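The conv2d side of the fix is just the resize above: output_max now takes its shape from the queried max_ptr_size rather than an assumed constant, so the buffer later materialized for it matches what the device writes. A hypothetical illustration (not in this diff) of the allocation that shape would drive in Run():

// Hypothetical: mutable_data allocates a kXPU buffer whose length
// comes from the {max_ptr_size} shape set in PrepareForRun above.
float* output_max_ptr =
    param.output_max->template mutable_data<float>(TARGET(kXPU));
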
16 changes: 10 additions & 6 deletions lite/kernels/xpu/__xpu__dynamic_lstm_compute.cc
@@ -26,6 +26,8 @@ namespace xpu {

 void XPUDynamicLstmCompute::PrepareForRun() {
   auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  int max_ptr_size = xdnn::get_max_ptr_size(ctx.GetRawContext());

   // transpose from weight_0[xdim, 4 * hdim] to transpose_weight_0[4 * hdim,
   // xdim]
@@ -139,22 +141,24 @@ void XPUDynamicLstmCompute::PrepareForRun() {
   auto weight_0_len = param.weight_0->numel();
   float max_weight_0 =
       paddle::lite::xpu::math::FindMaxAbs(weight_0_ptr, weight_0_len);
-  std::vector<float> max_weight_0_v(4, max_weight_0);
-  weight_0_max_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
+  std::vector<float> max_weight_0_v(max_ptr_size, max_weight_0);
+  weight_0_max_ =
+      TargetWrapperXPU::MallocScratchPad(max_ptr_size * sizeof(float));
   float* weight_0_max_addr = reinterpret_cast<float*>(weight_0_max_->addr_);
   XPU_CALL(xpu_memcpy(weight_0_max_addr,
                       max_weight_0_v.data(),
-                      4 * sizeof(float),
+                      max_ptr_size * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));

   float max_weight_1 = paddle::lite::xpu::math::FindMaxAbs(
       param.weight_1->template data<float>(), param.weight_1->numel());
-  std::vector<float> max_weight_1_v(4, max_weight_1);
-  weight_1_max_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
+  std::vector<float> max_weight_1_v(max_ptr_size, max_weight_1);
+  weight_1_max_ =
+      TargetWrapperXPU::MallocScratchPad(max_ptr_size * sizeof(float));
   float* weight_1_max_addr = reinterpret_cast<float*>(weight_1_max_->addr_);
   XPU_CALL(xpu_memcpy(weight_1_max_addr,
                       max_weight_1_v.data(),
-                      4 * sizeof(float),
+                      max_ptr_size * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
 }
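Since the same allocate-and-copy block now appears twice with only the names changed, a hypothetical refactor could route both weights through the FillMaxBuffer sketch given at the top of this page (same assumptions as there):

// Hypothetical usage of the FillMaxBuffer sketch from above.
FillMaxBuffer(ctx.GetRawContext(), max_weight_0, &weight_0_max_);
FillMaxBuffer(ctx.GetRawContext(), max_weight_1, &weight_1_max_);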
