Salome HOME
Aller Ok pour test017 i
[modules/adao.git] / src / daSalome / daYacsIntegration / daOptimizerLoop.py
1 #-*- coding: utf-8 -*-
2
3 import SALOMERuntime
4 import pilot
5 import pickle
6 import numpy
7 import threading
8
9 from daCore.AssimilationStudy import AssimilationStudy
10 from daYacsIntegration import daStudy
11
class OptimizerHooks:
  """Proxy operators that evaluate Direct/Tangent/Adjoint through the optimizer pool.

  Each call wraps its input in a 'SALOME_TYPES/ParametricInput' sample, pushes
  it into ``optim_algo.pool`` under a unique id, hands control back to the
  master with ``signalMasterAndWait`` and unpickles the matching out-sample.
  """

  def __init__(self, optim_algo):
    """Store the owning algorithm; it must expose ``runtime`` and ``pool``."""
    self.optim_algo = optim_algo

    # Counter management: every sample gets a unique, increasing id, and the
    # lock serialises whole evaluations (see _run_sample).
    self.sample_counter = 0
    self.counter_lock = threading.Lock()

  def _new_double_sequence(self, values):
    """Return a YACS sequence of doubles filled from ``values.flat``."""
    seq = pilot.SequenceAny_New(self.optim_algo.runtime.getTypeCode("double"))
    for val in values.flat:
      seq.pushBack(val)
    return seq

  def create_sample(self, data, method):
    """Build the ParametricInput sample sent to the computation node.

    data   -- an object exposing ``.flat`` (e.g. a numpy array/matrix), or a
              tuple of such objects; each one becomes one row of doubles.
    method -- operator name stored in the "method" specific parameter
              ("Direct", "Tangent" or "Adjoint" for the callers in this class).

    Returns the ``pilot`` StructAny sample.
    """
    runtime = self.optim_algo.runtime
    sample = pilot.StructAny_New(runtime.getTypeCode('SALOME_TYPES/ParametricInput'))

    # TODO Input, Output VarList: only a placeholder variable name is sent.
    inputVarList  = pilot.SequenceAny_New(runtime.getTypeCode("string"))
    outputVarList = pilot.SequenceAny_New(runtime.getTypeCode("string"))
    inputVarList.pushBack("adao_default")
    outputVarList.pushBack("adao_default")
    sample.setEltAtRank("inputVarList", inputVarList)
    sample.setEltAtRank("outputVarList", outputVarList)

    # ADAO specific parameters: the name of the operator to evaluate.
    specificParameters = pilot.SequenceAny_New(runtime.getTypeCode("SALOME_TYPES/Parameter"))
    method_name = pilot.StructAny_New(runtime.getTypeCode('SALOME_TYPES/Parameter'))
    method_name.setEltAtRank("name", "method")
    method_name.setEltAtRank("value", method)
    specificParameters.pushBack(method_name)
    sample.setEltAtRank("specificParameters", specificParameters)

    # The data: inputValues is a sequence of sequences of sequences of
    # doubles. The empty 1D sequence only serves to obtain the element type
    # of the 2D container.
    parameter_1D = pilot.SequenceAny_New(runtime.getTypeCode("double"))
    parameter_2D = pilot.SequenceAny_New(parameter_1D.getType())
    parameters_3D = pilot.SequenceAny_New(parameter_2D.getType())
    if isinstance(data, tuple):
      # One row per tuple member (used by Adjoint, which sends (X, Y)).
      for dat in data:
        parameter_2D.pushBack(self._new_double_sequence(dat))
    else:
      parameter_2D.pushBack(self._new_double_sequence(data))
    parameters_3D.pushBack(parameter_2D)
    sample.setEltAtRank("inputValues", parameters_3D)

    return sample

  def _run_sample(self, data, method, trace):
    """Push one sample in the pool, wait for its result and return it unpickled.

    Returns None when the optimizer requests termination while waiting.
    ``trace`` enables the "waiting"/"signal" progress messages.
    """
    # The lock is held for the whole evaluation (it should ideally be released
    # earlier, but that needs a new implementation of the optimizer loop).
    # Using ``with`` guarantees it is released on every exit path, including
    # the termination return and any exception.
    with self.counter_lock:
      # 1: Get a unique sample number
      self.sample_counter += 1
      local_counter = self.sample_counter

      # 2: Put sample in the job pool
      sample = self.create_sample(data, method)
      self.optim_algo.pool.pushInSample(local_counter, sample)

      # 3: Wait
      while True:
        if trace:
          print("waiting")
        self.optim_algo.signalMasterAndWait()
        if trace:
          print("signal")
        if self.optim_algo.isTerminationRequested():
          self.optim_algo.pool.destroyAll()
          return None
        if self.optim_algo.pool.getCurrentId() == local_counter:
          # 4: Data is ready
          matrix_from_pool = self.optim_algo.pool.getOutSample(local_counter).getStringValue()
          break

    # 5: The lock is released here; 6: return results.
    # NOTE(review): the payload is unpickled as-is, so the pool content must
    # come from a trusted computation node.
    return pickle.loads(matrix_from_pool)

  def _sync_not_implemented(self):
    """Report that only synchronous evaluation is available."""
    print("sync false is not yet implemented")
    self.optim_algo.setError("sync == false not yet implemented")

  def Direct(self, X, sync = 1):
    """Evaluate the direct operator on X; returns None on termination or sync != 1."""
    print("Call Direct OptimizerHooks")
    if sync == 1:
      return self._run_sample(X, "Direct", True)
    self._sync_not_implemented()

  def Tangent(self, X, sync = 1):
    """Evaluate the tangent operator on X; returns None on termination or sync != 1."""
    print("Call Tangent OptimizerHooks")
    if sync == 1:
      # Tangent never printed the waiting/signal traces; keep it quiet.
      return self._run_sample(X, "Tangent", False)
    self._sync_not_implemented()

  def Adjoint(self, xy, sync = 1):
    """Evaluate the adjoint operator on the pair ``xy`` = (X, Y).

    The pair is passed positionally as a single argument; it is unpacked
    in the body so the method is valid on both Python 2 and Python 3.
    Returns None on termination or sync != 1.
    """
    X, Y = xy
    print("Call Adjoint OptimizerHooks")
    if sync == 1:
      return self._run_sample((X, Y), "Adjoint", True)
    self._sync_not_implemented()
184
class AssimilationAlgorithm_asynch(SALOMERuntime.OptimizerAlgASync):
  """Asynchronous YACS optimizer algorithm that runs an ADAO assimilation study.

  The study arrives pickled on the algoinit port (see initialize), is analysed
  in startToTakeDecision, and is returned pickled by getAlgoResult. Observation
  operators declared as "Function" are replaced by OptimizerHooks proxies so
  their evaluations go through the optimizer pool.
  """

  def __init__(self):
    SALOMERuntime.RuntimeSALOME_setRuntime()
    SALOMERuntime.OptimizerAlgASync.__init__(self, None)
    self.runtime = SALOMERuntime.getSALOMERuntime()

    # Definition of the input and output types for the computation code
    self.tin  = self.runtime.getTypeCode("SALOME_TYPES/ParametricInput")
    self.tout = self.runtime.getTypeCode("pyobj")

    self.optim_hooks = OptimizerHooks(self)

  # input comes from the algoinit port; input is a YACS Any!
  def initialize(self,input):
    """Unpickle the daStudy carried by ``input`` and initialise its algorithm."""
    print("Algorithme initialize")

    # get the daStudy
    # NOTE(review): the study is unpickled as-is, so the algoinit port must
    # only be fed from a trusted source.
    str_da_study = input.getStringValue()
    self.da_study = pickle.loads(str_da_study)
    self.da_study.initAlgorithm()
    self.ADD = self.da_study.getAssimilationStudy()

  def startToTakeDecision(self):
    """Install proxy operators where needed, run the analysis, then destroy the pool."""
    print("Algorithme startToTakeDecision")

    # Check if ObservationOperator is already set: query each operator kind
    # once and remember which ones are declared as "Function".
    kinds = ("Direct", "Tangent", "Adjoint")
    is_function = {}
    for kind in kinds:
      is_function[kind] = self.da_study.getObservationOperatorType(kind) == "Function"

    if is_function["Direct"] or is_function["Tangent"] or is_function["Adjoint"]:
      print("Set Hooks")
      # Use proxy function for YACS
      self.hooks = OptimizerHooks(self)
      direct  = self.hooks.Direct  if is_function["Direct"]  else None
      tangent = self.hooks.Tangent if is_function["Tangent"] else None
      adjoint = self.hooks.Adjoint if is_function["Adjoint"] else None

      # Set ObservationOperator
      self.ADD.setObservationOperator(asFunction = {"Direct":direct, "Tangent":tangent, "Adjoint":adjoint})

    # Start Assimilation Study
    print("ADD analyze")
    self.ADD.analyze()

    # Assimilation Study is finished
    self.pool.destroyAll()

  def getAlgoResult(self):
    """Return the pickled daStudy after preparing the assimilation study for pickling."""
    print("getAlgoResult")
    self.ADD.prepare_to_pickle()
    result = pickle.dumps(self.da_study)
    return result

  # Mandatory ???
  def finish(self):
    print("Algorithme finish")
  def parseFileToInit(self,fileName):
    print("Algorithme parseFileToInit")

  # Functions that do not change
  def setPool(self,pool):
    self.pool=pool
  def getTCForIn(self):
    return self.tin
  def getTCForOut(self):
    return self.tout
  def getTCForAlgoInit(self):
    return self.tout
  def getTCForAlgoResult(self):
    return self.tout
259     return self.tout
260