]> SALOME platform Git repositories - modules/med.git/commitdiff
Salome HOME
CSR output matrix MEDCouplingRemapper::getCrudeCSRMatrix
authorageay <ageay>
Mon, 23 Sep 2013 10:42:35 +0000 (10:42 +0000)
committerageay <ageay>
Mon, 23 Sep 2013 10:42:35 +0000 (10:42 +0000)
src/MEDCoupling_Swig/MEDCouplingRemapperTest.py

index 60e52912e4b6fa882f93fdc9803d40fcefa3f1ee..36c6e50a9e2a5bd3d022ba1e2c10f976b4763858 100644 (file)
@@ -586,6 +586,63 @@ class MEDCouplingBasicsTest(unittest.TestCase):
         self.assertAlmostEqual(0.3438363717836785 ,m[6][188],12)
         self.assertAlmostEqual(0.3521445110626687 ,m[6][170],12)
         pass
+
+    @unittest.skipUnless(MEDCouplingHasNumPyBindings(),"requires numpy")
+    def testGetCrudeCSRMatrix1(self):
+        """ testing CSR matrix output using numpy/scipy.
+        """
+        from scipy.sparse import diags
+        import scipy
+        from numpy import array
+        arr=DataArrayDouble(3) ; arr.iota()
+        m=MEDCouplingCMesh() ; m.setCoords(arr,arr)
+        src=m.buildUnstructured()
+        trg=src.deepCpy() ; trg=trg[[0,1,3]]
+        trg.getCoords()[:]*=0.5 ; trg.getCoords()[:]+=[0.3,0.25]
+        # Let's interpolate.
+        rem=MEDCouplingRemapper()
+        rem.prepare(src,trg,"P0P0")
+        # Internal crude sparse matrix computed. Let's manipulate it using CSR matrix in scipy.
+        for i in xrange(10):
+            m=rem.getCrudeCSRMatrix()
+            pass
+        m2=rem.getCrudeCSRMatrix()
+        diff=m-m2
+        assert(isinstance(m,scipy.sparse.csr.csr_matrix))
+        assert(m.getnnz()==7)
+        self.assertAlmostEqual(m[0,0],0.25,12)
+        self.assertAlmostEqual(m[1,0],0.1,12)
+        self.assertAlmostEqual(m[1,1],0.15,12)
+        self.assertAlmostEqual(m[2,0],0.05,12)
+        self.assertAlmostEqual(m[2,1],0.075,12)
+        self.assertAlmostEqual(m[2,2],0.05,12)
+        self.assertAlmostEqual(m[2,3],0.075,12)
+        self.assertEqual(diff.getnnz(),0)
+        # IntegralGlobConstraint (division by sum of cols)
+        colSum=m.sum(axis=0)
+        m_0=m*diags(array(1/colSum),[0])
+        del colSum
+        self.assertAlmostEqual(m_0[0,0],0.625,12)
+        self.assertAlmostEqual(m_0[1,0],0.25,12)
+        self.assertAlmostEqual(m_0[1,1],0.6666666666666667,12)
+        self.assertAlmostEqual(m_0[2,0],0.125,12)
+        self.assertAlmostEqual(m_0[2,1],0.3333333333333333,12)
+        self.assertAlmostEqual(m_0[2,2],1.,12)
+        self.assertAlmostEqual(m_0[2,3],1.,12)
+        assert(m_0.getnnz()==7)
+        # ConservativeVolumic (division by sum of rows)
+        rowSum=m.sum(axis=1)
+        m_1=diags(array(1/rowSum.transpose()),[0])*m
+        del rowSum
+        self.assertAlmostEqual(m_1[0,0],1.,12)
+        self.assertAlmostEqual(m_1[1,0],0.4,12)
+        self.assertAlmostEqual(m_1[1,1],0.6,12)
+        self.assertAlmostEqual(m_1[2,0],0.2,12)
+        self.assertAlmostEqual(m_1[2,1],0.3,12)
+        self.assertAlmostEqual(m_1[2,2],0.2,12)
+        self.assertAlmostEqual(m_1[2,3],0.3,12)
+        assert(m_1.getnnz()==7)
+        pass
     
     def build2DSourceMesh_1(self):
         sourceCoords=[-0.3,-0.3, 0.7,-0.3, -0.3,0.7, 0.7,0.7]