/*
 * Salome HOME
 * Le cas crue est dans la boite  (English: "the flood ('crue') case is in the can", i.e. done)
 * [tools/adao_interface.git] / TestAdaoHelper.cxx
 */
// Project header first: it provides the Python C API used below (PyObject, ...),
// and the Python docs require Python.h before any standard header.
// NOTE(review): assumes PyObjectRAII.hxx is what pulls in Python.h — confirm.
#include "PyObjectRAII.hxx"

#include <cmath>
#include <sstream>
#include <string>
#include <vector>

/* Model used by test3DVar, testBlue and testNonLinearLeastSquares:
   (x, y, z) -> (x, 2*y, 3*z, x + 2*y + 3*z).
   Only the first three components of vec are read; a shorter input throws
   std::out_of_range (via vector::at) instead of reading out of bounds. */
std::vector<double> funcBase(const std::vector<double>& vec)
{
  const double x(vec.at(0));
  const double y(vec.at(1));
  const double z(vec.at(2));
  return {x,2.*y,3.*z,x+2.*y+3.*z};
}

/* Height H solving  Q = K_s * B * sqrt(alpha) * H^(5/3)  with
   alpha = (Z_m - Z_v) / L, i.e.  H = (Q / (K_s * B * sqrt(alpha)))^(3/5).
   NOTE(review): this has the shape of the Manning-Strickler law for a wide
   rectangular channel (Q flow rate, K_s Strickler coefficient, B width, L length,
   Z_m / Z_v upstream / downstream elevations); that reading comes from the names
   only — confirm.
   K_s == 0 divides by zero; a negative ratio yields NaN (non-integer exponent). */
double funcCrueInternal(double Q, double K_s)
{
  constexpr double L(5.0e3);
  constexpr double B(300.);
  constexpr double Z_v(49.);
  constexpr double Z_m(51.);
  constexpr double alpha( (Z_m - Z_v)/L );
  // <cmath> only guarantees pow/sqrt in namespace std.
  const double H(std::pow(Q/(K_s*B*std::sqrt(alpha)),3.0/5.0));
  return H;
}

23 /* func pour testCasCrue*/
24 std::vector<double> funcCrue(const std::vector<double>& vec)
25 {
26   double K_s(vec[0]);
27   constexpr double Qs[]={10.,20.,30.,40.};
28   constexpr size_t LEN(sizeof(Qs)/sizeof(double));
29   std::vector<double> ret(LEN);
30   for(std::size_t i=0;i<LEN;++i)
31     {
32       ret[i] = funcCrueInternal(Qs[i],K_s);
33     }
34   return ret;
35 }
36
37 PyObject *multiFuncCrue(PyObject *inp)
38 {
39   PyGILState_STATE gstate(PyGILState_Ensure());
40   PyObjectRAII iterator(PyObjectRAII::FromNew(PyObject_GetIter(inp)));
41   if(iterator.isNull())
42     throw AdaoExchangeLayerException("Input object is not iterable !");
43   PyObject *item(nullptr);
44   PyObjectRAII numpyModule(PyObjectRAII::FromNew(PyImport_ImportModule("numpy")));
45   if(numpyModule.isNull())
46     throw AdaoExchangeLayerException("Failed to load numpy");
47   PyObjectRAII ravelFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(numpyModule,"ravel")));
48   std::vector< PyObjectRAII > pyrets;
49   while( item = PyIter_Next(iterator) )
50     {
51       PyObjectRAII item2(PyObjectRAII::FromNew(item));
52       {
53         PyObjectRAII args(PyObjectRAII::FromNew(PyTuple_New(1)));
54         PyTuple_SetItem(args,0,item2.retn());
55         PyObjectRAII npyArray(PyObjectRAII::FromNew(PyObject_CallObject(ravelFunc,args)));
56         // Waiting management of npy arrays into py2cpp
57         PyObjectRAII lolistFunc(PyObjectRAII::FromNew(PyObject_GetAttrString(npyArray,"tolist")));
58         PyObjectRAII listPy;
59         {
60           PyObjectRAII args2(PyObjectRAII::FromNew(PyTuple_New(0)));
61           listPy=PyObjectRAII::FromNew(PyObject_CallObject(lolistFunc,args2));
62           }
63         std::vector<double> vect;
64         {
65           py2cpp::PyPtr listPy2(listPy.retn());
66           py2cpp::fromPyPtr(listPy2,vect);
67         }
68         //
69         PyGILState_Release(gstate);
70         std::vector<double> res(funcCrue(vect));
71         gstate=PyGILState_Ensure();
72         //
73         py2cpp::PyPtr resPy(py2cpp::toPyPtr(res));
74         PyObjectRAII resPy2(PyObjectRAII::FromBorrowed(resPy.get()));
75         pyrets.push_back(resPy2);
76       }
77     }
78   std::size_t len(pyrets.size());
79   PyObjectRAII ret(PyObjectRAII::FromNew(PyList_New(len)));
80   for(std::size_t i=0;i<len;++i)
81     {
82       PyList_SetItem(ret,i,pyrets[i].retn());
83     }
84   //PyObject *tmp(PyObject_Repr(ret));
85   // std::cerr << PyUnicode_AsUTF8(tmp) << std::endl;
86   PyGILState_Release(gstate);
87   return ret.retn();
88 }
89
90 class Visitor2 : public AdaoModel::PythonLeafVisitor
91 {
92 public:
93   Visitor2(PyObject *context):_context(context)
94   {
95     std::vector< std::vector<double> > bounds{ {0., 10.}, {3., 13.}, {1.5, 15.5} };
96     std::vector< double > Xb{5.,7.,9.};
97     py2cpp::PyPtr boundsPy(py2cpp::toPyPtr(bounds));
98     _bounds = boundsPy.get();
99     Py_XINCREF(_bounds);
100     py2cpp::PyPtr XbPy(py2cpp::toPyPtr(Xb));
101     _Xb = XbPy.get();
102     Py_XINCREF(_Xb);
103     std::vector<double> observation{2., 6., 12., 20.};
104     py2cpp::PyPtr observationPy(py2cpp::toPyPtr(observation));
105     _observation = observationPy.get();
106     Py_XINCREF(_observation);
107   }
108   
109   void visit(AdaoModel::MainModel *godFather, AdaoModel::PyObjKeyVal *obj) override
110   {
111     if(obj->getKey()=="Bounds")
112       {
113         std::ostringstream oss; oss << "___" << _cnt++;
114         std::string varname(oss.str());
115         obj->setVal(_bounds);
116         PyDict_SetItemString(_context,varname.c_str(),_bounds);
117         obj->setVarName(varname);
118         return ;
119       }
120     if(godFather->findPathOf(obj)=="Background/Vector")
121       {
122         std::ostringstream oss; oss << "___" << _cnt++;
123         std::string varname(oss.str());
124         obj->setVal(_Xb);
125         PyDict_SetItemString(_context,varname.c_str(),_Xb);
126         obj->setVarName(varname);
127       }
128     if(godFather->findPathOf(obj)=="Observation/Vector")
129       {
130         std::ostringstream oss; oss << "____" << _cnt++;
131         std::string varname(oss.str());
132         obj->setVal(_observation);
133         PyDict_SetItemString(_context,varname.c_str(),_observation);
134         obj->setVarName(varname);
135       }
136   }
137 private:
138   unsigned int _cnt = 0;
139   PyObject *_bounds = nullptr;
140   PyObject *_Xb = nullptr;
141   PyObject *_observation = nullptr;
142   PyObject *_context = nullptr;
143 };
144
145 class VisitorCruePython : public AdaoModel::PythonLeafVisitor
146 {
147 public:
148   VisitorCruePython(PyObject *context):_context(context)
149   {
150     {//case.set( 'Background',          Vector=thetaB)
151       std::vector< double > Xb{ 20. };//thetaB
152       py2cpp::PyPtr XbPy(py2cpp::toPyPtr(Xb));
153       _Xb = XbPy.get();
154       Py_XINCREF(_Xb);
155     }
156     {//case.set( 'BackgroundError',     DiagonalSparseMatrix=sigmaTheta )
157       std::vector< double > sigmaTheta{ 5.e10 };
158       py2cpp::PyPtr sigmaThetaPy(py2cpp::toPyPtr(sigmaTheta));
159       _sigmaTheta = sigmaThetaPy.get();
160       Py_XINCREF(_sigmaTheta);
161     }
162     {//case.set( 'Observation',         Vector=Hobs)
163       std::vector<double> observation{0.19694513, 0.298513, 0.38073079, 0.45246109};
164       py2cpp::PyPtr observationPy(py2cpp::toPyPtr(observation));
165       _observation = observationPy.get();
166       Py_XINCREF(_observation);
167     }
168     {//case.set( 'ObservationError',    ScalarSparseMatrix=sigmaH )
169       double sigmaH( 0.5);
170       py2cpp::PyPtr sigmaHPy(py2cpp::toPyPtr(sigmaH));
171       _sigmaH = sigmaHPy.get();
172       Py_XINCREF(_sigmaH);
173     }
174   }
175
176
177   void visit(AdaoModel::MainModel *godFather, AdaoModel::PyObjKeyVal *obj) override
178   {
179     if(godFather->findPathOf(obj)=="Background/Vector")
180       {
181         std::ostringstream oss; oss << "___" << _cnt++;
182         std::string varname(oss.str());
183         obj->setVal(_Xb);
184         PyDict_SetItemString(_context,varname.c_str(),_Xb);
185         obj->setVarName(varname);
186       }
187     if(godFather->findPathOf(obj)=="BackgroundError/Matrix")
188       {
189         std::ostringstream oss; oss << "___" << _cnt++;
190         std::string varname(oss.str());
191         obj->setVal(_sigmaTheta);
192         PyDict_SetItemString(_context,varname.c_str(),_Xb);
193         obj->setVarName(varname);
194       }
195     if(godFather->findPathOf(obj)=="Observation/Vector")
196       {
197         std::ostringstream oss; oss << "____" << _cnt++;
198         std::string varname(oss.str());
199         obj->setVal(_observation);
200         PyDict_SetItemString(_context,varname.c_str(),_observation);
201         obj->setVarName(varname);
202       }
203     if(godFather->findPathOf(obj)=="ObservationError/Matrix")
204       {
205         std::ostringstream oss; oss << "____" << _cnt++;
206         std::string varname(oss.str());
207         obj->setVal(_sigmaH);
208         PyDict_SetItemString(_context,varname.c_str(),_sigmaH);
209         obj->setVarName(varname);
210       }
211   }
212 private:
213   unsigned int _cnt = 0;
214   PyObject *_Xb = nullptr;
215   PyObject *_sigmaH = nullptr;
216   PyObject *_sigmaTheta = nullptr;
217   PyObject *_observation = nullptr;
218   PyObject *_context = nullptr;
219 };
220