Salome HOME
Initialization.
[tools/ydefx.git] / src / pydefx / samplecsvmanager.py
1 import csv
2 import inspect
3 import os
4 import pathlib
5 from .samplecsviterator import SampleIterator
6 from . import samplecsviterator
7 from . import sample
8
9 class SampleManager:
10   def __init__(self):
11     pass
12
13   def prepareRun(self, sample, directory):
14     """
15     Create a dump of the sample in the given directory.
16     Return a list of files to add to the input files list of the job.
17     """
18     datapath = os.path.join(directory, SampleIterator.DATAFILE)
19     with open(datapath, 'w', newline='') as csvfile:
20       writer = csv.DictWriter(csvfile,
21                               fieldnames=sample.getInputNames(),
22                               quoting=csv.QUOTE_NONNUMERIC )
23       writer.writeheader()
24       writer.writerows(sample.inputIterator())
25     
26     outnamespath = os.path.join(directory, SampleIterator.OUTPUTNAMESFILE)
27     with open(outnamespath, 'w') as outputfile:
28       for v in sample.getOutputNames():
29         outputfile.write(v+'\n')
30     filename = inspect.getframeinfo(inspect.currentframe()).filename
31     install_directory = pathlib.Path(filename).resolve().parent
32     iteratorFile = os.path.join(install_directory, "samplecsviterator.py")
33     return [datapath,
34             outnamespath,
35             iteratorFile
36             ]
37
38   def loadResult(self, sample, directory):
39     """ The directory should contain a file with the name given by
40     getResultFileName. The results are loaded from that file to the sample.
41     Return the modified sample.
42     """
43     datapath = os.path.join(directory, SampleIterator.RESULTFILE)
44     with open(datapath, newline='') as datafile:
45       data = csv.DictReader(datafile, quoting=csv.QUOTE_NONNUMERIC)
46       for elt in data:
47         index = int(elt[SampleIterator.IDCOLUMN]) # float to int
48         input_vals = {}
49         for name in sample.getInputNames():
50           input_vals[name] = elt[name]
51         output_vals = {}
52         for name in sample.getOutputNames():
53           output_vals[name] = elt[name]
54         try:
55           sample.checkId(index, input_vals)
56         except Exception as err:
57           extraInfo = "Error on processing file {} index number {}:".format(
58                                                 datapath,       str(index))
59           raise Exception(extraInfo + str(err))
60         sample.addResult(index, output_vals, elt[SampleIterator.ERRORCOLUMN])
61     return sample
62
63   def loadSample(self, directory):
64     """ The directory should contain the files created by prepareRun. A new
65     sample object is created and returned from those files.
66     This function is used to recover a previous run.
67     """
68     outputnamesfile = os.path.join(directory, SampleIterator.OUTPUTNAMESFILE)
69     outputnames = samplecsviterator._loadOutputNames(outputnamesfile)
70     inputFilePath = os.path.join(directory, SampleIterator.DATAFILE)
71     with open(inputFilePath) as datafile:
72       data = csv.DictReader(datafile, quoting=csv.QUOTE_NONNUMERIC)
73       inputvalues = {}
74       for name in data.fieldnames:
75         inputvalues[name] = []
76       for line in data:
77         for name in data.fieldnames:
78           inputvalues[name].append(line[name])
79     result = sample.Sample(data.fieldnames, outputnames)
80     result.setInputValues(inputvalues)
81     return result
82         
83
84   def getModuleName(self):
85     """
86     Return the module name which contains the class SampleIterator.
87     """
88     return "samplecsviterator"
89   
90   def getResultFileName(self):
91     return SampleIterator.RESULTFILE