Salome HOME
Updated copyright comment
[modules/yacs.git] / src / py2yacs / yacsdecorator.py
1 #!/usr/bin/env python3
2 # Copyright (C) 2006-2024  CEA, EDF
3 #
4 # This library is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU Lesser General Public
6 # License as published by the Free Software Foundation; either
7 # version 2.1 of the License, or (at your option) any later version.
8 #
9 # This library is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this library; if not, write to the Free Software
16 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
17 #
18 # See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
19 #
20 import sys
21 import json
22
23 # this is a pointer to the module object instance itself.
24 this_module = sys.modules[__name__]
25
26 class OutputPort:
27   def __init__(self, yacs_node, yacs_port):
28     self.yacs_node = yacs_node
29     self.yacs_port = yacs_port
30
31   def IAmAManagedPort(self):
32     """ Type check."""
33     return True
34
35   def linkTo(self, input_port, input_node, generator):
36     generator.proc.edAddLink(self.yacs_port, input_port)
37     generator.addCFLink(self.yacs_node, input_node)
38
39   def getPort(self):
40     return self.yacs_port
41
42   def getNode(self):
43     return self.yacs_node
44
45 class OutputPortWithCollector:
46   def __init__(self, output_port):
47     self.output_port = output_port
48     self.connectedInputPorts = []
49
50   def IAmAManagedPort(self):
51     """ Type check."""
52     return True
53
54   def linkTo(self, input_port, input_node, generator):
55     self.output_port.linkTo(input_port, input_node, generator)
56     self.connectedInputPorts.append(input_port)
57
58   def getPort(self):
59     return self.output_port.getPort()
60
61   def getNode(self):
62     return self.output_port.getNode()
63
64   def connectedPorts(self):
65     return self.connectedInputPorts
66
67 class LeafNodeType:
68   def __init__(self, path, fn_name, inputs, outputs, container_name):
69     self.path = path
70     self.fn_name = fn_name
71     self.inputs = inputs
72     self.outputs = outputs
73     self.container_name = container_name
74     self.number = 0
75
76   def newName(self):
77     name = self.fn_name + "_" + str(self.number)
78     self.number += 1
79     return name
80
81   def createNewNode(self, inputs):
82     """
83     inputs : dict {input_name:value}
84     """
85     generator = getGenerator()
86     output_ports = generator.createScriptNode(self, inputs)
87     return output_ports
88
89 class ContainerProperties():
90   def __init__(self, name, nb_cores, use_cache):
91     self.name = name
92     self.nb_cores = nb_cores
93     self.use_cache = use_cache
94
95 def jsonContainerEncoder(obj):
96   if isinstance(obj, ContainerProperties) :
97     return {
98             "name": obj.name,
99             "nb_cores": obj.nb_cores,
100             "use_cache": obj.use_cache }
101   else:
102     raise TypeError("Cannot serialize object "+str(obj))
103
104 def jsonContainerDecoder(dct):
105   if "name" in dct and "nb_cores" in dct and "use_cache" in dct :
106     return ContainerProperties(dct["name"], dct["nb_cores"], dct["use_cache"])
107   return dct
108
109 class ContainerManager():
110   defaultContainerName = "default_container"
111   def __init__(self):
112     self._containers = []
113     self._defaultContainer = ContainerProperties(
114                                 ContainerManager.defaultContainerName, 0, False)
115     self._containers.append(self._defaultContainer)
116
117   def setDefaultContainer(self, nb_cores, use_cache):
118     self._defaultContainer.nb_cores = nb_cores
119     self._defaultContainer.use_cache = use_cache
120
121   def loadFile(self, file_path):
122     with open(file_path, 'r') as json_file:
123       self._containers = json.load(json_file, object_hook=jsonContainerDecoder)
124     try:
125       self._defaultContainer = next(cont for cont in self._containers
126                           if cont.name == ContainerManager.defaultContainerName)
127     except StopIteration:
128       self._defaultContainer = ContainerProperties(
129                                 ContainerManager.defaultContainerName, 0, False)
130       self._containers.append(self._defaultContainer)
131
132   def saveFile(self, file_path):
133     with open(file_path, 'w') as json_file:
134       json.dump(self._containers, json_file,
135                 indent=2, default=jsonContainerEncoder)
136
137   def addContainer(self, name, nb_cores, use_cache):
138     try:
139       # if the name already exists
140       obj = next(cont for cont in self._containers if cont.name == name)
141       obj.nb_cores = nb_cores
142       obj.use_cache = use_cache
143     except StopIteration:
144       # new container
145       self._containers.append(ContainerProperties(name, nb_cores, use_cache))
146
147   def getContainer(self, name):
148     ret = self._defaultContainer
149     try:
150       ret = next(cont for cont in self._containers if cont.name == name)
151     except StopIteration:
152       # not found
153       pass
154     return ret
155
156 class SchemaGenerator():
157   """
158   Link to Salome for YACS schema generation.
159   """
160   def __init__(self):
161     import SALOMERuntime
162     SALOMERuntime.RuntimeSALOME.setRuntime()
163     self.runtime = SALOMERuntime.getSALOMERuntime()
164     self.proc = self.runtime.createProc("GeneratedSchema")
165     self.proc.setProperty("executor","workloadmanager")
166     self.containers = {}
167     self.pyobjtype = self.runtime.getTypeCode("pyobj")
168     self.seqpyobjtype = self.runtime.getTypeCode("seqpyobj")
169     self.booltype = self.runtime.getTypeCode("bool")
170     self.block_stack = [self.proc]
171     self.name_index = 0 # used to ensure unique names
172     self.container_manager = ContainerManager()
173
174   def newName(self, name):
175     new_name = name + "_" + str(self.name_index)
176     self.name_index += 1
177     return new_name
178
179   def isAManagedPort(self, port):
180     try:
181       isManagedPort = port.IAmAManagedPort()
182     except AttributeError:
183       isManagedPort = False
184     return isManagedPort
185
186   def getContextName(self):
187     context_name = ""
188     if len(self.block_stack) > 1:
189       # We are in a block
190       block_path = ".".join([ b.getName() for b in self.block_stack[1:] ])
191       context_name = block_path + "."
192     return context_name
193
194   def getContainer(self, container_type):
195     """
196     A new container may be created if it does not already exist for this type.
197     """
198     container_properties = self.container_manager.getContainer(container_type)
199     if container_type not in self.containers:
200       cont=self.proc.createContainer(container_properties.name,"Salome")
201       cont.setProperty("nb_parallel_procs", str(container_properties.nb_cores))
202       cont.setProperty("type","multi")
203       cont.usePythonCache(container_properties.use_cache)
204       cont.attachOnCloning()
205       self.containers[container_type] = cont
206     return self.containers[container_type]
207
208   def createScript(self, file_path, function_name, inputs, outputs):
209     import inspect
210     stack = inspect.stack()
211     stack_info = "Call stack\n"
212     # skip the first 4 levels in the stack
213     for level in stack[4:-1] :
214       info = inspect.getframeinfo(level[0])
215       stack_info += "file: {}, line: {}, function: {}, context: {}\n".format(
216         info.filename, info.lineno, info.function, info.code_context)
217      
218     if len(outputs) == 0:
219       result = ""
220     elif len(outputs) == 1:
221       result = "{} = ".format(outputs[0])
222     else:
223       result = ",".join(outputs)
224       result += " = "
225
226     if len(inputs) == 0:
227       params = ""
228     elif len(inputs) == 1:
229       params = "{} ".format(inputs[0])
230     else:
231       params = ",".join(inputs)
232     
233     script = """'''
234 {call_stack}
235 '''
236 import yacstools
237 study_function = yacstools.getFunction("{file_path}", "{function_name}")
238 {result}study_function({parameters})
239 """.format(call_stack=stack_info,
240            file_path=file_path,
241            function_name=function_name,
242            result=result,
243            parameters=params)
244     return script
245
246   def createScriptNode(self, leaf, input_values):
247     node_name = leaf.newName()
248     file_path = leaf.path
249     function_name = leaf.fn_name
250     inputs = leaf.inputs # names
251     outputs = leaf.outputs # names
252     script = self.createScript(file_path, function_name, inputs, outputs)
253     container = self.getContainer(leaf.container_name)
254     new_node = self.runtime.createScriptNode("Salome", node_name)
255     new_node.setContainer(container)
256     new_node.setExecutionMode("remote")
257     new_node.setScript(script)
258     self.block_stack[-1].edAddChild(new_node)
259     # create ports
260     for p in inputs:
261       new_node.edAddInputPort(p, self.pyobjtype)
262     output_obj_list = []
263     for p in outputs:
264       port = new_node.edAddOutputPort(p, self.pyobjtype)
265       output_obj_list.append(OutputPort(new_node, port))
266     # create links
267     for k,v in input_values.items():
268       input_port = new_node.getInputPort(k)
269       if self.isAManagedPort(v) :
270         v.linkTo(input_port, new_node, self)
271       else:
272         input_port.edInitPy(v)
273     # return output ports
274     result = None
275     if len(output_obj_list) == 1 :
276       result = output_obj_list[0]
277     elif len(output_obj_list) > 1 :
278       result = tuple(output_obj_list)
279     return result
280
281   def beginForeach(self, fn_name, input_values):
282     foreach_name = self.newName(fn_name)
283     new_foreach = self.runtime.createForEachLoopDyn(foreach_name,
284                                                     self.pyobjtype)
285     self.block_stack[-1].edAddChild(new_foreach)
286     block_name = "block_"+foreach_name
287     new_block = self.runtime.createBloc(block_name)
288     new_foreach.edAddChild(new_block)
289     sample_port = new_foreach.edGetSamplePort()
290     input_list_port = new_foreach.edGetSeqOfSamplesPort()
291     try:
292       isManagedPort = input_values.IAmAManagedPort()
293     except AttributeError:
294       isManagedPort = False
295     if self.isAManagedPort(input_values) :
296       # we need a conversion node pyobj -> seqpyobj
297       conversion_node = self.runtime.createScriptNode("Salome",
298                                                       "input_"+foreach_name)
299       port_name = "val"
300       input_port = conversion_node.edAddInputPort(port_name, self.pyobjtype)
301       output_port = conversion_node.edAddOutputPort(port_name,
302                                                     self.seqpyobjtype)
303       conversion_node.setExecutionMode("local") # no need for container
304       # no script, the same variable for input and output
305       conversion_node.setScript("")
306       self.block_stack[-1].edAddChild(conversion_node)
307       input_values.linkTo(input_port, conversion_node, self)
308       self.proc.edAddLink(output_port, input_list_port)
309       # No need to look for ancestors. Both nodes are on the same level.
310       self.proc.edAddCFLink(conversion_node, new_foreach)
311     else:
312       input_list_port.edInitPy(list(input_values))
313     self.block_stack.append(new_foreach)
314     self.block_stack.append(new_block)
315     return OutputPort(new_foreach, sample_port)
316
317   def endForeach(self, outputs):
318     self.block_stack.pop() # remove the block
319     for_each_node = self.block_stack.pop() # remove the foreach
320     converted_ret = None
321     if outputs is not None:
322       # We need a conversion node seqpyobj -> pyobj
323       if type(outputs) is tuple:
324         list_out = list(outputs)
325       else:
326         list_out = [outputs]
327       conversion_node_name = "output_" + for_each_node.getName()
328       conversion_node = self.runtime.createScriptNode("Salome",
329                                                       conversion_node_name)
330       conversion_node.setExecutionMode("local") # no need for container
331       conversion_node.setScript("")
332       self.block_stack[-1].edAddChild(conversion_node)
333       list_ret = []
334       idx_name = 0 # for unique port names
335       for port in list_out :
336         if self.isAManagedPort(port):
337           port_name = port.getPort().getName() + "_" + str(idx_name)
338           input_port = conversion_node.edAddInputPort(port_name,
339                                                       self.seqpyobjtype)
340           output_port = conversion_node.edAddOutputPort(port_name,
341                                                         self.pyobjtype)
342           self.proc.edAddLink(port.getPort(), input_port)
343           list_ret.append(OutputPort(conversion_node, output_port))
344           idx_name += 1
345         else:
346           list_ret.append(port)
347       self.proc.edAddCFLink(for_each_node, conversion_node)
348       if len(list_ret) > 1 :
349         converted_ret = tuple(list_ret)
350       else:
351         converted_ret = list_ret[0]
352     return converted_ret
353
354   def dump(self, file_path):
355     self.proc.saveSchema(file_path)
356
357   def addCFLink(self, node_from, node_to):
358     commonAncestor = self.proc.getLowestCommonAncestor(node_from, node_to)
359     if node_from.getName() != commonAncestor.getName() :
360       while node_from.getFather().getName() != commonAncestor.getName() :
361         node_from = node_from.getFather()
362       while node_to.getFather().getName() != commonAncestor.getName() :
363         node_to = node_to.getFather()
364       self.proc.edAddCFLink(node_from, node_to)
365     else:
366       # from node is ancestor of to node. No CF link needed.
367       pass
368
369   def beginWhileloop(self, fn_name, context):
370     whileloop_name = self.newName("whileloop_"+fn_name)
371     while_node = self.runtime.createWhileLoop(whileloop_name)
372     self.block_stack[-1].edAddChild(while_node)
373     if not self.isAManagedPort(context):
374       # create a init node in order to get a port for the context
375       indata_name = "Inputdata_" + whileloop_name
376       indata_node = self.runtime.createScriptNode("Salome", indata_name)
377       indata_inport = indata_node.edAddInputPort("context", self.pyobjtype)
378       indata_outport = indata_node.edAddOutputPort("context", self.pyobjtype)
379       indata_inport.edInitPy(context)
380       context = OutputPort(indata_node, indata_outport)
381       self.block_stack[-1].edAddChild(indata_node)
382
383     block_name = "block_"+whileloop_name
384     new_block = self.runtime.createBloc(block_name)
385     while_node.edAddChild(new_block)
386     self.block_stack.append(while_node)
387     self.block_stack.append(new_block)
388     self.proc.edAddCFLink(context.getNode(), while_node)
389     ret = OutputPortWithCollector(context)
390     return ret
391
392   def endWhileloop(self, condition, collected_context, loop_result):
393     while_node = self.block_stack[-2]
394     cport = while_node.edGetConditionPort()
395     # need a conversion node pyobj -> bool
396     conversion_node = self.runtime.createScriptNode("Salome",
397                                                     "while_condition")
398     conversion_node.setExecutionMode("local") # no need for container
399     conversion_node.setScript("")
400     port_name = "val"
401     input_port = conversion_node.edAddInputPort(port_name, self.pyobjtype)
402     output_port = conversion_node.edAddOutputPort(port_name, self.booltype)
403     self.block_stack[-1].edAddChild(conversion_node)
404     condition.linkTo(input_port, conversion_node, self)
405     self.proc.edAddLink(output_port, cport)
406     if not loop_result is None:
407       for p in collected_context.connectedPorts():
408         self.proc.edAddLink(loop_result.getPort(), p)
409     self.block_stack.pop() # remove the block
410     self.block_stack.pop() # remove the while node
411
412 _generator = None
413
414 _default_mode = "Default"
415 _yacs_mode = "YACS"
416 _exec_mode = _default_mode
417
418 # Public functions
419
420 def getGenerator():
421   """
422   Get the singleton object.
423   """
424   if this_module._generator is None:
425     if this_module._exec_mode == this_module._yacs_mode:
426       this_module._generator = SchemaGenerator()
427   return this_module._generator
428
429 def activateYacsMode():
430   this_module._exec_mode = this_module._yacs_mode
431
432 def activateDefaultMode():
433   this_module._exec_mode = this_module._default_mode
434
435 def loadContainers(file_path):
436   getGenerator().container_manager.loadFile(file_path)
437
438 def export(path):
439   if this_module._exec_mode == this_module._yacs_mode :
440     getGenerator().dump(path)
441
442 # Decorators
443 class LeafDecorator():
444   def __init__(self, container_name):
445     self.container_name = container_name
446
447   def __call__(self, f):
448     if this_module._exec_mode == this_module._default_mode:
449       return f
450     co = f.__code__
451     import py2yacs
452     props = py2yacs.function_properties(co.co_filename, co.co_name)
453     nodeType = LeafNodeType(co.co_filename, co.co_name,
454                             props.inputs, props.outputs, self.container_name)
455     def my_func(*args, **kwargs):
456       if len(args) + len(kwargs) != len(nodeType.inputs):
457         mes = "Wrong number of arguments when calling function '{}'.\n".format(
458                                                                 nodeType.fn_name)
459         mes += " {} arguments expected and {} arguments found.\n".format(
460                                     len(nodeType.inputs), len(args) + len(kwargs))
461         raise Exception(mes)
462       idx = 0
463       args_dic = {}
464       for a in args:
465         args_dic[nodeType.inputs[idx]] = a
466         idx += 1
467       for k,v in kwargs.items():
468         args_dic[k] = v
469       if len(args_dic) != len(nodeType.inputs):
470         mes="Wrong arguments when calling function {}.\n".format(nodeType.fn_name)
471         raise Exception(mes)
472       return nodeType.createNewNode(args_dic)
473     return my_func
474
475 def leaf(arg):
476   """
477   Decorator for python scripts.
478   """
479   if callable(arg):
480     # decorator used without parameters. arg is the function
481     container = ContainerManager.defaultContainerName
482     ret = (LeafDecorator(container))(arg)
483   else:
484     # decorator used with parameter. arg is the container name
485     ret = LeafDecorator(arg)
486   return ret
487
488 def block(f):
489   """
490   Decorator for blocks.
491   """
492   #co = f.__code__
493   #print("block :", co.co_name)
494   #print("  file:", co.co_filename)
495   #print("  line:", co.co_firstlineno)
496   #print("  args:", co.co_varnames)
497   return f
498
499 def seqblock(f):
500   """
501   Decorator for sequential blocks.
502   """
503   if this_module._exec_mode == this_module._yacs_mode:
504   # TODO create a new block and set a flag to add dependencies between
505   # nodes in the block
506     pass
507   return f
508
509 def default_foreach(f):
510   def my_func(lst):
511     result = []
512     for e in lst:
513       result.append(f(e))
514     t_result = result
515     if len(result) > 0 :
516       if type(result[0]) is tuple:
517         # transform the list of tuples in a tuple of lists
518         l_result = []
519         for e in result[0]:
520           l_result.append([])
521         for t in result:
522           idx = 0
523           for e in t:
524             l_result[idx].append(e)
525             idx += 1
526         t_result = tuple(l_result)
527     return t_result
528   return my_func
529
530 def yacs_foreach(f):
531   #co = f.__code__
532   #import yacsvisit
533   #props = yacsvisit.main(co.co_filename, co.co_name)
534   def my_func(input_list):
535     fn_name = f.__code__.co_name
536     generator = getGenerator()
537     sample_port = generator.beginForeach(fn_name, input_list)
538     output_list = f(sample_port)
539     output_list = generator.endForeach(output_list)
540     return output_list
541   return my_func
542
543 def foreach(f):
544   """
545   Decorator to generate foreach blocks
546   """
547   if this_module._exec_mode == this_module._default_mode:
548     return default_foreach(f)
549   elif this_module._exec_mode == this_module._yacs_mode:
550     return yacs_foreach(f)
551
552 def default_forloop(l, f, context):
553   for e in l:
554     context = f(e, context)
555   return context
556
557 def yacs_forloop(l, f, context):
558     # TODO
559     pass
560
561 def forloop(l, f, context):
562   """
563   Forloop structure for distributed computations.
564   This shall be used as a regular function, not as a decorator.
565   Parameters:
566   l : list of values to iterate on
567   f : a function which is the body of the loop
568   context : the value of the context for the first iteration.
569   Return: context of the last iteration.
570
571   The f function shall take two parameters. The first is an element of the list
572   and the second is the context returned by the previous iteration.
573   The f function shall return one value, which is the context needed by the next
574   iteration.
575   """
576   if this_module._exec_mode == this_module._default_mode:
577     return default_forloop(l, f, context)
578   elif this_module._exec_mode == this_module._yacs_mode:
579     return yacs_forloop(l, f, context)
580
581 def default_whileloop(f, context):
582   cond = True
583   while cond :
584     cond, context = f(context)
585   return context
586
587 def yacs_whileloop(f, context):
588   fn_name = f.__code__.co_name
589   generator = getGenerator()
590   managed_context = generator.beginWhileloop(fn_name, context)
591   # managed context extends the context with the list of all input ports
592   # the context is linked to
593   cond, ret = f(managed_context)
594   generator.endWhileloop(cond, managed_context, ret)
595   return ret
596
597 def whileloop( f, context):
598   """
599   Whileloop structure for distributed computations.
600   This shall be used as a regular function, not as a decorator.
601   Parameters:
602   f : a function which is the body of the loop
603   context : the value of the context for the first iteration.
604   Return: context of the last iteration.
605
606   The f function shall take one parameter which is the context returned by the
607   previous iteration. It shall return a tuple of two values. The first value
608   should be True or False, to say if the loop shall continue or not. The second
609   is the context used by the next iteration.
610   """
611   if this_module._exec_mode == this_module._default_mode:
612     return default_whileloop(f, context)
613   elif this_module._exec_mode == this_module._yacs_mode:
614     return yacs_whileloop(f, context)
615
616 DEFAULT_SWITCH_ID = -1973012217
617
618 def default_switch(t, cases, *args, **kwargs):
619   ret = None
620   if t in cases.keys():
621     ret = cases[t](*args, **kwargs)
622   elif DEFAULT_SWITCH_ID in cases.keys():
623     ret = cases[DEFAULT_SWITCH_ID](*args, **kwargs)
624   return ret
625
626 def yacs_switch(t, cases, *args, **kwargs):
627   # TODO
628   pass
629
630 def switch( t,       # integer value to test
631             cases,   # dic { value: function}
632             *args,   # args to call the function
633             **kwargs # kwargs to call the function
634            ):
635   if this_module._exec_mode == this_module._default_mode:
636     return default_switch(t, cases, *args, **kwargs)
637   elif this_module._exec_mode == this_module._yacs_mode:
638     return yacs_switch(t, cases, *args, **kwargs)
639
640 def begin_sequential_block():
641   if this_module._exec_mode == this_module._default_mode:
642     return
643   # TODO yacs mode
644
645 def end_sequential_block():
646   if this_module._exec_mode == this_module._default_mode:
647     return
648   # TODO yacs mode