From bcebadd7a080432c6f78e4989abb634a56559fbe Mon Sep 17 00:00:00 2001 From: ageay Date: Mon, 23 Sep 2013 10:42:35 +0000 Subject: [PATCH] CSR output matrix MEDCouplingRemapper::getCrudeCSRMatrix --- .../MEDCouplingRemapperTest.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py b/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py index 60e52912e..36c6e50a9 100644 --- a/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py +++ b/src/MEDCoupling_Swig/MEDCouplingRemapperTest.py @@ -586,6 +586,63 @@ class MEDCouplingBasicsTest(unittest.TestCase): self.assertAlmostEqual(0.3438363717836785 ,m[6][188],12) self.assertAlmostEqual(0.3521445110626687 ,m[6][170],12) pass + + @unittest.skipUnless(MEDCouplingHasNumPyBindings(),"requires numpy") + def testGetCrudeCSRMatrix1(self): + """ testing CSR matrix output using numpy/scipy. + """ + from scipy.sparse import diags + import scipy + from numpy import array + arr=DataArrayDouble(3) ; arr.iota() + m=MEDCouplingCMesh() ; m.setCoords(arr,arr) + src=m.buildUnstructured() + trg=src.deepCpy() ; trg=trg[[0,1,3]] + trg.getCoords()[:]*=0.5 ; trg.getCoords()[:]+=[0.3,0.25] + # Let's interpolate. + rem=MEDCouplingRemapper() + rem.prepare(src,trg,"P0P0") + # Internal crude sparse matrix computed. Let's manipulate it using CSR matrix in scipy. + for i in xrange(10): + m=rem.getCrudeCSRMatrix() + pass + m2=rem.getCrudeCSRMatrix() + diff=m-m2 + assert(isinstance(m,scipy.sparse.csr.csr_matrix)) + assert(m.getnnz()==7) + self.assertAlmostEqual(m[0,0],0.25,12) + self.assertAlmostEqual(m[1,0],0.1,12) + self.assertAlmostEqual(m[1,1],0.15,12) + self.assertAlmostEqual(m[2,0],0.05,12) + self.assertAlmostEqual(m[2,1],0.075,12) + self.assertAlmostEqual(m[2,2],0.05,12) + self.assertAlmostEqual(m[2,3],0.075,12) + self.assertEqual(diff.getnnz(),0) + # IntegralGlobConstraint (division by sum of cols) + colSum=m.sum(axis=0) + m_0=m*diags(array(1/colSum),[0]) + del colSum + self.assertAlmostEqual(m_0[0,0],0.625,12) + self.assertAlmostEqual(m_0[1,0],0.25,12) + self.assertAlmostEqual(m_0[1,1],0.6666666666666667,12) + self.assertAlmostEqual(m_0[2,0],0.125,12) + self.assertAlmostEqual(m_0[2,1],0.3333333333333333,12) + self.assertAlmostEqual(m_0[2,2],1.,12) + self.assertAlmostEqual(m_0[2,3],1.,12) + assert(m_0.getnnz()==7) + # ConservativeVolumic (division by sum of rows) + rowSum=m.sum(axis=1) + m_1=diags(array(1/rowSum.transpose()),[0])*m + del rowSum + self.assertAlmostEqual(m_1[0,0],1.,12) + self.assertAlmostEqual(m_1[1,0],0.4,12) + self.assertAlmostEqual(m_1[1,1],0.6,12) + self.assertAlmostEqual(m_1[2,0],0.2,12) + self.assertAlmostEqual(m_1[2,1],0.3,12) + self.assertAlmostEqual(m_1[2,2],0.2,12) + self.assertAlmostEqual(m_1[2,3],0.3,12) + assert(m_1.getnnz()==7) + pass def build2DSourceMesh_1(self): sourceCoords=[-0.3,-0.3, 0.7,-0.3, -0.3,0.7, 0.7,0.7] -- 2.39.2