2016-02-23 13 views
18

Używam scikit-learn do klasyfikacji dokumentów tekstowych (22000) do 100 klas. Używam metody macierzy zamieszania scikit-learn w celu obliczenia macierzy zamieszania.Jak mogę wykreślić matrycę zamieszania?

model1 = LogisticRegression() 
model1 = model1.fit(matrix, labels) 
pred = model1.predict(test_matrix) 
cm=metrics.confusion_matrix(test_labels,pred) 
print(cm) 
plt.imshow(cm, cmap='binary') 

To jest jak moja macierz zamieszanie wygląda następująco:

[[3962 325 0 ..., 0 0 0] 
[ 250 2765 0 ..., 0 0 0] 
[ 2 8 17 ..., 0 0 0] 
..., 
[ 1 6 0 ..., 5 0 0] 
[ 1 1 0 ..., 0 0 0] 
[ 9 0 0 ..., 0 0 9]] 

Ja jednak nie otrzymują wyraźny lub czytelnego wykresu. Czy jest lepszy sposób to zrobić?

Odpowiedz

13

@ amillerrhodes daje doskonałą odpowiedź w How to plot confusion matrix with string axis rather than integer in python.

confusion matrix example

Oto kod, który generuje powyższy obraz

import numpy as np 
import matplotlib.pyplot as plt 

conf_arr = [[33,2,0,0,0,0,0,0,0,1,3], 
      [3,31,0,0,0,0,0,0,0,0,0], 
      [0,4,41,0,0,0,0,0,0,0,1], 
      [0,1,0,30,0,6,0,0,0,0,1], 
      [0,0,0,0,38,10,0,0,0,0,0], 
      [0,0,0,3,1,39,0,0,0,0,4], 
      [0,2,2,0,4,1,31,0,0,0,2], 
      [0,1,0,0,0,0,0,36,0,2,0], 
      [0,0,0,0,0,0,1,5,37,5,1], 
      [3,0,0,0,0,0,0,0,0,39,0], 
      [0,0,0,0,0,0,0,0,0,0,38]] 

norm_conf = [] 
for i in conf_arr: 
    a = 0 
    tmp_arr = [] 
    a = sum(i, 0) 
    for j in i: 
     tmp_arr.append(float(j)/float(a)) 
    norm_conf.append(tmp_arr) 

fig = plt.figure() 
plt.clf() 
ax = fig.add_subplot(111) 
ax.set_aspect(1) 
res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, 
       interpolation='nearest') 

width, height = conf_arr.shape 

for x in xrange(width): 
    for y in xrange(height): 
     ax.annotate(str(conf_arr[x][y]), xy=(y, x), 
        horizontalalignment='center', 
        verticalalignment='center') 

cb = fig.colorbar(res) 
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 
plt.xticks(range(width), alphabet[:width]) 
plt.yticks(range(height), alphabet[:height]) 
plt.savefig('confusion_matrix.png', format='png') 

nadzieję, że to pomaga.

43

enter image description here

można użyć plt.matshow() zamiast plt.imshow() lub użyć Seaborn modułu heatmap wykreślić matrycy Zamieszanie

import seaborn as sn 
import pandas as pd 
import matplotlib.pyplot as plt 
array = [[33,2,0,0,0,0,0,0,0,1,3], 
     [3,31,0,0,0,0,0,0,0,0,0], 
     [0,4,41,0,0,0,0,0,0,0,1], 
     [0,1,0,30,0,6,0,0,0,0,1], 
     [0,0,0,0,38,10,0,0,0,0,0], 
     [0,0,0,3,1,39,0,0,0,0,4], 
     [0,2,2,0,4,1,31,0,0,0,2], 
     [0,1,0,0,0,0,0,36,0,2,0], 
     [0,0,0,0,0,0,1,5,37,5,1], 
     [3,0,0,0,0,0,0,0,0,39,0], 
     [0,0,0,0,0,0,0,0,0,0,38]] 
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"], 
        columns = [i for i in "ABCDEFGHIJK"]) 
plt.figure(figsize = (10,7)) 
sn.heatmap(df_cm, annot=True) 
14

odpowiedź @bninopaul „s nie jest całkowicie dla początkujących

oto kod można "kopiuj i uruchom"

import seaborn as sn 
import pandas as pd 
import matplotlib.pyplot as plt 

array = [[13,1,1,0,2,0], 
    [3,9,6,0,1,0], 
    [0,0,16,2,0,0], 
    [0,0,0,13,0,0], 
    [0,0,0,0,15,0], 
    [0,0,1,0,0,15]]   
df_cm = pd.DataFrame(array, range(6), 
        range(6)) 
#plt.figure(figsize = (10,7)) 
sn.set(font_scale=1.4)#for label size 
sn.heatmap(df_cm, annot=True,annot_kws={"size": 16})# font size 

result