]> SALOME platform Git repositories - modules/adao.git/commitdiff
Salome HOME
Allows extended keyword syntax settings for Parameters
authorJean-Philippe ARGAUD <jean-philippe.argaud@edf.fr>
Thu, 18 Jul 2019 12:41:52 +0000 (14:41 +0200)
committerJean-Philippe ARGAUD <jean-philippe.argaud@edf.fr>
Thu, 18 Jul 2019 12:41:52 +0000 (14:41 +0200)
src/daComposant/daCore/BasicObjects.py
src/daComposant/daCore/Templates.py
test/test6711/CTestTestfile.cmake
test/test6711/Doc_TUI_Exemple_02_Arguments.py [new file with mode: 0644]

index 5ac2e620ca3624751ffe129b8bc267646cab76dc..5a08802b8e225b9319db179931f25f4babce8f0c 100644 (file)
@@ -655,19 +655,24 @@ class Algorithm(object):
         #
         for k in self.StoredVariables:
             self.__canonical_stored_name[k.lower()] = k
+        #
+        for k, v in self.__variable_names_not_public.items():
+            self.__canonical_parameter_name[k.lower()] = k
+        self.__canonical_parameter_name["algorithm"] = "Algorithm"
+        self.__canonical_parameter_name["storesupplementarycalculations"] = "StoreSupplementaryCalculations"
 
     def _pre_run(self, Parameters, Xb=None, Y=None, R=None, B=None, Q=None ):
         "Pré-calcul"
         logging.debug("%s Lancement", self._name)
-        logging.debug("%s Taille mémoire utilisée de %.0f Mio", self._name, self._m.getUsedMemory("Mio"))
-        #
-        # Mise a jour de self._parameters avec Parameters
-        self.__setParameters(Parameters)
+        logging.debug("%s Taille mémoire utilisée de %.0f Mio"%(self._name, self._m.getUsedMemory("Mio")))
         #
+        # Mise a jour des paramètres internes avec le contenu de Parameters, en
+        # reprenant les valeurs par défauts pour toutes celles non définies
+        self.__setParameters(Parameters, reset=True)
         for k, v in self.__variable_names_not_public.items():
             if k not in self._parameters:  self.__setParameters( {k:v} )
         #
-        # Corrections et complements
+        # Corrections et compléments
         def __test_vvalue(argument, variable, argname):
             if argument is None:
                 if variable in self.__required_inputs["RequiredInputValues"]["mandatory"]:
@@ -856,6 +861,7 @@ class Algorithm(object):
                         raise ValueError("The value '%s' is not allowed for the parameter named '%s', it has to be in the list %s."%(v, __k, listval))
             elif __val not in listval:
                 raise ValueError("The value '%s' is not allowed for the parameter named '%s', it has to be in the list %s."%( __val, __k,listval))
+        #
         return __val
 
     def requireInputArguments(self, mandatory=(), optional=()):
@@ -865,16 +871,24 @@ class Algorithm(object):
         self.__required_inputs["RequiredInputValues"]["mandatory"] = tuple( mandatory )
         self.__required_inputs["RequiredInputValues"]["optional"]  = tuple( optional )
 
-    def __setParameters(self, fromDico={}):
+    def __setParameters(self, fromDico={}, reset=False):
         """
         Permet de stocker les paramètres reçus dans le dictionnaire interne.
         """
         self._parameters.update( fromDico )
+        __inverse_fromDico_keys = {}
+        for k in fromDico.keys():
+            if k.lower() in self.__canonical_parameter_name:
+                __inverse_fromDico_keys[self.__canonical_parameter_name[k.lower()]] = k
+        #~ __inverse_fromDico_keys = dict([(self.__canonical_parameter_name[k.lower()],k) for k in fromDico.keys()])
+        __canonic_fromDico_keys = __inverse_fromDico_keys.keys()
         for k in self.__required_parameters.keys():
-            if k in fromDico.keys():
-                self._parameters[k] = self.setParameterValue(k,fromDico[k])
-            else:
+            if k in __canonic_fromDico_keys:
+                self._parameters[k] = self.setParameterValue(k,fromDico[__inverse_fromDico_keys[k]])
+            elif reset:
                 self._parameters[k] = self.setParameterValue(k)
+            else:
+                pass
             logging.debug("%s %s : %s", self._name, self.__required_parameters[k]["message"], self._parameters[k])
 
 # ==============================================================================
index 1d1880da77cca775a1b760df52fc1e98b1a80721..aa9fcc3249f2727f7d6f38583e13192da029f455 100644 (file)
@@ -178,7 +178,7 @@ ObserverTemplates.store(
 ObserverTemplates.store(
     name    = "ValuePrinterSaverAndGnuPlotter",
     content = """print(str(info)+" "+str(var[-1]))\nimport numpy, re\nv=numpy.array(var[-1], ndmin=1)\nglobal istep\ntry:\n    istep += 1\nexcept:\n    istep = 0\nf='/tmp/value_%s_%05i.txt'%(info,istep)\nf=re.sub('\\s','_',f)\nprint('Value saved in \"%s\"'%f)\nnumpy.savetxt(f,v)\nimport Gnuplot\nglobal ifig,gp\ntry:\n    ifig += 1\n    gp(' set style data lines')\nexcept:\n    ifig = 0\n    gp = Gnuplot.Gnuplot(persist=1)\n    gp(' set style data lines')\ngp('set title  \"%s (Figure %i)\"'%(info,ifig))\ngp.plot( Gnuplot.Data( v, with_='lines lw 2' ) )""",
-    fr_FR   = "Imprime sur la sortie standard et, en même temps, enregistre dans un fichier du répertoire '/tmp' et affiche graphiquement la valeur courante de la variable ",
+    fr_FR   = "Imprime sur la sortie standard et, en même temps, enregistre dans un fichier du répertoire '/tmp' et affiche graphiquement la valeur courante de la variable",
     en_EN   = "Print on standard output and, in the same, time save in a file of the '/tmp' directory and graphically plot the current value of the variable",
     order   = "next",
     )
index 5bbc9128ef808e7b41c5307f79f621c32f0e2b60..e77828426a4b99b486d064cd68f90862f8faaeab 100644 (file)
@@ -21,6 +21,7 @@
 
 SET(TEST_NAMES
   Doc_TUI_Exemple_01_Savings
+  Doc_TUI_Exemple_02_Arguments
   )
 
 FOREACH(tfile ${TEST_NAMES})
diff --git a/test/test6711/Doc_TUI_Exemple_02_Arguments.py b/test/test6711/Doc_TUI_Exemple_02_Arguments.py
new file mode 100644 (file)
index 0000000..1c82de3
--- /dev/null
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2008-2019 EDF R&D
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
+#
+# See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
+#
+# Author: Jean-Philippe Argaud, jean-philippe.argaud@edf.fr, EDF R&D
+"Verification d'un exemple de la documentation"
+
+import sys, os, tempfile
+import unittest
+
+# ==============================================================================
+class Test_Adao(unittest.TestCase):
+    def test1(self):
+        """Test"""
+        from numpy import array, matrix
+        from adao import adaoBuilder
+        #-----------------------------------------------------------------------
+        # Analyse avec les paramètres de casse correcte
+        print("\nCase 1")
+        case = adaoBuilder.New()
+        case.set( 'AlgorithmParameters', Algorithm='3DVAR' )
+        case.set( 'AlgorithmParameters',
+            Algorithm = '3DVAR',
+            Parameters = {
+                "Minimizer":"CG",
+                "MaximumNumberOfSteps":3,
+                "CostDecrementTolerance":1.e-2,
+                "SetSeed":1234567,
+                "StoreSupplementaryCalculations":[
+                    "CostFunctionJAtCurrentOptimum",
+                    "SimulatedObservationAtCurrentOptimum",
+                    ],
+                }
+            )
+        case.set( 'Background',          Vector=[0, 1, 2] )
+        case.set( 'BackgroundError',     ScalarSparseMatrix=1.0 )
+        case.set( 'Observation',         Vector=array([0.5, 1.5, 2.5]) )
+        case.set( 'ObservationError',    DiagonalSparseMatrix='1 1 1' )
+        case.set( 'ObservationOperator', Matrix='1 0 0;0 2 0;0 0 3' )
+        case.set( 'Observer',            Variable="CostFunctionJAtCurrentOptimum", Template="ValuePrinter" )
+        case.set( 'Observer',            Variable="Analysis", Template="ValuePrinter" )
+        case.execute()
+        xa1 = case.get("Analysis")[-1]
+        del case
+        #
+        #-----------------------------------------------------------------------
+        # Analyse avec les paramètres de casse quelconque
+        print("\nCase 2")
+        case = adaoBuilder.New()
+        case.set( 'AlgorithmParameters',
+            Algorithm = '3DVAR',
+            Parameters = {
+                "MINIMIZER":"CG",
+                "maximumnumberofsteps":3,
+                "COSTDecrementTOLERANCE":1.e-2,
+                "STORESUPPLEMENTARYCALCULATIONS":[
+                    "CostFunctionJAtCurrentOptimum",
+                    "SimulatedObservationAtCurrentOptimum",
+                    ],
+                }
+            )
+        case.set( 'Background',          Vector=[0, 1, 2] )
+        case.set( 'BackgroundError',     ScalarSparseMatrix=1.0 )
+        case.set( 'Observation',         Vector=array([0.5, 1.5, 2.5]) )
+        case.set( 'ObservationError',    DiagonalSparseMatrix='1 1 1' )
+        case.set( 'ObservationOperator', Matrix='1 0 0;0 2 0;0 0 3' )
+        case.set( 'Observer',            Variable="CostFunctionJAtCurrentOptimum", Template="ValuePrinter" )
+        case.set( 'Observer',            Variable="Analysis", Template="ValuePrinter" )
+        case.execute()
+        xa2 = case.get("Analysis")[-1]
+        del case
+        #
+        #-----------------------------------------------------------------------
+        ecart = assertAlmostEqualArrays(xa1, xa2, places = 15)
+        #
+        print("\nTest correct")
+
+# ==============================================================================
+def filesize(name):
+    statinfo = os.stat(name)
+    return statinfo.st_size # Bytes
+
+def assertAlmostEqualArrays(first, second, places=7, msg=None, delta=None):
+    "Compare two vectors, like unittest.assertAlmostEqual"
+    import numpy
+    if msg is not None:
+        print(msg)
+    if delta is not None:
+        if ( numpy.abs(numpy.asarray(first) - numpy.asarray(second)) > float(delta) ).any():
+            raise AssertionError("%s != %s within %s places"%(first,second,delta))
+    else:
+        if ( numpy.abs(numpy.asarray(first) - numpy.asarray(second)) > 10**(-int(places)) ).any():
+            raise AssertionError("%s != %s within %i places"%(first,second,places))
+    return max(abs(numpy.asarray(first) - numpy.asarray(second)))
+
+# ==============================================================================
+if __name__ == "__main__":
+    print('\nAUTODIAGNOSTIC\n')
+    sys.stderr = sys.stdout
+    unittest.main(verbosity=2)