Skip to content

Commit

Permalink
Updated based on comments, ccl now an Allreduce algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Wilkins committed Feb 13, 2025
1 parent 5e592e7 commit ee76028
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 23 deletions.
23 changes: 12 additions & 11 deletions src/include/mpir_cclcomm.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@

#include <nccl.h>

typedef struct MPIR_CCLcomm {
MPIR_OBJECT_HEADER;
MPIR_Comm *comm;
ncclUniqueId id;
ncclComm_t ncclcomm;
cudaStream_t stream;
} MPIR_CCLcomm;
#define ENABLE_CCLCOMM 1 //Temporary, needs to get put in configure

int MPIR_CCL_red_op_is_supported(MPI_Op op);
#ifdef ENABLE_CCLCOMM
typedef struct MPIR_CCLcomm {
MPIR_OBJECT_HEADER;
MPIR_Comm *comm;
ncclUniqueId id;
ncclComm_t ncclcomm;
cudaStream_t stream;
} MPIR_CCLcomm;

int MPIR_CCL_datatype_is_supported(MPI_Datatype datatype);
int MPIR_CCL_red_op_is_supported(MPI_Op op);

int MPIR_CCL_Allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
int MPIR_CCL_datatype_is_supported(MPI_Datatype datatype);
#endif /* ENABLE_CCLCOMM */

#endif /* MPIR_CCLCOMM_H_INCLUDED */
30 changes: 18 additions & 12 deletions src/mpi/ccl/cclcomm.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#define CUDA_ERR_CHECK(ret) if (unlikely((ret) != cudaSuccess)) goto fn_fail

/*
* CCLcomm functions, currently tied to NCCL
* NCCL-specific functions
*/

int MPIR_CCLcomm_init(MPIR_Comm * comm, int rank)
int MPIR_NCCL_comm_init(MPIR_Comm * comm, int rank)
{
int mpi_errno = MPI_SUCCESS;
int mpi_errno = MPI_SUCCESS;
cudaError_t ret;
int comm_size = comm->local_size;

Expand Down Expand Up @@ -45,7 +45,7 @@ int MPIR_CCLcomm_init(MPIR_Comm * comm, int rank)
goto fn_exit;
}

int MPIR_CCLcomm_free(MPIR_Comm * comm)
int MPIR_NCCL_comm_free(MPIR_Comm * comm)
{
int mpi_errno = MPI_SUCCESS;
cudaError_t ret;
Expand All @@ -68,10 +68,6 @@ int MPIR_CCLcomm_free(MPIR_Comm * comm)
goto fn_exit;
}

/*
* NCCL-specific functions
*/

int MPIR_NCCL_red_op_is_supported(MPI_Op op)
{
switch (op) {
Expand Down Expand Up @@ -239,20 +235,30 @@ int MPIR_NCCL_Allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_
}

/*
* CCL wrapper functions
* CCLcomm wrapper functions, currently tied to NCCL
*/

int MPIR_CCLcomm_init(MPIR_Comm * comm, int rank)
{
return MPIR_NCCL_comm_init(comm, rank);
}

int MPIR_CCLcomm_free(MPIR_Comm * comm)
{
return MPIR_NCCL_comm_free(comm);
}

int MPIR_CCL_red_op_is_supported(MPI_Op op)
{
MPIR_NCCL_red_op_is_supported(op);
return MPIR_NCCL_red_op_is_supported(op);
}

int MPIR_CCL_datatype_is_supported(MPI_Datatype datatype)
{
MPIR_NCCL_datatype_is_supported(datatype);
return MPIR_NCCL_datatype_is_supported(datatype);
}

int MPIR_CCL_Allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype,
int MPIR_Allreduce_intra_ccl(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
{
return MPIR_NCCL_Allreduce(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag);
Expand Down
1 change: 1 addition & 0 deletions src/mpi/coll/coll_algorithms.txt
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ allreduce-intra:
restrictions: commutative
extra_params: k, single_phase_recv
cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV
ccl
allreduce-inter:
reduce_exchange_bcast
iallreduce-intra:
Expand Down
1 change: 1 addition & 0 deletions src/mpi/coll/cvars.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,7 @@ cvars:
recexch - Force generic transport recursive exchange algorithm
ring - Force ring algorithm
k_reduce_scatter_allgather - Force reduce scatter allgather algorithm
ccl - Force CCL algorithm

- name : MPIR_CVAR_ALLREDUCE_RECURSIVE_MULTIPLYING_KVAL
category : COLLECTIVE
Expand Down
1 change: 1 addition & 0 deletions src/mpi/coll/include/csel_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ typedef enum {
MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_intra_recexch,
MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_intra_ring,
MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_intra_k_reduce_scatter_allgather,
MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_intra_ccl,
MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_inter_reduce_exchange_bcast,
MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_allcomm_nb,
MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Alltoall_intra_brucks,
Expand Down

0 comments on commit ee76028

Please sign in to comment.