Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Host] move beam_search #5759

Merged
merged 3 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lite/backends/arm/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
norm.cc
pad2d.cc
negative.cc
beam_search.cc
reduce_max.cc
reduce_min.cc
reduce_max_min.cc
Expand Down
1 change: 0 additions & 1 deletion lite/backends/arm/math/funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "lite/backends/arm/math/anchor_generator.h"
#include "lite/backends/arm/math/argmax.h"
#include "lite/backends/arm/math/axpy.h"
#include "lite/backends/arm/math/beam_search.h"
#include "lite/backends/arm/math/box_coder.h"
#include "lite/backends/arm/math/clip.h"
#include "lite/backends/arm/math/col_im_transform.h"
Expand Down
1 change: 1 addition & 0 deletions lite/backends/host/math/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
lite_cc_library(math_host SRCS
beam_search.cc
sequence_padding.cc
slice.cc
pad3d.cc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/arm/math/beam_search.h"
#include <arm_neon.h>
#include "lite/backends/host/math/beam_search.h"
#include <cmath>
#include <string>
#include <vector>
#include "lite/utils/cp_logging.h"

namespace paddle {
namespace lite {
namespace arm {
namespace host {
namespace math {
/*
* The basic items help to sort.
Expand Down Expand Up @@ -207,9 +205,7 @@ void beam_search(const Tensor *pre_ids,
int level,
int beam_size,
int end_id,
bool is_accumulated,
Context<TARGET(kARM)> *ctx) {
// auto abs_lod = framework::ToAbsOffset(scores->lod());
bool is_accumulated) {
auto abs_lod = scores->lod();
auto &high_level = abs_lod[level];
auto items = SelectTopBeamSizeItems(pre_ids,
Expand Down Expand Up @@ -266,6 +262,6 @@ void beam_search(const Tensor *pre_ids,
}

} // namespace math
} // namespace arm
} // namespace host
} // namespace lite
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
// limitations under the License.

#pragma once

#include <cmath>
#include "lite/core/context.h"

namespace paddle {
namespace lite {
namespace arm {
namespace host {
namespace math {

void beam_search(const Tensor* pre_ids,
Expand All @@ -32,10 +30,9 @@ void beam_search(const Tensor* pre_ids,
int level,
int beam_size,
int end_id,
bool is_accumulated,
Context<TARGET(kARM)>* ctx);
bool is_accumulated);

} // namespace math
} // namespace arm
} // namespace host
} // namespace lite
} // namespace paddle
2 changes: 0 additions & 2 deletions lite/backends/x86/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ math_library(sequence2batch)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function jit_kernel_helper)
math_library(beam_search DEPS math_function)
#
## math_library(matrix_bit_code)
#
Expand All @@ -90,7 +89,6 @@ endif()
# cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
# cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
# cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
# cc_test(beam_search_test SRCS beam_search_test.cc DEPS beam_search)
# cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
# cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
math_library(box_coder DEPS math_function)
Expand Down
Loading