Skip to content

Commit

Permalink
Vectorize IEnumerable<T>.Sum where possible (#84519)
Browse files Browse the repository at this point in the history
* Vectorize IEnumerable<T>.Sum where possible

* Remove unnecessary CreateChecked and fix long test

* Add more assertions

* Improve comments, don't use Unsafe.Add, use T.MinValue
  • Loading branch information
brantburnett authored May 22, 2023
1 parent c5cfd7b commit 07af177
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 0 deletions.
135 changes: 135 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/Sum.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Linq
{
Expand Down Expand Up @@ -40,6 +43,26 @@ private static TResult Sum<T, TResult>(ReadOnlySpan<T> span)
where T : struct, INumber<T>
where TResult : struct, INumber<TResult>
{
if (typeof(T) == typeof(TResult)
&& Vector<T>.IsSupported
&& Vector.IsHardwareAccelerated
&& Vector<T>.Count > 2
&& span.Length >= Vector<T>.Count * 4)
{
// For cases where the vector may only contain two elements vectorization doesn't add any benefit
// due to the expense of overflow checking. This means that architectures where Vector<T> is 128 bit,
// such as ARM or Intel without AVX, will only vectorize spans of ints and not longs.

if (typeof(T) == typeof(long))
{
return (TResult) (object) SumSignedIntegersVectorized(MemoryMarshal.Cast<T, long>(span));
}
if (typeof(T) == typeof(int))
{
return (TResult) (object) SumSignedIntegersVectorized(MemoryMarshal.Cast<T, int>(span));
}
}

TResult sum = TResult.Zero;
foreach (T value in span)
{
Expand All @@ -49,6 +72,118 @@ private static TResult Sum<T, TResult>(ReadOnlySpan<T> span)
return sum;
}

/// <summary>
/// Adds every element of <paramref name="span"/> using <see cref="Vector{T}"/> arithmetic while
/// still detecting arithmetic overflow.
/// </summary>
/// <typeparam name="T">A signed binary integer type.</typeparam>
/// <param name="span">The values to add; debug builds assert it holds at least four vectors' worth of elements.</param>
/// <returns>The sum of all elements in <paramref name="span"/>.</returns>
/// <exception cref="OverflowException">
/// A lane of the running vector total wrapped around, or one of the final scalar additions overflowed.
/// </exception>
private static T SumSignedIntegersVectorized<T>(ReadOnlySpan<T> span)
    where T : struct, IBinaryInteger<T>, ISignedNumber<T>, IMinMaxValue<T>
{
    Debug.Assert(span.Length >= Vector<T>.Count * 4);
    Debug.Assert(Vector<T>.Count > 2);
    Debug.Assert(Vector.IsHardwareAccelerated);

    ref T first = ref MemoryMarshal.GetReference(span);
    nuint count = (nuint)span.Length;
    nuint lanes = (nuint)Vector<T>.Count;

    // How overflow is detected per lane:
    //
    //     overflowBits |= (sum ^ left) & (sum ^ right);
    //
    // (sum ^ left) has its sign bit set when the sum's sign differs from the left operand's sign,
    // and likewise (sum ^ right) for the right operand. And-ing the two leaves the sign bit set
    // only when the sum's sign differs from BOTH operands. That can only happen when both operands
    // shared a sign and the addition wrapped to the opposite sign, i.e. an overflow. Operands of
    // opposite signs can never overflow, and for them at least one of the two xors has a clear
    // sign bit.
    //
    // The per-step results are or-ed together so that the sign bits only need to be inspected once
    // per batch of additions; a set sign bit in any lane means some step in the batch overflowed.
    //
    // This test is valid for signed integers only. An unsigned variant would need:
    //     overflowBits |= (left & right) | Vector.AndNot(left | right, sum);

    Vector<T> total = Vector<T>.Zero;

    // T.MinValue has only the sign bit set, so this mask isolates the sign bit of each lane.
    Vector<T> signMask = new(T.MinValue);

    // Main loop: four vectors per pass. Unrolling lowers the frequency of the loop-bound and
    // overflow tests, and alternating between two running totals avoids copying one into the other.
    nuint offset = 0;
    nuint bound = count - lanes * 4;
    do
    {
        Vector<T> chunk = Vector.LoadUnsafe(ref first, offset);
        Vector<T> swap = total + chunk;
        Vector<T> overflowBits = (swap ^ total) & (swap ^ chunk);

        chunk = Vector.LoadUnsafe(ref first, offset + lanes);
        total = swap + chunk;
        overflowBits |= (total ^ swap) & (total ^ chunk);

        chunk = Vector.LoadUnsafe(ref first, offset + lanes * 2);
        swap = total + chunk;
        overflowBits |= (swap ^ total) & (swap ^ chunk);

        chunk = Vector.LoadUnsafe(ref first, offset + lanes * 3);
        total = swap + chunk;
        overflowBits |= (total ^ swap) & (total ^ chunk);

        if ((overflowBits & signMask) != Vector<T>.Zero)
        {
            ThrowHelper.ThrowOverflowException();
        }

        offset += lanes * 4;
    } while (offset < bound);

    // Secondary loop: one vector per pass for whatever whole vectors the main loop left behind.
    bound = count - lanes;
    if (offset < bound)
    {
        Vector<T> overflowBits = Vector<T>.Zero;

        do
        {
            Vector<T> chunk = Vector.LoadUnsafe(ref first, offset);
            Vector<T> next = total + chunk;
            overflowBits |= (next ^ total) & (next ^ chunk);
            total = next;

            offset += lanes;
        } while (offset < bound);

        if ((overflowBits & signMask) != Vector<T>.Zero)
        {
            ThrowHelper.ThrowOverflowException();
        }
    }

    // Collapse the lanes into a single value. Vector.Sum is not overflow-checked, so each lane
    // is added with a checked scalar addition instead.
    T result = T.Zero;
    for (int lane = 0; lane < Vector<T>.Count; lane++)
    {
        result = checked(result + total[lane]);
    }

    // Finish with checked scalar additions for the elements not consumed by the vector loops.
    for (; offset < count; offset++)
    {
        result = checked(result + Unsafe.Add(ref first, offset));
    }

    return result;
}

/// <summary>
/// Sums a sequence of nullable <see cref="int"/> values by forwarding to the generic
/// <c>Sum&lt;int, int&gt;</c> helper.
/// </summary>
/// <param name="source">The sequence of nullable values to add.</param>
/// <returns>The value produced by the generic helper for <paramref name="source"/>.</returns>
public static int? Sum(this IEnumerable<int?> source)
{
    return Sum<int, int>(source);
}

Expand Down
3 changes: 3 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/ThrowHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ internal static class ThrowHelper
/// <summary>Throws a new <see cref="NotSupportedException"/>; this method never returns normally.</summary>
/// <exception cref="NotSupportedException">Always thrown.</exception>
[DoesNotReturn]
internal static void ThrowNotSupportedException()
{
    throw new NotSupportedException();
}

/// <summary>Throws a new <see cref="OverflowException"/>; this method never returns normally.</summary>
/// <exception cref="OverflowException">Always thrown.</exception>
[DoesNotReturn]
internal static void ThrowOverflowException()
{
    throw new OverflowException();
}

private static string GetArgumentString(ExceptionArgument argument)
{
switch (argument)
Expand Down
78 changes: 78 additions & 0 deletions src/libraries/System.Linq/tests/SumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Numerics;
using Xunit;

namespace System.Linq.Tests
Expand Down Expand Up @@ -361,6 +362,21 @@ public void SumOfNullableOfDecimal_SourceIsNotEmpty_ProperSumReturned()

#region SourceSumsToOverflow - OverflowExceptionThrown or Infinity returned

// Theory data for the vectorized-overflow tests. Each row is { element, verticalOffset }:
//   - element        (0 or 1) selects which lane of the vector holds the values that overflow;
//   - verticalOffset (1 to 5) selects how many vectors into the data the second addend is placed.
// Offsets 4 and 5 land in the 5th and 6th vectors of the data, so overflow detection is also
// exercised in the code that runs after the unrolled loop that consumes four vectors at a time.
public static IEnumerable<object[]> SumOverflowsVerticalVectorLanes() =>
    from element in Enumerable.Range(0, 2)
    from verticalOffset in Enumerable.Range(1, 5)
    select new object[] { element, verticalOffset };

[Fact]
public void SumOfInt_SourceSumsToOverflow_OverflowExceptionThrown()
{
Expand All @@ -369,6 +385,37 @@ public void SumOfInt_SourceSumsToOverflow_OverflowExceptionThrown()
Assert.Throws<OverflowException>(() => sourceInt.Sum(x => x));
}

// Verifies that overflow is reported when it only appears once values from different lanes are
// combined. The first Vector<int>.Count elements are int.MaxValue - 3 and the remaining three
// vectors' worth are 1, so each lane-wise total is exactly int.MaxValue (no per-lane overflow),
// while the total across all elements exceeds int.MaxValue.
[Fact]
public void SumOfInt_SourceSumsToOverflowVectorHorizontally_OverflowExceptionThrown()
{
    int[] sourceInt = new int[Vector<int>.Count * 4];

    // Set everything to 1, then overwrite the first vector's worth with the large value.
    Array.Fill(sourceInt, 1);
    Array.Fill(sourceInt, int.MaxValue - 3, 0, Vector<int>.Count);

    Assert.Throws<OverflowException>(() => sourceInt.Sum());
}

// Verifies that overflow within a single lane is reported. int.MaxValue is stored at index
// `element` and a 1 is stored `verticalOffset` vectors later at the same lane position; all
// other elements are zero, so adding those two values is what overflows.
[Theory]
[MemberData(nameof(SumOverflowsVerticalVectorLanes))]
public void SumOfInt_SourceSumsToOverflowVectorVertically_OverflowExceptionThrown(int element, int verticalOffset)
{
    // A newly allocated array is already zero-filled, so only the two non-zero values are written.
    int[] sourceInt = new int[Vector<int>.Count * 6];
    sourceInt[element] = int.MaxValue;
    sourceInt[element + Vector<int>.Count * verticalOffset] = 1;

    Assert.Throws<OverflowException>(() => sourceInt.Sum());
}

[Fact]
public void SumOfNullableOfInt_SourceSumsToOverflow_OverflowExceptionThrown()
{
Expand All @@ -385,6 +432,37 @@ public void SumOfLong_SourceSumsToOverflow_OverflowExceptionThrown()
Assert.Throws<OverflowException>(() => sourceLong.Sum(x => x));
}

// Verifies that overflow is reported when it only appears once values from different lanes are
// combined. The first Vector<long>.Count elements are long.MaxValue - 3 and the remaining three
// vectors' worth are 1, so each lane-wise total is exactly long.MaxValue (no per-lane overflow),
// while the total across all elements exceeds long.MaxValue.
[Fact]
public void SumOfLong_SourceSumsToOverflowVectorHorizontally_OverflowExceptionThrown()
{
    long[] sourceLong = new long[Vector<long>.Count * 4];

    // Set everything to 1, then overwrite the first vector's worth with the large value.
    Array.Fill(sourceLong, 1);
    Array.Fill(sourceLong, long.MaxValue - 3, 0, Vector<long>.Count);

    Assert.Throws<OverflowException>(() => sourceLong.Sum());
}

// Verifies that overflow within a single lane is reported. long.MaxValue is stored at index
// `element` and a 1 is stored `verticalOffset` vectors later at the same lane position; all
// other elements are zero, so adding those two values is what overflows.
[Theory]
[MemberData(nameof(SumOverflowsVerticalVectorLanes))]
public void SumOfLong_SourceSumsToOverflowVectorVertically_OverflowExceptionThrown(int element, int verticalOffset)
{
    // A newly allocated array is already zero-filled, so only the two non-zero values are written.
    long[] sourceLong = new long[Vector<long>.Count * 6];
    sourceLong[element] = long.MaxValue;
    sourceLong[element + Vector<long>.Count * verticalOffset] = 1;

    Assert.Throws<OverflowException>(() => sourceLong.Sum());
}

[Fact]
public void SumOfNullableOfLong_SourceSumsToOverflow_OverflowExceptionThrown()
{
Expand Down

0 comments on commit 07af177

Please sign in to comment.