Salome HOME
Key point. OK for GUI persalys + python persalys + Test C++. Tests follow new explici...
[tools/adao_interface.git] / TestAdaoExchange.cxx
1 // Copyright (C) 2019 EDF R&D
2 //
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
16 //
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
18 //
19 // Author: Anthony Geay, anthony.geay@edf.fr, EDF R&D
20
21 #include "TestAdaoExchange.hxx"
22
23 #include "AdaoExchangeLayer.hxx"
24 #include "AdaoExchangeLayerException.hxx"
25 #include "AdaoModelKeyVal.hxx"
26 #include "PyObjectRAII.hxx"
27
28 #include "py2cpp/py2cpp.hxx"
29
30 #include <vector>
31 #include <iterator>
32
33 #include "TestAdaoHelper.cxx"
34
35 // Functor a remplacer par un appel a un evaluateur parallele
36 class NonParallelFunctor
37 {
38 public:
39   NonParallelFunctor(std::function< std::vector<double>(const std::vector<double>&) > cppFunction):_cpp_function(cppFunction) { }
40   PyObject *operator()(PyObject *inp) const
41   {
42     PyGILState_STATE gstate(PyGILState_Ensure());
43     PyObjectRAII iterator(PyObjectRAII::FromNew(PyObject_GetIter(inp)));
44     if(iterator.isNull())
45       throw AdaoExchangeLayerException("Input object is not iterable !");
46     PyObject *item(nullptr);
47     PyObjectRAII numpyModule(PyObjectRAII::FromNew(PyImport_ImportModule("numpy")));
48     if(numpyModule.isNull())
49       throw AdaoExchangeLayerException("Failed to load numpy");
50     PyObjectRAII ravelFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(numpyModule,"ravel")));
51     std::vector< PyObjectRAII > pyrets;
52     while( item = PyIter_Next(iterator) )
53       {
54         PyObjectRAII item2(PyObjectRAII::FromNew(item));
55         {
56           PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(1)));
57           PyTuple_SetItem(args,0,item2.retn());
58           PyObjectRAII npyArray(PyObjectRAII::FromNew(PyObject_CallObject(ravelFunc,args)));
59           // Waiting management of npy arrays into py2cpp
60           PyObjectRAII lolistFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(npyArray,"tolist")));
61           PyObjectRAII listPy;
62           {
63             PyObjectRAII args2(PyObjectRAII::FromNew(PyTuple_New(0)));
64             listPy=PyObjectRAII::FromNew(PyObject_CallObject(lolistFunc,args2));
65           }
66           std::vector<double> vect;
67           {
68             py2cpp::PyPtr listPy2(listPy.retn());
69             py2cpp::fromPyPtr(listPy2,vect);
70           }
71           //
72           PyGILState_Release(gstate);
73           std::vector<double> res(_cpp_function(vect));// L'appel effectif est ici
74           gstate=PyGILState_Ensure();
75           //
76           py2cpp::PyPtr resPy(py2cpp::toPyPtr(res));
77           PyObjectRAII resPy2(PyObjectRAII::FromBorrowed(resPy.get()));
78           pyrets.push_back(resPy2);
79         }
80       }
81     std::size_t len(pyrets.size());
82     PyObjectRAII ret(PyObjectRAII::FromNew(PyList_New(len)));
83     for(std::size_t i=0;i<len;++i)
84       {
85         PyList_SetItem(ret,i,pyrets[i].retn());
86       }
87     //PyObject *tmp(PyObject_Repr(ret));
88     // std::cerr << PyUnicode_AsUTF8(tmp) << std::endl;
89     PyGILState_Release(gstate);
90     return ret.retn();
91   }
92 private:
93   std::function< std::vector<double>(const std::vector<double>&) > _cpp_function;
94 };
95
96 PyObjectRAII NumpyToListWaitingForPy2CppManagement(PyObject *npObj)
97 {
98   PyObjectRAII func(PyObjectRAII::FromNew(PyObject_GetAttrString(npObj,"tolist")));
99   if(func.isNull())
100     throw AdaoExchangeLayerException("input pyobject does not have tolist attribute !");
101   PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(0)));
102   PyObjectRAII ret(PyObjectRAII::FromNew(PyObject_CallObject(func,args)));
103   return ret;
104 }
105
106 void AdaoExchangeTest::setUp()
107 {
108 }
109
110 void AdaoExchangeTest::tearDown()
111 {
112 }
113
114 void AdaoExchangeTest::cleanUp()
115 {
116 }
117
118 using namespace AdaoModel;
119
120 void AdaoExchangeTest::test3DVar()
121 {
122   NonParallelFunctor functor(funcBase);
123   MainModel mm;
124   AdaoExchangeLayer adao;
125   adao.init();
126   // For bounds, Background/Vector, Observation/Vector
127   Visitor2 visitorPythonObj(adao.getPythonContext());
128   {
129     AutoGIL agil;
130     mm.visitPythonLeaves(&visitorPythonObj);
131   }
132   //
133   adao.loadTemplate(&mm);
134   //
135   {
136     std::string sciptPyOfModelMaker(mm.pyStr());
137     //std::cerr << sciptPyOfModelMaker << std::endl;
138   }
139   adao.execute();
140   PyObject *listOfElts( nullptr );
141   while( adao.next(listOfElts) )
142     {
143       PyObject *resultOfChunk(functor(listOfElts));
144       adao.setResult(resultOfChunk);
145     }
146   PyObject *res(adao.getResult());
147   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
148   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
149   std::vector<double> vect;
150   {
151     py2cpp::PyPtr obj(optimum_4_py2cpp);
152     py2cpp::fromPyPtr(obj,vect);
153   }
154   CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
155   CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
156   CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
157   CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
158 }
159
160 void AdaoExchangeTest::testBlue()
161 {
162   class TestBlueVisitor : public RecursiveVisitor
163   {
164   public:
165     void visit(GenericKeyVal *obj)
166     {
167       EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
168       if(objc)
169         objc->setVal(EnumAlgo::Blue);
170     }
171     void enterSubDir(DictKeyVal *subdir) { }
172     void exitSubDir(DictKeyVal *subdir) { }
173   };
174
175   NonParallelFunctor functor(funcBase);
176   MainModel mm;
177   //
178   TestBlueVisitor vis;
179   mm.visitAll(&vis);
180   //
181   AdaoExchangeLayer adao;
182   adao.init();
183   // For bounds, Background/Vector, Observation/Vector
184   Visitor2 visitorPythonObj(adao.getPythonContext());
185   {
186     AutoGIL agil;
187     mm.visitPythonLeaves(&visitorPythonObj);
188   }
189   //
190   adao.loadTemplate(&mm);
191   //
192   {
193     std::string sciptPyOfModelMaker(mm.pyStr());
194     //std::cerr << sciptPyOfModelMaker << std::endl;
195   }
196   adao.execute();
197     PyObject *listOfElts( nullptr );
198     while( adao.next(listOfElts) )
199       {
200         PyObject *resultOfChunk(functor(listOfElts));
201         adao.setResult(resultOfChunk);
202       }
203     PyObject *res(adao.getResult());
204     PyObjectRAII optimum(PyObjectRAII::FromNew(res));
205     PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
206     std::vector<double> vect;
207     {
208       py2cpp::PyPtr obj(optimum_4_py2cpp);
209       py2cpp::fromPyPtr(obj,vect);
210     }
211     CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
212     CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
213     CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
214     CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
215 }
216
217 void AdaoExchangeTest::testNonLinearLeastSquares()
218 {
219   class TestNonLinearLeastSquaresVisitor : public RecursiveVisitor
220   {
221   public:
222     void visit(GenericKeyVal *obj)
223     {
224       EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
225       if(objc)
226         objc->setVal(EnumAlgo::NonLinearLeastSquares);
227     }
228     void enterSubDir(DictKeyVal *subdir) { }
229     void exitSubDir(DictKeyVal *subdir) { }
230   };
231   NonParallelFunctor functor(funcBase);
232   MainModel mm;
233   //
234   TestNonLinearLeastSquaresVisitor vis;
235   mm.visitAll(&vis);
236   //
237   AdaoExchangeLayer adao;
238   adao.init();
239   // For bounds, Background/Vector, Observation/Vector
240   Visitor2 visitorPythonObj(adao.getPythonContext());
241   {
242     AutoGIL agil;
243     mm.visitPythonLeaves(&visitorPythonObj);
244   }
245   //
246   adao.loadTemplate(&mm);
247   //
248   {
249     std::string sciptPyOfModelMaker(mm.pyStr());
250     //std::cerr << sciptPyOfModelMaker << std::endl;
251   }
252   adao.execute();
253   PyObject *listOfElts( nullptr );
254   while( adao.next(listOfElts) )
255     {
256       PyObject *resultOfChunk(functor(listOfElts));
257       adao.setResult(resultOfChunk);
258     }
259   PyObject *res(adao.getResult());
260   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
261   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
262   std::vector<double> vect;
263   {
264     py2cpp::PyPtr obj(optimum_4_py2cpp);
265     py2cpp::fromPyPtr(obj,vect);
266   }
267   CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
268   CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
269   CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
270   CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
271 }
272
273 void AdaoExchangeTest::testCasCrue()
274 {
275   NonParallelFunctor functor(funcCrue);
276   MainModel mm;
277   AdaoExchangeLayer adao;
278   adao.init();
279   // For bounds, Background/Vector, Observation/Vector
280   VisitorCruePython visitorPythonObj(adao.getPythonContext());
281   {
282     AutoGIL agil;
283     mm.visitPythonLeaves(&visitorPythonObj);
284   }
285   //
286   adao.loadTemplate(&mm);
287   //
288   {
289     std::string sciptPyOfModelMaker(mm.pyStr());
290     //std::cerr << sciptPyOfModelMaker << std::endl;
291   }
292   adao.execute();
293   PyObject *listOfElts( nullptr );
294   while( adao.next(listOfElts) )
295     {
296       PyObject *resultOfChunk(functor(listOfElts));
297       adao.setResult(resultOfChunk);
298     }
299   PyObject *res(adao.getResult());
300   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
301   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
302   std::vector<double> vect;
303   {
304     py2cpp::PyPtr obj(optimum_4_py2cpp);
305     py2cpp::fromPyPtr(obj,vect);
306   }
307   CPPUNIT_ASSERT_EQUAL(1,(int)vect.size());
308   CPPUNIT_ASSERT_DOUBLES_EQUAL(25.,vect[0],1e-3);
309 }
310
311 CPPUNIT_TEST_SUITE_REGISTRATION( AdaoExchangeTest );
312
313 #include <cppunit/CompilerOutputter.h>
314 #include <cppunit/TestResult.h>
315 #include <cppunit/TestResultCollector.h>
316 #include <cppunit/TextTestProgressListener.h>
317 #include <cppunit/BriefTestProgressListener.h>
318 #include <cppunit/extensions/TestFactoryRegistry.h>
319 #include <cppunit/TestRunner.h>
320 #include <cppunit/TextTestRunner.h>
321
322 int main(int argc, char* argv[])
323 {
324   // --- Create the event manager and test controller
325   CPPUNIT_NS::TestResult controller;
326
327   // ---  Add a listener that collects test result
328   CPPUNIT_NS::TestResultCollector result;
329   controller.addListener( &result );        
330
331   // ---  Add a listener that print dots as test run.
332 #ifdef WIN32
333   CPPUNIT_NS::TextTestProgressListener progress;
334 #else
335   CPPUNIT_NS::BriefTestProgressListener progress;
336 #endif
337   controller.addListener( &progress );      
338
339   // ---  Get the top level suite from the registry
340
341   CPPUNIT_NS::Test *suite =
342     CPPUNIT_NS::TestFactoryRegistry::getRegistry().makeTest();
343
344   // ---  Adds the test to the list of test to run
345
346   CPPUNIT_NS::TestRunner runner;
347   runner.addTest( suite );
348   runner.run( controller);
349
350   // ---  Print test in a compiler compatible format.
351   std::ofstream testFile;
352   testFile.open("test.log", std::ios::out | std::ios::app);
353   testFile << "------ ADAO exchange test log:" << std::endl;
354   CPPUNIT_NS::CompilerOutputter outputter( &result, testFile );
355   outputter.write(); 
356
357   // ---  Run the tests.
358
359   bool wasSucessful = result.wasSuccessful();
360   testFile.close();
361
362   // ---  Return error code 1 if the one of test failed.
363
364   return wasSucessful ? 0 : 1;
365 }