Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new WarpReduce overloadings #3884

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
22 changes: 8 additions & 14 deletions cub/cub/warp/specializations/warp_reduce_smem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@
# pragma system_header
#endif // no system header

#include <cub/thread/thread_load.cuh>
#include <cub/thread/thread_operators.cuh>
#include <cub/thread/thread_store.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>

#include <cuda/ptx>
Expand Down Expand Up @@ -152,21 +151,16 @@ struct WarpReduceSmem
ReduceStep(T input, int valid_items, ReductionOp reduction_op, constant_t<STEP> /*step*/)
{
constexpr int OFFSET = 1 << STEP;

// Share input through buffer
ThreadStore<STORE_VOLATILE>(&temp_storage.reduce[lane_id], input);

temp_storage.reduce[lane_id] = input;
__syncwarp(member_mask);

// Update input if peer_addend is in range
if ((ALL_LANES_VALID && IS_POW_OF_TWO) || ((lane_id + OFFSET) < valid_items))
{
T peer_addend = ThreadLoad<LOAD_VOLATILE>(&temp_storage.reduce[lane_id + OFFSET]);
T peer_addend = temp_storage.reduce[lane_id + OFFSET];
input = reduction_op(input, peer_addend);
}

__syncwarp(member_mask);

return ReduceStep<ALL_LANES_VALID>(input, valid_items, reduction_op, constant_v<STEP + 1>);
}

Expand Down Expand Up @@ -250,14 +244,14 @@ struct WarpReduceSmem
const int OFFSET = 1 << STEP;

// Share input into buffer
ThreadStore<STORE_VOLATILE>(&temp_storage.reduce[lane_id], input);
temp_storage.reduce[lane_id] = input;

__syncwarp(member_mask);

// Update input if peer_addend is in range
if (OFFSET + lane_id < next_flag)
{
T peer_addend = ThreadLoad<LOAD_VOLATILE>(&temp_storage.reduce[lane_id + OFFSET]);
T peer_addend = temp_storage.reduce[lane_id + OFFSET];
input = reduction_op(input, peer_addend);
}

Expand Down Expand Up @@ -297,7 +291,7 @@ struct WarpReduceSmem
};

// Alias flags onto shared data storage
volatile SmemFlag* flag_storage = temp_storage.flags;
SmemFlag* flag_storage = temp_storage.flags;

SmemFlag flag_status = (flag) ? SET : UNSET;

Expand All @@ -306,12 +300,12 @@ struct WarpReduceSmem
const int OFFSET = 1 << STEP;

// Share input through buffer
ThreadStore<STORE_VOLATILE>(&temp_storage.reduce[lane_id], input);
temp_storage.reduce[lane_id] = input;

__syncwarp(member_mask);

// Get peer from buffer
T peer_addend = ThreadLoad<LOAD_VOLATILE>(&temp_storage.reduce[lane_id + OFFSET]);
T peer_addend = temp_storage.reduce[lane_id + OFFSET];

__syncwarp(member_mask);

Expand Down
Loading
Loading