Skip to content

Commit

Permalink
Auto merge of rust-lang#136457 - calder:master, r=<try>
Browse files Browse the repository at this point in the history
Expose algebraic floating point intrinsics

# Problem

A stable Rust implementation of a simple dot product is 8x slower than C++ on modern x86-64 CPUs. The root cause is an inability to let the compiler reorder floating point operations for better vectorization.

See https://github.com/calder/dot-bench for benchmarks. Measurements below were performed on a i7-10875H.

### C++: 10us ✅

With Clang 18.1.3 and `-O2 -march=haswell`:
<table>
<tr>
    <th>C++</th>
    <th>Assembly</th>
</tr>
<tr>
<td>
<pre lang="cc">
float dot(float *a, float *b, size_t len) {
    #pragma clang fp reassociate(on)
    float sum = 0.0;
    for (size_t i = 0; i < len; ++i) {
        sum += a[i] * b[i];
    }
    return sum;
}
</pre>
</td>
<td>
<img src="https://github.com/user-attachments/assets/739573c0-380a-4d84-9fd9-141343ce7e68" />
</td>
</tr>
</table>

### Nightly Rust: 10us ✅

With rustc 1.86.0-nightly (8239a37) and `-C opt-level=3 -C target-feature=+avx2,+fma`:
<table>
<tr>
    <th>Rust</th>
    <th>Assembly</th>
</tr>
<tr>
<td>
<pre lang="rust">
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let mut sum = 0.0;
    for i in 0..a.len() {
        sum = fadd_algebraic(sum, fmul_algebraic(a[i], b[i]));
    }
    sum
}
</pre>
</td>
<td>
<img src="https://github.com/user-attachments/assets/9dcf953a-2cd7-42f3-bc34-7117de4c5fb9" />
</td>
</tr>
</table>

### Stable Rust: 84us ❌

With rustc 1.84.1 (e71f9a9) and `-C opt-level=3 -C target-feature=+avx2,+fma`:
<table>
<tr>
    <th>Rust</th>
    <th>Assembly</th>
</tr>
<tr>
<td>
<pre lang="rust">
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let mut sum = 0.0;
    for i in 0..a.len() {
        sum += a[i] * b[i];
    }
    sum
}
</pre>
</td>
<td>
<img src="https://github.com/user-attachments/assets/936a1f7e-33e4-4ff8-a732-c3cdfe068dca" />
</td>
</tr>
</table>

# Proposed Change

Add `core::intrinsics::f*_algebraic` wrappers to `f16`, `f32`, `f64`, and `f128` gated on a new `float_algebraic` feature.

# Alternatives Considered

rust-lang#21690 has a lot of good discussion of various options for supporting fast math in Rust, but is still open a decade later because any choice that opts in more than individual operations is ultimately contrary to Rust's design principles.

In the mean time, processors have evolved and we're leaving major performance on the table by not supporting vectorization. We shouldn't make users choose between an unstable compiler and an 8x performance hit.

# References

* rust-lang#21690
* rust-lang/libs-team#532
* rust-lang#136469
* https://github.com/calder/dot-bench
* https://www.felixcloutier.com/x86/vfmadd132ps:vfmadd213ps:vfmadd231ps

try-job: x86_64-gnu-nopt
  • Loading branch information
bors committed Feb 19, 2025
2 parents ed49386 + 806af98 commit 621c0ed
Show file tree
Hide file tree
Showing 15 changed files with 508 additions and 20 deletions.
10 changes: 5 additions & 5 deletions library/core/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2995,7 +2995,7 @@ pub unsafe fn float_to_int_unchecked<Float: Copy, Int: Copy>(_value: Float) -> I

/// Float addition that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
/// Stabilized as [`f16::algebraic_add`], [`f32::algebraic_add`], [`f64::algebraic_add`] and [`f128::algebraic_add`].
#[rustc_nounwind]
#[rustc_intrinsic]
#[rustc_intrinsic_must_be_overridden]
Expand All @@ -3005,7 +3005,7 @@ pub fn fadd_algebraic<T: Copy>(_a: T, _b: T) -> T {

/// Float subtraction that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
/// Stabilized as [`f16::algebraic_sub`], [`f32::algebraic_sub`], [`f64::algebraic_sub`] and [`f128::algebraic_sub`].
#[rustc_nounwind]
#[rustc_intrinsic]
#[rustc_intrinsic_must_be_overridden]
Expand All @@ -3015,7 +3015,7 @@ pub fn fsub_algebraic<T: Copy>(_a: T, _b: T) -> T {

/// Float multiplication that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
/// Stabilized as [`f16::algebraic_mul`], [`f32::algebraic_mul`], [`f64::algebraic_mul`] and [`f128::algebraic_mul`].
#[rustc_nounwind]
#[rustc_intrinsic]
#[rustc_intrinsic_must_be_overridden]
Expand All @@ -3025,7 +3025,7 @@ pub fn fmul_algebraic<T: Copy>(_a: T, _b: T) -> T {

/// Float division that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
/// Stabilized as [`f16::algebraic_div`], [`f32::algebraic_div`], [`f64::algebraic_div`] and [`f128::algebraic_div`].
#[rustc_nounwind]
#[rustc_intrinsic]
#[rustc_intrinsic_must_be_overridden]
Expand All @@ -3035,7 +3035,7 @@ pub fn fdiv_algebraic<T: Copy>(_a: T, _b: T) -> T {

/// Float remainder that allows optimizations based on algebraic rules.
///
/// This intrinsic does not have a stable counterpart.
/// Stabilized as [`f16::algebraic_rem`], [`f32::algebraic_rem`], [`f64::algebraic_rem`] and [`f128::algebraic_rem`].
#[rustc_nounwind]
#[rustc_intrinsic]
#[rustc_intrinsic_must_be_overridden]
Expand Down
50 changes: 50 additions & 0 deletions library/core/src/num/f128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1365,4 +1365,54 @@ impl f128 {
// SAFETY: this is actually a safe intrinsic
unsafe { intrinsics::copysignf128(self, sign) }
}

/// Float addition that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_add(self, rhs: f128) -> f128 {
intrinsics::fadd_algebraic(self, rhs)
}

/// Float subtraction that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_sub(self, rhs: f128) -> f128 {
intrinsics::fsub_algebraic(self, rhs)
}

/// Float multiplication that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_mul(self, rhs: f128) -> f128 {
intrinsics::fmul_algebraic(self, rhs)
}

/// Float division that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_div(self, rhs: f128) -> f128 {
intrinsics::fdiv_algebraic(self, rhs)
}

/// Float remainder that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_rem(self, rhs: f128) -> f128 {
intrinsics::frem_algebraic(self, rhs)
}
}
50 changes: 50 additions & 0 deletions library/core/src/num/f16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1341,4 +1341,54 @@ impl f16 {
// SAFETY: this is actually a safe intrinsic
unsafe { intrinsics::copysignf16(self, sign) }
}

/// Float addition that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_add(self, rhs: f16) -> f16 {
intrinsics::fadd_algebraic(self, rhs)
}

/// Float subtraction that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_sub(self, rhs: f16) -> f16 {
intrinsics::fsub_algebraic(self, rhs)
}

/// Float multiplication that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_mul(self, rhs: f16) -> f16 {
intrinsics::fmul_algebraic(self, rhs)
}

/// Float division that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_div(self, rhs: f16) -> f16 {
intrinsics::fdiv_algebraic(self, rhs)
}

/// Float remainder that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_rem(self, rhs: f16) -> f16 {
intrinsics::frem_algebraic(self, rhs)
}
}
50 changes: 50 additions & 0 deletions library/core/src/num/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1506,4 +1506,54 @@ impl f32 {
// SAFETY: this is actually a safe intrinsic
unsafe { intrinsics::copysignf32(self, sign) }
}

/// Float addition that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_add(self, rhs: f32) -> f32 {
intrinsics::fadd_algebraic(self, rhs)
}

/// Float subtraction that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_sub(self, rhs: f32) -> f32 {
intrinsics::fsub_algebraic(self, rhs)
}

/// Float multiplication that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_mul(self, rhs: f32) -> f32 {
intrinsics::fmul_algebraic(self, rhs)
}

/// Float division that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_div(self, rhs: f32) -> f32 {
intrinsics::fdiv_algebraic(self, rhs)
}

/// Float remainder that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_rem(self, rhs: f32) -> f32 {
intrinsics::frem_algebraic(self, rhs)
}
}
50 changes: 50 additions & 0 deletions library/core/src/num/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1506,4 +1506,54 @@ impl f64 {
// SAFETY: this is actually a safe intrinsic
unsafe { intrinsics::copysignf64(self, sign) }
}

/// Float addition that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_add(self, rhs: f64) -> f64 {
intrinsics::fadd_algebraic(self, rhs)
}

/// Float subtraction that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_sub(self, rhs: f64) -> f64 {
intrinsics::fsub_algebraic(self, rhs)
}

/// Float multiplication that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_mul(self, rhs: f64) -> f64 {
intrinsics::fmul_algebraic(self, rhs)
}

/// Float division that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_div(self, rhs: f64) -> f64 {
intrinsics::fdiv_algebraic(self, rhs)
}

/// Float remainder that allows optimizations based on algebraic rules.
///
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
#[must_use = "method returns a new number and does not mutate the original value"]
#[unstable(feature = "float_algebraic", issue = "136469")]
#[inline]
pub fn algebraic_rem(self, rhs: f64) -> f64 {
intrinsics::frem_algebraic(self, rhs)
}
}
45 changes: 45 additions & 0 deletions library/core/src/primitive_docs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,51 @@ mod prim_f16 {}
/// | `wasm32`, `wasm64` | If all input NaNs are quiet with all-zero payload: None.<br> Otherwise: all possible payloads. |
///
/// For targets not in this table, all payloads are possible.
///
/// # Algebraic operators
///
/// Algebraic operators of the form `a.algebraic_*(b)` allow the compiler to optimize
/// floating point operations using all the usual algebraic properties of real numbers --
/// despite the fact that those properties do *not* hold on floating point numbers.
/// This can give a great performance boost since it may unlock vectorization.
///
/// The exact set of optimizations is unspecified but typically allows combining operations,
/// rearranging series of operations based on mathematical properties, converting between division
/// and reciprocal multiplication, and disregarding the sign of zero. This means that the results of
/// elementary operations may have undefined precision, and "non-mathematical" values
/// such as NaN, +/-Inf, or -0.0 may behave in unexpected ways, but these operations
/// will never cause undefined behavior.
///
/// Because of the unpredictable nature of compiler optimizations, the same inputs may produce
/// different results even within a single program run. **Unsafe code must not rely on any property
/// of the return value for soundness.** However, implementations will generally do their best to
/// pick a reasonable tradeoff between performance and accuracy of the result.
///
/// For example:
///
/// ```
/// # #![feature(float_algebraic)]
/// # #![allow(unused_assignments)]
/// # let mut x: f32 = 0.0;
/// # let a: f32 = 1.0;
/// # let b: f32 = 2.0;
/// # let c: f32 = 3.0;
/// # let d: f32 = 4.0;
/// x = a.algebraic_add(b).algebraic_add(c).algebraic_add(d);
/// ```
///
/// May be rewritten as either:
///
/// ```
/// # #![allow(unused_assignments)]
/// # let mut x: f32 = 0.0;
/// # let a: f32 = 1.0;
/// # let b: f32 = 2.0;
/// # let c: f32 = 3.0;
/// # let d: f32 = 4.0;
/// x = a + b + c + d; // As written
/// x = (a + c) + (b + d); // Reordered to shorten critical path and enable vectorization
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
mod prim_f32 {}
Expand Down
1 change: 1 addition & 0 deletions library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@
#![feature(exact_size_is_empty)]
#![feature(exclusive_wrapper)]
#![feature(extend_one)]
#![feature(float_algebraic)]
#![feature(float_gamma)]
#![feature(float_minimum_maximum)]
#![feature(fmt_internals)]
Expand Down
12 changes: 12 additions & 0 deletions library/std/tests/floats/f128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -983,3 +983,15 @@ fn test_total_cmp() {
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f128::INFINITY));
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
}

#[test]
fn test_algebraic() {
let a: f128 = 123.0;
let b: f128 = 456.0;

assert_approx_eq!(a.algebraic_add(b), a + b);
assert_approx_eq!(a.algebraic_sub(b), a - b);
assert_approx_eq!(a.algebraic_mul(b), a * b);
assert_approx_eq!(a.algebraic_div(b), a / b);
assert_approx_eq!(a.algebraic_rem(b), a % b);
}
12 changes: 12 additions & 0 deletions library/std/tests/floats/f16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,3 +955,15 @@ fn test_total_cmp() {
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f16::INFINITY));
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
}

#[test]
fn test_algebraic() {
let a: f16 = 123.0;
let b: f16 = 456.0;

assert_approx_eq!(a.algebraic_add(b), a + b);
assert_approx_eq!(a.algebraic_sub(b), a - b);
assert_approx_eq!(a.algebraic_mul(b), a * b);
assert_approx_eq!(a.algebraic_div(b), a / b);
assert_approx_eq!(a.algebraic_rem(b), a % b);
}
12 changes: 12 additions & 0 deletions library/std/tests/floats/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,3 +915,15 @@ fn test_total_cmp() {
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f32::INFINITY));
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
}

#[test]
fn test_algebraic() {
let a: f32 = 123.0;
let b: f32 = 456.0;

assert_approx_eq!(a.algebraic_add(b), a + b);
assert_approx_eq!(a.algebraic_sub(b), a - b);
assert_approx_eq!(a.algebraic_mul(b), a * b);
assert_approx_eq!(a.algebraic_div(b), a / b);
assert_approx_eq!(a.algebraic_rem(b), a % b);
}
12 changes: 12 additions & 0 deletions library/std/tests/floats/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,15 @@ fn test_total_cmp() {
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f64::INFINITY));
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
}

#[test]
fn test_algebraic() {
let a: f64 = 123.0;
let b: f64 = 456.0;

assert_approx_eq!(a.algebraic_add(b), a + b);
assert_approx_eq!(a.algebraic_sub(b), a - b);
assert_approx_eq!(a.algebraic_mul(b), a * b);
assert_approx_eq!(a.algebraic_div(b), a / b);
assert_approx_eq!(a.algebraic_rem(b), a % b);
}
2 changes: 1 addition & 1 deletion library/std/tests/floats/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![feature(f16, f128, float_gamma, float_minimum_maximum)]
#![feature(f16, f128, float_algebraic, float_gamma, float_minimum_maximum)]

use std::fmt;
use std::ops::{Add, Div, Mul, Rem, Sub};
Expand Down
Loading

0 comments on commit 621c0ed

Please sign in to comment.