Salome HOME
Join modifications from BR_Dev_For_4_0 tag V4_1_1.
[modules/med.git] / src / MULTIPR / MULTIPR_DecimationFilter.cxx
1 // Project MULTIPR, IOLS WP1.2.1 - EDF/CS
2 // Partitioning/decimation module for the SALOME v3.2 platform
3
4 /**
5  * \file    MULTIPR_DecimationFilter.cxx
6  *
7  * \brief   see MULTIPR_DecimationFilter.hxx
8  *
9  * \author  Olivier LE ROUX - CS, Virtual Reality Dpt
10  * 
11  * \date    01/2007
12  */
13
14 //*****************************************************************************
15 // Includes section
16 //*****************************************************************************
17
18 #include "MULTIPR_DecimationFilter.hxx"
19 #include "MULTIPR_Field.hxx"
20 #include "MULTIPR_Mesh.hxx"
21 #include "MULTIPR_Nodes.hxx"
22 #include "MULTIPR_Elements.hxx"
23 #include "MULTIPR_Profil.hxx"
24 #include "MULTIPR_PointOfField.hxx"
25 #include "MULTIPR_DecimationAccel.hxx"
26 #include "MULTIPR_Exceptions.hxx"
27
28 #include <cmath>
29 #include <iostream>
30
31 using namespace std;
32
33
34 namespace multipr
35 {
36
37 //*****************************************************************************
38 // Class DecimationFilter implementation
39 //*****************************************************************************
40
41 // Factory used to build all filters from their name.
42 DecimationFilter* DecimationFilter::create(const char* pFilterName)
43 {
44     if (pFilterName == NULL) throw NullArgumentException("filter name should not be NULL", __FILE__, __LINE__);
45     
46     if (strcmp(pFilterName, "Filtre_GradientMoyen") == 0)
47     {
48         return new DecimationFilterGradAvg();
49     }
50         else if (strcmp(pFilterName, "Filtre_Direct") == 0)
51         {
52                 return new DecimationFilterTreshold();
53         }
54     else
55     {
56         throw IllegalArgumentException("unknown filter", __FILE__, __LINE__);
57     }
58 }
59
60
61 //*****************************************************************************
62 // Class DecimationFilterGradAvg
63 //*****************************************************************************
64
65 DecimationFilterGradAvg::DecimationFilterGradAvg() 
66 {
67     // do nothing
68 }
69
70
71 DecimationFilterGradAvg::~DecimationFilterGradAvg()  
72
73     // do nothing
74 }
75
76
77 Mesh* DecimationFilterGradAvg::apply(Mesh* pMesh, const char* pArgv, const char* pNameNewMesh)
78 {
79     //---------------------------------------------------------------------
80     // Retrieve and check parameters
81     //---------------------------------------------------------------------
82     if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
83     if (pArgv == NULL) throw NullArgumentException("pArgv should not be NULL", __FILE__, __LINE__);
84     if (pNameNewMesh == NULL) throw NullArgumentException("pNameNewMesh should not be NULL", __FILE__, __LINE__);
85     
86     char   fieldName[MED_TAILLE_NOM + 1];
87     int    fieldIt;
88         int        meshIt;
89     double threshold;
90     double radius;
91     int    boxing; // number of cells along axis (if 100 then grid will have 100*100*100 = 10**6 cells)
92     set<med_int> elementsToKeep[eMaxMedMesh];
93     Field* field = NULL;
94         
95     int ret = sscanf(pArgv, "%s %d %lf %lf %d",
96         fieldName,
97         &fieldIt,
98         &threshold,
99         &radius,
100         &boxing);
101     if (ret != 5) throw IllegalArgumentException("wrong number of arguments for filter GradAvg; expected 5 parameters", __FILE__, __LINE__);
102     
103         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
104         {
105                 
106                 //---------------------------------------------------------------------
107                 // Retrieve field = for each point: get its coordinate and the value of the field
108                 //---------------------------------------------------------------------
109                 field = pMesh->getFieldByName(fieldName, (eMeshType)meshIt);
110                 
111                 if (field == NULL) continue;
112                 if ((fieldIt < 1) || (fieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
113                 
114                 vector<PointOfField> points;
115                 pMesh->getAllPointsOfField(field, fieldIt, points, (eMeshType)meshIt);
116         
117                 //---------------------------------------------------------------------
118                 // Creates acceleration structure used to compute gradient
119                 //---------------------------------------------------------------------
120                 DecimationAccel* accel = new DecimationAccelGrid();
121                 char strCfg[256]; // a string is used for genericity
122                 sprintf(strCfg, "%d %d %d", boxing, boxing, boxing);
123                 accel->configure(strCfg);
124                 accel->create(points);
125                 
126                 //---------------------------------------------------------------------
127                 // Collects elements of the mesh to be kept
128                 //---------------------------------------------------------------------
129                 int numElements = 0;
130         if (field->isFieldOnNodes())
131         {
132             numElements = pMesh->getNodes()->getNumberOfNodes();
133         }
134         else
135         {
136             numElements = pMesh->getNumberOfElements((eMeshType)meshIt);
137         }
138         if (field->getProfil(fieldIt).size() != 0)
139         {
140             Profil* profil = pMesh->getProfil(field->getProfil(fieldIt));
141             if (profil == NULL) throw IllegalStateException("Can't find the profile in the mesh.", __FILE__, __LINE__);
142             numElements = profil->getSet().size();
143         }
144       
145                 int numGaussPointsByElt = 0;
146         if (field->isFieldOnNodes())
147         {
148             numGaussPointsByElt = 1;
149         }
150         else
151         {
152             numGaussPointsByElt = points.size() / numElements;
153         }
154                         
155                 // for each element
156                 for (int itElt = 0 ; itElt < numElements ; itElt++)
157                 {
158                         bool keepElement = false;
159                         
160                         // for each Gauss point of the current element
161                         for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
162                         {
163                                 const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
164                                 
165                                 vector<PointOfField> neighbours = accel->findNeighbours(
166                                         currentPt.mXYZ[0], 
167                                         currentPt.mXYZ[1], 
168                                         currentPt.mXYZ[2], 
169                                         radius);
170                                 
171                                 // if no neighbours => keep element
172                                 if (neighbours.size() == 0)
173                                 {
174                                         keepElement = true;
175                                         break;
176                                 }
177                                 
178                                 // otherwise compute gradient...
179                                 med_float normGrad = computeNormGrad(currentPt, neighbours);
180                                 
181                                 // debug
182                                 //cout << (itElt * numGaussPointsByElt + j) << ": " << normGrad << endl;
183                                 
184                                 if ((normGrad >= threshold) || isnan(normGrad))
185                                 {
186                                         keepElement = true;
187                                         break;
188                                 }
189                         }
190                         
191                         if (keepElement)
192                         {
193                                 // add index of the element to keep (index must start at 1)
194                                 elementsToKeep[meshIt].insert(med_int(itElt + 1));
195                         }
196                 }
197         
198                 //---------------------------------------------------------------------
199                 // Cleans
200                 //---------------------------------------------------------------------
201                 delete accel;
202         
203         if (field->isFieldOnNodes())
204         {
205             break;
206         }
207         }
208     
209     //---------------------------------------------------------------------
210     // Create the final mesh by extracting elements to keep from the current mesh
211     //---------------------------------------------------------------------
212     Mesh* newMesh = NULL;
213     if (field && field->isFieldOnNodes())
214     {
215         std::set<med_int> setOfElts[eMaxMedMesh];
216         
217         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
218         {
219             if (pMesh->getElements(meshIt) != NULL)
220             {
221                 pMesh->getElements(meshIt)->extractSubSetFromNodes(elementsToKeep[0], setOfElts[meshIt]);
222             }
223         }
224         newMesh = pMesh->createFromSetOfElements(setOfElts, pNameNewMesh);
225     }
226     else
227     {
228         newMesh = pMesh->createFromSetOfElements(elementsToKeep, pNameNewMesh);
229     }
230     
231     return newMesh;
232 }
233
234
235 void DecimationFilterGradAvg::getGradientInfo(
236         Mesh*       pMesh, 
237         const char* pFieldName, 
238         int         pFieldIt, 
239         double      pRadius,
240         int         pBoxing,
241         double*     pOutGradMin,
242         double*     pOutGradAvg,
243         double*     pOutGradMax)
244 {
245     if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
246     if (pFieldName == NULL) throw NullArgumentException("pFieldName should not be NULL", __FILE__, __LINE__);
247     
248     for (int meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
249         {
250
251         //---------------------------------------------------------------------
252         // Retrieve field = for each point: get its coordinate and the value of the field
253         //---------------------------------------------------------------------
254         Field* field = pMesh->getFieldByName(pFieldName, (eMeshType)meshIt);
255         
256         if (field == NULL) continue;
257         if ((pFieldIt < 1) || (pFieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
258         
259         vector<PointOfField> points;
260         pMesh->getAllPointsOfField(field, pFieldIt, points, (eMeshType)meshIt);
261     
262         //---------------------------------------------------------------------
263         // Creates acceleration structure used to compute gradient
264         //---------------------------------------------------------------------
265         DecimationAccel* accel = new DecimationAccelGrid();
266         char strCfg[256]; // a string is used for genericity
267         sprintf(strCfg, "%d %d %d", pBoxing, pBoxing, pBoxing);
268         accel->configure(strCfg);
269         accel->create(points);
270         
271         //---------------------------------------------------------------------
272         // Collects elements of the mesh to be kept
273         //---------------------------------------------------------------------
274                 int numElements = 0;
275         if (field->isFieldOnNodes())
276         {
277             numElements = pMesh->getNodes()->getNumberOfNodes();
278         }
279         else
280         {
281             numElements = pMesh->getNumberOfElements((eMeshType)meshIt);
282         }
283         if (field->getProfil(pFieldIt).size() != 0)
284         {
285             Profil* profil = pMesh->getProfil(field->getProfil(pFieldIt));
286             if (profil == NULL) throw IllegalStateException("Can't find the profile in the mesh.", __FILE__, __LINE__);
287             numElements = profil->getSet().size();
288         }
289       
290                 int numGaussPointsByElt = 0;
291         if (field->isFieldOnNodes())
292         {
293             numGaussPointsByElt = 1;
294         }
295         else
296         {
297             numGaussPointsByElt = points.size() / numElements;
298         }
299         
300         *pOutGradMax = -1e300;
301         *pOutGradMin = 1e300;
302         *pOutGradAvg = 0.0;
303         int count = 0;
304         
305         // for each element
306         for (int itElt = 0 ; itElt < numElements ; itElt++)
307         {
308             // for each Gauss point of the current element
309             for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
310             {
311                 const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
312                 
313                 vector<PointOfField> neighbours = accel->findNeighbours(
314                     currentPt.mXYZ[0], 
315                     currentPt.mXYZ[1], 
316                     currentPt.mXYZ[2], 
317                     pRadius);
318                 
319                 // if no neighbours => keep element
320                 if (neighbours.size() == 0)
321                 {
322                     continue;
323                 }
324                 
325                 // otherwise compute gradient...
326                 med_float normGrad = computeNormGrad(currentPt, neighbours);
327                 
328                 // debug
329                 //cout << (itElt * numGaussPointsByElt + j) << ": " << normGrad << endl;
330                 
331                 if (!isnan(normGrad))
332                 {
333                     if (normGrad > *pOutGradMax) *pOutGradMax = normGrad;
334                     if (normGrad < *pOutGradMin) *pOutGradMin = normGrad;
335                     *pOutGradAvg += normGrad;
336                     count++;
337                 }
338             }
339         }
340         
341         if (count != 0) *pOutGradAvg /= double(count);
342         
343         //---------------------------------------------------------------------
344         // Cleans
345         //---------------------------------------------------------------------
346         delete accel;
347         if (field->isFieldOnNodes())
348         {
349             break;
350         }
351     }
352 }
353
354
355 med_float DecimationFilterGradAvg::computeNormGrad(const PointOfField& pPt, const std::vector<PointOfField>& pNeighbours) const
356 {
357     med_float gradX = 0.0;
358     med_float gradY = 0.0;
359     med_float gradZ = 0.0;
360  
361     // for each neighbour
362     for (unsigned i = 0 ; i < pNeighbours.size() ; i++)
363     {
364         const PointOfField& neighbourPt = pNeighbours[i];
365         
366         med_float vecX = neighbourPt.mXYZ[0] - pPt.mXYZ[0];
367         med_float vecY = neighbourPt.mXYZ[1] - pPt.mXYZ[1];
368         med_float vecZ = neighbourPt.mXYZ[2] - pPt.mXYZ[2];
369         
370         med_float norm = med_float( sqrt( vecX*vecX + vecY*vecY + vecZ*vecZ ) );
371         med_float val =  neighbourPt.mVal - pPt.mVal;
372         
373         val /= norm;
374         
375         gradX += vecX * val;
376         gradY += vecY * val;
377         gradZ += vecZ * val;
378     }
379     
380     med_float invSize = 1.0 / med_float(pNeighbours.size());
381     
382     gradX *= invSize;
383     gradY *= invSize;
384     gradZ *= invSize;
385     
386     med_float norm = med_float( sqrt( gradX*gradX + gradY*gradY + gradZ*gradZ ) );
387     
388     return norm;
389     
390 }
391
//*****************************************************************************
// Class DecimationFilterTreshold
//*****************************************************************************
395
396 DecimationFilterTreshold::DecimationFilterTreshold() 
397 {
398     // do nothing
399 }
400
401
402 DecimationFilterTreshold::~DecimationFilterTreshold()  
403
404     // do nothing
405 }
406
407
408 Mesh* DecimationFilterTreshold::apply(Mesh* pMesh, const char* pArgv, const char* pNameNewMesh)
409 {
410     if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
411     if (pArgv == NULL) throw NullArgumentException("pArgv should not be NULL", __FILE__, __LINE__);
412     if (pNameNewMesh == NULL) throw NullArgumentException("pNameNewMesh should not be NULL", __FILE__, __LINE__);
413
414     char        fieldName[MED_TAILLE_NOM + 1];
415     int         fieldIt;
416     double      threshold;
417         int             meshIt;
418         set<med_int> elementsToKeep[eMaxMedMesh];
419     Field*  field = NULL;
420     
421         int ret = sscanf(pArgv, "%s %d %lf",
422         fieldName,
423         &fieldIt,
424         &threshold);
425
426         if (ret != 3) throw IllegalArgumentException("wrong number of arguments for filter Treshold; expected 3 parameters", __FILE__, __LINE__);
427         
428         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
429         {
430                 //---------------------------------------------------------------------
431                 // Retrieve field = for each point: get its coordinate and the value of the field
432                 //---------------------------------------------------------------------
433                 field = pMesh->getFieldByName(fieldName, (eMeshType)meshIt);
434                 if (field == NULL) continue;
435                 if ((fieldIt < 1) || (fieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
436                 
437                 vector<PointOfField> points;
438                 pMesh->getAllPointsOfField(field, fieldIt, points, (eMeshType)meshIt);
439                 
440                 //---------------------------------------------------------------------
441                 // Collects elements of the mesh to be kept
442                 //---------------------------------------------------------------------
443                 int numElements = 0;
444         if (field->isFieldOnNodes())
445         {
446             numElements = pMesh->getNodes()->getNumberOfNodes();
447         }
448         else
449         {
450             numElements = pMesh->getNumberOfElements((eMeshType)meshIt);
451         }
452         if (field->getProfil(fieldIt).size() != 0)
453         {
454             Profil* profil = pMesh->getProfil(field->getProfil(fieldIt));
455             if (profil == NULL) throw IllegalStateException("Can't find the profile in the mesh.", __FILE__, __LINE__);
456             numElements = profil->getSet().size();
457         }
458       
459                 int numGaussPointsByElt = 0;
460         if (field->isFieldOnNodes())
461         {
462             numGaussPointsByElt = 1;
463         }
464         else
465         {
466             numGaussPointsByElt = points.size() / numElements;
467         }
468                 // for each element
469                 for (int itElt = 0 ; itElt < numElements ; itElt++)
470                 {
471                         bool keepElement = false;
472                         
473                         // for each Gauss point of the current element
474                         for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
475                         {
476                                 const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
477
478                                 if (currentPt.mVal > threshold)
479                                 {
480                                         keepElement = true;
481                                         break;
482                                 }
483                         }
484                         
485                         if (keepElement)
486                         {
487                                 // add index of the element to keep (index must start at 1)
488                                 elementsToKeep[meshIt].insert(med_int(itElt + 1));
489                         }
490                 }
491         if (field->isFieldOnNodes())
492         {
493             break;
494         }
495         }
496     //---------------------------------------------------------------------
497     // Create the final mesh by extracting elements to keep from the current mesh
498     //---------------------------------------------------------------------
499     Mesh* newMesh = NULL;
500     if (field && field->isFieldOnNodes())
501     {
502         std::set<med_int> setOfElts[eMaxMedMesh];
503         
504         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
505         {
506             if (pMesh->getElements(meshIt) != NULL)
507             {
508                 pMesh->getElements(meshIt)->extractSubSetFromNodes(elementsToKeep[0], setOfElts[meshIt]);
509             }
510         }
511         newMesh = pMesh->createFromSetOfElements(setOfElts, pNameNewMesh);
512     }
513     else
514     {
515         newMesh = pMesh->createFromSetOfElements(elementsToKeep, pNameNewMesh);
516     }
517
518     return newMesh;
519 }
520
521 } // namespace multipr
522
523 // EOF