Salome HOME
Remove useless dependency to py2cpp.
[tools/adao_interface.git] / TestAdaoExchange.cxx
1 // Copyright (C) 2019 EDF R&D
2 //
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License.
7 //
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 // Lesser General Public License for more details.
12 //
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
16 //
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
18 //
19 // Author: Anthony Geay, anthony.geay@edf.fr, EDF R&D
20
#include "TestAdaoExchange.hxx"

#include "AdaoExchangeLayer.hxx"
#include "AdaoExchangeLayerException.hxx"
#include "AdaoModelKeyVal.hxx"
#include "PyObjectRAII.hxx"

#include "py2cpp/py2cpp.hxx"

#include <cstddef>
#include <functional>
#include <iterator>
#include <vector>

#include "TestAdaoHelper.cxx"
34
35 // Functor a remplacer par un appel a un evaluateur parallele
36 class NonParallelFunctor
37 {
38 public:
39   NonParallelFunctor(std::function< std::vector<double>(const std::vector<double>&) > cppFunction):_cpp_function(cppFunction) { }
40   PyObject *operator()(PyObject *inp) const
41   {
42     PyGILState_STATE gstate(PyGILState_Ensure());
43     PyObjectRAII iterator(PyObjectRAII::FromNew(PyObject_GetIter(inp)));
44     if(iterator.isNull())
45       throw AdaoExchangeLayerException("Input object is not iterable !");
46     PyObject *item(nullptr);
47     PyObjectRAII numpyModule(PyObjectRAII::FromNew(PyImport_ImportModule("numpy")));
48     if(numpyModule.isNull())
49       throw AdaoExchangeLayerException("Failed to load numpy");
50     PyObjectRAII ravelFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(numpyModule,"ravel")));
51     std::vector< PyObjectRAII > pyrets;
52     while( item = PyIter_Next(iterator) )
53       {
54         PyObjectRAII item2(PyObjectRAII::FromNew(item));
55         {
56           PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(1)));
57           PyTuple_SetItem(args,0,item2.retn());
58           PyObjectRAII npyArray(PyObjectRAII::FromNew(PyObject_CallObject(ravelFunc,args)));
59           // Waiting management of npy arrays into py2cpp
60           PyObjectRAII lolistFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(npyArray,"tolist")));
61           PyObjectRAII listPy;
62           {
63             PyObjectRAII args2(PyObjectRAII::FromNew(PyTuple_New(0)));
64             listPy=PyObjectRAII::FromNew(PyObject_CallObject(lolistFunc,args2));
65           }
66           std::vector<double> vect;
67           {
68             py2cpp::PyPtr listPy2(listPy.retn());
69             py2cpp::fromPyPtr(listPy2,vect);
70           }
71           //
72           PyGILState_Release(gstate);
73           std::vector<double> res(_cpp_function(vect));// L'appel effectif est ici
74           gstate=PyGILState_Ensure();
75           //
76           py2cpp::PyPtr resPy(py2cpp::toPyPtr(res));
77           PyObjectRAII resPy2(PyObjectRAII::FromBorrowed(resPy.get()));
78           pyrets.push_back(resPy2);
79         }
80       }
81     std::size_t len(pyrets.size());
82     PyObjectRAII ret(PyObjectRAII::FromNew(PyList_New(len)));
83     for(std::size_t i=0;i<len;++i)
84       {
85         PyList_SetItem(ret,i,pyrets[i].retn());
86       }
87     //PyObject *tmp(PyObject_Repr(ret));
88     // std::cerr << PyUnicode_AsUTF8(tmp) << std::endl;
89     PyGILState_Release(gstate);
90     return ret.retn();
91   }
92 private:
93   std::function< std::vector<double>(const std::vector<double>&) > _cpp_function;
94 };
95
96 PyObjectRAII NumpyToListWaitingForPy2CppManagement(PyObject *npObj)
97 {
98   PyObjectRAII func(PyObjectRAII::FromNew(PyObject_GetAttrString(npObj,"tolist")));
99   if(func.isNull())
100     throw AdaoExchangeLayerException("input pyobject does not have tolist attribute !");
101   PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(0)));
102   PyObjectRAII ret(PyObjectRAII::FromNew(PyObject_CallObject(func,args)));
103   return ret;
104 }
105
106 void AdaoExchangeTest::setUp()
107 {
108 }
109
110 void AdaoExchangeTest::tearDown()
111 {
112 }
113
114 void AdaoExchangeTest::cleanUp()
115 {
116 }
117
118 using namespace AdaoModel;
119
120 void AdaoExchangeTest::test3DVar()
121 {
122   NonParallelFunctor functor(funcBase);
123   MainModel mm;
124   AdaoExchangeLayer adao;
125   adao.init();
126   // For bounds, Background/Vector, Observation/Vector
127   adao.setFunctionCallbackInModel(&mm);
128   Visitor2 visitorPythonObj(adao.getPythonContext());
129   {
130     AutoGIL agil;
131     mm.visitPythonLeaves(&visitorPythonObj);
132   }
133   //
134   adao.loadTemplate(&mm);
135   //
136   {
137     std::string sciptPyOfModelMaker(mm.pyStr());
138     //std::cerr << sciptPyOfModelMaker << std::endl;
139   }
140   adao.execute();
141   PyObject *listOfElts( nullptr );
142   while( adao.next(listOfElts) )
143     {
144       PyObject *resultOfChunk(functor(listOfElts));
145       adao.setResult(resultOfChunk);
146     }
147   PyObject *res(adao.getResult());
148   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
149   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
150   std::vector<double> vect;
151   {
152     py2cpp::PyPtr obj(optimum_4_py2cpp);
153     py2cpp::fromPyPtr(obj,vect);
154   }
155   CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
156   CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
157   CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
158   CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
159 }
160
161 void AdaoExchangeTest::testBlue()
162 {
163   class TestBlueVisitor : public RecursiveVisitor
164   {
165   public:
166     void visit(GenericKeyVal *obj)
167     {
168       EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
169       if(objc)
170         objc->setVal(EnumAlgo::Blue);
171     }
172     void enterSubDir(DictKeyVal *subdir) { }
173     void exitSubDir(DictKeyVal *subdir) { }
174   };
175
176   NonParallelFunctor functor(funcBase);
177   MainModel mm;
178   //
179   TestBlueVisitor vis;
180   mm.visitAll(&vis);
181   //
182   AdaoExchangeLayer adao;
183   adao.init();
184   // For bounds, Background/Vector, Observation/Vector
185   adao.setFunctionCallbackInModel(&mm);
186   Visitor2 visitorPythonObj(adao.getPythonContext());
187   {
188     AutoGIL agil;
189     mm.visitPythonLeaves(&visitorPythonObj);
190   }
191   //
192   adao.loadTemplate(&mm);
193   //
194   {
195     std::string sciptPyOfModelMaker(mm.pyStr());
196     //std::cerr << sciptPyOfModelMaker << std::endl;
197   }
198   adao.execute();
199     PyObject *listOfElts( nullptr );
200     while( adao.next(listOfElts) )
201       {
202         PyObject *resultOfChunk(functor(listOfElts));
203         adao.setResult(resultOfChunk);
204       }
205     PyObject *res(adao.getResult());
206     PyObjectRAII optimum(PyObjectRAII::FromNew(res));
207     PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
208     std::vector<double> vect;
209     {
210       py2cpp::PyPtr obj(optimum_4_py2cpp);
211       py2cpp::fromPyPtr(obj,vect);
212     }
213     CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
214     CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
215     CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
216     CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
217 }
218
219 void AdaoExchangeTest::testNonLinearLeastSquares()
220 {
221   class TestNonLinearLeastSquaresVisitor : public RecursiveVisitor
222   {
223   public:
224     void visit(GenericKeyVal *obj)
225     {
226       EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
227       if(objc)
228         objc->setVal(EnumAlgo::NonLinearLeastSquares);
229     }
230     void enterSubDir(DictKeyVal *subdir) { }
231     void exitSubDir(DictKeyVal *subdir) { }
232   };
233   NonParallelFunctor functor(funcBase);
234   MainModel mm;
235   //
236   TestNonLinearLeastSquaresVisitor vis;
237   mm.visitAll(&vis);
238   //
239   AdaoExchangeLayer adao;
240   adao.init();
241   // For bounds, Background/Vector, Observation/Vector
242   adao.setFunctionCallbackInModel(&mm);
243   Visitor2 visitorPythonObj(adao.getPythonContext());
244   {
245     AutoGIL agil;
246     mm.visitPythonLeaves(&visitorPythonObj);
247   }
248   //
249   adao.loadTemplate(&mm);
250   //
251   {
252     std::string sciptPyOfModelMaker(mm.pyStr());
253     //std::cerr << sciptPyOfModelMaker << std::endl;
254   }
255   adao.execute();
256   PyObject *listOfElts( nullptr );
257   while( adao.next(listOfElts) )
258     {
259       PyObject *resultOfChunk(functor(listOfElts));
260       adao.setResult(resultOfChunk);
261     }
262   PyObject *res(adao.getResult());
263   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
264   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
265   std::vector<double> vect;
266   {
267     py2cpp::PyPtr obj(optimum_4_py2cpp);
268     py2cpp::fromPyPtr(obj,vect);
269   }
270   CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
271   CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
272   CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
273   CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
274 }
275
276 void AdaoExchangeTest::testCasCrue()
277 {
278   NonParallelFunctor functor(funcCrue);
279   MainModel mm;
280   AdaoExchangeLayer adao;
281   adao.init();
282   // For bounds, Background/Vector, Observation/Vector
283   adao.setFunctionCallbackInModel(&mm);
284   VisitorCruePython visitorPythonObj(adao.getPythonContext());
285   {
286     AutoGIL agil;
287     mm.visitPythonLeaves(&visitorPythonObj);
288   }
289   //
290   adao.loadTemplate(&mm);
291   //
292   {
293     std::string sciptPyOfModelMaker(mm.pyStr());
294     //std::cerr << sciptPyOfModelMaker << std::endl;
295   }
296   adao.execute();
297   PyObject *listOfElts( nullptr );
298   while( adao.next(listOfElts) )
299     {
300       PyObject *resultOfChunk(functor(listOfElts));
301       adao.setResult(resultOfChunk);
302     }
303   PyObject *res(adao.getResult());
304   PyObjectRAII optimum(PyObjectRAII::FromNew(res));
305   PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
306   std::vector<double> vect;
307   {
308     py2cpp::PyPtr obj(optimum_4_py2cpp);
309     py2cpp::fromPyPtr(obj,vect);
310   }
311   CPPUNIT_ASSERT_EQUAL(1,(int)vect.size());
312   CPPUNIT_ASSERT_DOUBLES_EQUAL(25.,vect[0],1e-3);
313 }
314
315 CPPUNIT_TEST_SUITE_REGISTRATION( AdaoExchangeTest );
316
317 #include <cppunit/CompilerOutputter.h>
318 #include <cppunit/TestResult.h>
319 #include <cppunit/TestResultCollector.h>
320 #include <cppunit/TextTestProgressListener.h>
321 #include <cppunit/BriefTestProgressListener.h>
322 #include <cppunit/extensions/TestFactoryRegistry.h>
323 #include <cppunit/TestRunner.h>
324 #include <cppunit/TextTestRunner.h>
325
326 int main(int argc, char* argv[])
327 {
328   // --- Create the event manager and test controller
329   CPPUNIT_NS::TestResult controller;
330
331   // ---  Add a listener that collects test result
332   CPPUNIT_NS::TestResultCollector result;
333   controller.addListener( &result );        
334
335   // ---  Add a listener that print dots as test run.
336 #ifdef WIN32
337   CPPUNIT_NS::TextTestProgressListener progress;
338 #else
339   CPPUNIT_NS::BriefTestProgressListener progress;
340 #endif
341   controller.addListener( &progress );      
342
343   // ---  Get the top level suite from the registry
344
345   CPPUNIT_NS::Test *suite =
346     CPPUNIT_NS::TestFactoryRegistry::getRegistry().makeTest();
347
348   // ---  Adds the test to the list of test to run
349
350   CPPUNIT_NS::TestRunner runner;
351   runner.addTest( suite );
352   runner.run( controller);
353
354   // ---  Print test in a compiler compatible format.
355   std::ofstream testFile;
356   testFile.open("test.log", std::ios::out | std::ios::app);
357   testFile << "------ ADAO exchange test log:" << std::endl;
358   CPPUNIT_NS::CompilerOutputter outputter( &result, testFile );
359   outputter.write(); 
360
361   // ---  Run the tests.
362
363   bool wasSucessful = result.wasSuccessful();
364   testFile.close();
365
366   // ---  Return error code 1 if the one of test failed.
367
368   return wasSucessful ? 0 : 1;
369 }