Salome HOME
f0d715048f4acadbd5c113a52d27492d0cb480f0
[tools/adao_interface.git] / TestAdaoExchange.cxx
1 // Copyright (C) 2019 EDF R&D
2 //
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
16 //
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
18 //
19 // Author: Anthony Geay, anthony.geay@edf.fr, EDF R&D
20
21 #include "TestAdaoExchange.hxx"
22
23 #include "AdaoExchangeLayer.hxx"
24 #include "AdaoExchangeLayerException.hxx"
25 #include "AdaoModelKeyVal.hxx"
26 #include "PyObjectRAII.hxx"
27
28 #include "py2cpp/py2cpp.hxx"
29
30 #include <vector>
31 #include <iterator>
32
33 #include "TestAdaoHelper.cxx"
34
35 // Functor a remplacer par un appel a un evaluateur parallele
36 class NonParallelFunctor
37 {
38 public:
39   NonParallelFunctor(std::function< std::vector<double>(const std::vector<double>&) > cppFunction):_cpp_function(cppFunction) { }
40   PyObject *operator()(PyObject *inp) const
41   {
42     PyGILState_STATE gstate(PyGILState_Ensure());
43     PyObjectRAII iterator(PyObjectRAII::FromNew(PyObject_GetIter(inp)));
44     if(iterator.isNull())
45       throw AdaoExchangeLayerException("Input object is not iterable !");
46     PyObject *item(nullptr);
47     PyObjectRAII numpyModule(PyObjectRAII::FromNew(PyImport_ImportModule("numpy")));
48     if(numpyModule.isNull())
49       throw AdaoExchangeLayerException("Failed to load numpy");
50     PyObjectRAII ravelFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(numpyModule,"ravel")));
51     std::vector< PyObjectRAII > pyrets;
52     while( item = PyIter_Next(iterator) )
53       {
54         PyObjectRAII item2(PyObjectRAII::FromNew(item));
55         {
56           PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(1)));
57           PyTuple_SetItem(args,0,item2.retn());
58           PyObjectRAII npyArray(PyObjectRAII::FromNew(PyObject_CallObject(ravelFunc,args)));
59           // Waiting management of npy arrays into py2cpp
60           PyObjectRAII lolistFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(npyArray,"tolist")));
61           PyObjectRAII listPy;
62           {
63             PyObjectRAII args2(PyObjectRAII::FromNew(PyTuple_New(0)));
64             listPy=PyObjectRAII::FromNew(PyObject_CallObject(lolistFunc,args2));
65           }
66           std::vector<double> vect;
67           {
68             py2cpp::PyPtr listPy2(listPy.retn());
69             py2cpp::fromPyPtr(listPy2,vect);
70           }
71           //
72           PyGILState_Release(gstate);
73           std::vector<double> res(_cpp_function(vect));// L'appel effectif est ici
74           gstate=PyGILState_Ensure();
75           //
76           py2cpp::PyPtr resPy(py2cpp::toPyPtr(res));
77           PyObjectRAII resPy2(PyObjectRAII::FromBorrowed(resPy.get()));
78           pyrets.push_back(resPy2);
79         }
80       }
81     std::size_t len(pyrets.size());
82     PyObjectRAII ret(PyObjectRAII::FromNew(PyList_New(len)));
83     for(std::size_t i=0;i<len;++i)
84       {
85         PyList_SetItem(ret,i,pyrets[i].retn());
86       }
87     //PyObject *tmp(PyObject_Repr(ret));
88     // std::cerr << PyUnicode_AsUTF8(tmp) << std::endl;
89     PyGILState_Release(gstate);
90     return ret.retn();
91   }
92 private:
93   std::function< std::vector<double>(const std::vector<double>&) > _cpp_function;
94 };
95
96 PyObjectRAII NumpyToListWaitingForPy2CppManagement(PyObject *npObj)
97 {
98   PyObjectRAII func(PyObjectRAII::FromNew(PyObject_GetAttrString(npObj,"tolist")));
99   if(func.isNull())
100     throw AdaoExchangeLayerException("input pyobject does not have tolist attribute !");
101   PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(0)));
102   PyObjectRAII ret(PyObjectRAII::FromNew(PyObject_CallObject(func,args)));
103   return ret;
104 }
105
106 void AdaoExchangeTest::setUp()
107 {
108 }
109
110 void AdaoExchangeTest::tearDown()
111 {
112 }
113
114 void AdaoExchangeTest::cleanUp()
115 {
116 }
117
118 using namespace AdaoModel;
119
120 void AdaoExchangeTest::test3DVar()
121 {
122   NonParallelFunctor functor(funcBase);
123   MainModel mm;
124   AdaoExchangeLayer adao;
125   adao.init();
126   // For bounds, Background/Vector, Observation/Vector
127   Visitor2 visitorPythonObj(adao.getPythonContext());
128   mm.visitPythonLeaves(&visitorPythonObj);
129   //
130   adao.loadTemplate(&mm);
131   //
132   {
133     std::string sciptPyOfModelMaker(mm.pyStr());
134     //std::cerr << sciptPyOfModelMaker << std::endl;
135   }
136   adao.execute();
137   PyObject *listOfElts( nullptr );
138   while( adao.next(listOfElts) )
139     {
140       PyObject *resultOfChunk(functor(listOfElts));
141       adao.setResult(resultOfChunk);
142     }
143   PyObject *res(adao.getResult());
144   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
145   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
146   std::vector<double> vect;
147   {
148     py2cpp::PyPtr obj(optimum_4_py2cpp);
149     py2cpp::fromPyPtr(obj,vect);
150   }
151   CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
152   CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
153   CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
154   CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
155 }
156
157 void AdaoExchangeTest::testBlue()
158 {
159   class TestBlueVisitor : public RecursiveVisitor
160   {
161   public:
162     void visit(GenericKeyVal *obj)
163     {
164       EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
165       if(objc)
166         objc->setVal(EnumAlgo::Blue);
167     }
168     void enterSubDir(DictKeyVal *subdir) { }
169     void exitSubDir(DictKeyVal *subdir) { }
170   };
171
172   NonParallelFunctor functor(funcBase);
173   MainModel mm;
174   //
175   TestBlueVisitor vis;
176   mm.visitAll(&vis);
177   //
178   AdaoExchangeLayer adao;
179   adao.init();
180   // For bounds, Background/Vector, Observation/Vector
181   Visitor2 visitorPythonObj(adao.getPythonContext());
182   mm.visitPythonLeaves(&visitorPythonObj);
183   //
184   adao.loadTemplate(&mm);
185   //
186   {
187     std::string sciptPyOfModelMaker(mm.pyStr());
188     //std::cerr << sciptPyOfModelMaker << std::endl;
189   }
190   adao.execute();
191     PyObject *listOfElts( nullptr );
192     while( adao.next(listOfElts) )
193       {
194         PyObject *resultOfChunk(functor(listOfElts));
195         adao.setResult(resultOfChunk);
196       }
197     PyObject *res(adao.getResult());
198     PyObjectRAII optimum(PyObjectRAII::FromNew(res));
199     PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
200     std::vector<double> vect;
201     {
202       py2cpp::PyPtr obj(optimum_4_py2cpp);
203       py2cpp::fromPyPtr(obj,vect);
204     }
205     CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
206     CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
207     CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
208     CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
209 }
210
211 void AdaoExchangeTest::testNonLinearLeastSquares()
212 {
213   class TestNonLinearLeastSquaresVisitor : public RecursiveVisitor
214   {
215   public:
216     void visit(GenericKeyVal *obj)
217     {
218       EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
219       if(objc)
220         objc->setVal(EnumAlgo::NonLinearLeastSquares);
221     }
222     void enterSubDir(DictKeyVal *subdir) { }
223     void exitSubDir(DictKeyVal *subdir) { }
224   };
225   NonParallelFunctor functor(funcBase);
226   MainModel mm;
227   //
228   TestNonLinearLeastSquaresVisitor vis;
229   mm.visitAll(&vis);
230   //
231   AdaoExchangeLayer adao;
232   adao.init();
233   // For bounds, Background/Vector, Observation/Vector
234   Visitor2 visitorPythonObj(adao.getPythonContext());
235   mm.visitPythonLeaves(&visitorPythonObj);
236   //
237   adao.loadTemplate(&mm);
238   //
239   {
240     std::string sciptPyOfModelMaker(mm.pyStr());
241     //std::cerr << sciptPyOfModelMaker << std::endl;
242   }
243   adao.execute();
244   PyObject *listOfElts( nullptr );
245   while( adao.next(listOfElts) )
246     {
247       PyObject *resultOfChunk(functor(listOfElts));
248       adao.setResult(resultOfChunk);
249     }
250   PyObject *res(adao.getResult());
251   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
252   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
253   std::vector<double> vect;
254   {
255     py2cpp::PyPtr obj(optimum_4_py2cpp);
256     py2cpp::fromPyPtr(obj,vect);
257   }
258   CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
259   CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
260   CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
261   CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
262 }
263
264 void AdaoExchangeTest::testCasCrue()
265 {
266   NonParallelFunctor functor(funcCrue);
267   MainModel mm;
268   AdaoExchangeLayer adao;
269   adao.init();
270   // For bounds, Background/Vector, Observation/Vector
271   VisitorCruePython visitorPythonObj(adao.getPythonContext());
272   mm.visitPythonLeaves(&visitorPythonObj);
273   //
274   adao.loadTemplate(&mm);
275   //
276   {
277     std::string sciptPyOfModelMaker(mm.pyStr());
278     //std::cerr << sciptPyOfModelMaker << std::endl;
279   }
280   adao.execute();
281   PyObject *listOfElts( nullptr );
282   while( adao.next(listOfElts) )
283     {
284       PyObject *resultOfChunk(functor(listOfElts));
285       adao.setResult(resultOfChunk);
286     }
287   PyObject *res(adao.getResult());
288   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
289   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
290   std::vector<double> vect;
291   {
292     py2cpp::PyPtr obj(optimum_4_py2cpp);
293     py2cpp::fromPyPtr(obj,vect);
294   }
295   CPPUNIT_ASSERT_EQUAL(1,(int)vect.size());
296   CPPUNIT_ASSERT_DOUBLES_EQUAL(25.,vect[0],1e-3);
297 }
298
299 CPPUNIT_TEST_SUITE_REGISTRATION( AdaoExchangeTest );
300
301 #include <cppunit/CompilerOutputter.h>
302 #include <cppunit/TestResult.h>
303 #include <cppunit/TestResultCollector.h>
304 #include <cppunit/TextTestProgressListener.h>
305 #include <cppunit/BriefTestProgressListener.h>
306 #include <cppunit/extensions/TestFactoryRegistry.h>
307 #include <cppunit/TestRunner.h>
308 #include <cppunit/TextTestRunner.h>
309
310 int main(int argc, char* argv[])
311 {
312   // --- Create the event manager and test controller
313   CPPUNIT_NS::TestResult controller;
314
315   // ---  Add a listener that collects test result
316   CPPUNIT_NS::TestResultCollector result;
317   controller.addListener( &result );        
318
319   // ---  Add a listener that print dots as test run.
320 #ifdef WIN32
321   CPPUNIT_NS::TextTestProgressListener progress;
322 #else
323   CPPUNIT_NS::BriefTestProgressListener progress;
324 #endif
325   controller.addListener( &progress );      
326
327   // ---  Get the top level suite from the registry
328
329   CPPUNIT_NS::Test *suite =
330     CPPUNIT_NS::TestFactoryRegistry::getRegistry().makeTest();
331
332   // ---  Adds the test to the list of test to run
333
334   CPPUNIT_NS::TestRunner runner;
335   runner.addTest( suite );
336   runner.run( controller);
337
338   // ---  Print test in a compiler compatible format.
339   std::ofstream testFile;
340   testFile.open("test.log", std::ios::out | std::ios::app);
341   testFile << "------ ADAO exchange test log:" << std::endl;
342   CPPUNIT_NS::CompilerOutputter outputter( &result, testFile );
343   outputter.write(); 
344
345   // ---  Run the tests.
346
347   bool wasSucessful = result.wasSuccessful();
348   testFile.close();
349
350   // ---  Return error code 1 if the one of test failed.
351
352   return wasSucessful ? 0 : 1;
353 }