1 // Copyright (C) 2019 EDF R&D
3 // This library is free software; you can redistribute it and/or
4 // modify it under the terms of the GNU Lesser General Public
5 // License as published by the Free Software Foundation; either
6 // version 2.1 of the License.
8 // This library is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11 // Lesser General Public License for more details.
13 // You should have received a copy of the GNU Lesser General Public
14 // License along with this library; if not, write to the Free Software
15 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 // See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 // Author: Anthony Geay, anthony.geay@edf.fr, EDF R&D
21 #include "TestAdaoExchange.hxx"
23 #include "AdaoExchangeLayer.hxx"
24 #include "AdaoExchangeLayerException.hxx"
25 #include "AdaoModelKeyVal.hxx"
26 #include "PyObjectRAII.hxx"
28 #include "py2cpp/py2cpp.hxx"
33 #include "TestAdaoHelper.cxx"
35 // Functor to be replaced by a call to a parallel evaluator
// Adapter that applies a C++ function (std::vector<double> -> std::vector<double>)
// to every item obtained by iterating over a Python object, and gathers the
// results into a Python list.
// NOTE(review): only the statements shown here were available when these comments
// were written; guards and return statements that are not shown are described as
// assumptions to confirm.
36 class NonParallelFunctor
// Stores the C++ function that will be evaluated once per input item.
39 NonParallelFunctor(std::function< std::vector<double>(const std::vector<double>&) > cppFunction):_cpp_function(cppFunction) { }
// Evaluates the stored function on each item of inp.
//   inp: Python object that is iterated with PyObject_GetIter / PyIter_Next.
//   Throws AdaoExchangeLayerException with "Input object is not iterable !" or
//   "Failed to load numpy".
//   Builds a Python list (ret) holding one converted result per item; presumably
//   this list is what gets returned - TODO confirm.
40 PyObject *operator()(PyObject *inp) const
// Take the GIL: everything below, except the call to _cpp_function, uses the Python C API.
42 PyGILState_STATE gstate(PyGILState_Ensure());
43 PyObjectRAII iterator(PyObjectRAII::FromNew(PyObject_GetIter(inp)));
// NOTE(review): presumably guarded by a null check on iterator - confirm.
// NOTE(review): no PyGILState_Release is visible before this throw or the one below.
45 throw AdaoExchangeLayerException("Input object is not iterable !");
46 PyObject *item(nullptr);
47 PyObjectRAII numpyModule(PyObjectRAII::FromNew(PyImport_ImportModule("numpy")));
48 if(numpyModule.isNull())
49 throw AdaoExchangeLayerException("Failed to load numpy");
// numpy.ravel is fetched once and reused for every item.
50 PyObjectRAII ravelFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(numpyModule,"ravel")));
// One Python result object per processed item, in iteration order.
51 std::vector< PyObjectRAII > pyrets;
// PyIter_Next yields a new reference per item and NULL when iteration stops.
52 while( item = PyIter_Next(iterator) )
54 PyObjectRAII item2(PyObjectRAII::FromNew(item));
// Call numpy.ravel(item): build the 1-tuple of arguments first.
56 PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(1)));
// PyTuple_SetItem steals a reference; retn() presumably gives up item2's ownership - confirm in PyObjectRAII.hxx.
57 PyTuple_SetItem(args,0,item2.retn());
58 PyObjectRAII npyArray(PyObjectRAII::FromNew(PyObject_CallObject(ravelFunc,args)));
59 // Waiting management of npy arrays into py2cpp
// Meanwhile the array is turned into a Python list by calling its tolist() method with no argument.
60 PyObjectRAII lolistFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(npyArray,"tolist")));
63 PyObjectRAII args2(PyObjectRAII::FromNew(PyTuple_New(0)));
64 listPy=PyObjectRAII::FromNew(PyObject_CallObject(lolistFunc,args2));
// Convert the Python list into a std::vector<double> through py2cpp.
66 std::vector<double> vect;
68 py2cpp::PyPtr listPy2(listPy.retn());
69 py2cpp::fromPyPtr(listPy2,vect);
// The GIL is released around the C++ evaluation and re-acquired right after it.
72 PyGILState_Release(gstate);
73 std::vector<double> res(_cpp_function(vect));// The actual call happens here
74 gstate=PyGILState_Ensure();
// Convert the C++ result back to a Python object and keep it for the final list.
76 py2cpp::PyPtr resPy(py2cpp::toPyPtr(res));
// FromBorrowed presumably takes an additional reference, so resPy and resPy2 each hold one - TODO confirm.
77 PyObjectRAII resPy2(PyObjectRAII::FromBorrowed(resPy.get()));
78 pyrets.push_back(resPy2);
// Assemble the output list; PyList_SetItem steals the reference handed over by retn().
81 std::size_t len(pyrets.size());
82 PyObjectRAII ret(PyObjectRAII::FromNew(PyList_New(len)));
83 for(std::size_t i=0;i<len;++i)
85 PyList_SetItem(ret,i,pyrets[i].retn());
87 //PyObject *tmp(PyObject_Repr(ret));
88 // std::cerr << PyUnicode_AsUTF8(tmp) << std::endl;
89 PyGILState_Release(gstate);
// C++ function evaluated for each input item.
93 std::function< std::vector<double>(const std::vector<double>&) > _cpp_function;
// Calls the "tolist" attribute of npObj with an empty argument tuple and wraps the
// outcome in a PyObjectRAII (ret).
//   npObj: Python object expected to expose a "tolist" attribute.
//   Throws AdaoExchangeLayerException with "input pyobject does not have tolist attribute !".
//   NOTE(review): the throw is presumably guarded by a null check on func, and ret is
//   presumably the returned value - neither statement was available when this was written.
96 PyObjectRAII NumpyToListWaitingForPy2CppManagement(PyObject *npObj)
98 PyObjectRAII func(PyObjectRAII::FromNew(PyObject_GetAttrString(npObj,"tolist")));
100 throw AdaoExchangeLayerException("input pyobject does not have tolist attribute !");
// tolist takes no argument, hence the empty tuple.
101 PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(0)));
102 PyObjectRAII ret(PyObjectRAII::FromNew(PyObject_CallObject(func,args)));
// Set-up hook of the AdaoExchangeTest fixture (presumably run by CppUnit before each test - confirm in TestAdaoExchange.hxx).
106 void AdaoExchangeTest::setUp()
// Tear-down hook of the AdaoExchangeTest fixture (presumably run by CppUnit after each test - confirm in TestAdaoExchange.hxx).
110 void AdaoExchangeTest::tearDown()
// Additional clean-up hook of the AdaoExchangeTest fixture; NOTE(review): who calls it is not shown here - confirm in TestAdaoExchange.hxx.
114 void AdaoExchangeTest::cleanUp()
118 using namespace AdaoModel;
// Drives an AdaoExchangeLayer run with funcBase as the evaluated function and checks
// that the result converts to a 3-value vector equal to (2, 3, 4) within 1e-7.
// NOTE(review): the declaration of mm was not available when these comments were written.
120 void AdaoExchangeTest::test3DVar()
// Wrap funcBase so it can be applied to the Python chunks handed out by adao.
122 NonParallelFunctor functor(funcBase);
124 AdaoExchangeLayer adao;
126 // For bounds, Background/Vector, Observation/Vector
127 Visitor2 visitorPythonObj(adao.getPythonContext());
130 mm.visitPythonLeaves(&visitorPythonObj);
133 adao.loadTemplate(&mm);
// Python script text produced by mm; only consumed by the commented-out debug print below.
136 std::string sciptPyOfModelMaker(mm.pyStr());
137 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: as long as adao.next() succeeds, evaluate the chunk it supplied
// through listOfElts and hand the outcome back with setResult.
140 PyObject *listOfElts( nullptr );
141 while( adao.next(listOfElts) )
143 PyObject *resultOfChunk(functor(listOfElts));
144 adao.setResult(resultOfChunk);
// Fetch the final result and convert it, via its tolist() form, into a std::vector<double>.
146 PyObject *res(adao.getResult());
147 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
148 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
149 std::vector<double> vect;
151 py2cpp::PyPtr obj(optimum_4_py2cpp);
152 py2cpp::fromPyPtr(obj,vect);
// Expected result: exactly three components, (2, 3, 4) with an absolute tolerance of 1e-7.
154 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
155 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
156 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
157 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Same scenario as the other funcBase-based tests, with a local visitor that sets
// EnumAlgoKeyVal entries to EnumAlgo::Blue. Checks that the result converts to a
// 3-value vector equal to (2, 3, 4) within 1e-7.
// NOTE(review): the declaration of mm and the place where the visitor is instantiated
// and applied were not available when these comments were written.
160 void AdaoExchangeTest::testBlue()
// Visitor that switches the algorithm key to Blue.
162 class TestBlueVisitor : public RecursiveVisitor
165 void visit(GenericKeyVal *obj)
// Only EnumAlgoKeyVal entries are of interest; NOTE(review): presumably a null check
// on objc precedes setVal - confirm.
167 EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
169 objc->setVal(EnumAlgo::Blue);
// Sub-directory entry/exit need no action.
171 void enterSubDir(DictKeyVal *subdir) { }
172 void exitSubDir(DictKeyVal *subdir) { }
// Wrap funcBase so it can be applied to the Python chunks handed out by adao.
175 NonParallelFunctor functor(funcBase);
181 AdaoExchangeLayer adao;
183 // For bounds, Background/Vector, Observation/Vector
184 Visitor2 visitorPythonObj(adao.getPythonContext());
187 mm.visitPythonLeaves(&visitorPythonObj);
190 adao.loadTemplate(&mm);
// Python script text produced by mm; only consumed by the commented-out debug print below.
193 std::string sciptPyOfModelMaker(mm.pyStr());
194 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: as long as adao.next() succeeds, evaluate the chunk it supplied
// through listOfElts and hand the outcome back with setResult.
197 PyObject *listOfElts( nullptr );
198 while( adao.next(listOfElts) )
200 PyObject *resultOfChunk(functor(listOfElts));
201 adao.setResult(resultOfChunk);
// Fetch the final result and convert it, via its tolist() form, into a std::vector<double>.
203 PyObject *res(adao.getResult());
204 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
205 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
206 std::vector<double> vect;
208 py2cpp::PyPtr obj(optimum_4_py2cpp);
209 py2cpp::fromPyPtr(obj,vect);
// Expected result: exactly three components, (2, 3, 4) with an absolute tolerance of 1e-7.
211 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
212 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
213 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
214 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Same scenario as the other funcBase-based tests, with a local visitor that sets
// EnumAlgoKeyVal entries to EnumAlgo::NonLinearLeastSquares. Checks that the result
// converts to a 3-value vector equal to (2, 3, 4) within 1e-7.
// NOTE(review): the declaration of mm and the place where vis is applied were not
// available when these comments were written.
217 void AdaoExchangeTest::testNonLinearLeastSquares()
// Visitor that switches the algorithm key to NonLinearLeastSquares.
219 class TestNonLinearLeastSquaresVisitor : public RecursiveVisitor
222 void visit(GenericKeyVal *obj)
// Only EnumAlgoKeyVal entries are of interest; NOTE(review): presumably a null check
// on objc precedes setVal - confirm.
224 EnumAlgoKeyVal *objc(dynamic_cast<EnumAlgoKeyVal *>(obj));
226 objc->setVal(EnumAlgo::NonLinearLeastSquares);
// Sub-directory entry/exit need no action.
228 void enterSubDir(DictKeyVal *subdir) { }
229 void exitSubDir(DictKeyVal *subdir) { }
// Wrap funcBase so it can be applied to the Python chunks handed out by adao.
231 NonParallelFunctor functor(funcBase);
234 TestNonLinearLeastSquaresVisitor vis;
237 AdaoExchangeLayer adao;
239 // For bounds, Background/Vector, Observation/Vector
240 Visitor2 visitorPythonObj(adao.getPythonContext());
243 mm.visitPythonLeaves(&visitorPythonObj);
246 adao.loadTemplate(&mm);
// Python script text produced by mm; only consumed by the commented-out debug print below.
249 std::string sciptPyOfModelMaker(mm.pyStr());
250 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: as long as adao.next() succeeds, evaluate the chunk it supplied
// through listOfElts and hand the outcome back with setResult.
253 PyObject *listOfElts( nullptr );
254 while( adao.next(listOfElts) )
256 PyObject *resultOfChunk(functor(listOfElts));
257 adao.setResult(resultOfChunk);
// Fetch the final result and convert it, via its tolist() form, into a std::vector<double>.
259 PyObject *res(adao.getResult());
260 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
261 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
262 std::vector<double> vect;
264 py2cpp::PyPtr obj(optimum_4_py2cpp);
265 py2cpp::fromPyPtr(obj,vect);
// Expected result: exactly three components, (2, 3, 4) with an absolute tolerance of 1e-7.
267 CPPUNIT_ASSERT_EQUAL(3,(int)vect.size());
268 CPPUNIT_ASSERT_DOUBLES_EQUAL(2.,vect[0],1e-7);
269 CPPUNIT_ASSERT_DOUBLES_EQUAL(3.,vect[1],1e-7);
270 CPPUNIT_ASSERT_DOUBLES_EQUAL(4.,vect[2],1e-7);
// Drives an AdaoExchangeLayer run with funcCrue as the evaluated function and a
// VisitorCruePython for the Python leaves, then checks that the result converts to a
// single value equal to 25 within 1e-3.
// NOTE(review): the declaration of mm was not available when these comments were written.
273 void AdaoExchangeTest::testCasCrue()
// Wrap funcCrue so it can be applied to the Python chunks handed out by adao.
275 NonParallelFunctor functor(funcCrue);
277 AdaoExchangeLayer adao;
279 // For bounds, Background/Vector, Observation/Vector
280 VisitorCruePython visitorPythonObj(adao.getPythonContext());
283 mm.visitPythonLeaves(&visitorPythonObj);
286 adao.loadTemplate(&mm);
// Python script text produced by mm; only consumed by the commented-out debug print below.
289 std::string sciptPyOfModelMaker(mm.pyStr());
290 //std::cerr << sciptPyOfModelMaker << std::endl;
// Evaluation loop: as long as adao.next() succeeds, evaluate the chunk it supplied
// through listOfElts and hand the outcome back with setResult.
293 PyObject *listOfElts( nullptr );
294 while( adao.next(listOfElts) )
296 PyObject *resultOfChunk(functor(listOfElts));
297 adao.setResult(resultOfChunk);
// Fetch the final result and convert it, via its tolist() form, into a std::vector<double>.
299 PyObject *res(adao.getResult());
300 PyObjectRAII optimum(PyObjectRAII::FromNew(res));
301 PyObjectRAII optimum_4_py2cpp(NumpyToListWaitingForPy2CppManagement(optimum));
302 std::vector<double> vect;
304 py2cpp::PyPtr obj(optimum_4_py2cpp);
305 py2cpp::fromPyPtr(obj,vect);
// Expected result: a single component equal to 25 with an absolute tolerance of 1e-3.
307 CPPUNIT_ASSERT_EQUAL(1,(int)vect.size());
308 CPPUNIT_ASSERT_DOUBLES_EQUAL(25.,vect[0],1e-3);
// Registers the AdaoExchangeTest suite in CppUnit's test factory registry.
311 CPPUNIT_TEST_SUITE_REGISTRATION( AdaoExchangeTest );
313 #include <cppunit/CompilerOutputter.h>
314 #include <cppunit/TestResult.h>
315 #include <cppunit/TestResultCollector.h>
316 #include <cppunit/TextTestProgressListener.h>
317 #include <cppunit/BriefTestProgressListener.h>
318 #include <cppunit/extensions/TestFactoryRegistry.h>
319 #include <cppunit/TestRunner.h>
320 #include <cppunit/TextTestRunner.h>
// Entry point of the test executable: runs every test obtained from the CppUnit
// registry, sets up a compiler-style report targeting "test.log" (opened in append
// mode) and returns 0 when all tests succeeded, 1 otherwise.
// argc and argv are not used by the statements below.
322 int main(int argc, char* argv[])
324 // --- Create the event manager and test controller
325 CPPUNIT_NS::TestResult controller;
327 // --- Add a listener that collects the test results
328 CPPUNIT_NS::TestResultCollector result;
329 controller.addListener( &result );
331 // --- Add a listener that prints progress while the tests run.
// NOTE(review): two declarations of `progress` appear below; presumably a preprocessor
// conditional selects exactly one of them - confirm.
333 CPPUNIT_NS::TextTestProgressListener progress;
335 CPPUNIT_NS::BriefTestProgressListener progress;
337 controller.addListener( &progress );
339 // --- Get the top level suite from the registry
341 CPPUNIT_NS::Test *suite =
342 CPPUNIT_NS::TestFactoryRegistry::getRegistry().makeTest();
344 // --- Add the suite to the list of tests to run, then run them
346 CPPUNIT_NS::TestRunner runner;
347 runner.addTest( suite );
348 runner.run( controller);
350 // --- Print the test results in a compiler-compatible format.
// The log file is opened in append mode, so earlier content of test.log is kept.
351 std::ofstream testFile;
352 testFile.open("test.log", std::ios::out | std::ios::app);
353 testFile << "------ ADAO exchange test log:" << std::endl;
// NOTE(review): no call to outputter.write() appears below - confirm the report is actually emitted.
354 CPPUNIT_NS::CompilerOutputter outputter( &result, testFile );
357 // --- Retrieve the overall outcome (the tests were already run by runner.run above).
359 bool wasSucessful = result.wasSuccessful();
362 // --- Return exit code 1 if at least one test failed, 0 otherwise.
364 return wasSucessful ? 0 : 1;