Salome HOME
Merge branch 'occ/24009'
[tools/medcoupling.git] / src / ParaMEDMEM / CommInterface.hxx
1 // Copyright (C) 2007-2021  CEA/DEN, EDF R&D
2 //
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License, or (at your option) any later version.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
16 //
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
18 //
19
20 #pragma once
21
22 #include "ParaIdType.hxx"
23 #include "MEDCouplingMemArray.hxx"
24
25 #include <mpi.h>
26
27 #include <memory>
28 #include <numeric>
29
30 namespace MEDCoupling
31 {
32   template<class T>
33   struct ParaTraits
34   {
35     using EltType = T;
36   };
37   
38   template<>
39   struct ParaTraits<double>
40   {
41     static MPI_Datatype MPIDataType;
42   };
43
44   template<>
45   struct ParaTraits<Int32>
46   {
47     static MPI_Datatype MPIDataType;
48   };
49
50   template<>
51   struct ParaTraits<Int64>
52   {
53     static MPI_Datatype MPIDataType;
54   };
55
56   /*! \anchor CommInterface-det
57      \class CommInterface
58
59     The class \a CommInterface is the gateway to the MPI library.
60     It is a wrapper around all MPI calls, thus trying to abstract the rest of the code from using the direct MPI API
61     (but this is not strictly respected overall in practice ...). It is used in all
62     the \ref parallel "DEC related classes".
63
64     It is typically instantiated after the MPI_Init() call in a program and is afterwards passed as a
65     parameter to the constructors of various \ref parallel "parallel objects" so that they access the
66     MPI library via this common interface.
67
68     As an example, the following code excerpt initializes a processor group made of the zero processor.
69
70     \verbatim
71     #include "CommInterface.hxx"
72     #include "ProcessorGroup.hxx"
73
74     int main(int argc, char** argv)
75     {
76     //initialization
77     MPI_Init(&argc, &argv);
78     MEDCoupling::CommInterface comm_interface;
79
80     //setting up a processor group with proc 0
81     set<int> procs;
82     procs.insert(0);
83     MEDCoupling::ProcessorGroup group(procs, comm_interface);
84
85     //cleanup
86     MPI_Finalize();
87     }
88     \endverbatim
89   */
90   class CommInterface
91   {
92   public:
93     CommInterface() { }
94     virtual ~CommInterface() { }
95     int worldSize() const {
96       int size;
97       MPI_Comm_size(MPI_COMM_WORLD, &size);
98       return size;}
99     int commSize(MPI_Comm comm, int* size) const { return MPI_Comm_size(comm,size); }
100     int commRank(MPI_Comm comm, int* rank) const { return MPI_Comm_rank(comm,rank); }
101     int commGroup(MPI_Comm comm, MPI_Group* group) const { return MPI_Comm_group(comm, group); }
102     int groupIncl(MPI_Group group, int size, int* ranks, MPI_Group* group_output) const { return MPI_Group_incl(group, size, ranks, group_output); }
103     int commCreate(MPI_Comm comm, MPI_Group group, MPI_Comm* comm_output) const { return MPI_Comm_create(comm,group,comm_output); }
104     int groupFree(MPI_Group* group) const { return MPI_Group_free(group); }
105     int commFree(MPI_Comm* comm) const { return MPI_Comm_free(comm); }
106
107     int send(void* buffer, int count, MPI_Datatype datatype, int target, int tag, MPI_Comm comm) const { return MPI_Send(buffer,count, datatype, target, tag, comm); }
108     int recv(void* buffer, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status* status) const { return MPI_Recv(buffer,count, datatype, source, tag, comm, status); }
109     int sendRecv(void* sendbuf, int sendcount, MPI_Datatype sendtype, 
110                  int dest, int sendtag, void* recvbuf, int recvcount, 
111                  MPI_Datatype recvtype, int source, int recvtag, MPI_Comm comm,
112                  MPI_Status* status) { return MPI_Sendrecv(sendbuf, sendcount, sendtype, dest, sendtag, recvbuf, recvcount, recvtype, source, recvtag, comm,status); }
113
114     int Isend(void* buffer, int count, MPI_Datatype datatype, int target,
115               int tag, MPI_Comm comm, MPI_Request *request) const { return MPI_Isend(buffer,count, datatype, target, tag, comm, request); }
116     int Irecv(void* buffer, int count, MPI_Datatype datatype, int source,
117               int tag, MPI_Comm comm, MPI_Request* request) const { return MPI_Irecv(buffer,count, datatype, source, tag, comm, request); }
118
119     int wait(MPI_Request *request, MPI_Status *status) const { return MPI_Wait(request, status); }
120     int test(MPI_Request *request, int *flag, MPI_Status *status) const { return MPI_Test(request, flag, status); }
121     int requestFree(MPI_Request *request) const { return MPI_Request_free(request); }
122     int waitany(int count, MPI_Request *array_of_requests, int *index, MPI_Status *status) const { return MPI_Waitany(count, array_of_requests, index, status); }
123     int testany(int count, MPI_Request *array_of_requests, int *index, int *flag, MPI_Status *status) const { return MPI_Testany(count, array_of_requests, index, flag, status); }
124     int waitall(int count, MPI_Request *array_of_requests, MPI_Status *array_of_status) const { return MPI_Waitall(count, array_of_requests, array_of_status); }
125     int testall(int count, MPI_Request *array_of_requests, int *flag, MPI_Status *array_of_status) const { return MPI_Testall(count, array_of_requests, flag, array_of_status); }
126     int waitsome(int incount, MPI_Request *array_of_requests,int *outcount, int *array_of_indices, MPI_Status *array_of_status) const { return MPI_Waitsome(incount, array_of_requests, outcount, array_of_indices, array_of_status); }
127     int testsome(int incount, MPI_Request *array_of_requests, int *outcount,
128                  int *array_of_indices, MPI_Status *array_of_status) const { return MPI_Testsome(incount, array_of_requests, outcount, array_of_indices, array_of_status); }
129     int probe(int source, int tag, MPI_Comm comm, MPI_Status *status) const { return MPI_Probe(source, tag, comm, status) ; }
130     int Iprobe(int source, int tag, MPI_Comm comm, int *flag, MPI_Status *status) const { return MPI_Iprobe(source, tag, comm, flag, status) ; }
131     int cancel(MPI_Request *request) const { return MPI_Cancel(request); }
132     int testCancelled(MPI_Status *status, int *flag) const { return MPI_Test_cancelled(status, flag); }
133     int barrier(MPI_Comm comm) const { return MPI_Barrier(comm); }
134     int errorString(int errorcode, char *string, int *resultlen) const { return MPI_Error_string(errorcode, string, resultlen); }
135     int getCount(MPI_Status *status, MPI_Datatype datatype, int *count) const { return MPI_Get_count(status, datatype, count); }
136
137     int broadcast(void* buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm) const { return MPI_Bcast(buffer, count,  datatype, root, comm); }
138     int gather(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) const { return MPI_Gather(const_cast<void*>(sendbuf),sendcount,sendtype,recvbuf,recvcount,recvtype,root,comm); }
139     int gatherV(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype, int root, MPI_Comm comm) const { return MPI_Gatherv(const_cast<void*>(sendbuf),sendcount,sendtype,recvbuf,const_cast<int *>(recvcounts),const_cast<int *>(displs),recvtype,root,comm); }
140     int allGather(void* sendbuf, int sendcount, MPI_Datatype sendtype,
141                   void* recvbuf, int recvcount, MPI_Datatype recvtype,
142                   MPI_Comm comm) const { return MPI_Allgather(sendbuf,sendcount, sendtype, recvbuf, recvcount, recvtype, comm); }
143     int allGatherV(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int recvcounts[],
144                    const int displs[], MPI_Datatype recvtype, MPI_Comm comm) const { return MPI_Allgatherv(const_cast<void*>(sendbuf),sendcount,sendtype,recvbuf,const_cast<int *>(recvcounts),const_cast<int *>(displs),recvtype,comm); }
145     int allToAll(void* sendbuf, int sendcount, MPI_Datatype sendtype,
146                  void* recvbuf, int recvcount, MPI_Datatype recvtype,
147                  MPI_Comm comm) const { return MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm); }
148     int allToAllV(const void* sendbuf, int* sendcounts, int* senddispls,
149                   MPI_Datatype sendtype, void* recvbuf, int* recvcounts,
150                   int* recvdispls, MPI_Datatype recvtype, 
151                   MPI_Comm comm) const { return MPI_Alltoallv(const_cast<void*>(sendbuf), sendcounts, senddispls, sendtype, recvbuf, recvcounts, recvdispls, recvtype, comm); }
152
153     int reduce(void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype,
154                MPI_Op op, int root, MPI_Comm comm) const { return MPI_Reduce(sendbuf, recvbuf, count, datatype, op, root, comm); }
155     int allReduce(void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const { return MPI_Allreduce(sendbuf, recvbuf, count, datatype, op, comm); }
156   public:
157     void gatherArrays(MPI_Comm comm, int root, const DataArrayIdType *array, std::vector< MCAuto<DataArrayIdType> >& arraysOut) const;
158     void allGatherArrays(MPI_Comm comm, const DataArrayIdType *array, std::vector< MCAuto<DataArrayIdType> >& arraysOut) const;
159     int allGatherArrays(MPI_Comm comm, const DataArrayIdType *array, std::unique_ptr<mcIdType[]>& result, std::unique_ptr<mcIdType[]>& resultIndex) const;
160     void allToAllArrays(MPI_Comm comm, const std::vector< MCAuto<DataArrayIdType> >& arrays, std::vector< MCAuto<DataArrayIdType> >& arraysOut) const;
161     void allToAllArrays(MPI_Comm comm, const std::vector< MCAuto<DataArrayDouble> >& arrays, std::vector< MCAuto<DataArrayDouble> >& arraysOut) const;
162     void allToAllArrays(MPI_Comm comm, const std::vector< MCAuto<DataArrayDouble> >& arrays, MCAuto<DataArrayDouble>& arraysOut) const;
163     
164     template<class T>
165     int gatherArraysT(MPI_Comm comm, int root, const typename Traits<T>::ArrayType *array, std::unique_ptr<T[]>& result, std::unique_ptr<mcIdType[]>& resultIndex, int& rank) const
166     {
167       int size;
168       this->commSize(comm,&size);
169       rank = -1;
170       this->commRank(comm,&rank);
171       std::unique_ptr<mcIdType[]> nbOfElems;
172       if(rank==root)
173         nbOfElems.reset(new mcIdType[size]);
174       mcIdType nbOfCellsRequested(array->getNumberOfTuples());
175       this->gather(&nbOfCellsRequested,1,MPI_ID_TYPE,nbOfElems.get(),1,MPI_ID_TYPE,root,comm);
176       std::unique_ptr<int[]> nbOfElemsInt,offsetsIn;
177       if(rank==root)
178       {
179         mcIdType nbOfCellIdsSum(std::accumulate(nbOfElems.get(),nbOfElems.get()+size,0));
180         result.reset(new T[nbOfCellIdsSum]);
181         nbOfElemsInt = CommInterface::ToIntArray<mcIdType>(nbOfElems,size);
182         offsetsIn = CommInterface::ComputeOffset(nbOfElemsInt,size);
183       }
184       this->gatherV(array->begin(),nbOfCellsRequested,ParaTraits<T>::MPIDataType,result.get(),nbOfElemsInt.get(),offsetsIn.get(),ParaTraits<T>::MPIDataType,root,comm);
185       if(rank==root)
186       {
187         resultIndex = ComputeOffsetFull<mcIdType>(nbOfElems,size);
188       }
189       return size;
190     }
191
192     template<class T>
193     void gatherArraysT2(MPI_Comm comm, int root, const typename Traits<T>::ArrayType *array, std::vector< MCAuto<typename Traits<T>::ArrayType> >& arraysOut) const
194     {
195       using DataArrayT = typename Traits<T>::ArrayType;
196       std::unique_ptr<T[]> result;
197       std::unique_ptr<mcIdType[]> resultIndex;
198       int rank(-1);
199       int size(this->gatherArraysT<T>(comm,root,array,result,resultIndex,rank));
200       arraysOut.resize(size);
201       for(int i = 0 ; i < size ; ++i)
202       {
203         arraysOut[i] = DataArrayT::New();
204         if(rank == root)
205         {
206           mcIdType nbOfEltPack(resultIndex[i+1]-resultIndex[i]);
207           arraysOut[i]->alloc(nbOfEltPack,1);
208           std::copy(result.get()+resultIndex[i],result.get()+resultIndex[i+1],arraysOut[i]->getPointer());
209         }
210       }
211     }
212
213     template<class T>
214     int allGatherArraysT(MPI_Comm comm, const typename Traits<T>::ArrayType *array, std::unique_ptr<T[]>& result, std::unique_ptr<mcIdType[]>& resultIndex) const
215     {
216       int size;
217       this->commSize(comm,&size);
218       std::unique_ptr<mcIdType[]> nbOfElems(new mcIdType[size]);
219       mcIdType nbOfCellsRequested(array->getNumberOfTuples());
220       this->allGather(&nbOfCellsRequested,1,MPI_ID_TYPE,nbOfElems.get(),1,MPI_ID_TYPE,comm);
221       mcIdType nbOfCellIdsSum(std::accumulate(nbOfElems.get(),nbOfElems.get()+size,0));
222       result.reset(new T[nbOfCellIdsSum]);
223       std::unique_ptr<int[]> nbOfElemsInt( CommInterface::ToIntArray<mcIdType>(nbOfElems,size) );
224       std::unique_ptr<int[]> offsetsIn( CommInterface::ComputeOffset(nbOfElemsInt,size) );
225       this->allGatherV(array->begin(),nbOfCellsRequested,ParaTraits<T>::MPIDataType,result.get(),nbOfElemsInt.get(),offsetsIn.get(),ParaTraits<T>::MPIDataType,comm);
226       resultIndex = ComputeOffsetFull<mcIdType>(nbOfElems,size);
227       return size;
228     }
229
230     template<class T>
231     void allGatherArraysT2(MPI_Comm comm, const typename Traits<T>::ArrayType *array, std::vector< MCAuto<typename Traits<T>::ArrayType> >& arraysOut) const
232     {
233       using DataArrayT = typename Traits<T>::ArrayType;
234       std::unique_ptr<T[]> result;
235       std::unique_ptr<mcIdType[]> resultIndex;
236       int size(this->allGatherArraysT<T>(comm,array,result,resultIndex));
237       arraysOut.resize(size);
238       for(int i = 0 ; i < size ; ++i)
239       {
240         arraysOut[i] = DataArrayT::New();
241         mcIdType nbOfEltPack(resultIndex[i+1]-resultIndex[i]);
242         arraysOut[i]->alloc(nbOfEltPack,1);
243         std::copy(result.get()+resultIndex[i],result.get()+resultIndex[i+1],arraysOut[i]->getPointer());
244       }
245     }
246
247     template<class T>
248     int allToAllArraysT2(MPI_Comm comm, const std::vector< MCAuto<typename Traits<T>::ArrayType> >& arrays, MCAuto<typename Traits<T>::ArrayType>& arrayOut, std::unique_ptr<mcIdType[]>& nbOfElems2, mcIdType& nbOfComponents) const
249     {
250       using DataArrayT = typename Traits<T>::ArrayType;
251       int size;
252       this->commSize(comm,&size);
253       if( arrays.size() != ToSizeT(size) )
254         throw INTERP_KERNEL::Exception("AllToAllArrays : internal error ! Invalid size of input array.");
255         
256       std::vector< const DataArrayT *> arraysBis(FromVecAutoToVecOfConst<DataArrayT>(arrays));
257       std::unique_ptr<mcIdType[]> nbOfElems3(new mcIdType[size]);
258       nbOfElems2.reset(new mcIdType[size]);
259       nbOfComponents = std::numeric_limits<mcIdType>::max();
260       for(int curRk = 0 ; curRk < size ; ++curRk)
261       {
262         mcIdType curNbOfCompo( ToIdType( arrays[curRk]->getNumberOfComponents() ) );
263         if(nbOfComponents != std::numeric_limits<mcIdType>::max())
264         {
265           if( nbOfComponents != curNbOfCompo )
266             throw INTERP_KERNEL::Exception("AllToAllArrays : internal error ! Nb of components is not homogeneous !");
267         }
268         else
269         {
270           nbOfComponents = curNbOfCompo;
271         }
272         nbOfElems3[curRk] = arrays[curRk]->getNbOfElems();
273       }
274       this->allToAll(nbOfElems3.get(),1,MPI_ID_TYPE,nbOfElems2.get(),1,MPI_ID_TYPE,comm);
275       mcIdType nbOfCellIdsSum(std::accumulate(nbOfElems2.get(),nbOfElems2.get()+size,0));
276       arrayOut = DataArrayT::New();
277       arrayOut->alloc(nbOfCellIdsSum/nbOfComponents,nbOfComponents);
278       std::unique_ptr<int[]> nbOfElemsInt( CommInterface::ToIntArray<mcIdType>(nbOfElems3,size) ),nbOfElemsOutInt( CommInterface::ToIntArray<mcIdType>(nbOfElems2,size) );
279       std::unique_ptr<int[]> offsetsIn( CommInterface::ComputeOffset(nbOfElemsInt,size) ), offsetsOut( CommInterface::ComputeOffset(nbOfElemsOutInt,size) );
280       {
281         MCAuto<DataArrayT> arraysAcc(DataArrayT::Aggregate(arraysBis));
282         this->allToAllV(arraysAcc->begin(),nbOfElemsInt.get(),offsetsIn.get(),ParaTraits<T>::MPIDataType,
283                         arrayOut->getPointer(),nbOfElemsOutInt.get(),offsetsOut.get(),ParaTraits<T>::MPIDataType,comm);
284       }
285       return size;
286     }
287
288     template<class T>
289     void allToAllArraysT(MPI_Comm comm, const std::vector< MCAuto<typename Traits<T>::ArrayType> >& arrays, std::vector< MCAuto<typename Traits<T>::ArrayType> >& arraysOut) const
290     {
291       using DataArrayT = typename Traits<T>::ArrayType;
292       MCAuto<DataArrayT> cellIdsFromProcs;
293       std::unique_ptr<mcIdType[]> nbOfElems2;
294       mcIdType nbOfComponents(0);
295       int size(this->allToAllArraysT2<T>(comm,arrays,cellIdsFromProcs,nbOfElems2,nbOfComponents));
296       std::unique_ptr<mcIdType[]> offsetsOutIdType( CommInterface::ComputeOffset(nbOfElems2,size) );
297       // build output arraysOut by spliting cellIdsFromProcs into parts
298       arraysOut.resize(size);
299       for(int curRk = 0 ; curRk < size ; ++curRk)
300       {
301         arraysOut[curRk] = DataArrayT::NewFromArray(cellIdsFromProcs->begin()+offsetsOutIdType[curRk],cellIdsFromProcs->begin()+offsetsOutIdType[curRk]+nbOfElems2[curRk]);
302         arraysOut[curRk]->rearrange(nbOfComponents);
303       }
304     }
305   public:
306
307     /*!
308     * \a counts is expected to be an array of array length. This method returns an array of split array.
309     */
310     static std::unique_ptr<mcIdType[]> SplitArrayOfLength(const std::unique_ptr<mcIdType[]>& counts, std::size_t countsSz, int rk, int size)
311     {
312       std::unique_ptr<mcIdType[]> ret(new mcIdType[countsSz]);
313       for(std::size_t i=0;i<countsSz;++i)
314       {
315         mcIdType a,b;
316         DataArray::GetSlice(0,counts[i],1,rk,size,a,b);
317         ret[i] = b-a;
318       }
319       return ret;
320     }
321
322     /*!
323     * Helper of alltoallv and allgatherv
324     */
325     template<class T>
326     static std::unique_ptr<int []> ToIntArray(const std::unique_ptr<T []>& arr, std::size_t size)
327     {
328       std::unique_ptr<int []> ret(new int[size]);
329       std::copy(arr.get(),arr.get()+size,ret.get());
330       return ret;
331     }
332     
333     /*!
334     * Helper of alltoallv and allgatherv
335     */
336     template<class T>
337     static std::unique_ptr<T []> ComputeOffset(const std::unique_ptr<T []>& counts, std::size_t sizeOfCounts)
338     {
339       std::unique_ptr<T []> ret(new T[sizeOfCounts]);
340       ret[0] = static_cast<T>(0);
341       for(std::size_t i = 1 ; i < sizeOfCounts ; ++i)
342       {
343         ret[i] = ret[i-1] + counts[i-1];
344       }
345       return ret;
346     }
347
348     /*!
349     * Helper of alltoallv and allgatherv
350     */
351     template<class T>
352     static std::unique_ptr<T []> ComputeOffsetFull(const std::unique_ptr<T []>& counts, std::size_t sizeOfCounts)
353     {
354       std::unique_ptr<T []> ret(new T[sizeOfCounts+1]);
355       ret[0] = static_cast<T>(0);
356       for(std::size_t i = 1 ; i < sizeOfCounts+1 ; ++i)
357       {
358         ret[i] = ret[i-1] + counts[i-1];
359       }
360       return ret;
361     }
362   };
363 }