#===============================================================================================================================
# Name : Tests of the library mpi4py from MPI4PY tutorial
# Author : Michaël Ndjinga
# Copyright : CEA Saclay 2021
# Description : https://mpi4py.readthedocs.io/en/stable/tutorial.html
#================================================================================================================================

from mpi4py import MPI
import numpy as np

# Tests from MPI4PY tutorial https://mpi4py.readthedocs.io/en/stable/tutorial.html
# Launch under MPI, e.g. `mpiexec -n 4 python <this file>`. The point-to-point
# examples exchange data between ranks 0 and 1, so they need at least 2 processes.

# Communicator spanning every process of the job; size/rank drive all the
# rank-dependent branches below.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

print("My rank is ", rank, " among ", size, "processors ")
### Point-to-Point Communication
# Every example below sends from rank 0 to rank 1. Other ranks take no part.
# With a single process, rank 1 does not exist, so the section is skipped
# instead of failing on an invalid destination.
if size >= 2:
    # Python objects (pickle under the hood):
    if rank == 0:
        data = {'a': 7, 'b': 3.14}
        comm.send(data, dest=1, tag=11)
    elif rank == 1:
        data = comm.recv(source=0, tag=11)
        assert data == {'a': 7, 'b': 3.14}

    # Python objects with non-blocking communication:
    # isend/irecv only start the transfer; wait() completes it. On the
    # receiving side, wait() returns the received object.
    if rank == 0:
        data = {'a': 7, 'b': 3.14}
        req = comm.isend(data, dest=1, tag=11)
        req.wait()
    elif rank == 1:
        req = comm.irecv(source=0, tag=11)
        data = req.wait()
        assert data == {'a': 7, 'b': 3.14}

    # passing MPI datatypes explicitly
    # (the receiver must preallocate a buffer of matching size and type)
    if rank == 0:
        data = np.arange(1000, dtype='i')
        comm.Send([data, MPI.INT], dest=1, tag=77)
    elif rank == 1:
        data = np.empty(1000, dtype='i')
        comm.Recv([data, MPI.INT], source=0, tag=77)
        assert np.array_equal(data, np.arange(1000, dtype='i'))

    # automatic MPI datatype discovery
    # (the MPI datatype is inferred from the NumPy dtype)
    if rank == 0:
        data = np.arange(100, dtype=np.float64)
        comm.Send(data, dest=1, tag=13)
    elif rank == 1:
        data = np.empty(100, dtype=np.float64)
        comm.Recv(data, source=0, tag=13)
        assert np.array_equal(data, np.arange(100, dtype=np.float64))
### Collective Communication

# Broadcasting a Python dictionary:
# Only the root builds the object. Every rank, root included, gets bcast's return value.
if rank == 0:
    data = {'key1' : [7, 2.72, 2+3j],
            'key2' : ( 'abc', 'xyz')}
else:
    data = None
data = comm.bcast(data, root=0)
assert data == {'key1': [7, 2.72, 2+3j], 'key2': ('abc', 'xyz')}

# Scattering Python objects:
# The root supplies a list with one entry per rank; rank i receives entry i.
if rank == 0:
    data = [(i+1)**2 for i in range(size)]
else:
    data = None
data = comm.scatter(data, root=0)
assert data == (rank+1)**2

# Gathering Python objects:
# Each rank contributes one value. The root gets them as a list in rank order;
# the other ranks get None.
data = (rank+1)**2
data = comm.gather(data, root=0)
if rank == 0:
    for i in range(size):
        assert data[i] == (i+1)**2
else:
    assert data is None
# Broadcasting a NumPy array:
# Non-root ranks preallocate a receive buffer of the same size and dtype,
# which Bcast fills in place.
if rank == 0:
    data = np.arange(100, dtype='i')
else:
    data = np.empty(100, dtype='i')
comm.Bcast(data, root=0)
assert np.array_equal(data, np.arange(100, dtype='i'))

# Scattering NumPy arrays:
# Only the root needs a send buffer: a (size, 100) array whose row i is filled
# with the value i and is delivered to rank i.
sendbuf = None
if rank == 0:
    sendbuf = np.empty([size, 100], dtype='i')
    sendbuf.T[:,:] = range(size)
recvbuf = np.empty(100, dtype='i')
comm.Scatter(sendbuf, recvbuf, root=0)
assert np.allclose(recvbuf, rank)

# Gathering NumPy arrays:
# Each rank sends 100 copies of its rank. Only the root allocates the
# (size, 100) receive buffer, and only the root checks the gathered rows.
sendbuf = np.zeros(100, dtype='i') + rank
recvbuf = None
if rank == 0:
    recvbuf = np.empty([size, 100], dtype='i')
comm.Gather(sendbuf, recvbuf, root=0)
if rank == 0:
    for i in range(size):
        assert np.allclose(recvbuf[i,:], i)

# Parallel matrix-vector product: