Tuesday, December 1, 2009

matlab sparse matrix format

I have encountered several cases where I needed to handle MATLAB sparse matrices in MEX files written in C or C++.

Here, I use a simple example to explain how MATLAB stores a sparse matrix. The format is clever, but it is not always convenient for my own purposes, so it pays to understand its structure.

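The sparse matrix format in MATLAB, a compressed sparse column (CSC) layout, involves several components:
nz: the number of nonzero elements
m: the number of rows
n: the number of columns

jc: the column offsets; jc[j] is the position in ir and pr where column j begins
ir: the row index of each nonzero element
pr: the values of the nonzero elements, stored in a double array (pi holds the imaginary parts if the matrix is complex)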
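
The list above maps naturally onto a C struct. Below is a minimal sketch, not the MEX API itself: the struct name SparseCSC and the helper csc_is_valid are hypothetical, and size_t stands in for MATLAB's mwIndex/mwSize. The check spells out the invariants the format implies: jc starts at 0, never decreases, ends at nz, and every zero-based row index is below m.

```c
#include <stddef.h>

/* Hypothetical plain-C mirror of the components listed above.
   size_t stands in for mwIndex/mwSize (an assumption for this sketch). */
typedef struct {
    size_t m;          /* number of rows */
    size_t n;          /* number of columns */
    size_t nz;         /* number of nonzero elements */
    const size_t *jc;  /* column offsets, n+1 entries */
    const size_t *ir;  /* zero-based row index of each nonzero, nz entries */
    const double *pr;  /* value of each nonzero, nz entries */
} SparseCSC;

/* Returns 1 if the arrays satisfy the CSC invariants, 0 otherwise. */
int csc_is_valid(const SparseCSC *A)
{
    size_t j, k;
    if (A->jc[0] != 0 || A->jc[A->n] != A->nz)
        return 0;
    for (j = 0; j < A->n; j++) {
        if (A->jc[j] > A->jc[j + 1])   /* offsets never decrease */
            return 0;
        for (k = A->jc[j]; k < A->jc[j + 1]; k++)
            if (A->ir[k] >= A->m)      /* rows must lie in 0..m-1 */
                return 0;
    }
    return 1;
}
```

In a MEX file, the fields would be filled from mxGetM, mxGetN, mxGetJc, mxGetIr and mxGetPr.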

The tricky part is the relation between jc, ir and pr.

For example, consider a 7-by-3 sparse mxArray named Sparrow containing six nonzero elements, created by typing:

Sparrow = zeros(7,3);
Sparrow(2,1) = 1;
Sparrow(5,1) = 1;
Sparrow(3,2) = 1;
Sparrow(2,3) = 2;
Sparrow(5,3) = 1;
Sparrow(6,3) = 1;
Sparrow = sparse(Sparrow);
The full matrix then looks like this:

>> full(Sparrow)

ans =

0 0 0
1 0 2
0 1 0
0 0 0
1 0 1
0 0 1
0 0 0

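The contents of the ir, jc, and pr arrays are listed in this table. Note that ir and jc hold zero-based indices, whereas MATLAB subscripts are one-based: the element at (2,1) is stored with ir = 1. The jc column lists jc[0] through jc[3] and is not aligned with the individual elements.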

Subscript   ir   pr   jc   Comment
(2,1)       1    1    0    Column 1 contains two nonzero elements, with rows designated by ir[0] and ir[1]
(5,1)       4    1    2    Column 2 contains one nonzero element, with row designated by ir[2]
(3,2)       2    1    3    Column 3 contains three nonzero elements, with rows designated by ir[3], ir[4], and ir[5]
(2,3)       1    2    6    There are six nonzero elements in all.
(5,3)       4    1
(6,3)       5    1
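
To make the table concrete, here is a sketch that writes the Sparrow arrays as C literals and expands them back into a dense, column-major m-by-n array (the layout MATLAB uses for full matrices). The names sparrow_jc, sparrow_ir, sparrow_pr and csc_to_dense are hypothetical, and size_t stands in for mwIndex.

```c
#include <stddef.h>

/* The Sparrow arrays from the table, as C literals. */
static const size_t sparrow_jc[] = {0, 2, 3, 6};        /* n+1 = 4 entries */
static const size_t sparrow_ir[] = {1, 4, 2, 1, 4, 5};  /* zero-based rows */
static const double sparrow_pr[] = {1, 1, 1, 2, 1, 1};  /* values */

/* Expand a CSC matrix into a zero-filled column-major dense array
   of m*n doubles; element (r, c) lives at dense[r + c*m]. */
void csc_to_dense(size_t m, size_t n, const size_t *jc, const size_t *ir,
                  const double *pr, double *dense)
{
    size_t j, k;
    for (k = 0; k < m * n; k++)
        dense[k] = 0.0;
    for (j = 0; j < n; j++)
        for (k = jc[j]; k < jc[j + 1]; k++)
            dense[ir[k] + j * m] = pr[k];   /* element (ir[k], j) */
}
```

Calling csc_to_dense(7, 3, sparrow_jc, sparrow_ir, sparrow_pr, d) with a 21-element array d reproduces full(Sparrow); for instance d[1 + 2*7] is 2, the entry at MATLAB subscript (2,3).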

If the jth column (counting from 0) of the sparse mxArray has any nonzero elements:

  • jc[j] is the index in ir, pr, and pi (if it exists) of the first nonzero element in the jth column.

  • jc[j+1]-1 is the index of the last nonzero element in the jth column.

  • For the jth column of the sparse matrix, jc[j] is the total number of nonzero elements in all preceding columns.
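
The three rules above can be sketched in C as follows. jc_from_counts builds jc as a running (exclusive prefix) sum of per-column counts, which is exactly the third rule, and column_range returns the first and last positions from the first two rules. Both function names are hypothetical, and size_t stands in for mwIndex.

```c
#include <stddef.h>

/* jc[j] = number of nonzeros in columns 0..j-1 (exclusive prefix sum).
   counts has n entries; jc must have room for n+1 entries. */
void jc_from_counts(size_t n, const size_t *counts, size_t *jc)
{
    size_t j;
    jc[0] = 0;
    for (j = 0; j < n; j++)
        jc[j + 1] = jc[j] + counts[j];
}

/* For column j, store the positions in ir/pr of its first and last
   nonzero. Returns 0 for an empty column, where jc[j+1]-1 would not
   point into the column (and could even wrap around for size_t). */
int column_range(const size_t *jc, size_t j, size_t *first, size_t *last)
{
    if (jc[j + 1] == jc[j])
        return 0;
    *first = jc[j];
    *last  = jc[j + 1] - 1;
    return 1;
}
```

For Sparrow the per-column counts are 2, 1 and 3, which gives jc = {0, 2, 3, 6}, and column 2 (zero-based) spans positions 3 through 5.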

The number of nonzero elements in the jth column of the sparse mxArray is:

jc[j+1] - jc[j];
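
As an illustration of the count formula, the sketch below uses it to look up a single element (row, col), both zero-based, by scanning only the entries of that column. The helpers column_nnz and csc_get are hypothetical; a linear scan is used so the sketch does not depend on the order of row indices within a column.

```c
#include <stddef.h>

/* Number of nonzeros in zero-based column j. */
size_t column_nnz(const size_t *jc, size_t j)
{
    return jc[j + 1] - jc[j];
}

/* Return the value at (row, col), or 0.0 if that element is not stored. */
double csc_get(const size_t *jc, const size_t *ir, const double *pr,
               size_t row, size_t col)
{
    size_t start = jc[col];               /* first entry of the column */
    size_t count = column_nnz(jc, col);   /* entries in the column */
    size_t k;
    for (k = 0; k < count; k++)
        if (ir[start + k] == row)
            return pr[start + k];
    return 0.0;
}
```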

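Note that jc has n+1 entries, with jc[0] = 0 and jc[n] = nz, while ir and pr each hold nz entries. (Strictly, ir and pr are allocated with nzmax entries, which can exceed nz; only the first jc[n] entries are meaningful.)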

Hence, to iterate over the sparse matrix in C, you can use the following code:

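/* A is a real (double) sparse mxArray, e.g. prhs[0] in mexFunction */
mwSize   n  = mxGetN(A);    /* number of columns */
mwIndex *jc = mxGetJc(A);   /* column offsets, n+1 entries */
mwIndex *ir = mxGetIr(A);   /* zero-based row index of each nonzero */
double  *pr = mxGetPr(A);   /* value of each nonzero */
mwIndex  col, i, startIndex, endIndex, row;
double   val;

for (col = 0; col < n; col++) {
    startIndex = jc[col];       /* first nonzero in this column */
    endIndex   = jc[col + 1];   /* one past the last nonzero */
    for (i = startIndex; i < endIndex; i++) {
        row = ir[i];
        val = pr[i];
        /* ... process (row, col, val) here ... */
    }
}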
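
As one application of this loop, here is a sketch of a sparse matrix-vector product y = A*x in plain C. It is an illustration, not MATLAB's own implementation; the function name csc_matvec is hypothetical, and size_t stands in for mwIndex. Because the matrix is stored by columns, each nonzero pr[i] of column col is scattered into y[ir[i]] instead of computing y one row at a time.

```c
#include <stddef.h>

/* y = A*x for an m-by-n CSC matrix; x has n entries, y has m entries. */
void csc_matvec(size_t m, size_t n, const size_t *jc, const size_t *ir,
                const double *pr, const double *x, double *y)
{
    size_t row, col, i;
    for (row = 0; row < m; row++)
        y[row] = 0.0;
    for (col = 0; col < n; col++) {
        for (i = jc[col]; i < jc[col + 1]; i++) {
            y[ir[i]] += pr[i] * x[col];   /* A(ir[i], col) * x(col) */
        }
    }
}
```

For Sparrow and x = (1, 1, 1), y holds the row sums (0, 3, 1, 0, 2, 1, 0). Inside a MEX file, the same function could be fed mxGetJc(prhs[0]), mxGetIr(prhs[0]) and mxGetPr(prhs[0]), provided mwIndex matches size_t in your build (typically the case for 64-bit builds with large array dimensions).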