2015-12-14 16 views
5

Próbuję obliczyć średnią wartości niezerowych w każdym wierszu macierzy rzadkich rzędów. Stosując metodę średniej matrycy nie robi tego:Średnia niezerowych wartości w macierzy rzadkiej?

>>> from scipy.sparse import csr_matrix 
>>> a = csr_matrix([[0, 0, 2], [1, 3, 8]]) 
>>> a.mean(axis=1) 
matrix([[ 0.66666667], 
     [ 4.  ]]) 

następujące prace ale jest powolne dla dużych matrycach:

>>> import numpy as np 
>>> b = np.zeros(a.shape[0]) 
>>> for i in range(a.shape[0]): 
... b[i] = a.getrow(i).data.mean() 
... 
>>> b 
array([ 2., 4.]) 

Czy ktoś proszę powiedzieć, czy istnieje szybszy sposób?

Odpowiedz

4

Wydaje typowy problem, gdzie można skorzystać numpy.bincount. Do tego skorzystała z trzech funkcji:

(x,y,z)=scipy.sparse.find(a) 

zwraca wierszy (x), kolumny (y) i wartości (z) z nielicznych matrycy. Dla instace, x jest zwraca, dla każdego numeru wiersza, jak meny niezerowe ma Ciebie.

numpy.bincount(x,wights=z) zwraca dla każdego wiersza sumy elementów niezerowych.

Ostateczna kod roboczych:

from scipy.sparse import csr_matrix 
a = csr_matrix([[0, 0, 2], [1, 3, 8]]) 

import numpy 
import scipy.sparse 
(x,y,z)=scipy.sparse.find(a) 
countings=numpy.bincount(x) 
sums=numpy.bincount(x,weights=z) 
averages=sums/countings 

print(averages) 

powraca:

[ 2. 4.] 
+0

Doskonale, dzięki – batsc

5

z matrycą formatu CSR, można to zrobić jeszcze łatwiej:

sums = a.sum(axis=1).A1 
counts = np.diff(a.indptr) 
averages = sums/counts 

wiersz sumy są bezpośrednio obsługiwany, a struktura formatu CSR oznacza, że ​​różnica między kolejnymi wartościami w indptr array odpowiadają dokładnie liczbie niezerowych elementów w każdym wierszu.

1

Zawsze lubię sumowanie wartości w dowolnej osi, którą interesuje i dzielenie przez sumę niezerowych elementów w odpowiednim wierszu/kolumnie.

tak:

sp_arr = csr_matrix([[0, 0, 2], [1, 3, 8]]) 
col_avg = sp_arr.sum(0)/(sp_arr != 0).sum(0) 
row_avg = sp_arr.sum(1)/(sp_arr != 0).sum(1) 
print(col_avg) 
matrix([[ 1., 3., 5.]]) 
print(row_avg) 
matrix([[ 2.], 
     [ 4.]]) 

Zasadniczo są zsumowanie wartości wszystkich wpisów wzdłuż danej osi i podzielenie przez sumę True wpisów gdzie macierz = 0 (czyli liczbę prawdziwe! wpisy).

Uważam, że to podejście jest mniej skomplikowane i łatwiejsze niż inne opcje.

Powiązane problemy