2016-03-18 22 views
9

Podążam za tutorialem w this link i próbuję zmienić metodę oceny dla modelu (na dole). Chciałbym, aby uzyskać najwyższą ocenę i-5 Próbuję użyć poniższego kodu:TensorFlow in_top_k dane wejściowe do oceny

topFiver=tf.nn.in_top_k(y, y_, 5, name=None) 

to jednak daje następujący błąd:

File "AlexNet.py", line 111, in <module> 
    topFiver = tf.nn.in_top_k(pred, y, 5, name=None) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 346, in in_top_k 
    targets=targets, k=k, name=name) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 486, in apply_op 
    _Attr(op_def, input_arg.type_attr)) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 59, in _SatisfiesTypeConstraint 
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list))) 
TypeError: DataType float32 for attr 'T' not in list of allowed values: int32, int64 

O ile mogę powiedzieć, Problem polega na tym, że tf.nn.in_top_k() działa tylko dla danych tf.int32 lub tf.int64, ale moje dane są w formacie tf.float32. Czy istnieje jakieś obejście tego problemu?

Odpowiedz

19

Argument targets dla tf.nn.in_top_k(predictions, targets, k) musi być wektorem ID klas (tj. Wskaźników kolumn w macierzy predictions). Oznacza to, że działa tylko w przypadku problemów z klasyfikacją jednej klasy.

Jeśli Twój problem jest problemem jednej klasy, to zakładam, że twój tensor y_ jest jednoznacznym kodowaniem prawdziwych etykiet dla twoich przykładów (na przykład, ponieważ przekazujesz je również do operacji typu tf.nn.softmax_cross_entropy_with_logits(). przypadek, masz dwie możliwości:.

  • Jeśli etykiety były pierwotnie zapisany jako liczba całkowita etykiet, przekazać je bezpośrednio do tf.nn.in_top_k() bez konwertowania ich do jednego gorąco (również rozważyć użycie tf.nn.sparse_softmax_cross_entropy_with_logits() jako funkcji straty, ponieważ może być bardziej wydajnym.)
  • Jeśli etykiety były pierwotnie przechowywane w jeden gorący Format, można przekonwertować je do liczb całkowitych przy użyciu tf.argmax():

    labels = tf.argmax(y_, 1) 
    topFiver = tf.nn.in_top_k(y, labels, 5) 
    
Powiązane problemy