1 // Copyright (C) 2019 EDF R&D
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License.
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 // Author: Anthony Geay, anthony.geay@edf.fr, EDF R&D
21 #include "TestAdaoExchange.hxx"
23 #include "AdaoExchangeLayer.hxx"
24 #include "AdaoExchangeLayerException.hxx"
25 #include "AdaoModelKeyVal.hxx"
26 #include "PyObjectRAII.hxx"
28 #include "py2cpp/py2cpp.hxx"
33 #include "TestAdaoHelper.cxx"
35 // Functor a remplacer par un appel a un evaluateur parallele
36 class NonParallelFunctor
39 NonParallelFunctor(std::function< std::vector<double>(const std::vector<double>&) > cppFunction):_cpp_function(cppFunction) { }
40 PyObject *operator()(PyObject *inp) const
42 PyGILState_STATE gstate(PyGILState_Ensure());
43 PyObjectRAII iterator(PyObjectRAII::FromNew(PyObject_GetIter(inp)));
45 throw AdaoExchangeLayerException("Input object is not iterable !");
46 PyObject *item(nullptr);
47 PyObjectRAII numpyModule(PyObjectRAII::FromNew(PyImport_ImportModule("numpy")));
48 if(numpyModule.isNull())
49 throw AdaoExchangeLayerException("Failed to load numpy");
50 PyObjectRAII ravelFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(numpyModule,"ravel")));
51 std::vector< PyObjectRAII > pyrets;
52 while( item = PyIter_Next(iterator) )
54 PyObjectRAII item2(PyObjectRAII::FromNew(item));
56 PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(1)));
57 PyTuple_SetItem(args,0,item2.retn());
58 PyObjectRAII npyArray(PyObjectRAII::FromNew(PyObject_CallObject(ravelFunc,args)));
59 // Waiting management of npy arrays into py2cpp
60 PyObjectRAII lolistFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(npyArray,"tolist")));
63 PyObjectRAII args2(PyObjectRAII::FromNew(PyTuple_New(0)));
64 listPy=PyObjectRAII::FromNew(PyObject_CallObject(lolistFunc,args2));
66 std::vector<double> vect;
68 py2cpp::PyPtr listPy2(listPy.retn());
69 py2cpp::fromPyPtr(listPy2,vect);
72 PyGILState_Release(gstate);
73 std::vector<double> res(_cpp_function(vect));// L'appel effectif est ici
74 gstate=PyGILState_Ensure();
76 py2cpp::PyPtr resPy(py2cpp::toPyPtr(res));
77 PyObjectRAII resPy2(PyObjectRAII::FromBorrowed(resPy.get()));
78 pyrets.push_back(resPy2);
81 std::size_t len(pyrets.size());
82 PyObjectRAII ret(PyObjectRAII::FromNew(PyList_New(len)));
83 for(std::size_t i=0;i<len;++i)
85 PyList_SetItem(ret,i,pyrets[i].retn());
87 //PyObject *tmp(PyObject_Repr(ret));
88 // std::cerr << PyUnicode_AsUTF8(tmp) << std::endl;
89 PyGILState_Release(gstate);
93 std::function< std::vector<double>(const std::vector<double>&) > _cpp_function;
96 PyObjectRAII NumpyToListWaitingForPy2CppManagement(PyObject *npObj)
98 PyObjectRAII func(PyObjectRAII::FromNew(PyObject_GetAttrString(npObj,"tolist")));
100 throw AdaoExchangeLayerException("input pyobject does not have tolist attribute !");
101 PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(0)));
102 PyObjectRAII ret(PyObjectRAII::FromNew(PyObject_CallObject(func,args)));
// Presumably the CppUnit TestFixture hook run before each test.
// NOTE(review): the body of this definition is not visible in this view.
106 void AdaoExchangeTest::setUp()
// Presumably the CppUnit TestFixture hook run after each test.
// NOTE(review): the body of this definition is not visible in this view.
110 void AdaoExchangeTest::tearDown()
// NOTE(review): cleanUp is not one of the CppUnit TestFixture hooks and no caller is visible here;
// its body is not visible in this view either. Confirm whether it is still needed.
114 void AdaoExchangeTest::cleanUp()
118 using namespace AdaoModel;
// Runs an assimilation through AdaoExchangeLayer on the model `mm` without changing its algorithm
// (presumably 3DVar given the test name -- TODO confirm), evaluating every requested chunk with
// NonParallelFunctor(funcBase), and checks that the optimum is (2, 3, 4) within 1e-7.
// NOTE(review): `mm` and funcBase are not declared in the visible lines (funcBase presumably comes
// from TestAdaoHelper.cxx).
120 void AdaoExchangeTest::test3DVar()
122 NonParallelFunctor functor(funcBase);
124 AdaoExchangeLayer adao;
126 // For bounds, Background/Vector, Observation/Vector
127 adao.setFunctionCallbackInModel(&mm);
128 Visitor2 visitorPythonObj(adao.getPythonContext());
131 mm.visitPythonLeaves(&visitorPythonObj);
134 adao.loadTemplate(&mm);
// Python script generated from the model, kept for debugging only.
137 std::string sciptPyOfModelMaker(mm.pyStr());
138 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluate each chunk handed out by adao.next() in C++ and give the result back with setResult().
// NOTE(review): confirm setResult takes ownership of resultOfChunk, otherwise it leaks once per chunk.
141 PyObject *listOfElts( nullptr );
142 while( adao.next(listOfElts) )
144 PyObject *resultOfChunk(functor(listOfElts));
145 adao.setResult(resultOfChunk);
// getResult() is wrapped with FromNew, i.e. treated as returning a new reference.
147 PyObject *res(adao.getResult());
148 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
149 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
150 std::vector<double> vect;
// py2cpp::PyPtr owns the pointer it is built from (the functor feeds it listPy.retn() for the same
// reason), so give it its own reference; passing optimum_4_py2cpp directly would let both objects
// decrement the same single reference.
152 py2cpp::PyPtr obj(optimum_4_py2cpp.retn());
153 py2cpp::fromPyPtr(obj,vect);
// Expected optimum: (2, 3, 4).
155 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
156 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
157 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
158 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Same scenario as test3DVar, but the model's algorithm is meant to be switched to EnumAlgo::Blue by
// TestBlueVisitor; checks that the optimum is (2, 3, 4) within 1e-7.
// NOTE(review): no instance of TestBlueVisitor and no declaration of `mm` are visible in this view --
// confirm the visitor is actually applied to `mm`.
161 void AdaoExchangeTest::testBlue()
// Sets EnumAlgo::Blue on every visited EnumAlgoKeyVal.
// NOTE(review): no null check of objc is visible before setVal; dynamic_cast yields nullptr for any
// other kind of GenericKeyVal -- confirm a guard exists.
163 class TestBlueVisitor : public RecursiveVisitor
166 void visit(GenericKeyVal *obj)
168 EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
170 objc->setVal(EnumAlgo::Blue);
// Sub-directories need no specific handling.
172 void enterSubDir(DictKeyVal *subdir) { }
173 void exitSubDir(DictKeyVal *subdir) { }
176 NonParallelFunctor functor(funcBase);
182 AdaoExchangeLayer adao;
184 // For bounds, Background/Vector, Observation/Vector
185 adao.setFunctionCallbackInModel(&mm);
186 Visitor2 visitorPythonObj(adao.getPythonContext());
189 mm.visitPythonLeaves(&visitorPythonObj);
192 adao.loadTemplate(&mm);
// Python script generated from the model, kept for debugging only.
195 std::string sciptPyOfModelMaker(mm.pyStr());
196 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluate each chunk handed out by adao.next() in C++ and give the result back with setResult().
199 PyObject *listOfElts( nullptr );
200 while( adao.next(listOfElts) )
202 PyObject *resultOfChunk(functor(listOfElts));
203 adao.setResult(resultOfChunk);
// getResult() is wrapped with FromNew, i.e. treated as returning a new reference.
205 PyObject *res(adao.getResult());
206 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
207 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
208 std::vector<double> vect;
// py2cpp::PyPtr owns the pointer it is built from, so give it its own reference with retn();
// passing optimum_4_py2cpp directly would let both objects decrement the same single reference.
210 py2cpp::PyPtr obj(optimum_4_py2cpp.retn());
211 py2cpp::fromPyPtr(obj,vect);
// Expected optimum: (2, 3, 4).
213 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
214 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
215 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
216 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Same scenario as test3DVar, but the model's algorithm is meant to be switched to
// EnumAlgo::NonLinearLeastSquares by `vis`; checks that the optimum is (2, 3, 4) within 1e-7.
// NOTE(review): `vis` is declared but its use on `mm` (and `mm` itself) is not visible in this view.
219 void AdaoExchangeTest::testNonLinearLeastSquares()
// Sets EnumAlgo::NonLinearLeastSquares on every visited EnumAlgoKeyVal.
// NOTE(review): no null check of objc is visible before setVal; dynamic_cast yields nullptr for any
// other kind of GenericKeyVal -- confirm a guard exists.
221 class TestNonLinearLeastSquaresVisitor : public RecursiveVisitor
224 void visit(GenericKeyVal *obj)
226 EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
228 objc->setVal(EnumAlgo::NonLinearLeastSquares);
// Sub-directories need no specific handling.
230 void enterSubDir(DictKeyVal *subdir) { }
231 void exitSubDir(DictKeyVal *subdir) { }
233 NonParallelFunctor functor(funcBase);
236 TestNonLinearLeastSquaresVisitor vis;
239 AdaoExchangeLayer adao;
241 // For bounds, Background/Vector, Observation/Vector
242 adao.setFunctionCallbackInModel(&mm);
243 Visitor2 visitorPythonObj(adao.getPythonContext());
246 mm.visitPythonLeaves(&visitorPythonObj);
249 adao.loadTemplate(&mm);
// Python script generated from the model, kept for debugging only.
252 std::string sciptPyOfModelMaker(mm.pyStr());
253 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluate each chunk handed out by adao.next() in C++ and give the result back with setResult().
256 PyObject *listOfElts( nullptr );
257 while( adao.next(listOfElts) )
259 PyObject *resultOfChunk(functor(listOfElts));
260 adao.setResult(resultOfChunk);
// getResult() is wrapped with FromNew, i.e. treated as returning a new reference.
262 PyObject *res(adao.getResult());
263 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
264 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
265 std::vector<double> vect;
// py2cpp::PyPtr owns the pointer it is built from, so give it its own reference with retn();
// passing optimum_4_py2cpp directly would let both objects decrement the same single reference.
267 py2cpp::PyPtr obj(optimum_4_py2cpp.retn());
268 py2cpp::fromPyPtr(obj,vect);
// Expected optimum: (2, 3, 4).
270 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
271 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
272 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
273 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Runs an assimilation on the "Crue" case: chunks are evaluated with NonParallelFunctor(funcCrue) and
// the model's Python leaves are filled by VisitorCruePython. Checks that the optimum is a single
// value close to 25 (tolerance 1e-3).
// NOTE(review): `mm` and funcCrue are not declared in the visible lines (funcCrue presumably comes
// from TestAdaoHelper.cxx).
276 void AdaoExchangeTest::testCasCrue()
278 NonParallelFunctor functor(funcCrue);
280 AdaoExchangeLayer adao;
282 // For bounds, Background/Vector, Observation/Vector
283 adao.setFunctionCallbackInModel(&mm);
284 VisitorCruePython visitorPythonObj(adao.getPythonContext());
287 mm.visitPythonLeaves(&visitorPythonObj);
290 adao.loadTemplate(&mm);
// Python script generated from the model, kept for debugging only.
293 std::string sciptPyOfModelMaker(mm.pyStr());
294 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluate each chunk handed out by adao.next() in C++ and give the result back with setResult().
297 PyObject *listOfElts( nullptr );
298 while( adao.next(listOfElts) )
300 PyObject *resultOfChunk(functor(listOfElts));
301 adao.setResult(resultOfChunk);
// getResult() is wrapped with FromNew, i.e. treated as returning a new reference.
303 PyObject *res(adao.getResult());
304 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
305 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
306 std::vector<double> vect;
// py2cpp::PyPtr owns the pointer it is built from, so give it its own reference with retn();
// passing optimum_4_py2cpp directly would let both objects decrement the same single reference.
308 py2cpp::PyPtr obj(optimum_4_py2cpp.retn());
309 py2cpp::fromPyPtr(obj,vect);
// Expected optimum: a single value close to 25.
311 CPPUNIT_ASSERT_EQUAL(1,(int)vect.size());
312 CPPUNIT_ASSERT_DOUBLES_EQUAL(25.,vect[0],1e-3);
// Registers the AdaoExchangeTest suite in the default CppUnit registry, which main() queries through
// TestFactoryRegistry::getRegistry().
315 CPPUNIT_TEST_SUITE_REGISTRATION( AdaoExchangeTest );
317 #include <cppunit/CompilerOutputter.h>
318 #include <cppunit/TestResult.h>
319 #include <cppunit/TestResultCollector.h>
320 #include <cppunit/TextTestProgressListener.h>
321 #include <cppunit/BriefTestProgressListener.h>
322 #include <cppunit/extensions/TestFactoryRegistry.h>
323 #include <cppunit/TestRunner.h>
324 #include <cppunit/TextTestRunner.h>
// Test driver: runs every suite registered in the default CppUnit registry, sets up a compiler-style
// report on test.log (opened in append mode) and returns 0 if all tests passed, 1 otherwise.
// argc/argv are not used.
326 int main(int argc, char* argv[])
328 // --- Create the event manager and test controller
329 CPPUNIT_NS::TestResult controller;
331 // --- Add a listener that collects test results
332 CPPUNIT_NS::TestResultCollector result;
333 controller.addListener( &result );
335 // --- Add a listener that prints progress as the tests run.
// NOTE(review): `progress` is declared twice below; presumably a preprocessor conditional (not visible
// in this view) selects one of the two declarations.
337 CPPUNIT_NS::TextTestProgressListener progress;
339 CPPUNIT_NS::BriefTestProgressListener progress;
341 controller.addListener( &progress );
343 // --- Get the top level suite from the registry
345 CPPUNIT_NS::Test *suite =
346 CPPUNIT_NS::TestFactoryRegistry::getRegistry().makeTest();
348 // --- Add the suite to the runner and run it; the controller's listeners receive the events
350 CPPUNIT_NS::TestRunner runner;
351 runner.addTest( suite );
352 runner.run( controller);
354 // --- Report the results in a compiler-compatible format into test.log (appended to).
// NOTE(review): no outputter.write() call (nor testFile.close()) is visible in this view; without
// write() nothing is reported to test.log -- confirm it is present.
355 std::ofstream testFile;
356 testFile.open("test.log", std::ios::out | std::ios::app);
357 testFile << "------ ADAO exchange test log:" << std::endl;
358 CPPUNIT_NS::CompilerOutputter outputter( &result, testFile );
361 // --- Retrieve the overall outcome (the tests already ran in runner.run above).
363 bool wasSucessful = result.wasSuccessful();
366 // --- Return error code 1 if at least one test failed.
368 return wasSucessful ? 0 : 1;