Salome HOME
Merge from BR_V5_DEV 16Feb09
[modules/med.git] / src / MULTIPR / MULTIPR_DecimationFilter.cxx
1 //  Copyright (C) 2007-2008  CEA/DEN, EDF R&D
2 //
3 //  This library is free software; you can redistribute it and/or
4 //  modify it under the terms of the GNU Lesser General Public
5 //  License as published by the Free Software Foundation; either
6 //  version 2.1 of the License.
7 //
8 //  This library is distributed in the hope that it will be useful,
9 //  but WITHOUT ANY WARRANTY; without even the implied warranty of
10 //  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 //  Lesser General Public License for more details.
12 //
13 //  You should have received a copy of the GNU Lesser General Public
14 //  License along with this library; if not, write to the Free Software
15 //  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
16 //
17 //  See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
18 //
19 // Partitioning/decimation module for the SALOME v3.2 platform
20 //
21 /**
22  * \file    MULTIPR_DecimationFilter.cxx
23  *
24  * \brief   see MULTIPR_DecimationFilter.hxx
25  *
26  * \author  Olivier LE ROUX - CS, Virtual Reality Dpt
27  * 
28  * \date    01/2007
29  */
30
31 //*****************************************************************************
32 // Includes section
33 //*****************************************************************************
34
35 #include "MULTIPR_DecimationFilter.hxx"
36 #include "MULTIPR_Field.hxx"
37 #include "MULTIPR_Mesh.hxx"
38 #include "MULTIPR_Nodes.hxx"
39 #include "MULTIPR_Elements.hxx"
40 #include "MULTIPR_Profil.hxx"
41 #include "MULTIPR_PointOfField.hxx"
42 #include "MULTIPR_DecimationAccel.hxx"
43 #include "MULTIPR_Exceptions.hxx"
44
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
47
48 using namespace std;
49
50
51 namespace multipr
52 {
53
54 //*****************************************************************************
55 // Class DecimationFilter implementation
56 //*****************************************************************************
57
58 // Factory used to build all filters from their name.
59 DecimationFilter* DecimationFilter::create(const char* pFilterName)
60 {
61     if (pFilterName == NULL) throw NullArgumentException("filter name should not be NULL", __FILE__, __LINE__);
62     
63     if (strcmp(pFilterName, "Filtre_GradientMoyen") == 0)
64     {
65         return new DecimationFilterGradAvg();
66     }
67         else if (strcmp(pFilterName, "Filtre_Direct") == 0)
68         {
69                 return new DecimationFilterTreshold();
70         }
71     else
72     {
73         throw IllegalArgumentException("unknown filter", __FILE__, __LINE__);
74     }
75 }
76
77
78 //*****************************************************************************
79 // Class DecimationFilterGradAvg
80 //*****************************************************************************
81
82 DecimationFilterGradAvg::DecimationFilterGradAvg() 
83 {
84     // do nothing
85 }
86
87
88 DecimationFilterGradAvg::~DecimationFilterGradAvg()  
89
90     // do nothing
91 }
92
93
94 Mesh* DecimationFilterGradAvg::apply(Mesh* pMesh, const char* pArgv, const char* pNameNewMesh)
95 {
96     //---------------------------------------------------------------------
97     // Retrieve and check parameters
98     //---------------------------------------------------------------------
99     if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
100     if (pArgv == NULL) throw NullArgumentException("pArgv should not be NULL", __FILE__, __LINE__);
101     if (pNameNewMesh == NULL) throw NullArgumentException("pNameNewMesh should not be NULL", __FILE__, __LINE__);
102     
103     char   fieldName[MED_TAILLE_NOM + 1];
104     int    fieldIt;
105         int        meshIt;
106     double threshold;
107     double radius;
108     int    boxing; // number of cells along axis (if 100 then grid will have 100*100*100 = 10**6 cells)
109     set<med_int> elementsToKeep[eMaxMedMesh];
110     Field* field = NULL;
111         
112     int ret = sscanf(pArgv, "%s %d %lf %lf %d",
113         fieldName,
114         &fieldIt,
115         &threshold,
116         &radius,
117         &boxing);
118     if (ret != 5) throw IllegalArgumentException("wrong number of arguments for filter GradAvg; expected 5 parameters", __FILE__, __LINE__);
119     
120         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
121         {
122                 
123                 //---------------------------------------------------------------------
124                 // Retrieve field = for each point: get its coordinate and the value of the field
125                 //---------------------------------------------------------------------
126                 field = pMesh->getFieldByName(fieldName, (eMeshType)meshIt);
127                 
128                 if (field == NULL) continue;
129                 if ((fieldIt < 1) || (fieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
130                 
131                 vector<PointOfField> points;
132                 pMesh->getAllPointsOfField(field, fieldIt, points, (eMeshType)meshIt);
133         
134                 //---------------------------------------------------------------------
135                 // Creates acceleration structure used to compute gradient
136                 //---------------------------------------------------------------------
137                 DecimationAccel* accel = new DecimationAccelGrid();
138                 char strCfg[256]; // a string is used for genericity
139                 sprintf(strCfg, "%d %d %d", boxing, boxing, boxing);
140                 accel->configure(strCfg);
141                 accel->create(points);
142                 
143                 //---------------------------------------------------------------------
144                 // Collects elements of the mesh to be kept
145                 //---------------------------------------------------------------------
146                 int numElements = 0;
147         if (field->isFieldOnNodes())
148         {
149             numElements = pMesh->getNodes()->getNumberOfNodes();
150         }
151         else
152         {
153             numElements = pMesh->getNumberOfElements((eMeshType)meshIt);
154         }
155         if (field->getProfil(fieldIt).size() != 0)
156         {
157             Profil* profil = pMesh->getProfil(field->getProfil(fieldIt));
158             if (profil == NULL) throw IllegalStateException("Can't find the profile in the mesh.", __FILE__, __LINE__);
159             numElements = profil->getSet().size();
160         }
161       
162                 int numGaussPointsByElt = 0;
163         if (field->isFieldOnNodes())
164         {
165             numGaussPointsByElt = 1;
166         }
167         else
168         {
169             numGaussPointsByElt = points.size() / numElements;
170         }
171                         
172                 // for each element
173                 for (int itElt = 0 ; itElt < numElements ; itElt++)
174                 {
175                         bool keepElement = false;
176                         
177                         // for each Gauss point of the current element
178                         for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
179                         {
180                                 const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
181                                 
182                                 vector<PointOfField> neighbours = accel->findNeighbours(
183                                         currentPt.mXYZ[0], 
184                                         currentPt.mXYZ[1], 
185                                         currentPt.mXYZ[2], 
186                                         radius);
187                                 
188                                 // if no neighbours => keep element
189                                 if (neighbours.size() == 0)
190                                 {
191                                         keepElement = true;
192                                         break;
193                                 }
194                                 
195                                 // otherwise compute gradient...
196                                 med_float normGrad = computeNormGrad(currentPt, neighbours);
197                                 
198                                 // debug
199                                 //cout << (itElt * numGaussPointsByElt + j) << ": " << normGrad << endl;
200                                 
201                                 if ((normGrad >= threshold) || isnan(normGrad))
202                                 {
203                                         keepElement = true;
204                                         break;
205                                 }
206                         }
207                         
208                         if (keepElement)
209                         {
210                                 // add index of the element to keep (index must start at 1)
211                                 elementsToKeep[meshIt].insert(med_int(itElt + 1));
212                         }
213                 }
214         
215                 //---------------------------------------------------------------------
216                 // Cleans
217                 //---------------------------------------------------------------------
218                 delete accel;
219         
220         if (field->isFieldOnNodes())
221         {
222             break;
223         }
224         }
225     
226     //---------------------------------------------------------------------
227     // Create the final mesh by extracting elements to keep from the current mesh
228     //---------------------------------------------------------------------
229     Mesh* newMesh = NULL;
230     if (field && field->isFieldOnNodes())
231     {
232         std::set<med_int> setOfElts[eMaxMedMesh];
233         
234         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
235         {
236             if (pMesh->getElements(meshIt) != NULL)
237             {
238                 pMesh->getElements(meshIt)->extractSubSetFromNodes(elementsToKeep[0], setOfElts[meshIt]);
239             }
240         }
241         newMesh = pMesh->createFromSetOfElements(setOfElts, pNameNewMesh);
242     }
243     else
244     {
245         newMesh = pMesh->createFromSetOfElements(elementsToKeep, pNameNewMesh);
246     }
247     
248     return newMesh;
249 }
250
251
252 void DecimationFilterGradAvg::getGradientInfo(
253         Mesh*       pMesh, 
254         const char* pFieldName, 
255         int         pFieldIt, 
256         double      pRadius,
257         int         pBoxing,
258         double*     pOutGradMin,
259         double*     pOutGradAvg,
260         double*     pOutGradMax)
261 {
262     if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
263     if (pFieldName == NULL) throw NullArgumentException("pFieldName should not be NULL", __FILE__, __LINE__);
264     
265     for (int meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
266         {
267
268         //---------------------------------------------------------------------
269         // Retrieve field = for each point: get its coordinate and the value of the field
270         //---------------------------------------------------------------------
271         Field* field = pMesh->getFieldByName(pFieldName, (eMeshType)meshIt);
272         
273         if (field == NULL) continue;
274         if ((pFieldIt < 1) || (pFieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
275         
276         vector<PointOfField> points;
277         pMesh->getAllPointsOfField(field, pFieldIt, points, (eMeshType)meshIt);
278     
279         //---------------------------------------------------------------------
280         // Creates acceleration structure used to compute gradient
281         //---------------------------------------------------------------------
282         DecimationAccel* accel = new DecimationAccelGrid();
283         char strCfg[256]; // a string is used for genericity
284         sprintf(strCfg, "%d %d %d", pBoxing, pBoxing, pBoxing);
285         accel->configure(strCfg);
286         accel->create(points);
287         
288         //---------------------------------------------------------------------
289         // Collects elements of the mesh to be kept
290         //---------------------------------------------------------------------
291                 int numElements = 0;
292         if (field->isFieldOnNodes())
293         {
294             numElements = pMesh->getNodes()->getNumberOfNodes();
295         }
296         else
297         {
298             numElements = pMesh->getNumberOfElements((eMeshType)meshIt);
299         }
300         if (field->getProfil(pFieldIt).size() != 0)
301         {
302             Profil* profil = pMesh->getProfil(field->getProfil(pFieldIt));
303             if (profil == NULL) throw IllegalStateException("Can't find the profile in the mesh.", __FILE__, __LINE__);
304             numElements = profil->getSet().size();
305         }
306       
307                 int numGaussPointsByElt = 0;
308         if (field->isFieldOnNodes())
309         {
310             numGaussPointsByElt = 1;
311         }
312         else
313         {
314             numGaussPointsByElt = points.size() / numElements;
315         }
316         
317         *pOutGradMax = -1e300;
318         *pOutGradMin = 1e300;
319         *pOutGradAvg = 0.0;
320         int count = 0;
321         
322         // for each element
323         for (int itElt = 0 ; itElt < numElements ; itElt++)
324         {
325             // for each Gauss point of the current element
326             for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
327             {
328                 const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
329                 
330                 vector<PointOfField> neighbours = accel->findNeighbours(
331                     currentPt.mXYZ[0], 
332                     currentPt.mXYZ[1], 
333                     currentPt.mXYZ[2], 
334                     pRadius);
335                 
336                 // if no neighbours => keep element
337                 if (neighbours.size() == 0)
338                 {
339                     continue;
340                 }
341                 
342                 // otherwise compute gradient...
343                 med_float normGrad = computeNormGrad(currentPt, neighbours);
344                 
345                 // debug
346                 //cout << (itElt * numGaussPointsByElt + j) << ": " << normGrad << endl;
347                 
348                 if (!isnan(normGrad))
349                 {
350                     if (normGrad > *pOutGradMax) *pOutGradMax = normGrad;
351                     if (normGrad < *pOutGradMin) *pOutGradMin = normGrad;
352                     *pOutGradAvg += normGrad;
353                     count++;
354                 }
355             }
356         }
357         
358         if (count != 0) *pOutGradAvg /= double(count);
359         
360         //---------------------------------------------------------------------
361         // Cleans
362         //---------------------------------------------------------------------
363         delete accel;
364         if (field->isFieldOnNodes())
365         {
366             break;
367         }
368     }
369 }
370
371
372 med_float DecimationFilterGradAvg::computeNormGrad(const PointOfField& pPt, const std::vector<PointOfField>& pNeighbours) const
373 {
374     med_float gradX = 0.0;
375     med_float gradY = 0.0;
376     med_float gradZ = 0.0;
377  
378     // for each neighbour
379     for (unsigned i = 0 ; i < pNeighbours.size() ; i++)
380     {
381         const PointOfField& neighbourPt = pNeighbours[i];
382         
383         med_float vecX = neighbourPt.mXYZ[0] - pPt.mXYZ[0];
384         med_float vecY = neighbourPt.mXYZ[1] - pPt.mXYZ[1];
385         med_float vecZ = neighbourPt.mXYZ[2] - pPt.mXYZ[2];
386         
387         med_float norm = med_float( sqrt( vecX*vecX + vecY*vecY + vecZ*vecZ ) );
388         med_float val =  neighbourPt.mVal - pPt.mVal;
389         
390         val /= norm;
391         
392         gradX += vecX * val;
393         gradY += vecY * val;
394         gradZ += vecZ * val;
395     }
396     
397     med_float invSize = 1.0 / med_float(pNeighbours.size());
398     
399     gradX *= invSize;
400     gradY *= invSize;
401     gradZ *= invSize;
402     
403     med_float norm = med_float( sqrt( gradX*gradX + gradY*gradY + gradZ*gradZ ) );
404     
405     return norm;
406     
407 }
408
409 //*****************************************************************************
410 // Class DecimationFilterGradAvg
411 //*****************************************************************************
412
413 DecimationFilterTreshold::DecimationFilterTreshold() 
414 {
415     // do nothing
416 }
417
418
419 DecimationFilterTreshold::~DecimationFilterTreshold()  
420
421     // do nothing
422 }
423
424
425 Mesh* DecimationFilterTreshold::apply(Mesh* pMesh, const char* pArgv, const char* pNameNewMesh)
426 {
427     if (pMesh == NULL) throw NullArgumentException("pMesh should not be NULL", __FILE__, __LINE__);
428     if (pArgv == NULL) throw NullArgumentException("pArgv should not be NULL", __FILE__, __LINE__);
429     if (pNameNewMesh == NULL) throw NullArgumentException("pNameNewMesh should not be NULL", __FILE__, __LINE__);
430
431     char        fieldName[MED_TAILLE_NOM + 1];
432     int         fieldIt;
433     double      threshold;
434         int             meshIt;
435         set<med_int> elementsToKeep[eMaxMedMesh];
436     Field*  field = NULL;
437     
438         int ret = sscanf(pArgv, "%s %d %lf",
439         fieldName,
440         &fieldIt,
441         &threshold);
442
443         if (ret != 3) throw IllegalArgumentException("wrong number of arguments for filter Treshold; expected 3 parameters", __FILE__, __LINE__);
444         
445         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
446         {
447                 //---------------------------------------------------------------------
448                 // Retrieve field = for each point: get its coordinate and the value of the field
449                 //---------------------------------------------------------------------
450                 field = pMesh->getFieldByName(fieldName, (eMeshType)meshIt);
451                 if (field == NULL) continue;
452                 if ((fieldIt < 1) || (fieldIt > field->getNumberOfTimeSteps())) throw IllegalArgumentException("invalid field iteration", __FILE__, __LINE__);
453                 
454                 vector<PointOfField> points;
455                 pMesh->getAllPointsOfField(field, fieldIt, points, (eMeshType)meshIt);
456                 
457                 //---------------------------------------------------------------------
458                 // Collects elements of the mesh to be kept
459                 //---------------------------------------------------------------------
460                 int numElements = 0;
461         if (field->isFieldOnNodes())
462         {
463             numElements = pMesh->getNodes()->getNumberOfNodes();
464         }
465         else
466         {
467             numElements = pMesh->getNumberOfElements((eMeshType)meshIt);
468         }
469         if (field->getProfil(fieldIt).size() != 0)
470         {
471             Profil* profil = pMesh->getProfil(field->getProfil(fieldIt));
472             if (profil == NULL) throw IllegalStateException("Can't find the profile in the mesh.", __FILE__, __LINE__);
473             numElements = profil->getSet().size();
474         }
475       
476                 int numGaussPointsByElt = 0;
477         if (field->isFieldOnNodes())
478         {
479             numGaussPointsByElt = 1;
480         }
481         else
482         {
483             numGaussPointsByElt = points.size() / numElements;
484         }
485                 // for each element
486                 for (int itElt = 0 ; itElt < numElements ; itElt++)
487                 {
488                         bool keepElement = false;
489                         
490                         // for each Gauss point of the current element
491                         for (int itPtGauss = 0 ; itPtGauss < numGaussPointsByElt ; itPtGauss++)
492                         {
493                                 const PointOfField& currentPt = points[itElt * numGaussPointsByElt + itPtGauss];
494
495                                 if (currentPt.mVal > threshold)
496                                 {
497                                         keepElement = true;
498                                         break;
499                                 }
500                         }
501                         
502                         if (keepElement)
503                         {
504                                 // add index of the element to keep (index must start at 1)
505                                 elementsToKeep[meshIt].insert(med_int(itElt + 1));
506                         }
507                 }
508         if (field->isFieldOnNodes())
509         {
510             break;
511         }
512         }
513     //---------------------------------------------------------------------
514     // Create the final mesh by extracting elements to keep from the current mesh
515     //---------------------------------------------------------------------
516     Mesh* newMesh = NULL;
517     if (field && field->isFieldOnNodes())
518     {
519         std::set<med_int> setOfElts[eMaxMedMesh];
520         
521         for (meshIt = 0; meshIt < eMaxMedMesh; ++meshIt)
522         {
523             if (pMesh->getElements(meshIt) != NULL)
524             {
525                 pMesh->getElements(meshIt)->extractSubSetFromNodes(elementsToKeep[0], setOfElts[meshIt]);
526             }
527         }
528         newMesh = pMesh->createFromSetOfElements(setOfElts, pNameNewMesh);
529     }
530     else
531     {
532         newMesh = pMesh->createFromSetOfElements(elementsToKeep, pNameNewMesh);
533     }
534
535     return newMesh;
536 }
537
538 } // namespace multipr
539
540 // EOF