]> SALOME platform Git repositories - tools/medcoupling.git/commitdiff
Salome HOME
Add MEDCouplingRemapper.ToCSRMatrix static method to compare matrices
authorAnthony Geay <anthony.geay@edf.fr>
Wed, 24 Mar 2021 16:07:01 +0000 (17:07 +0100)
committerAnthony Geay <anthony.geay@edf.fr>
Wed, 24 Mar 2021 16:07:01 +0000 (17:07 +0100)
src/MEDCoupling_Swig/MEDCouplingRemapperCommon.i
src/MEDCoupling_Swig/MEDCouplingRemapperTest.py

index a3fb3653d49f4642e7d21196fe10a03e89123aa0..b5133e81ea6c54b807f5a01e8523ef68ccc852b1 100644 (file)
@@ -91,6 +91,12 @@ namespace MEDCoupling
            {
              return ToCSRMatrix(self->getCrudeMatrix(),self->getNumberOfColsOfMatrix());
            }
+           static PyObject *ToCSRMatrix(PyObject *m, mcIdType nbOfCols)
+           {
+              std::vector<std::map<mcIdType,double> > mCpp;
+              convertToVectMapIntDouble(m,mCpp);
+              return ToCSRMatrix(mCpp,nbOfCols);
+           }
 #endif
            void setCrudeMatrix(const MEDCouplingMesh *srcMesh, const MEDCouplingMesh *targetMesh, const std::string& method, PyObject *m)
            {
index 49ff6b9e305735d66cce4f8d88f80ce4bca12602..00a716ccfdcbae83fb66e890303483d7d156803e 100644 (file)
@@ -1492,6 +1492,22 @@ class MEDCouplingBasicsTest(unittest.TestCase):
         self.checkMatrix(rem.getCrudeMatrix(),[{0:0.65,1:0.35}],src.getNumberOfNodes(),1e-12)
         pass
 
+    @unittest.skipUnless(MEDCouplingHasNumPyBindings() and MEDCouplingHasSciPyBindings(),"requires numpy AND scipy")
+    def testRemToCSRMatrix(self):
+        import scipy
+        mPy = [{0:1.0,1:3.0,3:7.0,6:10.},{1:12.0,2:23.0}]
+        m = MEDCouplingRemapper.ToCSRMatrix(mPy,8)
+        self.assertTrue(isinstance(m,scipy.sparse.csr.csr_matrix))
+        self.assertEqual(m.getnnz(),6)
+        self.assertAlmostEqual(m[0,0],1.0,12)
+        self.assertAlmostEqual(m[0,1],3.0,12)
+        self.assertAlmostEqual(m[0,3],7.0,12)
+        self.assertAlmostEqual(m[0,6],10.0,12)
+        self.assertAlmostEqual(m[1,1],12.0,12)
+        self.assertAlmostEqual(m[1,2],23.0,12)
+        self.assertEqual(m.shape,(2,8))
+        pass
+
     def checkMatrix(self,mat1,mat2,nbCols,eps):
         self.assertEqual(len(mat1),len(mat2))
         for i in range(len(mat1)):