Zauważ, że gdy mówimy o dokładności jednej klasy jeden może odnosić się do jednej z następujących (nie odpowiednik) dwóch kwot:.
- dokładność, co dla klasy C jest stosunek przykładach oznaczonych klasy C że przewiduje się posiadają stopień C.
- przywołanie, co dla klasy C jest stosunek przykładach przewiduje się klasy C, które są w rzeczywistości oznaczonego klasy C.
Zamiast wykonywać skomplikowane indeksowanie, możesz po prostu polegać na maskowaniu w celu obliczenia. Zakładając, że mówimy tu o precyzji (zmiana na przywołanie byłaby trywialna).
from keras import backend as K
INTERESTING_CLASS_ID = 0 # Choose the class of interest
def single_class_accuracy(y_true, y_pred):
class_id_true = K.argmax(y_true, axis=-1)
class_id_preds = K.argmax(y_pred, axis=-1)
# Replace class_id_preds with class_id_true for recall here
accuracy_mask = K.cast(K.equal(class_id_preds, INTERESTING_CLASS_ID), 'int32')
class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
class_acc = K.sum(class_acc_tensor)/K.maximum(K.sum(accuracy_mask), 1)
return class_acc
Jeśli chcesz być bardziej elastyczne, można również klasę zainteresowania sparametryzowane:
from keras import backend as K
def single_class_accuracy(interesting_class_id):
def fn(y_true, y_pred):
class_id_true = K.argmax(y_true, axis=-1)
class_id_preds = K.argmax(y_pred, axis=-1)
# Replace class_id_preds with class_id_true for recall here
accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int32')
class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
class_acc = K.sum(class_acc_tensor)/K.maximum(K.sum(accuracy_mask), 1)
return class_acc
return fn
a używać go jako:
model.compile(..., metrics=[single_class_accuracy(INTERESTING_CLASS_ID)])
ja nie znam Keras i nie wiem, czy twój kod będzie działał z maskami boolowskimi czy jawnymi indeksami. Czy rzuciłeś maskę, aby wpisać boolean? tf.cast (binary_mask, tf.bool). Za pomocą Theano możesz użyć funkcji bool_mask.nonzero(), aby uzyskać indeksy maski boolowskiej. Daj nam znać, czy to rozwiązanie działa. – rafaelvalle
Czy zaakceptowałbyś odpowiedź, która używa wywołania zwrotnego? –
Aby się upewnić - y_true jest 2D? co mają reprezentować wiersze i kolumny? – ginge