# src/daSalome/daYacsIntegration/daOptimizerLoop.py
#-*- coding: utf-8 -*-

import SALOMERuntime
import pilot
import pickle
import numpy
import threading

from daCore.AssimilationStudy import AssimilationStudy
from daYacsIntegration import daStudy

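# OptimizerHooks wraps the ADAO observation operator calls (Direct, Tangent,
# Adjoint) as proxies handed over to the YACS OptimizerLoop: each argument is
# packed into a SALOME_TYPES/ParametricInput sample, pushed into the sample
# pool, and the calling thread waits until the matching result comes back
# from the loop body.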
class OptimizerHooks:

  def __init__(self, optim_algo):
    self.optim_algo = optim_algo

    # Sample counter management
    self.sample_counter = 0
    self.counter_lock = threading.Lock()

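  # create_sample builds the SALOME_TYPES/ParametricInput structure expected
  # by the OptimizerLoop: the variable name lists, a "method" specific
  # parameter telling the remote side which operator to apply, and the input
  # values packed as a 3D sequence of doubles.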
  def create_sample(self, data, method):
    sample = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/ParametricInput'))

    # TODO Input, Output VarList
    inputVarList  = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
    outputVarList = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("string"))
    inputVarList.pushBack("adao_default")
    outputVarList.pushBack("adao_default")
    sample.setEltAtRank("inputVarList", inputVarList)
    sample.setEltAtRank("outputVarList", outputVarList)

    # ADAO-specific parameters
    specificParameters = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("SALOME_TYPES/Parameter"))
    method_name = pilot.StructAny_New(self.optim_algo.runtime.getTypeCode('SALOME_TYPES/Parameter'))
    method_name.setEltAtRank("name", "method")
    method_name.setEltAtRank("value", method)
    specificParameters.pushBack(method_name)
    sample.setEltAtRank("specificParameters", specificParameters)

    # The data
    # TODO: still to be completed
    parameter_1D = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
    #print data
    #print data.ndim
    #print data.shape
    #print data[:,0]
    #print data.flatten()
    #print data.flatten().shape
    it = data.flat
    for val in it:
      print val
      parameter_1D.pushBack(val)
    parameter_2D = pilot.SequenceAny_New(parameter_1D.getType())
    parameter_2D.pushBack(parameter_1D)
    parameters_3D = pilot.SequenceAny_New(parameter_2D.getType())
    parameters_3D.pushBack(parameter_2D)
    sample.setEltAtRank("inputValues", parameters_3D)

    return sample

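  # Direct, Tangent and Adjoint below all follow the same handshake with the
  # OptimizerLoop: take a unique sample id under the lock, push the sample
  # into the pool, then loop on signalMasterAndWait() until either termination
  # is requested or the loop reports that our sample id has been computed; the
  # pickled result is then read back from the pool.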
  def Direct(self, X, sync = 1):
    print "Call Direct OptimizerHooks"
    if sync == 1:
      # 1: Get a unique sample number
      self.counter_lock.acquire()
      self.sample_counter += 1
      local_counter = self.sample_counter

      # 2: Put the sample in the job pool
      sample = self.create_sample(X, "Direct")
      self.optim_algo.pool.pushInSample(local_counter, sample)

      # 3: Wait for the result
      while 1:
        print "waiting"
        self.optim_algo.signalMasterAndWait()
        print "signal"
        if self.optim_algo.isTerminationRequested():
          self.optim_algo.pool.destroyAll()
          return
        else:
          # Get current Id
          sample_id = self.optim_algo.pool.getCurrentId()
          if sample_id == local_counter:
            # 4: Data is ready
            matrix_from_pool = self.optim_algo.pool.getOutSample(local_counter).getStringValue()

            # 5: Release the lock
            # Should be released earlier, but that requires a new
            # implementation of the optimizer loop
            self.counter_lock.release()

            # 6: Return the result
            Y = pickle.loads(matrix_from_pool)
            return Y
    else:
      print "sync false is not yet implemented"
      raise daStudy.daError("sync == false not yet implemented")

  def Tangent(self, X, sync = 1):
    print "Call Tangent OptimizerHooks"
    if sync == 1:
      # 1: Get a unique sample number
      self.counter_lock.acquire()
      self.sample_counter += 1
      local_counter = self.sample_counter

      # 2: Put the sample in the job pool
      sample = self.create_sample(X, "Tangent")
      self.optim_algo.pool.pushInSample(local_counter, sample)

      # 3: Wait for the result
      while 1:
        self.optim_algo.signalMasterAndWait()
        if self.optim_algo.isTerminationRequested():
          self.optim_algo.pool.destroyAll()
          return
        else:
          # Get current Id
          sample_id = self.optim_algo.pool.getCurrentId()
          if sample_id == local_counter:
            # 4: Data is ready
            matrix_from_pool = self.optim_algo.pool.getOutSample(local_counter).getStringValue()

            # 5: Release the lock
            # Should be released earlier, but that requires a new
            # implementation of the optimizer loop
            self.counter_lock.release()

            # 6: Return the result
            Y = pickle.loads(matrix_from_pool)
            return Y
    else:
      print "sync false is not yet implemented"
      raise daStudy.daError("sync == false not yet implemented")

  def Adjoint(self, (X, Y), sync = 1):
    print "Call Adjoint OptimizerHooks"
    if sync == 1:
      # 1: Get a unique sample number
      self.counter_lock.acquire()
      self.sample_counter += 1
      local_counter = self.sample_counter

      # 2: Put the sample in the job pool
      # Note: only Y is forwarded in the sample for now; X is not used yet
      sample = self.create_sample(Y, "Adjoint")
      self.optim_algo.pool.pushInSample(local_counter, sample)

      # 3: Wait for the result
      while 1:
        print "waiting"
        self.optim_algo.signalMasterAndWait()
        print "signal"
        if self.optim_algo.isTerminationRequested():
          self.optim_algo.pool.destroyAll()
          return
        else:
          # Get current Id
          sample_id = self.optim_algo.pool.getCurrentId()
          if sample_id == local_counter:
            # 4: Data is ready
            matrix_from_pool = self.optim_algo.pool.getOutSample(local_counter).getStringValue()

            # 5: Release the lock
            # Should be released earlier, but that requires a new
            # implementation of the optimizer loop
            self.counter_lock.release()

            # 6: Return the result
            Z = pickle.loads(matrix_from_pool)
            return Z
    else:
      print "sync false is not yet implemented"
      raise daStudy.daError("sync == false not yet implemented")

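# AssimilationAlgorithm_asynch is the YACS OptimizerAlgASync implementation
# driving an ADAO AssimilationStudy: initialize() unpickles the daStudy object
# received on the algoinit port, startToTakeDecision() wires the
# OptimizerHooks proxies as observation operators and runs the analysis, and
# getAlgoResult() pickles the study back for the output port.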
class AssimilationAlgorithm_asynch(SALOMERuntime.OptimizerAlgASync):

  def __init__(self):
    SALOMERuntime.RuntimeSALOME_setRuntime()
    SALOMERuntime.OptimizerAlgASync.__init__(self, None)
    self.runtime = SALOMERuntime.getSALOMERuntime()

    # Definition of the input and output types for the computation code
    self.tin  = self.runtime.getTypeCode("SALOME_TYPES/ParametricInput")
    self.tout = self.runtime.getTypeCode("pyobj")

    self.optim_hooks = OptimizerHooks(self)

  # input comes from the algoinit port; input is a YACS Any!
  def initialize(self, input):
    print "Algorithme initialize"

    # get the daStudy
    #print "[Debug] Input is ", input
    str_da_study = input.getStringValue()
    self.da_study = pickle.loads(str_da_study)
    #print "[Debug] da_study is ", self.da_study
    self.da_study.initAlgorithm()
    self.ADD = self.da_study.getAssimilationStudy()

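  # The hooks are only installed for the operators that the study declares as
  # "Function": those calls are then evaluated through the YACS loop body
  # instead of an operator provided directly in the study.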
  def startToTakeDecision(self):
    print "Algorithme startToTakeDecision"

    # Check whether the ObservationOperator is provided as a Function
    if self.da_study.getObservationOperatorType("Direct") == "Function" or self.da_study.getObservationOperatorType("Tangent") == "Function" or self.da_study.getObservationOperatorType("Adjoint") == "Function":
      print "Set Hooks"
      # Use proxy functions for YACS
      self.hooks = OptimizerHooks(self)
      direct = tangent = adjoint = None
      if self.da_study.getObservationOperatorType("Direct") == "Function":
        direct = self.hooks.Direct
      if self.da_study.getObservationOperatorType("Tangent") == "Function":
        tangent = self.hooks.Tangent
      if self.da_study.getObservationOperatorType("Adjoint") == "Function":
        adjoint = self.hooks.Adjoint

      # Set ObservationOperator
      self.ADD.setObservationOperator(asFunction = {"Direct":direct, "Tangent":tangent, "Adjoint":adjoint})

    # Start the Assimilation Study
    print "ADD analyze"
    self.ADD.analyze()

    # The Assimilation Study is finished
    self.pool.destroyAll()

  def getAlgoResult(self):
    print "getAlgoResult"
    self.ADD.prepare_to_pickle()
    result = pickle.dumps(self.da_study)
    return result

  # Mandatory ???
  def finish(self):
    print "Algorithme finish"
  def parseFileToInit(self, fileName):
    print "Algorithme parseFileToInit"

  # Functions that do not change
  def setPool(self, pool):
    self.pool = pool
  def getTCForIn(self):
    return self.tin
  def getTCForOut(self):
    return self.tout
  def getTCForAlgoInit(self):
    return self.tout
  def getTCForAlgoResult(self):
    return self.tout