SALOME platform Git repositories - modules/adao.git/blob - src/daSalome/daYacsIntegration/daOptimizerLoop.py
Salome HOME
Adding KalmanFilter and treatment of evolution model
[modules/adao.git] / src / daSalome / daYacsIntegration / daOptimizerLoop.py
1 #-*- coding: utf-8 -*-
2 # Copyright (C) 2010-2011 EDF R&D
3 #
4 # This library is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU Lesser General Public
6 # License as published by the Free Software Foundation; either
7 # version 2.1 of the License.
8 #
9 # This library is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this library; if not, write to the Free Software
16 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
17 #
18 # See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 #
20 # Author: André Ribes, andre.ribes@edf.fr, EDF R&D
21
22 import SALOMERuntime
23 import pilot
24 import pickle
25 import numpy
26 import threading
27
28 from daCore.AssimilationStudy import AssimilationStudy
29 from daYacsIntegration import daStudy
30
31
class OptimizerHooks:
  """Proxy operator functions that delegate evaluations to YACS.

  Each call to Direct/Tangent/Adjoint packs its arguments into a
  SALOME_TYPES/ParametricInput sample, pushes it into the pool of the owning
  optimizer algorithm and blocks until YACS hands back the matching output
  sample, which is converted into a column numpy matrix.
  """

  def __init__(self, optim_algo, switch_value=-1):
    """Store the owning algorithm and the switch value sent with samples.

    optim_algo   -- object providing runtime, da_study, pool, counter_lock,
                    sample_counter, has_observer, has_evolution_model,
                    setError, signalMasterAndWait and isTerminationRequested.
    switch_value -- integer-convertible value sent as the "switch_value"
                    specific parameter when an evolution model is present;
                    stored as the string of its int value.
    """
    self.optim_algo = optim_algo
    self.switch_value = str(int(switch_value))

  def _new_parameter(self, name, value):
    """Return a SALOME_TYPES/Parameter struct holding (name, value)."""
    parameter = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/Parameter'))
    parameter.setEltAtRank("name", name)
    parameter.setEltAtRank("value", value)
    return parameter

  def create_sample(self, data, method):
    """Build the ParametricInput sample sent to YACS for one evaluation.

    data   -- either the state X alone, or a tuple (X, Z) whose second element
              is packed with the *output* variables layout (see add_parameters).
    method -- operator name stored in the "method" specific parameter
              ("Direct", "Tangent" or "Adjoint").
    Returns the filled StructAny sample.
    """
    runtime = self.optim_algo.runtime
    da_study = self.optim_algo.da_study
    sample = pilot.StructAny_New(runtime.getTypeCode('SALOME_TYPES/ParametricInput'))

    # TODO Input, Output VarList
    inputVarList  = pilot.SequenceAny_New(runtime.getTypeCode("string"))
    outputVarList = pilot.SequenceAny_New(runtime.getTypeCode("string"))
    for var in da_study.InputVariables:
      inputVarList.pushBack(var)
    for var in da_study.OutputVariables:
      outputVarList.pushBack(var)
    sample.setEltAtRank("inputVarList", inputVarList)
    sample.setEltAtRank("outputVarList", outputVarList)

    # ADAO-specific parameters
    specificParameters = pilot.SequenceAny_New(runtime.getTypeCode("SALOME_TYPES/Parameter"))
    specificParameters.pushBack(self._new_parameter("method", method))
    # Send at most ONE "switch_value" so the receiving side cannot pick the
    # wrong one: with an evolution model present, this hook's own value
    # identifies the operator family; with observers only, "1" is used.
    if self.optim_algo.has_evolution_model:
      specificParameters.pushBack(self._new_parameter("switch_value", self.switch_value))
    elif self.optim_algo.has_observer:
      specificParameters.pushBack(self._new_parameter("switch_value", "1"))
    sample.setEltAtRank("specificParameters", specificParameters)

    # Values: time sequence > state sequence > variable sequence > doubles
    variable          = pilot.SequenceAny_New(runtime.getTypeCode("double"))
    variable_sequence = pilot.SequenceAny_New(variable.getType())
    state_sequence    = pilot.SequenceAny_New(variable_sequence.getType())
    time_sequence     = pilot.SequenceAny_New(state_sequence.getType())

    if isinstance(data, tuple):
      self.add_parameters(data[0], variable_sequence)
      # NOTE(review): the second element (dX for Tangent, Y for Adjoint) is
      # always split with the output variables layout — confirm this is the
      # intended layout for Tangent's dX.
      self.add_parameters(data[1], variable_sequence, Output=True)
    else:
      self.add_parameters(data, variable_sequence)
    state_sequence.pushBack(variable_sequence)
    time_sequence.pushBack(state_sequence)
    sample.setEltAtRank("inputValues", time_sequence)
    return sample

  def add_parameters(self, data, variable_sequence, Output=False):
    """Append the values of `data` to `variable_sequence`, one double sequence per variable.

    Variable sizes are read from da_study.InputVariables (OutputVariables when
    Output is True), visited in InputVariablesOrder (OutputVariablesOrder).
    A size of -1 means "all remaining values". Values past the last sized
    variable are ignored. `data` may be None (no values) or any object with a
    `.flat` iterator (presumably a numpy array/matrix).

    NOTE(review): if `data` holds fewer values than the declared sizes, the
    incomplete trailing variable is silently dropped.
    """
    da_study = self.optim_algo.da_study
    if Output:
      sizes, order = da_study.OutputVariables, da_study.OutputVariablesOrder
    else:
      sizes, order = da_study.InputVariables, da_study.InputVariablesOrder
    double_type = self.optim_algo.runtime.getTypeCode("double")

    param = pilot.SequenceAny_New(double_type)
    elt_list = 0                       # index of the current variable in `order`
    val_number = 0                     # values already stored for the current variable
    val_end = sizes[order[elt_list]]   # size of the current variable (-1 == all)

    values = [] if data is None else data.flat
    for val in values:
      param.pushBack(val)
      val_number += 1
      if val_end != -1 and val_number == val_end:
        # Current variable is complete: store it and move to the next one
        variable_sequence.pushBack(param)
        param = pilot.SequenceAny_New(double_type)
        val_number = 0
        elt_list += 1
        if elt_list >= len(order):
          break
        val_end = sizes[order[elt_list]]
    if val_end == -1:
      variable_sequence.pushBack(param)

  def get_data_from_any(self, any_data):
    """Convert a YACS output sample into a column numpy matrix.

    A non-zero "returnCode" is reported through optim_algo.setError. The
    doubles of every variable in outputValues[0][0] are then concatenated.
    """
    error = any_data["returnCode"].getIntValue()
    if error != 0:
      self.optim_algo.setError(any_data["errorMessage"].getStringValue())

    data = []
    for variable in any_data["outputValues"][0][0]:
      data.extend(variable[i].getDoubleValue() for i in range(variable.size()))
    return numpy.matrix(data).T

  def _evaluate(self, data, method, sync):
    """Push one sample for `method` into the pool and wait for its result.

    Returns the result matrix, or None when YACS requests termination or when
    sync != 1 (asynchronous mode is reported through setError).
    """
    algo = self.optim_algo
    if sync != 1:
      algo.setError("sync == false not yet implemented")
      return None

    # The counter lock is held during the whole round trip (the optimizer loop
    # does not support overlapping requests yet). It is released in `finally`
    # so termination or an error never leaves it locked.
    algo.counter_lock.acquire()
    try:
      # 1: Get a unique sample number
      algo.sample_counter += 1
      local_counter = algo.sample_counter

      # 2: Put sample in the job pool
      algo.pool.pushInSample(local_counter, self.create_sample(data, method))

      # 3: Wait until our sample is ready or termination is requested
      while True:
        algo.signalMasterAndWait()
        if algo.isTerminationRequested():
          algo.pool.destroyAll()
          return None
        if algo.pool.getCurrentId() == local_counter:
          # 4: Data is ready
          return self.get_data_from_any(algo.pool.getOutSample(local_counter))
    finally:
      # 5: Release lock
      algo.counter_lock.release()

  def Direct(self, X, sync = 1):
    """Evaluate the direct operator at X through YACS."""
    return self._evaluate(X, "Direct", sync)

  def Tangent(self, X_dX, sync = 1):
    """Evaluate the tangent operator; X_dX is the pair (X, dX)."""
    X, dX = X_dX
    return self._evaluate((X, dX), "Tangent", sync)

  def Adjoint(self, X_Y, sync = 1):
    """Evaluate the adjoint operator; X_Y is the pair (X, Y)."""
    X, Y = X_Y
    return self._evaluate((X, Y), "Adjoint", sync)
254
class AssimilationAlgorithm_asynch(SALOMERuntime.OptimizerAlgASync):
  """YACS asynchronous optimizer algorithm that runs an ADAO assimilation study.

  Operators declared as "Function" in the study are replaced by
  OptimizerHooks proxies that evaluate them through the YACS pool, and study
  observers are forwarded to YACS through `obs`.
  """

  _OPERATOR_NAMES = ("Direct", "Tangent", "Adjoint")

  def __init__(self):
    SALOMERuntime.RuntimeSALOME_setRuntime()
    SALOMERuntime.OptimizerAlgASync.__init__(self, None)
    self.runtime = SALOMERuntime.getSALOMERuntime()

    # Set to True in startToTakeDecision; read by OptimizerHooks.create_sample
    self.has_evolution_model = False
    self.has_observer = False

    # Sample counter shared by the hooks and observers, guarded by counter_lock
    self.sample_counter = 0
    self.counter_lock = threading.Lock()

    # Type codes of the inputs and outputs exchanged with the computation code
    self.tin      = self.runtime.getTypeCode("SALOME_TYPES/ParametricInput")
    self.tout     = self.runtime.getTypeCode("SALOME_TYPES/ParametricOutput")
    self.pyobject = self.runtime.getTypeCode("pyobj")

    # "self.optim_hooks" absolutely must be defined here
    # (otherwise an "Unknown Exception" is raised on the "finish" attribute)
    self.optim_hooks = OptimizerHooks(self)

  def initialize(self, input):
    """Load the daStudy received from the algoinit port (a YACS Any).

    The Any's string value is a pickled daStudy; its algorithm is initialised
    and its assimilation study is kept in self.ADD.
    """
    # NOTE(review): pickle.loads can execute arbitrary code; the payload must
    # come from a trusted ADAO schema.
    str_da_study = input.getStringValue()
    self.da_study = pickle.loads(str_da_study)
    self.da_study.initAlgorithm()
    self.ADD = self.da_study.getAssimilationStudy()

  def _uses_functions(self, get_type):
    """Return True if any of Direct/Tangent/Adjoint has type "Function"."""
    return any(get_type(name) == "Function" for name in self._OPERATOR_NAMES)

  def _operator_functions(self, hooks, get_type):
    """Build the asFunction dict: the hook for "Function" operators, else None."""
    functions = {}
    for name in self._OPERATOR_NAMES:
      if get_type(name) == "Function":
        functions[name] = getattr(hooks, name)
      else:
        functions[name] = None
    return functions

  def _set_data_observer(self, observer_name):
    """(Re)attach self.obs as data observer `observer_name` on the study."""
    scheduler = self.da_study.observers_dict[observer_name]["scheduler"]
    if scheduler != "":
      self.ADD.setDataObserver(observer_name, HookFunction=self.obs, Scheduler = scheduler, HookParameters = observer_name)
    else:
      self.ADD.setDataObserver(observer_name, HookFunction=self.obs, HookParameters = observer_name)

  def _make_parameter(self, name, value):
    """Return a SALOME_TYPES/Parameter struct holding (name, value)."""
    parameter = pilot.StructAny_New(self.runtime.getTypeCode('SALOME_TYPES/Parameter'))
    parameter.setEltAtRank("name", name)
    parameter.setEltAtRank("value", value)
    return parameter

  def startToTakeDecision(self):
    """Plug the YACS proxies and observers into the study, then run it."""
    # Observation operator given as functions: evaluate them through YACS
    get_oo_type = self.da_study.getObservationOperatorType
    if self._uses_functions(get_oo_type):
      self.hooksOO = OptimizerHooks(self, switch_value=1)
      self.ADD.setObservationOperator(asFunction = self._operator_functions(self.hooksOO, get_oo_type))

    # Evolution model given as functions: same mechanism, other switch value
    get_em_type = self.da_study.getEvolutionModelType
    if self._uses_functions(get_em_type):
      self.has_evolution_model = True
      self.hooksEM = OptimizerHooks(self, switch_value=2)
      self.ADD.setEvolutionModel(asFunction = self._operator_functions(self.hooksEM, get_em_type))

    # Set Observers
    for observer_name in self.da_study.observers_dict.keys():
      self.has_observer = True
      self._set_data_observer(observer_name)

    # Start Assimilation Study
    print("Launching the analyse\n")
    self.ADD.analyze()

    # Assimilation Study is finished
    self.pool.destroyAll()

  def obs(self, var, info):
    """Observer hook: send `var` to YACS and wait until it has been handled.

    var  -- observed study variable; it is pickled into the "var" parameter.
    info -- observer name, the key used in da_study.observers_dict.
    Errors raised while waiting are printed, not propagated.
    """
    sample = pilot.StructAny_New(self.runtime.getTypeCode('SALOME_TYPES/ParametricInput'))

    # Fake data: a single variable "a" with value 1.0
    inputVarList  = pilot.SequenceAny_New(self.runtime.getTypeCode("string"))
    outputVarList = pilot.SequenceAny_New(self.runtime.getTypeCode("string"))
    inputVarList.pushBack("a")
    outputVarList.pushBack("a")
    sample.setEltAtRank("inputVarList", inputVarList)
    sample.setEltAtRank("outputVarList", outputVarList)
    variable          = pilot.SequenceAny_New(self.runtime.getTypeCode("double"))
    variable_sequence = pilot.SequenceAny_New(variable.getType())
    state_sequence    = pilot.SequenceAny_New(variable_sequence.getType())
    time_sequence     = pilot.SequenceAny_New(state_sequence.getType())
    variable.pushBack(1.0)
    variable_sequence.pushBack(variable)
    state_sequence.pushBack(variable_sequence)
    time_sequence.pushBack(state_sequence)
    sample.setEltAtRank("inputValues", time_sequence)

    # Observer payload goes into the specific parameters
    observer = self.da_study.observers_dict[info]
    specificParameters = pilot.SequenceAny_New(self.runtime.getTypeCode("SALOME_TYPES/Parameter"))
    specificParameters.pushBack(self._make_parameter("switch_value", observer["number"]))

    # The observer is detached while pickling the variable (presumably so the
    # hook is not pickled along with it), and always attached again afterwards.
    var.removeDataObserver(self.obs)
    try:
      var_str = pickle.dumps(var)
    finally:
      self._set_data_observer(info)
    specificParameters.pushBack(self._make_parameter("var", var_str))

    specificParameters.pushBack(self._make_parameter("info", observer["info"]))
    sample.setEltAtRank("specificParameters", specificParameters)

    self.counter_lock.acquire()
    try:
      self.sample_counter += 1
      local_counter = self.sample_counter
      self.pool.pushInSample(local_counter, sample)

      # Wait until our sample has been processed or termination is requested
      try:
        while True:
          self.signalMasterAndWait()
          if self.isTerminationRequested():
            self.pool.destroyAll()
            return
          if self.pool.getCurrentId() == local_counter:
            return
      except Exception:
        import sys, traceback
        print("Exception in user code:")
        print('-'*60)
        traceback.print_exc(file=sys.stdout)
        print('-'*60)
    finally:
      # Always released, so termination or errors cannot deadlock later calls
      self.counter_lock.release()

  def getAlgoResult(self):
    """Return the pickled daStudy as the algorithm result."""
    self.ADD.prepare_to_pickle()
    # Data observers must be removed, otherwise the assimilation study
    # object cannot be pickled
    for observer_name in self.da_study.observers_dict.keys():
      self.ADD.removeDataObserver(observer_name, self.obs)
    return pickle.dumps(self.da_study)

  # Presumably required by the YACS optimizer interface (unconfirmed)
  def finish(self):
    pass
  def parseFileToInit(self,fileName):
    pass

  # Unchanging accessors used by YACS
  def setPool(self,pool):
    self.pool=pool
  def getTCForIn(self):
    return self.tin
  def getTCForOut(self):
    return self.tout
  def getTCForAlgoInit(self):
    return self.pyobject
  def getTCForAlgoResult(self):
    return self.pyobject
449