-
Notifications
You must be signed in to change notification settings - Fork 207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add implementation of inclusive scan via upsweep-downsweep #723
Conversation
#pragma unroll | ||
for (int stride = 1; stride < Power2ScanSize; stride <<= 1) { | ||
int index = (threadIdx.x + 1) * stride * 2 - 1; | ||
if (index < Power2ScanSize) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"This code still works for collections that
+// do not exactly contain a power of 2 number of elements, simply round up to the
+// nearest power of 2 and then call."
This is not true here, you should pass in an extra size parameter for the data instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess to clarify, the algorithm works as long as you have Power2ScanSize space in smem, but yes we could add a size parameter to condition on as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, nevermind, it won't work with a non-power of 2 size, this is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duh, I take that back, it will work. A size parameter sounds like a good addition, since not every scan will involve a power-of-2 size, especially the tail end of a set of data (either that, or the smem will have to be reset with an identity value for the reduction; e.g., for +, would have to be filled with 0).
Can you note the performance improvements ? |
// 15 | ||
// 3 10 21 | ||
template <typename T, class BinaryOp, int Power2ScanSize> | ||
__device__ void inclusivePrefixScan(T *smem, BinaryOp binop) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also this function should take an input and pass an output like the others, instead of assuming the values are already in shared memory?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the other case each thread is responsible for a single element, where as in this case each thread has two associated elements. So we could match it via:
__device__ void inclusivePrefixScan(T *smem, T a, T b, T *out, BinaryOp op) { ... }
but I think it is a little less clean than compared with the others.
A slightly more efficient Scan that uses upsweep/downsweep like mechanisms.
Tested outside of cutorch codebase on buffers of size 2, 3, 21, 32, 33, 64 and verified that it properly calculated the prefix sum when templatized via an addition binary op.