Salome HOME
05db2f940fa2d0c7dce2096406d1a02eac0c81fd
[tools/ydefx.git] / src / pydefx / samplecsvmanager.py
1 # Copyright (C) 2019  EDF R&D
2 #
3 # This library is free software; you can redistribute it and/or
4 # modify it under the terms of the GNU Lesser General Public
5 # License as published by the Free Software Foundation; either
6 # version 2.1 of the License, or (at your option) any later version.
7 #
8 # This library is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 # Lesser General Public License for more details.
12 #
13 # You should have received a copy of the GNU Lesser General Public
14 # License along with this library; if not, write to the Free Software
15 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
16 #
17 # See http://www.salome-platform.org/ or email : webmaster.salome@opencascade.com
18 #
19 import csv
20 import inspect
21 import os
22 import pathlib
23 from . import sample
24 from . import samplecsviterator
25 SampleIterator = samplecsviterator.SampleIterator
26
27 class SampleManager:
28   """
29   The SampleManager is used by the study for reading and writing a sample from
30   and to the file system. This SampleManager uses the csv format.
31   The following services are needed by the study:
32   - write the sample on the local file system (prepareRun).
33   - know what files were written in order to copy them on the remote file system
34   (return value of prepareRun).
35   - know what files contain the result in order to bring them back from the
36   remote file system to the local one (getResultFileName).
37   - load the results from the local file system to a sample (loadResult).
38   - restore a sample from a local directory when you want to recover a job
39   launched in a previous session.
40   - the name of the module which contains the class SampleIterator in order to
41   iterate over the input values of the sample (getModuleName).
42   This name is written by the study in a configuration file and it is used by
43   the optimizer loop plugin.
44   """
45   def __init__(self):
46     pass
47
48   # Functions used by the study
49   def prepareRun(self, sample, directory):
50     """
51     Create a dump of the sample in the given directory.
52     sample: Sample object.
53     directory: path to a local working directory where all the working files are
54                copied. This directory should be already created.
55     Return a list of files to add to the input files list of the job.
56     """
57     datapath = os.path.join(directory, SampleIterator.DATAFILE)
58     with open(datapath, 'w', newline='') as csvfile:
59       writer = csv.DictWriter(csvfile,
60                               fieldnames=sample.getInputNames(),
61                               quoting=csv.QUOTE_NONNUMERIC )
62       writer.writeheader()
63       writer.writerows(sample.inputIterator())
64
65     outnamespath = os.path.join(directory, SampleIterator.OUTPUTNAMESFILE)
66     with open(outnamespath, 'w') as outputfile:
67       for v in sample.getOutputNames():
68         outputfile.write(v+'\n')
69     filename = inspect.getframeinfo(inspect.currentframe()).filename
70     install_directory = pathlib.Path(filename).resolve().parent
71     iteratorFile = os.path.join(install_directory, "samplecsviterator.py")
72     return [datapath,
73             outnamespath,
74             iteratorFile
75             ]
76
77   def loadResult(self, sample, directory):
78     """
79     The directory should contain a RESULTDIR directory with the result files.
80     The results are loaded into the sample.
81     Return the modified sample.
82     """
83     resultdir = os.path.join(directory, SampleIterator.RESULTDIR)
84     datapath = os.path.join(resultdir, SampleIterator.RESULTFILE)
85     with open(datapath, newline='') as datafile:
86       data = csv.DictReader(datafile, quoting=csv.QUOTE_NONNUMERIC)
87       for elt in data:
88         index = int(elt[SampleIterator.IDCOLUMN]) # float to int
89         input_vals = {}
90         for name in sample.getInputNames():
91           input_vals[name] = elt[name]
92         output_vals = {}
93         for name in sample.getOutputNames():
94           output_vals[name] = samplecsviterator._decodeOutput(elt[name],
95                                                               resultdir)
96         try:
97           sample.checkId(index, input_vals)
98         except Exception as err:
99           extraInfo = "Error on processing file {} index number {}:".format(
100                                                 datapath,       str(index))
101           raise Exception(extraInfo + str(err))
102         sample.addResult(index, output_vals, elt[SampleIterator.ERRORCOLUMN])
103     return sample
104
105   def restoreSample(self, directory):
106     """ The directory should contain the files created by prepareRun. A new
107     sample object is created and returned from those files.
108     This function is used to recover a previous run.
109     """
110     sampleIt = SampleIterator(directory)
111     inputvalues = {}
112     for name in sampleIt.inputnames:
113       inputvalues[name] = []
114     for newid, values in sampleIt:
115       for name in sampleIt.inputnames:
116         inputvalues[name].append(values[name])
117     
118     result = sample.Sample(sampleIt.inputnames, sampleIt.outputnames)
119     result.setInputValues(inputvalues)
120     sampleIt.terminate()
121     return result
122
123   def getModuleName(self):
124     """
125     Return the module name which contains the class SampleIterator.
126     """
127     return "samplecsviterator"
128   
129   def getResultFileName(self):
130     """
131     Name of the file or directory which contains the result and needs to be
132     copied from the remote computer.
133     """
134     return SampleIterator.RESULTDIR
135