1 // Copyright (C) 2019 EDF R&D
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License.
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 // Author: Anthony Geay, anthony.geay@edf.fr, EDF R&D
21 #include "TestAdaoExchange.hxx"
23 #include "AdaoExchangeLayer.hxx"
24 #include "AdaoExchangeLayerException.hxx"
25 #include "AdaoModelKeyVal.hxx"
26 #include "PyObjectRAII.hxx"
28 #include "py2cpp/py2cpp.hxx"
33 #include "TestAdaoHelper.cxx"
35 // Functor a remplacer par un appel a un evaluateur parallele
36 class NonParallelFunctor
39 NonParallelFunctor(std::function< std::vector<double>(const std::vector<double>&) > cppFunction):_cpp_function(cppFunction) { }
40 PyObject *operator()(PyObject *inp) const
42 PyGILState_STATE gstate(PyGILState_Ensure());
43 PyObjectRAII iterator(PyObjectRAII::FromNew(PyObject_GetIter(inp)));
45 throw AdaoExchangeLayerException("Input object is not iterable !");
46 PyObject *item(nullptr);
47 PyObjectRAII numpyModule(PyObjectRAII::FromNew(PyImport_ImportModule("numpy")));
48 if(numpyModule.isNull())
49 throw AdaoExchangeLayerException("Failed to load numpy");
50 PyObjectRAII ravelFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(numpyModule,"ravel")));
51 std::vector< PyObjectRAII > pyrets;
52 while( item = PyIter_Next(iterator) )
54 PyObjectRAII item2(PyObjectRAII::FromNew(item));
56 PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(1)));
57 PyTuple_SetItem(args,0,item2.retn());
58 PyObjectRAII npyArray(PyObjectRAII::FromNew(PyObject_CallObject(ravelFunc,args)));
59 // Waiting management of npy arrays into py2cpp
60 PyObjectRAII lolistFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(npyArray,"tolist")));
63 PyObjectRAII args2(PyObjectRAII::FromNew(PyTuple_New(0)));
64 listPy=PyObjectRAII::FromNew(PyObject_CallObject(lolistFunc,args2));
66 std::vector<double> vect;
68 py2cpp::PyPtr listPy2(listPy.retn());
69 py2cpp::fromPyPtr(listPy2,vect);
72 PyGILState_Release(gstate);
73 std::vector<double> res(_cpp_function(vect));// L'appel effectif est ici
74 gstate=PyGILState_Ensure();
76 py2cpp::PyPtr resPy(py2cpp::toPyPtr(res));
77 PyObjectRAII resPy2(PyObjectRAII::FromBorrowed(resPy.get()));
78 pyrets.push_back(resPy2);
81 std::size_t len(pyrets.size());
82 PyObjectRAII ret(PyObjectRAII::FromNew(PyList_New(len)));
83 for(std::size_t i=0;i<len;++i)
85 PyList_SetItem(ret,i,pyrets[i].retn());
87 //PyObject *tmp(PyObject_Repr(ret));
88 // std::cerr << PyUnicode_AsUTF8(tmp) << std::endl;
89 PyGILState_Release(gstate);
93 std::function< std::vector<double>(const std::vector<double>&) > _cpp_function;
96 PyObjectRAII NumpyToListWaitingForPy2CppManagement(PyObject *npObj)
98 PyObjectRAII func(PyObjectRAII::FromNew(PyObject_GetAttrString(npObj,"tolist")));
100 throw AdaoExchangeLayerException("input pyobject does not have tolist attribute !");
101 PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(0)));
102 PyObjectRAII ret(PyObjectRAII::FromNew(PyObject_CallObject(func,args)));
106 void AdaoExchangeTest::setUp()
110 void AdaoExchangeTest::tearDown()
114 void AdaoExchangeTest::cleanUp()
// Unqualified model names used by the tests below (EnumAlgo, EnumAlgoKeyVal, RecursiveVisitor,
// GenericKeyVal, DictKeyVal, ...) presumably come from this namespace, declared in
// AdaoModelKeyVal.hxx -- confirm.
118 using namespace AdaoModel;
// Runs an ADAO case through AdaoExchangeLayer, evaluating the C++ model funcBase, and checks
// that the optimum returned by adao.getResult() is (2, 3, 4) within 1e-7.
// The test name suggests the 3DVar algorithm; no algorithm is selected in the visible lines,
// so it presumably relies on the model's default algorithm -- confirm.
// NOTE(review): several original lines are not visible in this view (braces, the declaration
// of `mm`, and presumably adao init/execute/finalize calls); the code lines below are
// reproduced unchanged.
120 void AdaoExchangeTest::test3DVar()
// Wrap the C++ model so it can be called on the batches of points handed over by adao.
122 NonParallelFunctor functor(funcBase);
124 AdaoExchangeLayer adao;
126 // For bounds, Background/Vector, Observation/Vector
// Visitor2 (presumably defined in TestAdaoHelper.cxx) fills the Python-valued leaves of the
// model using adao's Python context.
127 Visitor2 visitorPythonObj(adao.getPythonContext());
128 mm.visitPythonLeaves(&visitorPythonObj);
130 adao.loadTemplate(&mm);
// Python script generated from the model; only kept for debugging (the print is commented out).
133 std::string sciptPyOfModelMaker(mm.pyStr());
134 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: while adao.next() hands over a batch of points (presumably the points
// ADAO asks to evaluate), evaluate them in C++ and pass the results back with setResult().
137 PyObject *listOfElts( nullptr );
138 while( adao.next(listOfElts) )
// functor returns a new reference; whether setResult() takes ownership is not visible here.
140 PyObject *resultOfChunk(functor(listOfElts));
141 adao.setResult(resultOfChunk);
// getResult() is treated as returning a new reference (wrapped with FromNew).
143 PyObject *res(adao.getResult());
144 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
// numpy optimum -> Python list -> std::vector<double>.
145 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
146 std::vector<double> vect;
148 py2cpp::PyPtr obj(optimum_4_py2cpp);
149 py2cpp::fromPyPtr(obj,vect);
// Expected optimum for funcBase.
151 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
152 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
153 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
154 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Same scenario as test3DVar, but with the algorithm leaf set to EnumAlgo::Blue by a visitor.
// Checks that the optimum is (2, 3, 4) within 1e-7.
// NOTE(review): several original lines are not visible in this view (braces, access specifiers,
// the declaration of `mm`, the instantiation/application of TestBlueVisitor, and presumably
// adao init/execute/finalize calls); the code lines below are reproduced unchanged.
157 void AdaoExchangeTest::testBlue()
// Visitor that sets each EnumAlgoKeyVal leaf it visits to EnumAlgo::Blue.
159 class TestBlueVisitor : public RecursiveVisitor
162 void visit(GenericKeyVal *obj)
// dynamic_cast yields nullptr for leaves that are not an EnumAlgoKeyVal.
// NOTE(review): as shown, objc is dereferenced without a null check; the original presumably
// guards it with `if(objc)` on a line not visible here -- confirm.
164 EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
166 objc->setVal(EnumAlgo::Blue);
// Directory boundaries need no action.
168 void enterSubDir(DictKeyVal *subdir) { }
169 void exitSubDir(DictKeyVal *subdir) { }
// Wrap the C++ model so it can be called on the batches of points handed over by adao.
172 NonParallelFunctor functor(funcBase);
178 AdaoExchangeLayer adao;
180 // For bounds, Background/Vector, Observation/Vector
181 Visitor2 visitorPythonObj(adao.getPythonContext());
182 mm.visitPythonLeaves(&visitorPythonObj);
184 adao.loadTemplate(&mm);
// Python script generated from the model; only kept for debugging (the print is commented out).
187 std::string sciptPyOfModelMaker(mm.pyStr());
188 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: evaluate each batch handed over by adao.next() and pass results back.
191 PyObject *listOfElts( nullptr );
192 while( adao.next(listOfElts) )
194 PyObject *resultOfChunk(functor(listOfElts));
195 adao.setResult(resultOfChunk);
// getResult() is treated as returning a new reference (wrapped with FromNew).
197 PyObject *res(adao.getResult());
198 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
// numpy optimum -> Python list -> std::vector<double>.
199 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
200 std::vector<double> vect;
202 py2cpp::PyPtr obj(optimum_4_py2cpp);
203 py2cpp::fromPyPtr(obj,vect);
// Expected optimum for funcBase.
205 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
206 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
207 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
208 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Same scenario as test3DVar, but with the algorithm leaf set to EnumAlgo::NonLinearLeastSquares
// by a visitor. Checks that the optimum is (2, 3, 4) within 1e-7.
// NOTE(review): several original lines are not visible in this view (braces, access specifiers,
// the declaration of `mm`, the call applying `vis` to the model, and presumably adao
// init/execute/finalize calls); the code lines below are reproduced unchanged.
211 void AdaoExchangeTest::testNonLinearLeastSquares()
// Visitor that sets each EnumAlgoKeyVal leaf it visits to EnumAlgo::NonLinearLeastSquares.
213 class TestNonLinearLeastSquaresVisitor : public RecursiveVisitor
216 void visit(GenericKeyVal *obj)
// dynamic_cast yields nullptr for leaves that are not an EnumAlgoKeyVal.
// NOTE(review): as shown, objc is dereferenced without a null check; the original presumably
// guards it with `if(objc)` on a line not visible here -- confirm.
218 EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
220 objc->setVal(EnumAlgo::NonLinearLeastSquares);
// Directory boundaries need no action.
222 void enterSubDir(DictKeyVal *subdir) { }
223 void exitSubDir(DictKeyVal *subdir) { }
// Wrap the C++ model so it can be called on the batches of points handed over by adao.
225 NonParallelFunctor functor(funcBase);
// Presumably applied to `mm` on a line not visible here, before the template is loaded.
228 TestNonLinearLeastSquaresVisitor vis;
231 AdaoExchangeLayer adao;
233 // For bounds, Background/Vector, Observation/Vector
234 Visitor2 visitorPythonObj(adao.getPythonContext());
235 mm.visitPythonLeaves(&visitorPythonObj);
237 adao.loadTemplate(&mm);
// Python script generated from the model; only kept for debugging (the print is commented out).
240 std::string sciptPyOfModelMaker(mm.pyStr());
241 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: evaluate each batch handed over by adao.next() and pass results back.
244 PyObject *listOfElts( nullptr );
245 while( adao.next(listOfElts) )
247 PyObject *resultOfChunk(functor(listOfElts));
248 adao.setResult(resultOfChunk);
// getResult() is treated as returning a new reference (wrapped with FromNew).
250 PyObject *res(adao.getResult());
251 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
// numpy optimum -> Python list -> std::vector<double>.
252 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
253 std::vector<double> vect;
255 py2cpp::PyPtr obj(optimum_4_py2cpp);
256 py2cpp::fromPyPtr(obj,vect);
// Expected optimum for funcBase.
258 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
259 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
260 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
261 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Runs an ADAO case on the C++ model funcCrue, with Python leaves filled by VisitorCruePython,
// and checks that the optimum is a single value close to 25 (tolerance 1e-3).
// NOTE(review): several original lines are not visible in this view (braces, the declaration
// of `mm`, and presumably adao init/execute/finalize calls); the code lines below are
// reproduced unchanged.
264 void AdaoExchangeTest::testCasCrue()
// Wrap the C++ model so it can be called on the batches of points handed over by adao.
266 NonParallelFunctor functor(funcCrue);
268 AdaoExchangeLayer adao;
270 // For bounds, Background/Vector, Observation/Vector
271 VisitorCruePython visitorPythonObj(adao.getPythonContext());
272 mm.visitPythonLeaves(&visitorPythonObj);
274 adao.loadTemplate(&mm);
// Python script generated from the model; only kept for debugging (the print is commented out).
277 std::string sciptPyOfModelMaker(mm.pyStr());
278 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: evaluate each batch handed over by adao.next() and pass results back.
281 PyObject *listOfElts( nullptr );
282 while( adao.next(listOfElts) )
284 PyObject *resultOfChunk(functor(listOfElts));
285 adao.setResult(resultOfChunk);
// getResult() is treated as returning a new reference (wrapped with FromNew).
287 PyObject *res(adao.getResult());
288 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
// numpy optimum -> Python list -> std::vector<double>.
289 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
290 std::vector<double> vect;
292 py2cpp::PyPtr obj(optimum_4_py2cpp);
293 py2cpp::fromPyPtr(obj,vect);
// Expected optimum for funcCrue: one parameter, close to 25.
295 CPPUNIT_ASSERT_EQUAL(1,(int)vect.size());
296 CPPUNIT_ASSERT_DOUBLES_EQUAL(25.,vect[0],1e-3);
// Registers the AdaoExchangeTest suite in CppUnit's default registry; main() below retrieves
// it through TestFactoryRegistry::getRegistry().makeTest().
299 CPPUNIT_TEST_SUITE_REGISTRATION( AdaoExchangeTest );
301 #include <cppunit/CompilerOutputter.h>
302 #include <cppunit/TestResult.h>
303 #include <cppunit/TestResultCollector.h>
304 #include <cppunit/TextTestProgressListener.h>
305 #include <cppunit/BriefTestProgressListener.h>
306 #include <cppunit/extensions/TestFactoryRegistry.h>
307 #include <cppunit/TestRunner.h>
308 #include <cppunit/TextTestRunner.h>
// Test driver: runs every suite registered in CppUnit's default registry, opens test.log in
// append mode for a compiler-style report, and returns 0 if all tests passed, 1 otherwise.
// NOTE(review): some original lines are not visible in this view (braces, a preprocessor
// conditional around `progress`, and presumably the outputter write / file close calls);
// the code lines below are reproduced unchanged.
310 int main(int argc, char* argv[])
312 // --- Create the event manager and test controller
313 CPPUNIT_NS::TestResult controller;
315 // --- Add a listener that collects test result
316 CPPUNIT_NS::TestResultCollector result;
317 controller.addListener( &result );
319 // --- Add a listener that print dots as test run.
// NOTE(review): `progress` is declared twice below; the original presumably selects one of the
// two listener types with a preprocessor conditional on lines not visible here -- confirm.
321 CPPUNIT_NS::TextTestProgressListener progress;
323 CPPUNIT_NS::BriefTestProgressListener progress;
325 controller.addListener( &progress );
327 // --- Get the top level suite from the registry
329 CPPUNIT_NS::Test *suite =
330 CPPUNIT_NS::TestFactoryRegistry::getRegistry().makeTest();
332 // --- Adds the test to the list of test to run
334 CPPUNIT_NS::TestRunner runner;
335 runner.addTest( suite );
// The tests are executed here; outcomes reach `result` through the controller's listeners.
336 runner.run( controller);
338 // --- Print test in a compiler compatible format.
// test.log is opened in append mode, so reports from successive runs accumulate.
339 std::ofstream testFile;
340 testFile.open("test.log", std::ios::out | std::ios::app);
341 testFile << "------ ADAO exchange test log:" << std::endl;
// NOTE(review): the outputter is only constructed here; the call that writes the report
// (presumably outputter.write()) is not visible in this view.
342 CPPUNIT_NS::CompilerOutputter outputter( &result, testFile );
345 // --- Run the tests.
// NOTE(review): the tests already ran in runner.run() above; this only reads the collected outcome.
347 bool wasSucessful = result.wasSuccessful();
350 // --- Return error code 1 if the one of test failed.
352 return wasSucessful ? 0 : 1;