2 # Copyright (C) 2010-2012 EDF R&D
4 # This library is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU Lesser General Public
6 # License as published by the Free Software Foundation; either
7 # version 2.1 of the License.
9 # This library is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 # Lesser General Public License for more details.
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this library; if not, write to the Free Software
16 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18 # See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
20 # Author: André Ribes, andre.ribes@edf.fr, EDF R&D
28 from daCore.AssimilationStudy import AssimilationStudy
29 from daYacsIntegration import daStudy
# Constructor of the operator hooks, presumably class OptimizerHooks, which this
# file instantiates as OptimizerHooks(self) and
# OptimizerHooks(self, switch_value=...). The class header and the original
# indentation are not part of this listing.
# optim_algo: the owning algorithm object. The hooks read its runtime, pool,
#   counter_lock, sample_counter, da_study and has_* flags.
# switch_value: converted with int() (a float is truncated) and stored as a
#   decimal string. create_sample() sends it as the "switch_value" specific
#   parameter when optim_algo.has_evolution_model is set. The default -1 is
#   stored as "-1".
34 def __init__(self, optim_algo, switch_value=-1):
35 self.optim_algo = optim_algo
36 self.switch_value = str(int(switch_value))
# Build one YACS "SALOME_TYPES/ParametricInput" request for the computation code.
# The struct carries:
#   - inputVarList / outputVarList: names taken from da_study.InputVariables
#     and da_study.OutputVariables;
#   - specificParameters: "method" (the `method` argument; this file passes
#     "Direct", "Tangent" or "Adjoint"), plus optional "switch_value" entries;
#   - inputValues: nested sequences time -> state -> variable -> doubles,
#     filled from `data` by add_parameters().
# data: either a tuple, whose second element is laid out with the OUTPUT
#   variable sizes (Output=True), or a single value laid out with the input
#   variable sizes.
# NOTE(review): the return statement is not in this listing. Callers assign the
#   result to `sample` and push it into the pool, so it presumably returns
#   `sample`. Confirm.
38 def create_sample(self, data, method):
39 sample = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/ParametricInput'))
41 # TODO Input, Output VarList
42 inputVarList = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
43 outputVarList = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
# NOTE(review): names come from iterating InputVariables/OutputVariables, while
# add_parameters() lays out the values following InputVariablesOrder /
# OutputVariablesOrder. If these mappings are unordered dicts, the name order
# and the value order may differ. Confirm they agree.
44 for var in self.optim_algo.da_study.InputVariables:
45 inputVarList.pushBack(var)
46 for var in self.optim_algo.da_study.OutputVariables:
47 outputVarList.pushBack(var)
48 sample.setEltAtRank("inputVarList", inputVarList)
49 sample.setEltAtRank("outputVarList", outputVarList)
51 # ADAO-specific parameters
52 specificParameters = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("SALOME_TYPES/Parameter"))
53 method_name = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/Parameter'))
54 method_name.setEltAtRank("name", "method")
55 method_name.setEltAtRank("value", method)
56 specificParameters.pushBack(method_name)
57 # print self.optim_algo.has_observer
# NOTE(review): when has_observer is set, "switch_value" is always "1" here,
# regardless of self.switch_value. When has_evolution_model is also set, a
# second "switch_value" entry (self.switch_value) is appended, so the request
# carries two entries with the same name. Confirm which one the consumer reads.
58 if self.optim_algo.has_observer:
59 obs_switch = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/Parameter'))
60 obs_switch.setEltAtRank("name", "switch_value")
61 obs_switch.setEltAtRank("value", "1")
62 specificParameters.pushBack(obs_switch)
63 if self.optim_algo.has_evolution_model:
64 obs_switch = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/Parameter'))
65 obs_switch.setEltAtRank("name", "switch_value")
66 obs_switch.setEltAtRank("value", self.switch_value)
67 specificParameters.pushBack(obs_switch)
68 sample.setEltAtRank("specificParameters", specificParameters)
77 #print data.flatten().shape
# `variable` is only used to obtain the "sequence of double" type code for the
# nested sequences built below.
79 variable = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
80 variable_sequence = pilot.SequenceAny_New(variable.getType())
81 state_sequence = pilot.SequenceAny_New(variable_sequence.getType())
82 time_sequence = pilot.SequenceAny_New(state_sequence.getType())
84 #print "Input Data", data
# Tuple check (type((1,2)) is tuple): the second element is laid out as output.
85 if isinstance(data, type((1,2))):
86 self.add_parameters(data[0], variable_sequence)
87 self.add_parameters(data[1], variable_sequence, Output=True) # Output == Y
# NOTE(review): the line between the two branches (presumably `else:`) is
# missing from this listing. The next call presumably handles non-tuple data.
89 self.add_parameters(data, variable_sequence)
# A single state inside a single time step.
90 state_sequence.pushBack(variable_sequence)
91 time_sequence.pushBack(state_sequence)
92 sample.setEltAtRank("inputValues", time_sequence)
# Split the values of `data` into one "double" sequence per variable and push
# each sequence into `variable_sequence`.
# Variable sizes come from da_study.InputVariables (or OutputVariables when
# Output=True), visited in InputVariablesOrder (or OutputVariablesOrder).
# According to the original comment, a size of -1 means "all values".
# NOTE(review): this listing omits the branch lines that choose between the
# input and output tables (presumably `if not Output:` / `else:`). It also
# omits the loop that appends values from `data` to `param` and advances
# val_number / elt_list.
95 def add_parameters(self, data, variable_sequence, Output=False):
96 param = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
97 elt_list = 0 # index in the list of arguments
98 val_number = 0 # count of values taken for the current argument
100 val_end = self.optim_algo.da_study.InputVariables[self.optim_algo.da_study.InputVariablesOrder[elt_list]] # value count of the current argument (-1 == all)
102 val_end = self.optim_algo.da_study.OutputVariables[self.optim_algo.da_study.OutputVariablesOrder[elt_list]] # value count of the current argument (-1 == all)
111 # Check whether the current argument is complete
113 if val_number == val_end:
114 variable_sequence.pushBack(param)
115 param = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
# Fetch the size of the next variable, if one remains (bounds-checked against
# the order list).
119 if elt_list < len(self.optim_algo.da_study.InputVariablesOrder):
120 val_end = self.optim_algo.da_study.InputVariables[self.optim_algo.da_study.InputVariablesOrder[elt_list]]
124 if elt_list < len(self.optim_algo.da_study.OutputVariablesOrder):
125 val_end = self.optim_algo.da_study.OutputVariables[self.optim_algo.da_study.OutputVariablesOrder[elt_list]]
# Push the sequence currently being filled. This is the only push for a
# -1 ("all values") argument, since val_number never equals -1.
# NOTE(review): if the last variable was completed inside the loop, `param` was
# just reset to an empty sequence, so an empty sequence may be appended here.
# Confirm against the full source.
129 variable_sequence.pushBack(param)
# Convert the YACS output Any returned for a sample into a numpy column matrix.
# Reads "returnCode" and forwards "errorMessage" to optim_algo.setError. The
# guarding condition is not in this listing; presumably it fires when the
# code signals an error.
# Collects every double of every variable in outputValues[0][0], i.e. the first
# time step and first state, assuming the output mirrors the input nesting.
# Then builds numpy.matrix(data).T, an (n, 1) column.
# NOTE(review): the initialisation of `data` and the return statement are not
# visible here. Direct/Tangent/Adjoint use the result as their value Z.
# NOTE(review): NumPy discourages numpy.matrix. It is kept because callers may
# rely on matrix semantics.
131 def get_data_from_any(self, any_data):
132 error = any_data["returnCode"].getIntValue()
134 self.optim_algo.setError(any_data["errorMessage"].getStringValue())
137 outputValues = any_data["outputValues"]
139 for variable in outputValues[0][0]:
140 for i in range(variable.size()):
141 data.append(variable[i].getDoubleValue())
143 matrix = numpy.matrix(data).T
# Evaluate the operator at X through the YACS sample pool (blocking).
# Steps visible in this listing:
#   1. take counter_lock and reserve a unique sample number;
#   2. build a "Direct" sample from X and push it into the pool;
#   3. signal the master and wait; on a termination request, destroy the pool;
#   4. when the pool's current id matches this sample, fetch its output,
#      convert it with get_data_from_any(), then release counter_lock.
# The final setError call appears to be the sync-false branch (its guard is not
# visible).
# NOTE(review): several lines are not part of this listing: the wait loop, what
# follows destroyAll(), the sync branch header and the return of Z.
# NOTE(review): counter_lock is held across signalMasterAndWait() and released
# only on the matched-id path. A termination request or an exception before the
# release would leave the lock held. Consider try/finally once the full body is
# available.
146 def Direct(self, X, sync = 1):
147 # print "Call Direct OptimizerHooks"
149 # 1: Get a unique sample number
150 self.optim_algo.counter_lock.acquire()
151 self.optim_algo.sample_counter += 1
152 local_counter = self.optim_algo.sample_counter
154 # 2: Put sample in the job pool
155 sample = self.create_sample(X, "Direct")
156 self.optim_algo.pool.pushInSample(local_counter, sample)
161 self.optim_algo.signalMasterAndWait()
163 if self.optim_algo.isTerminationRequested():
164 self.optim_algo.pool.destroyAll()
168 sample_id = self.optim_algo.pool.getCurrentId()
169 if sample_id == local_counter:
171 any_data = self.optim_algo.pool.getOutSample(local_counter)
172 Z = self.get_data_from_any(any_data)
175 # Have to be done before but need a new implementation
176 # of the optimizer loop
177 self.optim_algo.counter_lock.release()
180 #print "sync false is not yet implemented"
181 self.optim_algo.setError("sync == false not yet implemented")
# Evaluate the tangent operator at X for the increment dX through the YACS
# sample pool (blocking). Same protocol as Direct: reserve a sample number
# under counter_lock, push a "Tangent" sample built from (X, dX), wait for the
# master, then convert the matching output with get_data_from_any() and release
# counter_lock. The final setError call appears to be the sync-false branch.
# NOTE(review): the tuple parameter `(X, dX)` is Python 2-only syntax (removed
# by PEP 3113). Callers pass a single 2-tuple that is unpacked here.
# NOTE(review): create_sample() lays out the second tuple element with the
# OUTPUT variable sizes (Output=True). dX is presumably an increment on X
# (input space). Confirm this layout is intended.
# NOTE(review): the wait loop, what follows destroyAll(), the sync branch header
# and the return are not in this listing. counter_lock is released only on the
# matched-id path, so an early exit would leave it held.
183 def Tangent(self, (X, dX), sync = 1):
184 #print "Call Tangent OptimizerHooks"
186 # 1: Get a unique sample number
187 self.optim_algo.counter_lock.acquire()
188 self.optim_algo.sample_counter += 1
189 local_counter = self.optim_algo.sample_counter
191 # 2: Put sample in the job pool
192 sample = self.create_sample((X,dX) , "Tangent")
193 self.optim_algo.pool.pushInSample(local_counter, sample)
197 self.optim_algo.signalMasterAndWait()
198 if self.optim_algo.isTerminationRequested():
199 self.optim_algo.pool.destroyAll()
203 sample_id = self.optim_algo.pool.getCurrentId()
204 if sample_id == local_counter:
206 any_data = self.optim_algo.pool.getOutSample(local_counter)
207 Z = self.get_data_from_any(any_data)
210 # Have to be done before but need a new implementation
211 # of the optimizer loop
212 self.optim_algo.counter_lock.release()
215 #print "sync false is not yet implemented"
216 self.optim_algo.setError("sync == false not yet implemented")
# Evaluate the adjoint operator at X applied to Y through the YACS sample pool
# (blocking). Same protocol as Direct: reserve a sample number under
# counter_lock, push an "Adjoint" sample built from (X, Y), wait for the master,
# then convert the matching output with get_data_from_any() and release
# counter_lock. Y is laid out with the output variable sizes (Output=True in
# create_sample). The final setError call appears to be the sync-false branch.
# NOTE(review): the tuple parameter `(X, Y)` is Python 2-only syntax (removed by
# PEP 3113). Callers pass a single 2-tuple that is unpacked here.
# NOTE(review): the wait loop, what follows destroyAll(), the sync branch header
# and the return are not in this listing. counter_lock is released only on the
# matched-id path, so an early exit would leave it held.
218 def Adjoint(self, (X, Y), sync = 1):
219 #print "Call Adjoint OptimizerHooks"
221 # 1: Get a unique sample number
222 self.optim_algo.counter_lock.acquire()
223 self.optim_algo.sample_counter += 1
224 local_counter = self.optim_algo.sample_counter
226 # 2: Put sample in the job pool
227 sample = self.create_sample((X,Y), "Adjoint")
228 self.optim_algo.pool.pushInSample(local_counter, sample)
233 self.optim_algo.signalMasterAndWait()
235 if self.optim_algo.isTerminationRequested():
236 self.optim_algo.pool.destroyAll()
240 sample_id = self.optim_algo.pool.getCurrentId()
241 if sample_id == local_counter:
243 any_data = self.optim_algo.pool.getOutSample(local_counter)
244 Z = self.get_data_from_any(any_data)
247 # Have to be done before but need a new implementation
248 # of the optimizer loop
249 self.optim_algo.counter_lock.release()
252 #print "sync false is not yet implemented"
253 self.optim_algo.setError("sync == false not yet implemented")
# YACS asynchronous optimizer algorithm (subclass of
# SALOMERuntime.OptimizerAlgASync). It runs an ADAO assimilation study and
# forwards operator evaluations and data-observer callbacks to the YACS pool.
255 class AssimilationAlgorithm_asynch(SALOMERuntime.OptimizerAlgASync):
# NOTE(review): the constructor's `def` line (presumably `def __init__(self):`)
# is not part of this listing. The lines below read as its body.
258 SALOMERuntime.RuntimeSALOME_setRuntime()
259 SALOMERuntime.OptimizerAlgASync.__init__(self, None)
260 self.runtime = SALOMERuntime.getSALOMERuntime()
# Both flags are set to True by startToTakeDecision(). create_sample() reads
# them to decide which "switch_value" parameters to send.
262 self.has_evolution_model = False
263 self.has_observer = False
265 # Sample counter management (shared by the operator hooks and obs())
266 self.sample_counter = 0
267 self.counter_lock = threading.Lock()
269 # Definition of the input and output types for the computation code
270 self.tin = self.runtime.getTypeCode("SALOME_TYPES/ParametricInput")
271 self.tout = self.runtime.getTypeCode("SALOME_TYPES/ParametricOutput")
272 self.pyobject = self.runtime.getTypeCode("pyobj")
274 # It is absolutely necessary to define "self.optim_hooks" this way
275 # (otherwise an "Unknown Exception" is raised on the "finish" attribute)
276 self.optim_hooks = OptimizerHooks(self)
278 # input comes from the algoinit port; input is a YACS Any!
# Unpickle the data-assimilation study received as a string on the algoinit
# port, let it initialise its algorithm, and keep its assimilation study as
# self.ADD.
# NOTE(review): pickle.loads can execute arbitrary code. This is only safe if
# the algoinit port content is trusted.
279 def initialize(self,input):
280 #print "Algorithme initialize"
283 #print "[Debug] Input is ", input
284 str_da_study = input.getStringValue()
285 self.da_study = pickle.loads(str_da_study)
286 #print "[Debug] da_study is ", self.da_study
287 self.da_study.initAlgorithm()
288 self.ADD = self.da_study.getAssimilationStudy()
# Algorithm run, as visible in this listing:
#   - if any ObservationOperator kind is of type "Function", route those kinds
#     to an OptimizerHooks with switch_value=1 (self.hooksOO). Kinds that are
#     not "Function" are passed as None.
#   - if any EvolutionModel kind is "Function", set has_evolution_model and
#     route those kinds to an OptimizerHooks with switch_value=2
#     (self.hooksEM).
#   - register self.obs as data observer for every entry of
#     da_study.observers_dict (with its scheduler when non-empty) and set
#     has_observer.
#   - launch the analysis, then destroy all samples left in the pool.
# NOTE(review): the line(s) that actually run the analysis, between the
# "Launching the analyse" print and destroyAll(), are not in this listing.
290 def startToTakeDecision(self):
291 #print "Algorithme startToTakeDecision"
293 # Check whether any ObservationOperator kind is of type "Function" (evaluated through YACS)
294 if self.da_study.getObservationOperatorType("Direct") == "Function" or self.da_study.getObservationOperatorType("Tangent") == "Function" or self.da_study.getObservationOperatorType("Adjoint") == "Function" :
296 # Use proxy function for YACS
297 self.hooksOO = OptimizerHooks(self, switch_value=1)
298 direct = tangent = adjoint = None
299 if self.da_study.getObservationOperatorType("Direct") == "Function":
300 direct = self.hooksOO.Direct
301 if self.da_study.getObservationOperatorType("Tangent") == "Function" :
302 tangent = self.hooksOO.Tangent
303 if self.da_study.getObservationOperatorType("Adjoint") == "Function" :
304 adjoint = self.hooksOO.Adjoint
306 # Set ObservationOperator
307 self.ADD.setObservationOperator(asFunction = {"Direct":direct, "Tangent":tangent, "Adjoint":adjoint})
309 # Check whether any EvolutionModel kind is of type "Function" (evaluated through YACS)
310 if self.da_study.getEvolutionModelType("Direct") == "Function" or self.da_study.getEvolutionModelType("Tangent") == "Function" or self.da_study.getEvolutionModelType("Adjoint") == "Function" :
311 self.has_evolution_model = True
313 # Use proxy function for YACS
314 self.hooksEM = OptimizerHooks(self, switch_value=2)
315 direct = tangent = adjoint = None
316 if self.da_study.getEvolutionModelType("Direct") == "Function":
317 direct = self.hooksEM.Direct
318 if self.da_study.getEvolutionModelType("Tangent") == "Function" :
319 tangent = self.hooksEM.Tangent
320 if self.da_study.getEvolutionModelType("Adjoint") == "Function" :
321 adjoint = self.hooksEM.Adjoint
324 self.ADD.setEvolutionModel(asFunction = {"Direct":direct, "Tangent":tangent, "Adjoint":adjoint})
# Register self.obs for every configured observer; the observer name is passed
# back to obs() as its `info` argument.
327 for observer_name in self.da_study.observers_dict.keys():
328 # print "observers %s found" % observer_name
329 self.has_observer = True
# The branch line between the two registrations (presumably `else:`) is not
# in this listing.
330 if self.da_study.observers_dict[observer_name]["scheduler"] != "":
331 self.ADD.setDataObserver(observer_name, HookFunction=self.obs, Scheduler = self.da_study.observers_dict[observer_name]["scheduler"], HookParameters = observer_name)
333 self.ADD.setDataObserver(observer_name, HookFunction=self.obs, HookParameters = observer_name)
335 # Start Assimilation Study
336 print "Launching the analyse\n"
339 # Assimilation Study is finished
340 self.pool.destroyAll()
# Data-observer callback, registered on self.ADD by startToTakeDecision() and
# re-registered below. `info` is the observer name passed as HookParameters.
# Builds a ParametricInput sample whose input values are a dummy single
# variable "a" = [1.0], with these specific parameters:
#   "switch_value" = da_study.observers_dict[info]["number"]
#   "var"          = pickle.dumps(var)
#   "info"         = da_study.observers_dict[info]["info"]
# It then pushes the sample into the pool under counter_lock and waits for the
# master.
# NOTE(review): several structural lines are not part of this listing. Given
# the traceback printing at the end, they apparently include a try/except
# around the pool exchange.
# NOTE(review): counter_lock is released only on the matched-id path. If the
# exception handler runs first, the lock may stay held. Confirm against the
# full source.
342 def obs(self, var, info):
343 # print "Call observer %s" % info
344 sample = pilot.StructAny_New(self.runtime.getTypeCode('SALOME_TYPES/ParametricInput'))
347 inputVarList = pilot.SequenceAny_New(self.runtime.getTypeCode("string"))
348 outputVarList = pilot.SequenceAny_New(self.runtime.getTypeCode("string"))
349 inputVarList.pushBack("a")
350 outputVarList.pushBack("a")
351 sample.setEltAtRank("inputVarList", inputVarList)
352 sample.setEltAtRank("outputVarList", outputVarList)
353 variable = pilot.SequenceAny_New(self.runtime.getTypeCode("double"))
354 variable_sequence = pilot.SequenceAny_New(variable.getType())
355 state_sequence = pilot.SequenceAny_New(variable_sequence.getType())
356 time_sequence = pilot.SequenceAny_New(state_sequence.getType())
357 variable.pushBack(1.0)
358 variable_sequence.pushBack(variable)
359 state_sequence.pushBack(variable_sequence)
360 time_sequence.pushBack(state_sequence)
361 sample.setEltAtRank("inputValues", time_sequence)
363 # Add observer values in specific parameters
364 specificParameters = pilot.SequenceAny_New(self.runtime.getTypeCode("SALOME_TYPES/Parameter"))
# NOTE(review): create_sample() sends "switch_value" as a string. Confirm that
# observers_dict[info]["number"] is also a string.
367 obs_switch = pilot.StructAny_New(self.runtime.getTypeCode('SALOME_TYPES/Parameter'))
368 obs_switch.setEltAtRank("name", "switch_value")
369 obs_switch.setEltAtRank("value", self.da_study.observers_dict[info]["number"])
370 specificParameters.pushBack(obs_switch)
373 var_struct = pilot.StructAny_New(self.runtime.getTypeCode('SALOME_TYPES/Parameter'))
374 var_struct.setEltAtRank("name", "var")
376 # Temporarily remove this data observer from var, presumably so var can be pickled without the bound method self.obs
377 var.removeDataObserver(self.obs)
379 var_str = pickle.dumps(var)
380 # Add Again Data Observer
381 if self.da_study.observers_dict[info]["scheduler"] != "":
382 self.ADD.setDataObserver(info, HookFunction=self.obs, Scheduler = self.da_study.observers_dict[info]["scheduler"], HookParameters = info)
384 self.ADD.setDataObserver(info, HookFunction=self.obs, HookParameters = info)
385 var_struct.setEltAtRank("value", var_str)
386 specificParameters.pushBack(var_struct)
389 info_struct = pilot.StructAny_New(self.runtime.getTypeCode('SALOME_TYPES/Parameter'))
390 info_struct.setEltAtRank("name", "info")
391 info_struct.setEltAtRank("value", self.da_study.observers_dict[info]["info"])
392 specificParameters.pushBack(info_struct)
394 sample.setEltAtRank("specificParameters", specificParameters)
# Reserve a unique sample number and submit the sample (same protocol as the
# operator hooks).
396 self.counter_lock.acquire()
397 self.sample_counter += 1
398 local_counter = self.sample_counter
399 self.pool.pushInSample(local_counter, sample)
402 import sys, traceback
405 self.signalMasterAndWait()
406 if self.isTerminationRequested():
407 self.pool.destroyAll()
410 sample_id = self.pool.getCurrentId()
411 if sample_id == local_counter:
413 # Have to be done before but need a new implementation
414 # of the optimizer loop
415 self.counter_lock.release()
418 print "Exception in user code:"
420 traceback.print_exc(file=sys.stdout)
# Serialise the data-assimilation study as the algorithm result. Prepares the
# assimilation study for pickling, detaches self.obs from every configured
# observer, then pickles self.da_study.
# NOTE(review): the rest of the method is not in this listing; presumably it
# wraps `result` for YACS and returns it.
423 def getAlgoResult(self):
424 #print "getAlgoResult"
425 self.ADD.prepare_to_pickle()
426 # Remove data observers: the assimilation study object cannot be pickled with them attached
427 for observer_name in self.da_study.observers_dict.keys():
428 self.ADD.removeDataObserver(observer_name, self.obs)
429 result = pickle.dumps(self.da_study)
# NOTE(review): the body of parseFileToInit is not in this listing. It is
# presumably an override required by SALOMERuntime.OptimizerAlgASync. Confirm.
435 def parseFileToInit(self,fileName):
438 # Functions that do not change
# Store the sample pool. The hooks and obs() use it through pushInSample,
# getOutSample, getCurrentId and destroyAll. The body is not in this listing;
# presumably `self.pool = pool`. Confirm.
439 def setPool(self,pool):
# Type code of the samples sent to the computation code. The body is not in
# this listing; presumably it returns self.tin ("SALOME_TYPES/ParametricInput").
# Confirm.
441 def getTCForIn(self):
# Type code of the results returned by the computation code. The body is not in
# this listing; presumably it returns self.tout
# ("SALOME_TYPES/ParametricOutput"). Confirm.
443 def getTCForOut(self):
# Type code of the algoinit port value; initialize() reads that value with
# getStringValue(). The body is not in this listing. Confirm which type code it
# returns.
445 def getTCForAlgoInit(self):
447 def getTCForAlgoResult(self):