Skip to content

Commit

Permalink
Merge pull request JuliaParallel#227 from samo-lin/enhancement/cartes…
Browse files Browse the repository at this point in the history
…ian-topo

Add functions for Cartesian process topology
  • Loading branch information
barche authored Jan 19, 2019
2 parents a915ee1 + 79f3b00 commit db9833b
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 1 deletion.
4 changes: 4 additions & 0 deletions deps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,17 @@ FortranCInterface_HEADER(jlmpi_f2c.h MACRO_NAMESPACE "JLMPI_" SYMBOLS
MPI_BCAST
MPI_BSEND
MPI_CANCEL
MPI_CART_CREATE
MPI_CART_COORDS
MPI_CART_SHIFT
MPI_COMM_DUP
MPI_COMM_FREE
MPI_COMM_GET_PARENT
MPI_COMM_RANK
MPI_COMM_SIZE
MPI_COMM_SPLIT
MPI_COMM_SPLIT_TYPE
MPI_DIMS_CREATE
MPI_EXSCAN
MPI_FETCH_AND_OP
MPI_FINALIZE
Expand Down
2 changes: 2 additions & 0 deletions deps/gen_constants.f90
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ program gen_constants

call output("MPI_INFO_NULL ", MPI_INFO_NULL)

call output("MPI_PROC_NULL ", MPI_PROC_NULL)

call output("MPI_STATUS_SIZE ", MPI_STATUS_SIZE)
call output("MPI_ERROR ", MPI_ERROR)
call output("MPI_SOURCE ", MPI_SOURCE)
Expand Down
4 changes: 4 additions & 0 deletions deps/gen_functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ int main(int argc, char *argv[]) {
printf(" :MPI_BCAST => \"%s\",\n", STRING(MPI_BCAST));
printf(" :MPI_BSEND => \"%s\",\n", STRING(MPI_BSEND));
printf(" :MPI_CANCEL => \"%s\",\n", STRING(MPI_CANCEL));
printf(" :MPI_CART_CREATE => \"%s\",\n", STRING(MPI_CART_CREATE));
printf(" :MPI_CART_COORDS => \"%s\",\n", STRING(MPI_CART_COORDS));
printf(" :MPI_CART_SHIFT => \"%s\",\n", STRING(MPI_CART_SHIFT));
printf(" :MPI_COMM_DUP => \"%s\",\n", STRING(MPI_COMM_DUP));
printf(" :MPI_COMM_FREE => \"%s\",\n", STRING(MPI_COMM_FREE));
printf(" :MPI_COMM_GET_PARENT => \"%s\",\n", STRING(MPI_COMM_GET_PARENT));
printf(" :MPI_COMM_RANK => \"%s\",\n", STRING(MPI_COMM_RANK));
printf(" :MPI_COMM_SIZE => \"%s\",\n", STRING(MPI_COMM_SIZE));
printf(" :MPI_COMM_SPLIT => \"%s\",\n", STRING(MPI_COMM_SPLIT));
printf(" :MPI_COMM_SPLIT_TYPE => \"%s\",\n", STRING(MPI_COMM_SPLIT_TYPE));
printf(" :MPI_DIMS_CREATE => \"%s\",\n", STRING(MPI_DIMS_CREATE));
printf(" :MPI_EXSCAN => \"%s\",\n", STRING(MPI_EXSCAN));
printf(" :MPI_FETCH_AND_OP => \"%s\",\n", STRING(MPI_FETCH_AND_OP));
printf(" :MPI_FINALIZE => \"%s\",\n", STRING(MPI_FINALIZE));
Expand Down
49 changes: 49 additions & 0 deletions src/mpi-base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,55 @@ function Comm_split_type(comm::Comm,split_type::Integer,key::Integer;info::Info=
MPI.Comm(newcomm[])
end

function Dims_create!(nnodes::Integer, ndims::Integer, dims::MPIBuffertype{T}) where T <: Integer
ccall(MPI_DIMS_CREATE, Nothing, (Ref{Cint}, Ref{Cint}, Ptr{T}, Ref{Cint}),
nnodes, ndims, dims, 0)
end

function Dims_create!(nnodes::Integer, dims::AbstractArray{T,N}) where T <: Integer where N
cdims = Cint.(dims[:])
ndims = length(cdims)
Dims_create!(nnodes, ndims, cdims)
dims[:] .= cdims
end

function Cart_create(comm_old::Comm, ndims::Integer, dims::MPIBuffertype{T}, periods::MPIBuffertype{T}, reorder::Integer) where T <: Integer
comm_cart = Ref{Cint}()
ccall(MPI_CART_CREATE, Nothing,
(Ref{Cint}, Ref{Cint}, Ptr{T}, Ptr{T}, Ref{Cint}, Ref{Cint}, Ref{Cint}),
comm_old.val, ndims, dims, periods, reorder, comm_cart, 0)
MPI.Comm(comm_cart[])
end

function Cart_create(comm_old::Comm, dims::AbstractArray{T,N}, periods::Array{T,N}, reorder::Integer) where T <: Integer where N
cdims = Cint.(dims[:])
cperiods = Cint.(periods[:])
ndims = length(cdims)
Cart_create(comm_old, ndims, cdims, cperiods, reorder)
end

function Cart_coords!(comm::Comm, rank::Integer, maxdims::Integer, coords::MPIBuffertype{T}) where T <: Integer
ccall(MPI_CART_COORDS, Nothing,
(Ref{Cint}, Ref{Cint}, Ref{Cint}, Ptr{T}, Ref{Cint}),
comm.val, rank, maxdims, coords, 0)
end

function Cart_coords!(comm::Comm, maxdims::Integer)
ccoords = Vector{Cint}(undef, maxdims)
rank = Comm_rank(comm)
Cart_coords!(comm, rank, maxdims, ccoords)
Int.(ccoords)
end

function Cart_shift(comm::Comm, direction::Integer, disp::Integer)
rank_source = Ref{Cint}()
rank_dest = Ref{Cint}()
ccall(MPI_CART_SHIFT, Nothing,
(Ref{Cint}, Ref{Cint}, Ref{Cint}, Ref{Cint}, Ref{Cint}, Ref{Cint}),
comm.val, direction, disp, rank_source, rank_dest, 0)
Int(rank_source[]), Int(rank_dest[])
end

function Wtick()
ccall(MPI_WTICK, Cdouble, ())
end
Expand Down
7 changes: 6 additions & 1 deletion src/win_mpiconstants.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# These constants were manually copied from the file mpi.h in the Microsoft
# MPI SDK v7
# MPI SDK v7; the value for MPI_PROC_NULL was obtained from v10.0

const MPI_BYTE = Int32(0x4c00010d)
const MPI_WCHAR = Int32(0x4c00020e)
Expand Down Expand Up @@ -47,6 +47,7 @@ const MPI_ERROR = Int32(5)
const MPI_SOURCE = Int32(3)
const MPI_TAG = Int32(4)
const MPI_ANY_SOURCE = Int32(-2)
const MPI_PROC_NULL = Int32(-1)
const MPI_ANY_TAG = Int32(-1)
const MPI_TAG_UB = Int32(1681915906)
const MPI_UNDEFINED = Int32(-32766)
Expand All @@ -59,12 +60,16 @@ const MPI_ACCUMULATE = (:MPI_ACCUMULATE, libmpi)
const MPI_ALLREDUCE = (:MPI_ALLREDUCE, libmpi)
const MPI_INIT = (:MPI_INIT, libmpi)
const MPI_CANCEL = (:MPI_CANCEL, libmpi)
const MPI_CART_CREATE = (:MPI_CART_CREATE, libmpi)
const MPI_CART_COORDS = (:MPI_CART_COORDS, libmpi)
const MPI_CART_SHIFT = (:MPI_CART_SHIFT, libmpi)
const MPI_COMM_FREE = (:MPI_COMM_FREE, libmpi)
const MPI_COMM_GET_PARENT = (:MPI_COMM_GET_PARENT, libmpi)
const MPI_COMM_RANK = (:MPI_COMM_RANK, libmpi)
const MPI_COMM_SIZE = (:MPI_COMM_SIZE, libmpi)
const MPI_COMM_SPLIT = (:MPI_COMM_SPLIT, libmpi)
const MPI_COMM_SPLIT_TYPE = (:MPI_COMM_SPLIT_TYPE, libmpi)
const MPI_DIMS_CREATE = (:MPI_DIMS_CREATE, libmpi)
const MPI_BARRIER = (:MPI_BARRIER, libmpi)
const MPI_FINALIZE = (:MPI_FINALIZE, libmpi)
const MPI_BCAST = (:MPI_BCAST, libmpi)
Expand Down
24 changes: 24 additions & 0 deletions test/test_cart_coords.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Test
using MPI

MPI.Init()
comm = MPI.COMM_WORLD
nnodes = MPI.Comm_size(comm)
ndims = 3
reorder = 1
periods = [0,1,0]
dims = [0,0,0]
MPI.Dims_create!(nnodes, dims)
comm_cart = MPI.Cart_create(comm, dims, periods, reorder)

rank = MPI.Comm_rank(comm)
ccoords = Cint[-1,-1,-1]
MPI.Cart_coords!(comm_cart, rank, ndims, ccoords)
@test all(ccoords .>= 0)
@test all(ccoords .< dims)

coords = MPI.Cart_coords!(comm_cart, ndims)
@test all(coords .>= 0)
@test all(coords .< dims)

MPI.Finalize()
21 changes: 21 additions & 0 deletions test/test_cart_create.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Test
using MPI

MPI.Init()
comm = MPI.COMM_WORLD
nnodes = MPI.Comm_size(comm)
ndims = 3
reorder = 1
periods = [0 1 0]
dims = [0 0 0]
MPI.Dims_create!(nnodes, dims)

cperiods = Cint.(periods[:])
cdims = Cint.(dims[:])
comm_cart = MPI.Cart_create(comm, ndims, cdims, cperiods, reorder)
@test MPI.Comm_size(comm_cart) == nnodes

comm_cart2 = MPI.Cart_create(comm, dims, periods, reorder)
@test MPI.Comm_size(comm_cart2) == nnodes

MPI.Finalize()
21 changes: 21 additions & 0 deletions test/test_cart_shift.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Test
using MPI

MPI.Init()
comm = MPI.COMM_WORLD
nnodes = MPI.Comm_size(comm)
ndims = 3
reorder = 1
periods = [0,1,0]
dims = [0,0,0]
MPI.Dims_create!(nnodes, dims)
comm_cart = MPI.Cart_create(comm, dims, periods, reorder)
coords = MPI.Cart_coords!(comm_cart, ndims)
disp = 1

for i in 0:2
neighbors = MPI.Cart_shift(comm_cart, i, disp)
@test all( ((neighbors .>= 0) .& (neighbors .< nnodes)) .| (neighbors .== MPI.MPI_PROC_NULL) )
end

MPI.Finalize()
22 changes: 22 additions & 0 deletions test/test_dims_create.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Test
using MPI

MPI.Init()
comm = MPI.COMM_WORLD
nnodes = MPI.Comm_size(comm)
ndims = 3

cdims = Cint[0,0,0]
MPI.Dims_create!(nnodes, ndims, cdims)
@test prod(cdims) == nnodes

cdims = Cint[1,0,1]
MPI.Dims_create!(nnodes, ndims, cdims)
@test cdims == Cint[1,nnodes,1]

for dims in ([0,0,0], [0 0 0])
MPI.Dims_create!(nnodes, dims)
@test prod(dims) == nnodes
end

MPI.Finalize()

0 comments on commit db9833b

Please sign in to comment.