2017-05-03 18 views
6

Reading https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py dla funkcji average_gradients następujący komentarz jest: Note that this function provides a synchronization point across all towers. Czy funkcja average_gradients blokowania połączeń i co rozumie się przez synchronization point?Tensorflow punkt synchronizacji cifar

Zakładam, że jest to wywołanie blokujące, ponieważ aby obliczyć średnią gradientów, każdy gradient musi zostać obliczony indywidualnie? Ale gdzie jest kod blokujący, który czeka na wszystkie indywidualne obliczenia gradientowe?

Odpowiedz

6

Sam w sobie nie jest funkcją blokującą. Mogła to być kolejna funkcja z operacjami tensorflow i nadal byłaby to punkt synchronizacji. Blokowanie polega na tym, że używa on argumentu, który zależy od wszystkich wykresów utworzonych w poprzedniej pętli for.

Zasadniczo to, co dzieje się tutaj, to tworzenie wykresu treningowego. Po pierwsze w pętli for for i in xrange(FLAGS.num_gpus) tworzonych jest kilka "wątków" wykresów. Każdy wygląda tak:

strata oblicz -> obliczyć gradienty -> dołącz do tower_grads

Każda z tych wykresów „wątków” jest przypisany do innego GPU przez with tf.device('/gpu:%d' % i) i każdy z nich może pracować niezależnie od siebie (i później będzie działać równolegle). Teraz następnym razem, gdy używa się tower_grads bez specyfikacji urządzenia, tworzy on ciąg dalszy wykresu na głównym urządzeniu, wiążąc wszystkie oddzielne "wątki" wykresu w jedno. Tensorflow upewni się, że każdy "wątek" wykresu, który jest częścią stworzenia tower_grads, został ukończony przed uruchomieniem wykresu wewnątrz funkcji average_gradients. Dlatego później, gdy zostanie wywołana sess.run([train_op, loss]), będzie to punkt synchronizacji wykresu.

Powiązane problemy