diff --git a/src/libraries/System.Linq/src/System/Linq/Sum.cs b/src/libraries/System.Linq/src/System/Linq/Sum.cs index 315df7186c5d78..d8cbdc20a29c85 100644 --- a/src/libraries/System.Linq/src/System/Linq/Sum.cs +++ b/src/libraries/System.Linq/src/System/Linq/Sum.cs @@ -2,7 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; namespace System.Linq { @@ -40,6 +43,26 @@ private static TResult Sum(ReadOnlySpan span) where T : struct, INumber where TResult : struct, INumber { + if (typeof(T) == typeof(TResult) + && Vector.IsSupported + && Vector.IsHardwareAccelerated + && Vector.Count > 2 + && span.Length >= Vector.Count * 4) + { + // For cases where the vector may only contain two elements vectorization doesn't add any benefit + // due to the expense of overflow checking. This means that architectures where Vector is 128 bit, + // such as ARM or Intel without AVX, will only vectorize spans of ints and not longs. + + if (typeof(T) == typeof(long)) + { + return (TResult) (object) SumSignedIntegersVectorized(MemoryMarshal.Cast(span)); + } + if (typeof(T) == typeof(int)) + { + return (TResult) (object) SumSignedIntegersVectorized(MemoryMarshal.Cast(span)); + } + } + TResult sum = TResult.Zero; foreach (T value in span) { @@ -49,6 +72,118 @@ private static TResult Sum(ReadOnlySpan span) return sum; } + private static T SumSignedIntegersVectorized(ReadOnlySpan span) + where T : struct, IBinaryInteger, ISignedNumber, IMinMaxValue + { + Debug.Assert(span.Length >= Vector.Count * 4); + Debug.Assert(Vector.Count > 2); + Debug.Assert(Vector.IsHardwareAccelerated); + + ref T ptr = ref MemoryMarshal.GetReference(span); + nuint length = (nuint)span.Length; + + // Overflow testing for vectors is based on setting the sign bit of the overflowTracking + // vector for an element if the following are all true: + // - The two elements being summed have the same sign bit. If one element is positive + // and the other is negative then an overflow is not possible. + // - The sign bit of the sum is not the same as the sign bit of the previous accumulator. + // This indicates that the new sum wrapped around to the opposite sign. + // + // This is done by: + // overflowTracking |= (result ^ input1) & (result ^ input2); + // + // The general premise here is that we're doing signof(result) ^ signof(input1). This will produce + // a sign-bit of 1 if they differ and 0 if they are the same. We do the same with + // signof(result) ^ signof(input2), then combine both results together with a logical &. + // + // Thus, if we had a sign swap compared to both inputs, then signof(input1) == signof(input2) and + // we must have overflowed. + // + // By bitwise or-ing the overflowTracking vector for each step we can save cycles by testing + // the sign bits less often. If any iteration has the sign bit set in any element it indicates + // there was an overflow. + // + // Note: The overflow checking in this algorithm is only correct for signed integers. + // If support is ever added for unsigned integers then the overflow check should be: + // overflowTracking |= (input1 & input2) | Vector.AndNot(input1 | input2, result); + + Vector accumulator = Vector.Zero; + + // Build a test vector with only the sign bit set in each element. + Vector overflowTestVector = new(T.MinValue); + + // Unroll the loop to sum 4 vectors per iteration. This reduces range check + // and overflow check frequency, allows us to eliminate move operations swapping + // accumulators, and may have pipelining benefits. + nuint index = 0; + nuint limit = length - (nuint)Vector.Count * 4; + do + { + // Switch accumulators with each step to avoid an additional move operation + Vector data = Vector.LoadUnsafe(ref ptr, index); + Vector accumulator2 = accumulator + data; + Vector overflowTracking = (accumulator2 ^ accumulator) & (accumulator2 ^ data); + + data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector.Count); + accumulator = accumulator2 + data; + overflowTracking |= (accumulator ^ accumulator2) & (accumulator ^ data); + + data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector.Count * 2); + accumulator2 = accumulator + data; + overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data); + + data = Vector.LoadUnsafe(ref ptr, index + (nuint)Vector.Count * 3); + accumulator = accumulator2 + data; + overflowTracking |= (accumulator ^ accumulator2) & (accumulator ^ data); + + if ((overflowTracking & overflowTestVector) != Vector.Zero) + { + ThrowHelper.ThrowOverflowException(); + } + + index += (nuint)Vector.Count * 4; + } while (index < limit); + + // Process remaining vectors, if any, without unrolling + limit = length - (nuint)Vector.Count; + if (index < limit) + { + Vector overflowTracking = Vector.Zero; + + do + { + Vector data = Vector.LoadUnsafe(ref ptr, index); + Vector accumulator2 = accumulator + data; + overflowTracking |= (accumulator2 ^ accumulator) & (accumulator2 ^ data); + accumulator = accumulator2; + + index += (nuint)Vector.Count; + } while (index < limit); + + if ((overflowTracking & overflowTestVector) != Vector.Zero) + { + ThrowHelper.ThrowOverflowException(); + } + } + + // Add the elements in the vector horizontally. + // Vector.Sum doesn't perform overflow checking, instead add elements individually. + T result = T.Zero; + for (int i = 0; i < Vector.Count; i++) + { + checked { result += accumulator[i]; } + } + + // Add any remaining elements + while (index < length) + { + checked { result += Unsafe.Add(ref ptr, index); } + + index++; + } + + return result; + } public static int? Sum(this IEnumerable source) => Sum(source); diff --git a/src/libraries/System.Linq/src/System/Linq/ThrowHelper.cs b/src/libraries/System.Linq/src/System/Linq/ThrowHelper.cs index 0b03039c8c8aea..0f0b40c3416799 100644 --- a/src/libraries/System.Linq/src/System/Linq/ThrowHelper.cs +++ b/src/libraries/System.Linq/src/System/Linq/ThrowHelper.cs @@ -29,6 +29,9 @@ internal static class ThrowHelper [DoesNotReturn] internal static void ThrowNotSupportedException() => throw new NotSupportedException(); + [DoesNotReturn] + internal static void ThrowOverflowException() => throw new OverflowException(); + private static string GetArgumentString(ExceptionArgument argument) { switch (argument) diff --git a/src/libraries/System.Linq/tests/SumTests.cs b/src/libraries/System.Linq/tests/SumTests.cs index 4aeb5edf35ee79..39e68a6ad80ad8 100644 --- a/src/libraries/System.Linq/tests/SumTests.cs +++ b/src/libraries/System.Linq/tests/SumTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Numerics; using Xunit; namespace System.Linq.Tests @@ -361,6 +362,21 @@ public void SumOfNullableOfDecimal_SourceIsNotEmpty_ProperSumReturned() #region SourceSumsToOverflow - OverflowExceptionThrown or Infinity returned + // For testing vectorized overflow, confirms that overflow is detected in multiple vertical lanes + // and with the overflow occurring at different vector offsets into the list of data. This includes + // the 5th and 6th vectors in the data to ensure overflow checks after the unrolled loop that processes + // four vectors at a time. + public static IEnumerable SumOverflowsVerticalVectorLanes() + { + for (int element = 0; element < 2; element++) + { + for (int verticalOffset = 1; verticalOffset < 6; verticalOffset++) + { + yield return new object[] {element, verticalOffset}; + } + } + } + [Fact] public void SumOfInt_SourceSumsToOverflow_OverflowExceptionThrown() { @@ -369,6 +385,37 @@ public void SumOfInt_SourceSumsToOverflow_OverflowExceptionThrown() Assert.Throws(() => sourceInt.Sum(x => x)); } + [Fact] + public void SumOfInt_SourceSumsToOverflowVectorHorizontally_OverflowExceptionThrown() + { + int[] sourceInt = new int[Vector.Count * 4]; + Array.Fill(sourceInt, 0); + + for (int i = 0; i < Vector.Count; i++) + { + sourceInt[i] = int.MaxValue - 3; + } + for (int i = Vector.Count; i < sourceInt.Length; i++) + { + sourceInt[i] = 1; + } + + Assert.Throws(() => sourceInt.Sum()); + } + + [Theory] + [MemberData(nameof(SumOverflowsVerticalVectorLanes))] + public void SumOfInt_SourceSumsToOverflowVectorVertically_OverflowExceptionThrown(int element, int verticalOffset) + { + int[] sourceInt = new int[Vector.Count * 6]; + Array.Fill(sourceInt, 0); + + sourceInt[element] = int.MaxValue; + sourceInt[element + Vector.Count * verticalOffset] = 1; + + Assert.Throws(() => sourceInt.Sum()); + } + [Fact] public void SumOfNullableOfInt_SourceSumsToOverflow_OverflowExceptionThrown() { @@ -385,6 +432,37 @@ public void SumOfLong_SourceSumsToOverflow_OverflowExceptionThrown() Assert.Throws(() => sourceLong.Sum(x => x)); } + [Fact] + public void SumOfLong_SourceSumsToOverflowVectorHorizontally_OverflowExceptionThrown() + { + long[] sourceLong = new long[Vector.Count * 4]; + Array.Fill(sourceLong, 0); + + for (int i = 0; i < Vector.Count; i++) + { + sourceLong[i] = long.MaxValue - 3; + } + for (int i = Vector.Count; i < sourceLong.Length; i++) + { + sourceLong[i] = 1; + } + + Assert.Throws(() => sourceLong.Sum()); + } + + [Theory] + [MemberData(nameof(SumOverflowsVerticalVectorLanes))] + public void SumOfLong_SourceSumsToOverflowVectorVertically_OverflowExceptionThrown(int element, int verticalOffset) + { + long[] sourceLong = new long[Vector.Count * 6]; + Array.Fill(sourceLong, 0); + + sourceLong[element] = long.MaxValue; + sourceLong[element + Vector.Count * verticalOffset] = 1; + + Assert.Throws(() => sourceLong.Sum()); + } + [Fact] public void SumOfNullableOfLong_SourceSumsToOverflow_OverflowExceptionThrown() {