Salome HOME
82f226b7b787d224101809042bc203df2c8b21d0
[modules/multipr.git] / src / MULTIPR / MULTIPR_DecimationFilter.cxx
1 // Project MULTIPR, IOLS WP1.2.1 - EDF/CS
2 // Partitioning/decimation module for the SALOME v3.2 platform
3
4 /**
5  * \file    MULTIPR_DecimationFilter.cxx
6  *
7  * \brief   see MULTIPR_DecimationFilter.hxx
8  *
9  * \author  Olivier LE ROUX - CS, Virtual Reality Dpt
10  * 
11  * \date    01/2007
12  */
13
14 //*****************************************************************************
15 // Includes section
16 //*****************************************************************************
17
#include "MULTIPR_DecimationFilter.hxx"
#include "MULTIPR_Field.hxx"
#include "MULTIPR_Mesh.hxx"
#include "MULTIPR_PointOfField.hxx"
#include "MULTIPR_DecimationAccel.hxx"
#include "MULTIPR_Exceptions.hxx"

#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>

26
27 using namespace std;
28
29
30 namespace multipr
31 {
32
33
34 //*****************************************************************************
35 // Class DecimationFilter implementation
36 //*****************************************************************************
37
38 // factory
39 DecimationFilter* DecimationFilter::create(const char* pFilterName)
40 {
41         if (pFilterName == NULL) throw NullArgumentException("filter name should not be NULL", __FILE__, __LINE__);
42         
43         if (strcmp(pFilterName, "Filtre_GradientMoyen") == 0)
44         {
45                 return new DecimationFilterGradAvg();
46         }
47         else
48         {
49                 throw IllegalArgumentException("unknown filter", __FILE__, __LINE__);
50         }
51 }
52
53
54 //*****************************************************************************
55 // Class DecimationFilterGradAvg
56 //*****************************************************************************
57
58 DecimationFilterGradAvg::DecimationFilterGradAvg() 
59 {
60         // do nothing
61 }
62
63
64 DecimationFilterGradAvg::~DecimationFilterGradAvg()  
65
66         // do nothing
67 }
68
69
70 Mesh* DecimationFilterGradAvg::apply(Mesh* pMesh, const char* pArgv, const char* pNameNewMesh)
71 {
72         //---------------------------------------------------------------------
73         // Retrieve and check parameters
74         //---------------------------------------------------------------------
75         if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
76         if (pArgv == NULL) throw NullArgumentException("pArgv should not be NULL", __FILE__, __LINE__);
77         if (pNameNewMesh == NULL) throw NullArgumentException("pNameNewMesh should not be NULL", __FILE__, __LINE__);
78         
79         char   fieldName[MED_TAILLE_NOM + 1];
80         int    fieldIt;
81         double threshold;
82         double radius;
83         int    boxing; // number of cells along axis (if 100 then grid will have 100*100*100 = 10**6 cells)
84         
85         int ret = sscanf(pArgv, "%s %d %lf %lf %d",
86                 fieldName,
87                 &fieldIt,
88                 &threshold,
89                 &radius,
90                 &boxing);
91         
92         if (ret != 5) throw IllegalArgumentException("wrong number of arguments for filter GradAvg; expected 5 parameters", __FILE__, __LINE__);
93         
94         //---------------------------------------------------------------------
95         // Retrieve field = for each point: get its coordinate and the value of the field
96         //---------------------------------------------------------------------
97         Field* field = pMesh->getFieldByName(fieldName);
98         
99         if (field == NULL) throw IllegalArgumentException("field not found", __FILE__, __LINE__);
100         if ((fieldIt < 1) || (fieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
101         
102         vector<PointOfField> points;
103         pMesh->getAllPointsOfField(field, fieldIt, points);
104
105         //---------------------------------------------------------------------
106         // Creates acceleration structure used to compute gradient
107         //---------------------------------------------------------------------
108         DecimationAccel* accel = new DecimationAccelGrid();
109         char strCfg[256]; // a string is used for genericity
110         sprintf(strCfg, "%d %d %d", boxing, boxing, boxing);
111         accel->configure(strCfg);
112         accel->create(points);
113         
114         //---------------------------------------------------------------------
115         // Collects elements of the mesh to be kept
116         //---------------------------------------------------------------------
117         set<int> elementsToKeep;
118         
119         int numElements = pMesh->getNumberOfElements();
120         int numGaussPointsByElt = points.size() / numElements; // for a TETRA10, should be 5 for a field of elements and 10 for a field of nodes
121         
122         // for each element
123         for (int itElt = 0 ; itElt < numElements ; itElt++)
124         {
125                 bool keepElement = false;
126                 
127                 // for each Gauss point of the current element
128                 for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
129                 {
130                         const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
131                         
132                         vector<PointOfField> neighbours = accel->findNeighbours(
133                                 currentPt.mXYZ[0], 
134                                 currentPt.mXYZ[1], 
135                                 currentPt.mXYZ[2], 
136                                 radius);
137                         
138                         // if no neighbours => keep element
139                         if (neighbours.size() == 0)
140                         {
141                                 keepElement = true;
142                                 break;
143                         }
144                         
145                         // otherwise compute gradient...
146                         med_float normGrad = computeNormGrad(currentPt, neighbours);
147                         
148                         // debug
149                         //cout << (itElt * numGaussPointsByElt + j) << ": " << normGrad << endl;
150                         
151                         if ((normGrad >= threshold) || isnan(normGrad))
152                         {
153                                 keepElement = true;
154                                 break;
155                         }
156                 }
157                 
158                 if (keepElement)
159                 {
160                         // add index of the element to keep (index must start at 1)
161                         elementsToKeep.insert(itElt + 1);
162                 }
163         }
164
165         //---------------------------------------------------------------------
166         // Cleans
167         //---------------------------------------------------------------------
168         delete accel;
169         
170         //---------------------------------------------------------------------
171         // Create the final mesh by extracting elements to keep from the current mesh
172         //---------------------------------------------------------------------
173         Mesh* newMesh = pMesh->createFromSetOfElements(elementsToKeep, pNameNewMesh);
174         
175         return newMesh;
176 }
177
178
179 void DecimationFilterGradAvg::getGradientInfo(
180                 Mesh*       pMesh, 
181                 const char* pFieldName, 
182                 int         pFieldIt, 
183                 double      pRadius,
184                 int         pBoxing,
185                 double*     pOutGradMin,
186                 double*     pOutGradAvg,
187                 double*     pOutGradMax)
188 {
189         if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
190         if (pFieldName == NULL) throw NullArgumentException("pFieldName should not be NULL", __FILE__, __LINE__);
191         
192         Field* field = pMesh->getFieldByName(pFieldName);
193         
194         if (field == NULL) throw IllegalArgumentException("field not found", __FILE__, __LINE__);
195         if ((pFieldIt < 1) || (pFieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
196         
197         vector<PointOfField> points;
198         pMesh->getAllPointsOfField(field, pFieldIt, points);
199
200         //---------------------------------------------------------------------
201         // Creates acceleration structure used to compute gradient
202         //---------------------------------------------------------------------
203         DecimationAccel* accel = new DecimationAccelGrid();
204         char strCfg[256]; // a string is used for genericity
205         sprintf(strCfg, "%d %d %d", pBoxing, pBoxing, pBoxing);
206         accel->configure(strCfg);
207         accel->create(points);
208         
209         //---------------------------------------------------------------------
210         // Collects elements of the mesh to be kept
211         //---------------------------------------------------------------------
212         
213         int numElements = pMesh->getNumberOfElements();
214         int numGaussPointsByElt = points.size() / numElements; // for a TETRA10, should be 5 for a field of elements and 10 for a field of nodes
215         
216         *pOutGradMax = -1e300;
217         *pOutGradMin = 1e300;
218         *pOutGradAvg = 0.0;
219         int count = 0;
220         
221         //cout << "numElements=" << numElements << endl;
222         //cout << "num gauss pt by elt=" << numGaussPointsByElt << endl;
223         
224         // for each element
225         for (int itElt = 0 ; itElt < numElements ; itElt++)
226         {
227                 // for each Gauss point of the current element
228                 for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
229                 {
230                         const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
231                         
232                         vector<PointOfField> neighbours = accel->findNeighbours(
233                                 currentPt.mXYZ[0], 
234                                 currentPt.mXYZ[1], 
235                                 currentPt.mXYZ[2], 
236                                 pRadius);
237                         
238                         // if no neighbours => keep element
239                         if (neighbours.size() == 0)
240                         {
241                                 continue;
242                         }
243                         
244                         // otherwise compute gradient...
245                         med_float normGrad = computeNormGrad(currentPt, neighbours);
246                         
247                         // debug
248                         //cout << (itElt * numGaussPointsByElt + j) << ": " << normGrad << endl;
249                         
250                         if (!isnan(normGrad))
251                         {
252                                 if (normGrad > *pOutGradMax) *pOutGradMax = normGrad;
253                                 if (normGrad < *pOutGradMin) *pOutGradMin = normGrad;
254                                 *pOutGradAvg += normGrad;
255                                 count++;
256                         }
257                 }
258         }
259         
260         if (count != 0) *pOutGradAvg /= double(count);
261
262         //---------------------------------------------------------------------
263         // Cleans
264         //---------------------------------------------------------------------
265         delete accel;
266 }
267
268
269 med_float DecimationFilterGradAvg::computeNormGrad(const PointOfField& pPt, const std::vector<PointOfField>& pNeighbours) const
270 {
271         med_float gradX = 0.0;
272         med_float gradY = 0.0;
273         med_float gradZ = 0.0;
274  
275         // for each neighbour
276         for (unsigned i = 0 ; i < pNeighbours.size() ; i++)
277         {
278                 const PointOfField& neighbourPt = pNeighbours[i];
279                 
280                 med_float vecX = neighbourPt.mXYZ[0] - pPt.mXYZ[0];
281                 med_float vecY = neighbourPt.mXYZ[1] - pPt.mXYZ[1];
282                 med_float vecZ = neighbourPt.mXYZ[2] - pPt.mXYZ[2];
283                 
284                 med_float norm = med_float( sqrt( vecX*vecX + vecY*vecY + vecZ*vecZ ) );
285                 med_float val =  neighbourPt.mVal - pPt.mVal;
286                 
287                 val /= norm;
288                 
289                 gradX += vecX * val;
290                 gradY += vecY * val;
291                 gradZ += vecZ * val;
292         }
293         
294         med_float invSize = 1.0 / med_float(pNeighbours.size());
295         
296         gradX *= invSize;
297         gradY *= invSize;
298         gradZ *= invSize;
299         
300         med_float norm = med_float( sqrt( gradX*gradX + gradY*gradY + gradZ*gradZ ) );
301         
302         return norm;
303         
304 }
305
306
307 } // namespace multipr
308
309 // EOF