]> SALOME platform Git repositories - modules/adao.git/blob - src/daSalome/daYacsIntegration/daOptimizerLoop.py
Salome HOME
Generation of a schema with observers in progress
[modules/adao.git] / src / daSalome / daYacsIntegration / daOptimizerLoop.py
1 #-*- coding: utf-8 -*-
2 # Copyright (C) 2010-2011 EDF R&D
3 #
4 # This library is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU Lesser General Public
6 # License as published by the Free Software Foundation; either
7 # version 2.1 of the License.
8 #
9 # This library is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this library; if not, write to the Free Software
16 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
17 #
18 # See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 #
20 # Author: André Ribes, andre.ribes@edf.fr, EDF R&D
21
22 import SALOMERuntime
23 import pilot
24 import pickle
25 import numpy
26 import threading
27
28 from daCore.AssimilationStudy import AssimilationStudy
29 from daYacsIntegration import daStudy
30
31
32 class OptimizerHooks:
33
34   def __init__(self, optim_algo):
35     self.optim_algo = optim_algo
36
37     # Gestion du compteur
38     self.sample_counter = 0
39     self.counter_lock = threading.Lock()
40
41   def create_sample(self, data, method):
42     sample = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/ParametricInput'))
43
44     # TODO Input, Output VarList
45     inputVarList  = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
46     outputVarList = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
47     for var in self.optim_algo.da_study.InputVariables:
48       inputVarList.pushBack(var)
49     for var in self.optim_algo.da_study.OutputVariables:
50       outputVarList.pushBack(var)
51     sample.setEltAtRank("inputVarList", inputVarList)
52     sample.setEltAtRank("outputVarList", outputVarList)
53
54     # Les parametres specifiques à ADAO
55     specificParameters = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("SALOME_TYPES/Parameter"))
56     method_name = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/Parameter'))
57     method_name.setEltAtRank("name", "method")
58     method_name.setEltAtRank("value", method)
59     specificParameters.pushBack(method_name)
60     sample.setEltAtRank("specificParameters", specificParameters)
61
62     # Les données
63     # TODO à faire
64     #print data
65     #print data.ndim
66     #print data.shape
67     #print data[:,0]
68     #print data.flatten()
69     #print data.flatten().shape
70
71     variable          = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
72     variable_sequence = pilot.SequenceAny_New(variable.getType())
73     state_sequence    = pilot.SequenceAny_New(variable_sequence.getType())
74     time_sequence     = pilot.SequenceAny_New(state_sequence.getType())
75
76     #print "Input Data", data
77     if isinstance(data, type((1,2))):
78       self.add_parameters(data[0], variable_sequence)
79       self.add_parameters(data[1], variable_sequence, Output=True) # Output == Y
80     else:
81       self.add_parameters(data, variable_sequence)
82     state_sequence.pushBack(variable_sequence)
83     time_sequence.pushBack(state_sequence)
84     sample.setEltAtRank("inputValues", time_sequence)
85     return sample
86
87   def add_parameters(self, data, variable_sequence, Output=False):
88     param = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
89     elt_list = 0 # index dans la liste des arguments
90     val_number = 0 # nbre dans l'argument courant
91     if not Output:
92       val_end = self.optim_algo.da_study.InputVariables[self.optim_algo.da_study.InputVariablesOrder[elt_list]] # nbr de l'argument courant (-1 == tout)
93     else:
94       val_end = self.optim_algo.da_study.OutputVariables[self.optim_algo.da_study.OutputVariablesOrder[elt_list]] # nbr de l'argument courant (-1 == tout)
95
96     it = data.flat
97     for val in it:
98       param.pushBack(val)
99       val_number += 1
100       # Test si l'argument est ok
101       if val_end != -1:
102         if val_number == val_end:
103           variable_sequence.pushBack(param)
104           param = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
105           val_number = 0
106           elt_list += 1
107           if not Output:
108             if elt_list < len(self.optim_algo.da_study.InputVariablesOrder):
109               val_end = self.optim_algo.da_study.InputVariables[self.optim_algo.da_study.InputVariablesOrder[elt_list]]
110             else:
111               break
112           else:
113             if elt_list < len(self.optim_algo.da_study.OutputVariablesOrder):
114               val_end = self.optim_algo.da_study.OutputVariables[self.optim_algo.da_study.OutputVariablesOrder[elt_list]]
115             else:
116               break
117     if val_end == -1:
118       variable_sequence.pushBack(param)
119
120   def get_data_from_any(self, any_data):
121     error = any_data["returnCode"].getIntValue()
122     if error != 0:
123       self.optim_algo.setError(any_data["errorMessage"].getStringValue())
124
125     data = []
126     outputValues = any_data["outputValues"]
127     #print outputValues
128     for variable in outputValues[0][0]:
129       for i in range(variable.size()):
130         data.append(variable[i].getDoubleValue())
131
132     matrix = numpy.matrix(data).T
133     return matrix
134
135   def Direct(self, X, sync = 1):
136     #print "Call Direct OptimizerHooks"
137     if sync == 1:
138       # 1: Get a unique sample number
139       self.counter_lock.acquire()
140       self.sample_counter += 1
141       local_counter = self.sample_counter
142
143       # 2: Put sample in the job pool
144       sample = self.create_sample(X, "Direct")
145       self.optim_algo.pool.pushInSample(local_counter, sample)
146
147       # 3: Wait
148       while 1:
149         #print "waiting"
150         self.optim_algo.signalMasterAndWait()
151         #print "signal"
152         if self.optim_algo.isTerminationRequested():
153           self.optim_algo.pool.destroyAll()
154           return
155         else:
156           # Get current Id
157           sample_id = self.optim_algo.pool.getCurrentId()
158           if sample_id == local_counter:
159             # 4: Data is ready
160             any_data = self.optim_algo.pool.getOutSample(local_counter)
161             Y = self.get_data_from_any(any_data)
162
163             # 5: Release lock
164             # Have to be done before but need a new implementation
165             # of the optimizer loop
166             self.counter_lock.release()
167             return Y
168     else:
169       #print "sync false is not yet implemented"
170       self.optim_algo.setError("sync == false not yet implemented")
171
172   def Tangent(self, X, sync = 1):
173     #print "Call Tangent OptimizerHooks"
174     if sync == 1:
175       # 1: Get a unique sample number
176       self.counter_lock.acquire()
177       self.sample_counter += 1
178       local_counter = self.sample_counter
179
180       # 2: Put sample in the job pool
181       sample = self.create_sample(X, "Tangent")
182       self.optim_algo.pool.pushInSample(local_counter, sample)
183
184       # 3: Wait
185       while 1:
186         self.optim_algo.signalMasterAndWait()
187         if self.optim_algo.isTerminationRequested():
188           self.optim_algo.pool.destroyAll()
189           return
190         else:
191           # Get current Id
192           sample_id = self.optim_algo.pool.getCurrentId()
193           if sample_id == local_counter:
194             # 4: Data is ready
195             any_data = self.optim_algo.pool.getOutSample(local_counter)
196             Y = self.get_data_from_any(any_data)
197
198             # 5: Release lock
199             # Have to be done before but need a new implementation
200             # of the optimizer loop
201             self.counter_lock.release()
202             return Y
203     else:
204       #print "sync false is not yet implemented"
205       self.optim_algo.setError("sync == false not yet implemented")
206
207   def Adjoint(self, (X, Y), sync = 1):
208     #print "Call Adjoint OptimizerHooks"
209     if sync == 1:
210       # 1: Get a unique sample number
211       self.counter_lock.acquire()
212       self.sample_counter += 1
213       local_counter = self.sample_counter
214
215       # 2: Put sample in the job pool
216       sample = self.create_sample((X,Y), "Adjoint")
217       self.optim_algo.pool.pushInSample(local_counter, sample)
218
219       # 3: Wait
220       while 1:
221         #print "waiting"
222         self.optim_algo.signalMasterAndWait()
223         #print "signal"
224         if self.optim_algo.isTerminationRequested():
225           self.optim_algo.pool.destroyAll()
226           return
227         else:
228           # Get current Id
229           sample_id = self.optim_algo.pool.getCurrentId()
230           if sample_id == local_counter:
231             # 4: Data is ready
232             any_data = self.optim_algo.pool.getOutSample(local_counter)
233             Z = self.get_data_from_any(any_data)
234
235             # 5: Release lock
236             # Have to be done before but need a new implementation
237             # of the optimizer loop
238             self.counter_lock.release()
239             return Z
240     else:
241       #print "sync false is not yet implemented"
242       self.optim_algo.setError("sync == false not yet implemented")
243
244 class AssimilationAlgorithm_asynch(SALOMERuntime.OptimizerAlgASync):
245
246   def __init__(self):
247     SALOMERuntime.RuntimeSALOME_setRuntime()
248     SALOMERuntime.OptimizerAlgASync.__init__(self, None)
249     self.runtime = SALOMERuntime.getSALOMERuntime()
250
251     # Definission des types d'entres et de sorties pour le code de calcul
252     self.tin      = self.runtime.getTypeCode("SALOME_TYPES/ParametricInput")
253     self.tout     = self.runtime.getTypeCode("SALOME_TYPES/ParametricOutput")
254     self.pyobject = self.runtime.getTypeCode("pyobj")
255
256     self.optim_hooks = OptimizerHooks(self)
257
258   # input vient du port algoinit, input est un Any YACS !
259   def initialize(self,input):
260     #print "Algorithme initialize"
261
262     # get the daStudy
263     #print "[Debug] Input is ", input
264     str_da_study = input.getStringValue()
265     self.da_study = pickle.loads(str_da_study)
266     #print "[Debug] da_study is ", self.da_study
267     self.da_study.initAlgorithm()
268     self.ADD = self.da_study.getAssimilationStudy()
269
270   def startToTakeDecision(self):
271     #print "Algorithme startToTakeDecision"
272
273     # Check if ObservationOperator is already set
274     if self.da_study.getObservationOperatorType("Direct") == "Function" or self.da_study.getObservationOperatorType("Tangent") == "Function" or self.da_study.getObservationOperatorType("Adjoint") == "Function" :
275       #print "Set Hooks"
276       # Use proxy function for YACS
277       self.hooks = OptimizerHooks(self)
278       direct = tangent = adjoint = None
279       if self.da_study.getObservationOperatorType("Direct") == "Function":
280         direct = self.hooks.Direct
281       if self.da_study.getObservationOperatorType("Tangent") == "Function" :
282         tangent = self.hooks.Tangent
283       if self.da_study.getObservationOperatorType("Adjoint") == "Function" :
284         adjoint = self.hooks.Adjoint
285
286       # Set ObservationOperator
287       self.ADD.setObservationOperator(asFunction = {"Direct":direct, "Tangent":tangent, "Adjoint":adjoint})
288
289
290     # Start Assimilation Study
291     #print "ADD analyze"
292     self.ADD.analyze()
293
294     # Assimilation Study is finished
295     self.pool.destroyAll()
296
297   def getAlgoResult(self):
298     #print "getAlgoResult"
299     self.ADD.prepare_to_pickle()
300     result = pickle.dumps(self.da_study)
301     return result
302
303   # Obligatoire ???
304   def finish(self):
305     pass
306   def parseFileToInit(self,fileName):
307     pass
308
309   # Fonctions qui ne changent pas
310   def setPool(self,pool):
311     self.pool=pool
312   def getTCForIn(self):
313     return self.tin
314   def getTCForOut(self):
315     return self.tout
316   def getTCForAlgoInit(self):
317     return self.pyobject
318   def getTCForAlgoResult(self):
319     return self.pyobject
320