Salome HOME
Copyright update 2021
[tools/medcoupling.git] / src / ParaMEDMEM / ParaUMesh.hxx
1 // Copyright (C) 2020-2021  CEA/DEN, EDF R&D
2 //
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License, or (at your option) any later version.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
16 //
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
18 //
19 // Author : Anthony Geay (EDF R&D)
20
21 #pragma once
22
23 #include "MEDCouplingUMesh.hxx"
24 #include "ProcessorGroup.hxx"
25 #include "MEDCouplingMemArray.hxx"
26
27 #include <string>
28 #include <vector>
29
30 namespace MEDCoupling
31 {
32   /*!
33    * Parallel representation of an unstructured mesh.
34    *
35    * This class is very specific to the requirement of parallel code computations.
36    */
37   class ParaUMesh : public RefCountObject
38   {
39   public:
40     static ParaUMesh *New(MEDCouplingUMesh *mesh, DataArrayIdType *globalCellIds, DataArrayIdType *globalNodeIds);
41     MCAuto<DataArrayIdType> getCellIdsLyingOnNodes(const DataArrayIdType *globalNodeIds, bool fullyIn) const;
42     ParaUMesh *redistributeCells(const DataArrayIdType *globalCellIds) const;
43     DataArrayDouble *redistributeCellField(const DataArrayIdType *globalCellIds, const DataArrayDouble *fieldValueToRed) const;
44     DataArrayIdType *redistributeCellField(const DataArrayIdType *globalCellIds, const DataArrayIdType *fieldValueToRed) const;
45     DataArrayDouble *redistributeNodeField(const DataArrayIdType *globalCellIds, const DataArrayDouble *fieldValueToRed) const;
46     DataArrayIdType *redistributeNodeField(const DataArrayIdType *globalCellIds, const DataArrayIdType *fieldValueToRed) const;
47     MEDCouplingUMesh *getMesh() { return _mesh; }
48     DataArrayIdType *getGlobalCellIds() { return _cell_global; }
49     DataArrayIdType *getGlobalNodeIds() { return _node_global; }
50   protected:
51     virtual ~ParaUMesh() { }
52     ParaUMesh(MEDCouplingUMesh *mesh, DataArrayIdType *globalCellIds, DataArrayIdType *globalNodeIds);
53     std::string getClassName() const override { return "ParaUMesh"; }
54     std::size_t getHeapMemorySizeWithoutChildren() const override;
55     std::vector<const BigMemoryObject *> getDirectChildrenWithNull() const override;
56   private:
57     MCAuto<MEDCouplingUMesh> _mesh;
58     MCAuto<DataArrayIdType> _cell_global;
59     MCAuto<DataArrayIdType> _node_global;
60   private:
61     MCAuto<DataArrayIdType> getCellIdsLyingOnNodesFalse(const DataArrayIdType *globalNodeIds) const;
62     MCAuto<DataArrayIdType> getCellIdsLyingOnNodesTrue(const DataArrayIdType *globalNodeIds) const;
63     template<class T>
64     typename Traits<T>::ArrayType *redistributeCellFieldT(const DataArrayIdType *globalCellIds, const typename Traits<T>::ArrayType *fieldValueToRed) const
65     {
66       using DataArrayT = typename Traits<T>::ArrayType;
67       MPI_Comm comm(MPI_COMM_WORLD);
68       CommInterface ci;
69       if( _cell_global->getNumberOfTuples() != fieldValueToRed->getNumberOfTuples() )
70         throw INTERP_KERNEL::Exception("PAraUMesh::redistributeCellFieldT : invalid input length of array !");
71       std::unique_ptr<mcIdType[]> allGlobalCellIds,allGlobalCellIdsIndex;
72       int size(ci.allGatherArrays(comm,globalCellIds,allGlobalCellIds,allGlobalCellIdsIndex));
73       // Prepare ParaUMesh parts to be sent : compute for each proc the contribution of current rank.
74       std::vector< MCAuto<DataArrayIdType> > globalCellIdsToBeSent(size);
75       std::vector< MCAuto<DataArrayT> > fieldToBeSent(size);
76       for(int curRk = 0 ; curRk < size ; ++curRk)
77       {
78         mcIdType offset(allGlobalCellIdsIndex[curRk]);
79         MCAuto<DataArrayIdType> globalCellIdsOfCurProc(DataArrayIdType::New());
80         globalCellIdsOfCurProc->useArray(allGlobalCellIds.get()+offset,false,DeallocType::CPP_DEALLOC,allGlobalCellIdsIndex[curRk+1]-offset,1);
81         // the key call is here : compute for rank curRk the cells to be sent
82         MCAuto<DataArrayIdType> globalCellIdsCaptured(_cell_global->buildIntersection(globalCellIdsOfCurProc));// OK for the global cellIds
83         MCAuto<DataArrayIdType> localCellIdsCaptured(_cell_global->findIdForEach(globalCellIdsCaptured->begin(),globalCellIdsCaptured->end()));
84         globalCellIdsToBeSent[curRk] = globalCellIdsCaptured;
85         fieldToBeSent[curRk] = fieldValueToRed->selectByTupleIdSafe(localCellIdsCaptured->begin(),localCellIdsCaptured->end());
86       }
87       // Receive
88       std::vector< MCAuto<DataArrayIdType> > globalCellIdsReceived;
89       ci.allToAllArrays(comm,globalCellIdsToBeSent,globalCellIdsReceived);
90       std::vector< MCAuto<DataArrayT> > fieldValueReceived;
91       ci.allToAllArrays(comm,fieldToBeSent,fieldValueReceived);
92       // use globalCellIdsReceived to reorganize everything
93       MCAuto<DataArrayIdType> aggregatedCellIds( DataArrayIdType::Aggregate(FromVecAutoToVecOfConst<DataArrayIdType>(globalCellIdsReceived)) );
94       MCAuto<DataArrayIdType> aggregatedCellIdsSorted(aggregatedCellIds->copySorted());
95       MCAuto<DataArrayIdType> idsIntoAggregatedIds(DataArrayIdType::FindPermutationFromFirstToSecondDuplicate(aggregatedCellIdsSorted,aggregatedCellIds));
96       MCAuto<DataArrayIdType> cellIdsOfSameNodeIds(aggregatedCellIdsSorted->indexOfSameConsecutiveValueGroups());
97       MCAuto<DataArrayIdType> n2o_cells(idsIntoAggregatedIds->selectByTupleIdSafe(cellIdsOfSameNodeIds->begin(),cellIdsOfSameNodeIds->end()-1));//new == new ordering so that global cell ids are sorted . old == coarse ordering implied by the aggregation
98       //
99       MCAuto<DataArrayT> fieldAggregated(DataArrayT::Aggregate(FromVecAutoToVecOfConst<DataArrayT>(fieldValueReceived)));
100       MCAuto<DataArrayT> ret(fieldAggregated->selectByTupleIdSafe(n2o_cells->begin(),n2o_cells->end()));
101       return ret.retn();
102     }
103     
104     template<class T>
105     typename Traits<T>::ArrayType *redistributeNodeFieldT(const DataArrayIdType *globalCellIds, const typename Traits<T>::ArrayType *fieldValueToRed) const
106     {
107       using DataArrayT = typename Traits<T>::ArrayType;
108       MPI_Comm comm(MPI_COMM_WORLD);
109       CommInterface ci;
110       if( _node_global->getNumberOfTuples() != fieldValueToRed->getNumberOfTuples() )
111         throw INTERP_KERNEL::Exception("PAraUMesh::redistributeNodeFieldT : invalid input length of array !");
112       std::unique_ptr<mcIdType[]> allGlobalCellIds,allGlobalCellIdsIndex;
113       int size(ci.allGatherArrays(comm,globalCellIds,allGlobalCellIds,allGlobalCellIdsIndex));
114       // Prepare ParaUMesh parts to be sent : compute for each proc the contribution of current rank.
115       std::vector< MCAuto<DataArrayIdType> > globalNodeIdsToBeSent(size);
116       std::vector< MCAuto<DataArrayT> > fieldToBeSent(size);
117       for(int curRk = 0 ; curRk < size ; ++curRk)
118       {
119         mcIdType offset(allGlobalCellIdsIndex[curRk]);
120         MCAuto<DataArrayIdType> globalCellIdsOfCurProc(DataArrayIdType::New());
121         globalCellIdsOfCurProc->useArray(allGlobalCellIds.get()+offset,false,DeallocType::CPP_DEALLOC,allGlobalCellIdsIndex[curRk+1]-offset,1);
122         // the key call is here : compute for rank curRk the cells to be sent
123         MCAuto<DataArrayIdType> globalCellIdsCaptured(_cell_global->buildIntersection(globalCellIdsOfCurProc));// OK for the global cellIds
124         MCAuto<DataArrayIdType> localCellIdsCaptured(_cell_global->findIdForEach(globalCellIdsCaptured->begin(),globalCellIdsCaptured->end()));
125         MCAuto<MEDCouplingUMesh> meshPart(_mesh->buildPartOfMySelf(localCellIdsCaptured->begin(),localCellIdsCaptured->end(),true));
126         MCAuto<DataArrayIdType> o2n(meshPart->zipCoordsTraducer());// OK for the mesh
127         MCAuto<DataArrayIdType> n2o(o2n->invertArrayO2N2N2O(meshPart->getNumberOfNodes()));
128         MCAuto<DataArrayIdType> globalNodeIdsPart(_node_global->selectByTupleIdSafe(n2o->begin(),n2o->end())); // OK for the global nodeIds
129         globalNodeIdsToBeSent[curRk] = globalNodeIdsPart;
130         fieldToBeSent[curRk] = fieldValueToRed->selectByTupleIdSafe(n2o->begin(),n2o->end());
131       }
132       // Receive
133       std::vector< MCAuto<DataArrayIdType> > globalNodeIdsReceived;
134       ci.allToAllArrays(comm,globalNodeIdsToBeSent,globalNodeIdsReceived);
135       std::vector< MCAuto<DataArrayT> > fieldValueReceived;
136       ci.allToAllArrays(comm,fieldToBeSent,fieldValueReceived);
137       // firstly deal with nodes.
138       MCAuto<DataArrayIdType> aggregatedNodeIds( DataArrayIdType::Aggregate(FromVecAutoToVecOfConst<DataArrayIdType>(globalNodeIdsReceived)) );
139       MCAuto<DataArrayIdType> aggregatedNodeIdsSorted(aggregatedNodeIds->copySorted());
140       MCAuto<DataArrayIdType> nodeIdsIntoAggregatedIds(DataArrayIdType::FindPermutationFromFirstToSecondDuplicate(aggregatedNodeIdsSorted,aggregatedNodeIds));
141       MCAuto<DataArrayIdType> idxOfSameNodeIds(aggregatedNodeIdsSorted->indexOfSameConsecutiveValueGroups());
142       MCAuto<DataArrayIdType> n2o_nodes(nodeIdsIntoAggregatedIds->selectByTupleIdSafe(idxOfSameNodeIds->begin(),idxOfSameNodeIds->end()-1));//new == new ordering so that global node ids are sorted . old == coarse ordering implied by the aggregation
143       //
144       MCAuto<DataArrayT> fieldAggregated(DataArrayT::Aggregate(FromVecAutoToVecOfConst<DataArrayT>(fieldValueReceived)));
145       MCAuto<DataArrayT> ret(fieldAggregated->selectByTupleIdSafe(n2o_nodes->begin(),n2o_nodes->end()));
146       //
147       return ret.retn();
148     }
149   };
150 }