2016-10-21 14 views
6

Próbuję napisać skrypt, który pozwoli mi narysować obraz cyfry, a następnie określić, jaka jest cyfra z modelem wyszkolonym na MNIST.Tensorflow - Testowanie sieci neuronowej mniszka z moimi własnymi obrazami

Oto mój kod:

import random 
import image 
from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
import numpy as np 
import scipy.ndimage 

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 


x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 

y = tf.nn.softmax(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 10]) 

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 

init = tf.initialize_all_variables() 

sess = tf.Session() 
sess.run(init) 

for i in range(1000): 
    batch_xs, batch_ys = mnist.train.next_batch(1000) 
    sess.run(train_step, feed_dict= {x: batch_xs, y_: batch_ys}) 

print ("done with training") 


data = np.ndarray.flatten(scipy.ndimage.imread("im_01.jpg", flatten=True)) 

result = sess.run(tf.argmax(y,1), feed_dict={x: [data]}) 

print (' '.join(map(str, result))) 

Z jakiegoś powodu zawsze wyniki są złe, ale ma dokładność 92%, gdy używam standardowej metody testowania.

Myślę, że problem może być jak ja zakodowany obraz:

data = np.ndarray.flatten(scipy.ndimage.imread("im_01.jpg", flatten=True)) 

Próbowałem patrząc w kodzie tensorflow dla the next_batch() function aby zobaczyć, jak oni to zrobili, ale nie mam pojęcia, w jaki sposób można porównać z moim podejście.

Problem może być również w innym miejscu.

Każda pomoc w dokładności 80 +% byłaby bardzo doceniana.

+1

Jeśli chodzi o kodowanie obrazu, spróbuj użyć .png. Z moich testów format .jpg jest zły, o ile pozostawia artefakty (szare piksele) na obrazie. – Link

Odpowiedz

6

znalazłem mój błąd: jest kodowany odwrotnie, czarni byli na 255 zamiast 0.

data = np.vectorize(lambda x: 255 - x)(np.ndarray.flatten(scipy.ndimage.imread("im_01.jpg", flatten=True))) 

naprawił.

+0

Dziękuję, to było pomocne. –

+0

Czy możesz podać wymiary użytego obrazu? – Pre

+1

@ Tutaj Oto [jeden z obrazów, które testowałem przeciwko] (https://github.com/cloutier/tf/blob/master/im_01.jpg) –

Powiązane problemy