Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix non-inlined fn calls for some arm vfma intrinsics #1214

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/core_arch/src/arm_shared/neon/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8697,7 +8697,7 @@ vfmaq_f32_(b, c, a)
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfma_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
vfma_f32(a, b, vdup_n_f32(c))
vfma_f32(a, b, vdup_n_f32_v8(c))
}

/// Floating-point fused Multiply-Add to accumulator(vector)
Expand All @@ -8707,7 +8707,7 @@ pub unsafe fn vfma_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfmaq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
vfmaq_f32(a, b, vdupq_n_f32(c))
vfmaq_f32(a, b, vdupq_n_f32_v8(c))
}

/// Floating-point fused multiply-subtract from accumulator
Expand Down Expand Up @@ -8739,7 +8739,7 @@ pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfms_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
vfms_f32(a, b, vdup_n_f32(c))
vfms_f32(a, b, vdup_n_f32_v8(c))
}

/// Floating-point fused Multiply-subtract to accumulator(vector)
Expand All @@ -8749,7 +8749,7 @@ pub unsafe fn vfms_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfmsq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
vfmsq_f32(a, b, vdupq_n_f32(c))
vfmsq_f32(a, b, vdupq_n_f32_v8(c))
}

/// Subtract
Expand Down
20 changes: 20 additions & 0 deletions crates/core_arch/src/arm_shared/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3704,6 +3704,16 @@ pub unsafe fn vdupq_n_f32(value: f32) -> float32x4_t {
float32x4_t(value, value, value, value)
}

/// Duplicate vector element to vector or scalar
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
unsafe fn vdupq_n_f32_v8(value: f32) -> float32x4_t {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment explaining why this is needed? Also, this should probably be under #[cfg(target_arch = "arm")].

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Amanieu Will do, but first I will fix arm fused multiply-add to only require vfp4 instead of v8 (PR forthcoming) and then the fn name will change to vdupq_n_f32_vfp4.

float32x4_t(value, value, value, value)
}

/// Duplicate vector element to vector or scalar
#[inline]
#[target_feature(enable = "neon")]
Expand Down Expand Up @@ -3814,6 +3824,16 @@ pub unsafe fn vdup_n_f32(value: f32) -> float32x2_t {
float32x2_t(value, value)
}

/// Duplicate vector element to vector or scalar
///
/// Private counterpart of `vdup_n_f32` with the same body: it returns a
/// `float32x2_t` whose two lanes are both set to `value`. On 32-bit `arm`
/// it is additionally compiled with the `fp-armv8,v8` target features.
///
/// NOTE(review): presumably this exists so that the `_n_` fused
/// multiply-add/subtract wrappers that call it (`vfma_n_f32`, `vfms_n_f32`)
/// can inline the splat on `arm` instead of emitting a function call --
/// confirm against the target features enabled on those wrappers.
#[inline]
#[target_feature(enable = "neon")]
// Extra features are requested on 32-bit `arm` only; other targets keep just `neon`.
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
// Test-only checks on the instruction this function is expected to compile to.
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
unsafe fn vdup_n_f32_v8(value: f32) -> float32x2_t {
float32x2_t(value, value)
}

/// Duplicate vector element to vector or scalar
#[inline]
#[target_feature(enable = "neon")]
Expand Down
4 changes: 2 additions & 2 deletions crates/stdarch-gen/neon.spec
Original file line number Diff line number Diff line change
Expand Up @@ -2576,7 +2576,7 @@ generate float*_t
/// Floating-point fused Multiply-Add to accumulator(vector)
name = vfma
n-suffix
multi_fn = vfma-self-noext, a, b, {vdup-nself-noext, c}
multi_fn = vfma-self-noext, a, b, {vdup-nselfv8-noext, c}
a = 2.0, 3.0, 4.0, 5.0
b = 6.0, 4.0, 7.0, 8.0
c = 8.0
Expand Down Expand Up @@ -2653,7 +2653,7 @@ generate float*_t
/// Floating-point fused Multiply-subtract to accumulator(vector)
name = vfms
n-suffix
multi_fn = vfms-self-noext, a, b, {vdup-nself-noext, c}
multi_fn = vfms-self-noext, a, b, {vdup-nselfv8-noext, c}
a = 50.0, 35.0, 60.0, 69.0
b = 6.0, 4.0, 7.0, 8.0
c = 8.0
Expand Down
12 changes: 11 additions & 1 deletion crates/stdarch-gen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,7 @@ fn gen_aarch64(
out_t,
fixed,
None,
true,
));
}
calls
Expand Down Expand Up @@ -1947,6 +1948,7 @@ fn gen_arm(
out_t,
fixed,
None,
false,
));
}
calls
Expand Down Expand Up @@ -2364,6 +2366,7 @@ fn get_call(
out_t: &str,
fixed: &Vec<String>,
n: Option<i32>,
aarch64: bool,
) -> String {
let params: Vec<_> = in_str.split(',').map(|v| v.trim().to_string()).collect();
assert!(params.len() > 0);
Expand Down Expand Up @@ -2531,7 +2534,8 @@ fn get_call(
in_t,
out_t,
fixed,
Some(i as i32)
Some(i as i32),
aarch64
)
);
call.push_str(&sub_match);
Expand Down Expand Up @@ -2580,6 +2584,7 @@ fn get_call(
out_t,
fixed,
n.clone(),
aarch64,
);
if !param_str.is_empty() {
param_str.push_str(", ");
Expand Down Expand Up @@ -2650,6 +2655,11 @@ fn get_call(
fn_name.push_str(type_to_suffix(in_t[1]));
} else if fn_format[1] == "nself" {
fn_name.push_str(type_to_n_suffix(in_t[1]));
} else if fn_format[1] == "nselfv8" {
fn_name.push_str(type_to_n_suffix(in_t[1]));
if !aarch64 {
fn_name.push_str("_v8");
}
} else if fn_format[1] == "out" {
fn_name.push_str(type_to_suffix(out_t));
} else if fn_format[1] == "in0" {
Expand Down
2 changes: 1 addition & 1 deletion crates/stdarch-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ pub fn assert(shim_addr: usize, fnname: &str, expected: &str) {
// failed inlining something.
s[0].starts_with("call ") && s[1].starts_with("pop") // FIXME: original logic but does not match comment
})
} else if cfg!(target_arch = "aarch64") {
} else if cfg!(target_arch = "aarch64") || cfg!(target_arch = "arm") {
instrs.iter().any(|s| s.starts_with("bl "))
} else {
// FIXME: Add detection for other archs
Expand Down