2016-09-05 15 views
11

TensorFlow docs opisuje kilka sposobów odczytywania danych za pomocą TFRecordReader, TextLineReader, QueueRunner itp. I kolejek.Jak trenować sieć TensorFlow za pomocą generatora do produkcji wejść?

To, co chciałbym zrobić, jest o wiele prostsze: mam funkcję generatora pythonów, która generuje nieskończoną sekwencję danych treningowych jako krotki (X, y) (obie są tablicami numpy, a pierwszym wymiarem jest partia rozmiar). Chcę po prostu wyszkolić sieć wykorzystującą te dane jako dane wejściowe.

Czy istnieje prosty samodzielny przykład szkolenia sieci TensorFlow przy użyciu generatora, który generuje dane? (Wzdłuż linii przykładach MNIST lub CIFAR)

+2

Istnieje ['tf.data.Dataset.from_generator'] (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator), który może być przydatny w twoim przypadku. – Jakub

Odpowiedz

15

Załóżmy, że masz funkcję, która generuje dane:

def generator(data): 
    ... 
    yield (X, y) 

teraz trzeba inną funkcję, która opisuje swój model architektury. Może to być dowolna funkcja, która przetwarza X i musi przewidzieć y jako wynik (powiedzmy, sieć neuronowa).

Załóżmy, że funkcja przyjmuje X i Y jako wejścia, oblicza przewidywanemu Y od X w jakiś sposób i powrocie funkcji strat (np przekrój entropii lub MSE w przypadku regresji) między Y i przewidywane Y:

def neural_network(X, y): 
    # computation of prediction for y using X 
    ... 
    return loss(y, y_pred) 

aby kontynuować pracę modelu, trzeba określić zastępcze zarówno dla X i Y, a następnie uruchomić sesję:

X = tf.placeholder(tf.float32, shape=(batch_size, x_dim)) 
y = tf.placeholder(tf.float32, shape=(batch_size, y_dim)) 

zastępcze są SOMET hing jak „wolnych” zmiennych, które należy określić podczas uruchamiania sesji przez feed_dict:

with tf.Session() as sess: 
    # variables need to be initialized before any sess.run() calls 
    tf.global_variables_initializer().run() 

    for X_batch, y_batch in generator(data): 
     feed_dict = {X: X_batch, y: y_batch} 
     _, loss_value, ... = sess.run([train_op, loss, ...], feed_dict) 
     # train_op here stands for optimization operation you have defined 
     # and loss for loss function (return value of neural_network function) 

Nadzieję, że okażą się przydatne. Należy jednak pamiętać, że nie jest to w pełni działająca implementacja, ale raczej pseudokod, ponieważ nie podano prawie żadnych szczegółów.

+0

Czy istnieje sposób przekazania funkcji generatora do modelu estymatora, zamiast ręcznego kodowania następnej funkcji? – skadoosh

+0

@skadoosh Myślę, że powinieneś rozważyć użycie Keras. –

+2

@skadoosh - jeśli chcesz korzystać z tensorflow, wersja 1.6 pozwala 'Estimator.train' akceptować' tf.data.Dataset', które można wykonać za pomocą 'tf.data.Dataset.from_generator'. – Jakub

Powiązane problemy