Salome HOME
Copyright update 2021
[modules/yacs.git] / src / py2yacs / yacsdecorator.py
1 #!/usr/bin/env python3
2 # Copyright (C) 2006-2021  CEA/DEN, EDF R&D
3 #
4 # This library is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU Lesser General Public
6 # License as published by the Free Software Foundation; either
7 # version 2.1 of the License, or (at your option) any later version.
8 #
9 # This library is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this library; if not, write to the Free Software
16 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
17 #
18 # See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 #
20 import sys
21 import json
22
23 # this is a pointer to the module object instance itself.
24 this_module = sys.modules[__name__]
25
26 class OutputPort:
27   def __init__(self, yacs_node, yacs_port):
28     self.yacs_node = yacs_node
29     self.yacs_port = yacs_port
30
31 class LeafNodeType:
32   def __init__(self, path, fn_name, inputs, outputs, container_name):
33     self.path = path
34     self.fn_name = fn_name
35     self.inputs = inputs
36     self.outputs = outputs
37     self.container_name = container_name
38     self.number = 0
39
40   def newName(self):
41     name = self.fn_name + "_" + str(self.number)
42     self.number += 1
43     return name
44
45   def createNewNode(self, inputs):
46     """
47     inputs : dict {input_name:value}
48     """
49     generator = getGenerator()
50     output_ports = generator.createScriptNode(self, inputs)
51     return output_ports
52
53 class ContainerProperties():
54   def __init__(self, name, nb_cores, use_cache):
55     self.name = name
56     self.nb_cores = nb_cores
57     self.use_cache = use_cache
58
59 def jsonContainerEncoder(obj):
60   if isinstance(obj, ContainerProperties) :
61     return {
62             "name": obj.name,
63             "nb_cores": obj.nb_cores,
64             "use_cache": obj.use_cache }
65   else:
66     raise TypeError("Cannot serialize object "+str(obj))
67
68 def jsonContainerDecoder(dct):
69   if "name" in dct and "nb_cores" in dct and "use_cache" in dct :
70     return ContainerProperties(dct["name"], dct["nb_cores"], dct["use_cache"])
71   return dct
72
73 class ContainerManager():
74   defaultContainerName = "default_container"
75   def __init__(self):
76     self._containers = []
77     self._defaultContainer = ContainerProperties(
78                                 ContainerManager.defaultContainerName, 0, False)
79     self._containers.append(self._defaultContainer)
80
81   def setDefaultContainer(self, nb_cores, use_cache):
82     self._defaultContainer.nb_cores = nb_cores
83     self._defaultContainer.use_cache = use_cache
84
85   def loadFile(self, file_path):
86     with open(file_path, 'r') as json_file:
87       self._containers = json.load(json_file, object_hook=jsonContainerDecoder)
88     try:
89       self._defaultContainer = next(cont for cont in self._containers
90                           if cont.name == ContainerManager.defaultContainerName)
91     except StopIteration:
92       self._defaultContainer = ContainerProperties(
93                                 ContainerManager.defaultContainerName, 0, False)
94       self._containers.append(self._defaultContainer)
95
96   def saveFile(self, file_path):
97     with open(file_path, 'w') as json_file:
98       json.dump(self._containers, json_file,
99                 indent=2, default=jsonContainerEncoder)
100
101   def addContainer(self, name, nb_cores, use_cache):
102     try:
103       # if the name already exists
104       obj = next(cont for cont in self._containers if cont.name == name)
105       obj.nb_cores = nb_cores
106       obj.use_cache = use_cache
107     except StopIteration:
108       # new container
109       self._containers.append(ContainerProperties(name, nb_cores, use_cache))
110
111   def getContainer(self, name):
112     ret = self._defaultContainer
113     try:
114       ret = next(cont for cont in self._containers if cont.name == name)
115     except StopIteration:
116       # not found
117       pass
118     return ret
119
120 class SchemaGenerator():
121   """
122   Link to Salome for YACS schema generation.
123   """
124   def __init__(self):
125     import SALOMERuntime
126     SALOMERuntime.RuntimeSALOME.setRuntime()
127     self.runtime = SALOMERuntime.getSALOMERuntime()
128     self.proc = self.runtime.createProc("GeneratedSchema")
129     self.proc.setProperty("executor","workloadmanager")
130     self.containers = {}
131     self.pyobjtype = self.runtime.getTypeCode("pyobj")
132     self.seqpyobjtype = self.runtime.getTypeCode("seqpyobj")
133     self.bloc_stack = [self.proc]
134     self.name_index = 0 # used to ensure unique names
135     self.container_manager = ContainerManager()
136
137   def newName(self, name):
138     new_name = name + "_" + str(self.name_index)
139     self.name_index += 1
140     return new_name
141
142   def getContextName(self):
143     context_name = ""
144     if len(self.bloc_stack) > 1:
145       # We are in a block
146       block_path = ".".join([ b.getName() for b in self.bloc_stack[1:] ])
147       context_name = block_path + "."
148     return context_name
149
150   def getContainer(self, container_type):
151     """
152     A new container may be created if it does not already exist for this type.
153     """
154     container_properties = self.container_manager.getContainer(container_type)
155     if container_type not in self.containers:
156       cont=self.proc.createContainer(container_properties.name,"Salome")
157       cont.setProperty("nb_parallel_procs", str(container_properties.nb_cores))
158       cont.setProperty("type","multi")
159       cont.usePythonCache(container_properties.use_cache)
160       cont.attachOnCloning()
161       self.containers[container_type] = cont
162     return self.containers[container_type]
163
164   def createScript(self, file_path, function_name, inputs, outputs):
165     import inspect
166     stack = inspect.stack()
167     stack_info = "Call stack\n"
168     # skip the first 4 levels in the stack
169     for level in stack[4:-1] :
170       info = inspect.getframeinfo(level[0])
171       stack_info += "file: {}, line: {}, function: {}, context: {}\n".format(
172         info.filename, info.lineno, info.function, info.code_context)
173      
174     if len(outputs) == 0:
175       result = ""
176     elif len(outputs) == 1:
177       result = "{} = ".format(outputs[0])
178     else:
179       result = ",".join(outputs)
180       result += " = "
181
182     if len(inputs) == 0:
183       params = ""
184     elif len(inputs) == 1:
185       params = "{} ".format(inputs[0])
186     else:
187       params = ",".join(inputs)
188     
189     script = """'''
190 {call_stack}
191 '''
192 import yacstools
193 study_function = yacstools.getFunction("{file_path}", "{function_name}")
194 {result}study_function({parameters})
195 """.format(call_stack=stack_info,
196            file_path=file_path,
197            function_name=function_name,
198            result=result,
199            parameters=params)
200     return script
201
202   def createScriptNode(self, leaf, input_values):
203     node_name = leaf.newName()
204     file_path = leaf.path
205     function_name = leaf.fn_name
206     inputs = leaf.inputs # names
207     outputs = leaf.outputs # names
208     script = self.createScript(file_path, function_name, inputs, outputs)
209     container = self.getContainer(leaf.container_name)
210     new_node = self.runtime.createScriptNode("Salome", node_name)
211     new_node.setContainer(container)
212     new_node.setExecutionMode("remote")
213     new_node.setScript(script)
214     self.bloc_stack[-1].edAddChild(new_node)
215     # create ports
216     for p in inputs:
217       new_node.edAddInputPort(p, self.pyobjtype)
218     output_obj_list = []
219     for p in outputs:
220       port = new_node.edAddOutputPort(p, self.pyobjtype)
221       output_obj_list.append(OutputPort(new_node, port))
222     # create links
223     for k,v in input_values.items():
224       input_port = new_node.getInputPort(k)
225       if isinstance(v, OutputPort):
226         self.proc.edAddLink(v.yacs_port, input_port)
227         self.addCFLink(v.yacs_node, new_node)
228         #self.proc.edAddCFLink(v.yacs_node, new_node)
229       else:
230         input_port.edInitPy(v)
231     # return output ports
232     result = None
233     if len(output_obj_list) == 1 :
234       result = output_obj_list[0]
235     elif len(output_obj_list) > 1 :
236       result = tuple(output_obj_list)
237     return result
238
239   def beginForeach(self, fn_name, input_values):
240     foreach_name = self.newName(fn_name)
241     new_foreach = self.runtime.createForEachLoopDyn(foreach_name,
242                                                     self.pyobjtype)
243     self.bloc_stack[-1].edAddChild(new_foreach)
244     bloc_name = "bloc_"+foreach_name
245     new_block = self.runtime.createBloc(bloc_name)
246     new_foreach.edAddChild(new_block)
247     sample_port = new_foreach.edGetSamplePort()
248     input_list_port = new_foreach.edGetSeqOfSamplesPort()
249     if isinstance(input_values, OutputPort):
250       # we need a conversion node pyobj -> seqpyobj
251       conversion_node = self.runtime.createScriptNode("Salome",
252                                                       "input_"+foreach_name)
253       port_name = "val"
254       input_port = conversion_node.edAddInputPort(port_name, self.pyobjtype)
255       output_port = conversion_node.edAddOutputPort(port_name,
256                                                     self.seqpyobjtype)
257       conversion_node.setExecutionMode("local") # no need for container
258       # no script, the same variable for input and output
259       conversion_node.setScript("")
260       self.bloc_stack[-1].edAddChild(conversion_node)
261       self.proc.edAddLink(input_values.yacs_port, input_port)
262       self.addCFLink(input_values.yacs_node, conversion_node)
263       self.proc.edAddLink(output_port, input_list_port)
264       # No need to look for ancestors. Both nodes are on the same level.
265       self.proc.edAddCFLink(conversion_node, new_foreach)
266     else:
267       input_list_port.edInitPy(list(input_values))
268     self.bloc_stack.append(new_foreach)
269     self.bloc_stack.append(new_block)
270     return OutputPort(new_foreach, sample_port)
271
272   def endForeach(self, outputs):
273     self.bloc_stack.pop() # remove the block
274     for_each_node = self.bloc_stack.pop() # remove the foreach
275     converted_ret = None
276     if outputs is not None:
277       # We need a conversion node seqpyobj -> pyobj
278       if type(outputs) is tuple:
279         list_out = list(outputs)
280       else:
281         list_out = [outputs]
282       conversion_node_name = "output_" + for_each_node.getName()
283       conversion_node = self.runtime.createScriptNode("Salome",
284                                                       conversion_node_name)
285       conversion_node.setExecutionMode("local") # no need for container
286       conversion_node.setScript("")
287       self.bloc_stack[-1].edAddChild(conversion_node)
288       list_ret = []
289       idx_name = 0 # for unique port names
290       for port in list_out :
291         if isinstance(port, OutputPort):
292           port_name = port.yacs_port.getName() + "_" + str(idx_name)
293           idx_name += 1
294           input_port = conversion_node.edAddInputPort(port_name,
295                                                       self.seqpyobjtype)
296           output_port = conversion_node.edAddOutputPort(port_name,
297                                                         self.pyobjtype)
298           self.proc.edAddLink(port.yacs_port, input_port)
299           list_ret.append(OutputPort(conversion_node, output_port))
300         else:
301           list_ret.append(port)
302       self.proc.edAddCFLink(for_each_node, conversion_node)
303       if len(list_ret) > 1 :
304         converted_ret = tuple(list_ret)
305       else:
306         converted_ret = list_ret[0]
307     return converted_ret
308
309   def dump(self, file_path):
310     self.proc.saveSchema(file_path)
311
312   def addCFLink(self, node_from, node_to):
313     commonAncestor = self.proc.getLowestCommonAncestor(node_from, node_to)
314     if node_from.getName() != commonAncestor.getName() :
315       link_from = node_from
316       while link_from.getFather().getName() != commonAncestor.getName() :
317         link_from = link_from.getFather()
318       link_to = node_to
319       while link_to.getFather().getName() != commonAncestor.getName() :
320         link_to = link_to.getFather()
321       self.proc.edAddCFLink(link_from, link_to)
322     else:
323       # from node is ancestor of to node. No CF link needed.
324       pass
325
326 _generator = None
327
328 _default_mode = "Default"
329 _yacs_mode = "YACS"
330 _exec_mode = _default_mode
331
332 # Public functions
333
334 def getGenerator():
335   """
336   Get the singleton object.
337   """
338   if this_module._generator is None:
339     if this_module._exec_mode == this_module._yacs_mode:
340       this_module._generator = SchemaGenerator()
341   return this_module._generator
342
343 def activateYacsMode():
344   this_module._exec_mode = this_module._yacs_mode
345
346 def activateDefaultMode():
347   this_module._exec_mode = this_module._default_mode
348
349 def loadContainers(file_path):
350   getGenerator().container_manager.loadFile(file_path)
351
352 def export(path):
353   if this_module._exec_mode == this_module._yacs_mode :
354     getGenerator().dump(path)
355
356 # Decorators
357 class LeafDecorator():
358   def __init__(self, container_name):
359     self.container_name = container_name
360
361   def __call__(self, f):
362     if this_module._exec_mode == this_module._default_mode:
363       return f
364     co = f.__code__
365     import py2yacs
366     props = py2yacs.function_properties(co.co_filename, co.co_name)
367     nodeType = LeafNodeType(co.co_filename, co.co_name,
368                             props.inputs, props.outputs, self.container_name)
369     def my_func(*args, **kwargs):
370       if len(args) + len(kwargs) != len(nodeType.inputs):
371         mes = "Wrong number of arguments when calling function '{}'.\n".format(
372                                                                 nodeType.fn_name)
373         mes += " {} arguments expected and {} arguments found.\n".format(
374                                     len(nodeType.inputs), len(args) + len(kwargs))
375         raise Exception(mes)
376       idx = 0
377       args_dic = {}
378       for a in args:
379         args_dic[nodeType.inputs[idx]] = a
380         idx += 1
381       for k,v in kwargs.items():
382         args_dic[k] = v
383       if len(args_dic) != len(nodeType.inputs):
384         mes="Wrong arguments when calling function {}.\n".format(nodeType.fn_name)
385         raise Exception(mes)
386       return nodeType.createNewNode(args_dic)
387     return my_func
388
389 def leaf(arg):
390   """
391   Decorator for python scripts.
392   """
393   if callable(arg):
394     # decorator used without parameters. arg is the function
395     container = ContainerManager.defaultContainerName
396     ret = (LeafDecorator(container))(arg)
397   else:
398     # decorator used with parameter. arg is the container name
399     ret = LeafDecorator(arg)
400   return ret
401
402 def bloc(f):
403   """
404   Decorator for blocs.
405   """
406   #co = f.__code__
407   #print("bloc :", co.co_name)
408   #print("  file:", co.co_filename)
409   #print("  line:", co.co_firstlineno)
410   #print("  args:", co.co_varnames)
411   return f
412
413 def default_foreach(f):
414   def my_func(lst):
415     result = []
416     for e in lst:
417       result.append(f(e))
418     t_result = result
419     if len(result) > 0 :
420       if type(result[0]) is tuple:
421         # transform the list of tuples in a tuple of lists
422         l_result = []
423         for e in result[0]:
424           l_result.append([])
425         for t in result:
426           idx = 0
427           for e in t:
428             l_result[idx].append(e)
429             idx += 1
430         t_result = tuple(l_result)
431     return t_result
432   return my_func
433
434 def yacs_foreach(f):
435   #co = f.__code__
436   #import yacsvisit
437   #props = yacsvisit.main(co.co_filename, co.co_name)
438   def my_func(input_list):
439     fn_name = f.__code__.co_name
440     generator = getGenerator()
441     sample_port = generator.beginForeach(fn_name, input_list)
442     output_list = f(sample_port)
443     output_list = generator.endForeach(output_list)
444     return output_list
445   return my_func
446
447 def foreach(f):
448   """
449   Decorator to generate foreach blocs
450   """
451   if this_module._exec_mode == this_module._default_mode:
452     return default_foreach(f)
453   elif this_module._exec_mode == this_module._yacs_mode:
454     return yacs_foreach(f)
455
456