From fd1c80f1c77841056befaaff88f631b0749748fe Mon Sep 17 00:00:00 2001 From: Anthony Geay Date: Wed, 24 Mar 2021 17:07:01 +0100 Subject: [PATCH] Add MEDCouplingRemapper.ToCSRMatrix static method to compare matrices --- src/MEDCoupling_Swig/MEDCouplingRemapperCommon.i | 6 ++++++ src/MEDCoupling_Swig/MEDCouplingRemapperTest.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/MEDCoupling_Swig/MEDCouplingRemapperCommon.i b/src/MEDCoupling_Swig/MEDCouplingRemapperCommon.i index a3fb3653d..b5133e81e 100644 --- a/src/MEDCoupling_Swig/MEDCouplingRemapperCommon.i +++ b/src/MEDCoupling_Swig/MEDCouplingRemapperCommon.i @@ -91,6 +91,12 @@ namespace MEDCoupling { return ToCSRMatrix(self->getCrudeMatrix(),self->getNumberOfColsOfMatrix()); } + static PyObject *ToCSRMatrix(PyObject *m, mcIdType nbOfCols) + { + std::vector > mCpp; + convertToVectMapIntDouble(m,mCpp); + return ToCSRMatrix(mCpp,nbOfCols); + } #endif void setCrudeMatrix(const MEDCouplingMesh *srcMesh, const MEDCouplingMesh *targetMesh, const std::string& method, PyObject *m) { diff --git a/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py b/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py index 49ff6b9e3..00a716ccf 100644 --- a/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py +++ b/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py @@ -1492,6 +1492,22 @@ class MEDCouplingBasicsTest(unittest.TestCase): self.checkMatrix(rem.getCrudeMatrix(),[{0:0.65,1:0.35}],src.getNumberOfNodes(),1e-12) pass + @unittest.skipUnless(MEDCouplingHasNumPyBindings() and MEDCouplingHasSciPyBindings(),"requires numpy AND scipy") + def testRemToCSRMatrix(self): + import scipy + mPy = [{0:1.0,1:3.0,3:7.0,6:10.},{1:12.0,2:23.0}] + m = MEDCouplingRemapper.ToCSRMatrix(mPy,8) + self.assertTrue(isinstance(m,scipy.sparse.csr.csr_matrix)) + self.assertEqual(m.getnnz(),6) + self.assertAlmostEqual(m[0,0],1.0,12) + self.assertAlmostEqual(m[0,1],3.0,12) + self.assertAlmostEqual(m[0,3],7.0,12) + self.assertAlmostEqual(m[0,6],10.0,12) + self.assertAlmostEqual(m[1,1],12.0,12) + self.assertAlmostEqual(m[1,2],23.0,12) + self.assertEqual(m.shape,(2,8)) + pass + def checkMatrix(self,mat1,mat2,nbCols,eps): self.assertEqual(len(mat1),len(mat2)) for i in range(len(mat1)): -- 2.39.2