[XPU] Fix xpu_fc_pass, cast and fill_any op #9366

Closed · wants to merge 1 commit
28 changes: 28 additions & 0 deletions lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc
@@ -38,6 +38,34 @@ class XPUFcFuser : public FuseBase {
  auto* W = VarNode("W")->assert_is_op_input(mul_type_, "Y")->AsInput();
  auto* mul = OpNode("mul", mul_type_)->AsIntermediate();
  auto* mul_out = VarNode("mul_out")->assert_is_op_output(mul_type_, "Out");
  // Teller for "matmul": fuse only when neither input is transposed
  // and alpha == 1.0f.
  auto input_attr_teller = [](const Node* node) -> bool {
    auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
    bool trans_x = op_desc.GetAttr<bool>("transpose_X");
    bool trans_y = op_desc.GetAttr<bool>("transpose_Y");
    // require alpha == 1.0f (within float tolerance)
    auto alpha = op_desc.GetAttr<float>("alpha");
    bool has_alpha = (fabsf(alpha - 1.f) > 1e-8f);
    bool res = (!trans_x && !trans_y && !has_alpha);
    return res;
  };
  // Teller for "matmul_v2": same checks, but the attribute names differ
  // and "alpha" is optional.
  auto input_attr_teller_v2 = [](const Node* node) -> bool {
    auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
    bool trans_x = op_desc.GetAttr<bool>("trans_x");
    bool trans_y = op_desc.GetAttr<bool>("trans_y");
    bool has_alpha = false;
    if (op_desc.HasAttr("alpha")) {
      auto alpha = op_desc.GetAttr<float>("alpha");
      has_alpha = (fabsf(alpha - 1.f) > 1e-8f);
    }
    bool res = (!trans_x && !trans_y && !has_alpha);
    return res;
  };
  if (mul_type_ == "matmul") {
    mul = OpNode("mul", mul_type_)->assert_node_satisfied(input_attr_teller);
  } else if (mul_type_ == "matmul_v2") {
    mul =
        OpNode("mul", mul_type_)->assert_node_satisfied(input_attr_teller_v2);
  }
  PMNode* bias = nullptr;
  PMNode* add = nullptr;
  PMNode* add_out = nullptr;
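For readers outside the pattern matcher: `assert_node_satisfied` attaches a predicate to the matched op node, so the fc fusion only fires when the teller returns true. A minimal sketch of the check both tellers encode, written as a standalone helper (the helper name is illustrative, not part of this PR):

```cpp
#include <cmath>

// Sketch: the fusion precondition shared by both tellers. It holds only
// when neither matmul input is transposed and alpha is 1.0 (within a
// small float tolerance, since exact float comparison is fragile).
bool IsPlainMatmul(bool trans_x, bool trans_y, float alpha = 1.0f) {
  bool has_alpha = std::fabs(alpha - 1.0f) > 1e-8f;
  return !trans_x && !trans_y && !has_alpha;
}
```

The matmul_v2 variant differs only in attribute names (`trans_x`/`trans_y` instead of `transpose_X`/`transpose_Y`) and in treating `alpha` as optional.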
12 changes: 12 additions & 0 deletions lite/kernels/xpu/cast_compute.cc
@@ -50,6 +50,18 @@ void CastCompute<InType>::Run() {
    auto* out_data = out->template mutable_data<int64_t>(TARGET(kXPU));
    r = xdnn::cast_v2<InType, int64_t>(
        ctx.GetRawContext(), in_data, out_data, numel);
  } else if (out_dtype == 0) {  // BOOL
    auto* out_data = out->template mutable_data<bool>(TARGET(kXPU));
    // Stage the cast through an int scratch buffer, then copy the
    // result into the bool output on-device.
    XPUScratchPadGuard out_int_guard =
        TargetWrapperXPU::MallocScratchPad(out->numel() * sizeof(int));
    r = xdnn::cast_v2<InType, int>(ctx.GetRawContext(),
                                   in_data,
                                   reinterpret_cast<int*>(out_int_guard->addr_),
                                   numel);
    XPU_CALL(xpu_memcpy(out_data,
                        reinterpret_cast<bool*>(out_int_guard->addr_),
                        out->numel() * sizeof(bool),
                        XPUMemcpyKind::XPU_DEVICE_TO_DEVICE));
  } else {
    LOG(FATAL) << "cast from in_type("
               << lite_api::PrecisionToStr(
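The new `out_dtype == 0` branch (BOOL) stages the cast through an `int` scratchpad before filling the bool output. A host-side sketch of the intended element-wise semantics, assuming `xdnn::cast_v2` offers no direct bool output (the helper name is illustrative):

```cpp
#include <cstdint>
#include <vector>

// Sketch: reference semantics for narrowing an int staging buffer to
// bool, one element at a time; nonzero maps to true.
std::vector<bool> NarrowIntStageToBool(const std::vector<int32_t>& staged) {
  std::vector<bool> out(staged.size());
  for (size_t i = 0; i < staged.size(); ++i) {
    out[i] = (staged[i] != 0);
  }
  return out;
}
```

The element-wise loop makes the narrowing explicit: each staged `int` occupies `sizeof(int)` bytes while each output `bool` occupies one, so the two device buffers do not share a byte layout.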
11 changes: 11 additions & 0 deletions lite/kernels/xpu/fill_any_like_compute.cc
@@ -42,6 +42,9 @@ void FillAnyLikeCompute::Run() {
    case PRECISION(kInt64):
      dtype = static_cast<int32_t>(lite::core::FluidType::INT64);
      break;
    case PRECISION(kBool):
      dtype = static_cast<int32_t>(lite::core::FluidType::BOOL);
      break;
    default:
      LOG(FATAL) << "not supported x dtype: "
                 << lite_api::PrecisionToStr(param.X->precision());
@@ -51,6 +54,14 @@

  int r = 0;
  switch (dtype) {
    case 0: {  // FluidType::BOOL
      auto data = param.Out->mutable_data<bool>(TARGET(kXPU));
      r = xdnn::constant<bool>(ctx.GetRawContext(),
                               data,
                               write_size,
                               static_cast<bool>(param.value));
      break;
    }
    case 1: {  // FluidType::INT16
      auto data = param.Out->mutable_data<int16_t>(TARGET(kXPU));
      r = xdnn::constant<int16_t>(ctx.GetRawContext(),
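The bare integer case labels mirror Fluid's `VarType::Type` numbering, which `lite::core::FluidType` follows. For reference, a sketch of the codes the switch above depends on (values follow the Fluid `framework.proto` numbering; the enum name is illustrative):

```cpp
#include <cstdint>

// Sketch: the dtype codes behind the case labels above.
// BOOL = 0 is the case added by this PR; INT16 = 1 was already handled.
enum class FluidTypeRef : int32_t {
  BOOL = 0,
  INT16 = 1,
  INT32 = 2,
  INT64 = 3,
  FP16 = 4,
  FP32 = 5,
};
```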