2017-01-04 8 views
8

Próbuję nauczyć TensorFlow i studiując na przykład: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynbTensorFlow: w jaki sposób zdefiniowano zbiór dataset.train.next_batch?

Potem kilka pytań w poniższym kodzie:

for epoch in range(training_epochs): 
    # Loop over all batches 
    for i in range(total_batch): 
     batch_xs, batch_ys = mnist.train.next_batch(batch_size) 
     # Run optimization op (backprop) and cost op (to get loss value) 
     _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs}) 
    # Display logs per epoch step 
    if epoch % display_step == 0: 
     print("Epoch:", '%04d' % (epoch+1), 
       "cost=", "{:.9f}".format(c)) 

Od mnist jest tylko zbiór danych, co dokładnie ma mnist.train.next_batch oznacza? Jak zdefiniowano dataset.train.next_batch?

Dzięki!

Odpowiedz

19

Obiekt mnist jest zwracany z read_data_sets() function zdefiniowanego w module tf.contrib.learn. Metoda mnist.train.next_batch(batch_size) jest zaimplementowana jako here i zwraca krotkę dwóch tablic, gdzie pierwsza reprezentuje grupę obrazów batch_size MNIST, a druga reprezentuje grupę etykiet batch-size odpowiadających tym obrazom.

Obrazy są zwrócone w tablicy 2-D NumPy wielkości [batch_size, 784] (ponieważ istnieje 784 pikseli w MNIST obrazu) i etykiety są zwrócone w każdej tablicy 1 D NumPy wielkości [batch_size] (jeśli read_data_sets() był wywołana z one_hot=False) lub 2-D NumPy tablica o rozmiarze [batch_size, 10] (jeśli read_data_sets() został wywołany z one_hot=True).

+7

Warto wspomnieć, że [next_batch] (https://github.com/tensorflow/tensorflow/blob/7c36309c37b04843030664cdc64aca2bb7d6ecaa/tensorflow/contrib/learn/python/learn/datasets/mnist.py#L160) przetasowania przykłady po przejściu przez wszystkie z nich w każdej epoce. Możesz śledzić, gdzie jesteś w epoce, przez 'DataSet._index_in_epoch', np.' Mnist.train._index_in_epoch' –

Powiązane problemy