Salome HOME
35549326b61d6d71515d05e022a370e09594eb3d
[modules/adao.git] / src / daSalome / daYacsIntegration / daOptimizerLoop.py
1 #-*- coding: utf-8 -*-
2 # Copyright (C) 2010-2011 EDF R&D
3 #
4 # This library is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU Lesser General Public
6 # License as published by the Free Software Foundation; either
7 # version 2.1 of the License.
8 #
9 # This library is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this library; if not, write to the Free Software
16 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
17 #
18 # See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 #
20 # Author: André Ribes, andre.ribes@edf.fr, EDF R&D
21
22 import SALOMERuntime
23 import pilot
24 import pickle
25 import numpy
26 import threading
27
28 from daCore.AssimilationStudy import AssimilationStudy
29 from daYacsIntegration import daStudy
30
31 class OptimizerHooks:
32
33   def __init__(self, optim_algo):
34     self.optim_algo = optim_algo
35
36     # Gestion du compteur
37     self.sample_counter = 0
38     self.counter_lock = threading.Lock()
39
40   def create_sample(self, data, method):
41     sample = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/ParametricInput'))
42
43     # TODO Input, Output VarList
44     inputVarList  = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
45     outputVarList = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
46     for var in self.optim_algo.da_study.InputVariables:
47       inputVarList.pushBack(var)
48     for var in self.optim_algo.da_study.OutputVariables:
49       outputVarList.pushBack(var)
50     sample.setEltAtRank("inputVarList", inputVarList)
51     sample.setEltAtRank("outputVarList", outputVarList)
52
53     # Les parametres specifiques à ADAO
54     specificParameters = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("SALOME_TYPES/Parameter"))
55     method_name = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/Parameter'))
56     method_name.setEltAtRank("name", "method")
57     method_name.setEltAtRank("value", method)
58     specificParameters.pushBack(method_name)
59     sample.setEltAtRank("specificParameters", specificParameters)
60
61     # Les données
62     # TODO à faire
63     #print data
64     #print data.ndim
65     #print data.shape
66     #print data[:,0]
67     #print data.flatten()
68     #print data.flatten().shape
69
70     variable          = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
71     variable_sequence = pilot.SequenceAny_New(variable.getType())
72     state_sequence    = pilot.SequenceAny_New(variable_sequence.getType())
73     time_sequence     = pilot.SequenceAny_New(state_sequence.getType())
74
75     #print "Input Data", data
76     if isinstance(data, type((1,2))):
77       self.add_parameters(data[0], variable_sequence)
78       self.add_parameters(data[1], variable_sequence, Output=True) # Output == Y
79     else:
80       self.add_parameters(data, variable_sequence)
81     state_sequence.pushBack(variable_sequence)
82     time_sequence.pushBack(state_sequence)
83     sample.setEltAtRank("inputValues", time_sequence)
84     return sample
85
86   def add_parameters(self, data, variable_sequence, Output=False):
87     param = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
88     elt_list = 0 # index dans la liste des arguments
89     val_number = 0 # nbre dans l'argument courant
90     if not Output:
91       val_end = self.optim_algo.da_study.InputVariables[self.optim_algo.da_study.InputVariablesOrder[elt_list]] # nbr de l'argument courant (-1 == tout)
92     else:
93       val_end = self.optim_algo.da_study.OutputVariables[self.optim_algo.da_study.OutputVariablesOrder[elt_list]] # nbr de l'argument courant (-1 == tout)
94
95     it = data.flat
96     for val in it:
97       param.pushBack(val)
98       val_number += 1
99       # Test si l'argument est ok
100       if val_end != -1:
101         if val_number == val_end:
102           variable_sequence.pushBack(param)
103           param = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
104           val_number = 0
105           elt_list += 1
106           if not Output:
107             if elt_list < len(self.optim_algo.da_study.InputVariablesOrder):
108               val_end = self.optim_algo.da_study.InputVariables[self.optim_algo.da_study.InputVariablesOrder[elt_list]]
109             else:
110               break
111           else:
112             if elt_list < len(self.optim_algo.da_study.OutputVariablesOrder):
113               val_end = self.optim_algo.da_study.OutputVariables[self.optim_algo.da_study.OutputVariablesOrder[elt_list]]
114             else:
115               break
116     if val_end == -1:
117       variable_sequence.pushBack(param)
118
119   def get_data_from_any(self, any_data):
120     error = any_data["returnCode"].getIntValue()
121     if error != 0:
122       self.optim_algo.setError(any_data["errorMessage"].getStringValue())
123
124     data = []
125     outputValues = any_data["outputValues"]
126     #print outputValues
127     for variable in outputValues[0][0]:
128       for i in range(variable.size()):
129         data.append(variable[i].getDoubleValue())
130
131     matrix = numpy.matrix(data).T
132     return matrix
133
134   def Direct(self, X, sync = 1):
135     #print "Call Direct OptimizerHooks"
136     if sync == 1:
137       # 1: Get a unique sample number
138       self.counter_lock.acquire()
139       self.sample_counter += 1
140       local_counter = self.sample_counter
141
142       # 2: Put sample in the job pool
143       sample = self.create_sample(X, "Direct")
144       self.optim_algo.pool.pushInSample(local_counter, sample)
145
146       # 3: Wait
147       while 1:
148         #print "waiting"
149         self.optim_algo.signalMasterAndWait()
150         #print "signal"
151         if self.optim_algo.isTerminationRequested():
152           self.optim_algo.pool.destroyAll()
153           return
154         else:
155           # Get current Id
156           sample_id = self.optim_algo.pool.getCurrentId()
157           if sample_id == local_counter:
158             # 4: Data is ready
159             any_data = self.optim_algo.pool.getOutSample(local_counter)
160             Y = self.get_data_from_any(any_data)
161
162             # 5: Release lock
163             # Have to be done before but need a new implementation
164             # of the optimizer loop
165             self.counter_lock.release()
166             return Y
167     else:
168       #print "sync false is not yet implemented"
169       self.optim_algo.setError("sync == false not yet implemented")
170
171   def Tangent(self, X, sync = 1):
172     #print "Call Tangent OptimizerHooks"
173     if sync == 1:
174       # 1: Get a unique sample number
175       self.counter_lock.acquire()
176       self.sample_counter += 1
177       local_counter = self.sample_counter
178
179       # 2: Put sample in the job pool
180       sample = self.create_sample(X, "Tangent")
181       self.optim_algo.pool.pushInSample(local_counter, sample)
182
183       # 3: Wait
184       while 1:
185         self.optim_algo.signalMasterAndWait()
186         if self.optim_algo.isTerminationRequested():
187           self.optim_algo.pool.destroyAll()
188           return
189         else:
190           # Get current Id
191           sample_id = self.optim_algo.pool.getCurrentId()
192           if sample_id == local_counter:
193             # 4: Data is ready
194             any_data = self.optim_algo.pool.getOutSample(local_counter)
195             Y = self.get_data_from_any(any_data)
196
197             # 5: Release lock
198             # Have to be done before but need a new implementation
199             # of the optimizer loop
200             self.counter_lock.release()
201             return Y
202     else:
203       #print "sync false is not yet implemented"
204       self.optim_algo.setError("sync == false not yet implemented")
205
206   def Adjoint(self, (X, Y), sync = 1):
207     #print "Call Adjoint OptimizerHooks"
208     if sync == 1:
209       # 1: Get a unique sample number
210       self.counter_lock.acquire()
211       self.sample_counter += 1
212       local_counter = self.sample_counter
213
214       # 2: Put sample in the job pool
215       sample = self.create_sample((X,Y), "Adjoint")
216       self.optim_algo.pool.pushInSample(local_counter, sample)
217
218       # 3: Wait
219       while 1:
220         #print "waiting"
221         self.optim_algo.signalMasterAndWait()
222         #print "signal"
223         if self.optim_algo.isTerminationRequested():
224           self.optim_algo.pool.destroyAll()
225           return
226         else:
227           # Get current Id
228           sample_id = self.optim_algo.pool.getCurrentId()
229           if sample_id == local_counter:
230             # 4: Data is ready
231             any_data = self.optim_algo.pool.getOutSample(local_counter)
232             Z = self.get_data_from_any(any_data)
233
234             # 5: Release lock
235             # Have to be done before but need a new implementation
236             # of the optimizer loop
237             self.counter_lock.release()
238             return Z
239     else:
240       #print "sync false is not yet implemented"
241       self.optim_algo.setError("sync == false not yet implemented")
242
243 class AssimilationAlgorithm_asynch(SALOMERuntime.OptimizerAlgASync):
244
245   def __init__(self):
246     SALOMERuntime.RuntimeSALOME_setRuntime()
247     SALOMERuntime.OptimizerAlgASync.__init__(self, None)
248     self.runtime = SALOMERuntime.getSALOMERuntime()
249
250     # Definission des types d'entres et de sorties pour le code de calcul
251     self.tin      = self.runtime.getTypeCode("SALOME_TYPES/ParametricInput")
252     self.tout     = self.runtime.getTypeCode("SALOME_TYPES/ParametricOutput")
253     self.pyobject = self.runtime.getTypeCode("pyobj")
254
255     self.optim_hooks = OptimizerHooks(self)
256
257   # input vient du port algoinit, input est un Any YACS !
258   def initialize(self,input):
259     #print "Algorithme initialize"
260
261     # get the daStudy
262     #print "[Debug] Input is ", input
263     str_da_study = input.getStringValue()
264     self.da_study = pickle.loads(str_da_study)
265     #print "[Debug] da_study is ", self.da_study
266     self.da_study.initAlgorithm()
267     self.ADD = self.da_study.getAssimilationStudy()
268
269   def startToTakeDecision(self):
270     #print "Algorithme startToTakeDecision"
271
272     # Check if ObservationOperator is already set
273     if self.da_study.getObservationOperatorType("Direct") == "Function" or self.da_study.getObservationOperatorType("Tangent") == "Function" or self.da_study.getObservationOperatorType("Adjoint") == "Function" :
274       #print "Set Hooks"
275       # Use proxy function for YACS
276       self.hooks = OptimizerHooks(self)
277       direct = tangent = adjoint = None
278       if self.da_study.getObservationOperatorType("Direct") == "Function":
279         direct = self.hooks.Direct
280       if self.da_study.getObservationOperatorType("Tangent") == "Function" :
281         tangent = self.hooks.Tangent
282       if self.da_study.getObservationOperatorType("Adjoint") == "Function" :
283         adjoint = self.hooks.Adjoint
284
285       # Set ObservationOperator
286       self.ADD.setObservationOperator(asFunction = {"Direct":direct, "Tangent":tangent, "Adjoint":adjoint})
287
288
289     # Start Assimilation Study
290     #print "ADD analyze"
291     self.ADD.analyze()
292
293     # Assimilation Study is finished
294     self.pool.destroyAll()
295
296   def getAlgoResult(self):
297     #print "getAlgoResult"
298     self.ADD.prepare_to_pickle()
299     result = pickle.dumps(self.da_study)
300     return result
301
302   # Obligatoire ???
303   def finish(self):
304     pass
305   def parseFileToInit(self,fileName):
306     pass
307
308   # Fonctions qui ne changent pas
309   def setPool(self,pool):
310     self.pool=pool
311   def getTCForIn(self):
312     return self.tin
313   def getTCForOut(self):
314     return self.tout
315   def getTCForAlgoInit(self):
316     return self.pyobject
317   def getTCForAlgoResult(self):
318     return self.pyobject
319