2017-03-29 19 views
6

W Tensorflow mogliśmy tworzyć i tworzyć wiele sesji Tensorflow, korzystając z szkolenia Between-graph Replication. MonitoredTrainingSession() koordynuje wiele sesji Tensorflow i istnieje argument checkpoint_dir dla MonitoredTrainingSession() do przywrócenia sesji/wykresu Tensorflow. Teraz mam następujące pytania:W jaki sposób funkcja `MonitoredTrainingSession()` działa z "przywracaniem" i "trybem testowania"?

  1. zazwyczaj używamy przedmiot tf.train.Saver() przywrócić wykresy Tensorflow przez saver.restore(...). Ale jak możemy je przywrócić za pomocą MonitoredTrainingSession()?
  2. Ponieważ prowadzimy wiele procesów, a każdy proces buduje i tworzy sesję Tensorflow do treningu, zastanawiam się, czy po treningu musimy również uruchomić wiele procesów do testowania (lub przewidywania). Innymi słowy, w jaki sposób MonitoredTrainingSession() działa z trybem testowania (lub przewidywania)?

Przeczytałem Tensorflow Doc, ale nie znalazłem odpowiedzi na te 2 pytania. Naprawdę doceniam, jeśli ktoś ma rozwiązania. Dzięki!

Odpowiedz

-1
  1. Wygląda na to, że przywracanie jest obsługiwane. Int on Dokumentacja API mówi, że dzwoni MonitoredTrainingSession zwraca instancję MonitoredSession który na stworzeniu „... przywraca zmienne jeśli punkt kontrolny istnieje ...”

  2. odjazdu tf.contrib.learn.Estimator(..).predict(..) a dokładniej tf.contrib.learn.Estimator(..)._infer_model(..) metody here i here. Tworzą także tam MonitoredSession.

0

Krótka odpowiedź:

  1. Trzeba zdać globalny krok do optymalizatora możesz przekazać do mon_sess.run. Umożliwia to zapisywanie i pobieranie zapisanych punktów kontrolnych.
  2. Możliwe jest jednoczesne uruchomienie treningu + sesji walidacji krzyżowej za pośrednictwem pojedynczej sesji MonitoredTraining. Po pierwsze, musisz przejść przez partie treningowe i krzyżować partie walidacji poprzez oddzielne strumienie tego samego wykresu (polecam wyszukać this guide po informacje, jak to zrobić). Po drugie, musisz - do mon_sess.run() - przekazać optymalizator dla strumienia treningowego, a także parametr straty (/ parametru, który chcesz śledzić) strumienia sprawdzania krzyżowego. Jeśli chcesz uruchomić sesję testową oddzielnie od szkolenia, po prostu uruchom tylko zestaw testowy na wykresie i uruchom tylko test_loss (/ inne parametry, które chcesz śledzić) za pośrednictwem wykresu. Aby uzyskać więcej informacji o tym, jak to zrobić, spójrz poniżej.

Długa odpowiedź:

będę aktualizować moją odpowiedź jak ja dostać lepszy widok na to, co można zrobić z tf.train.MonitoredSession (tf.train.MonitoredTrainingSession jest po prostu tworzenie specjalistycznego wersję tf.train.MonitoredSession, jak można zobaczyć w source code).

Poniżej przedstawiono przykładowy kod pokazujący sposób zapisywania punktów kontrolnych co 5 sekund do "./ckpt_dir".Gdy przerwane, to restart na ostatniej zapisanej checkpoint:

def train(inputs, labels_onehot, global_step): 
    out = tf.contrib.layers.fully_connected(
          inputs, 
          num_outputs=10, 
          activation_fn=tf.nn.sigmoid) 
    loss = tf.reduce_mean(
      tf.reduce_sum(
       tf.nn.sigmoid_cross_entropy_with_logits(
          logits=out, 
          labels=labels_onehot), axis=1)) 
    train_op = opt.minimize(loss, global_step=global_step) 
    return train_op 

with tf.Graph().as_default(): 
    global_step = tf.train.get_or_create_global_step() 
    inputs = ... 
    labels_onehot = ... 
    train_op = train(inputs, labels_onehot, global_step) 

    with tf.train.MonitoredTrainingSession(
     checkpoint_dir='./ckpt_dir', 
     save_checkpoint_secs=5, 
     hooks=[ ... ] # Choose your hooks 
    ) as mon_sess: 
     while not mon_sess.should_stop(): 
      mon_sess.run(train_op) 

Co się dzieje w MonitoredTrainingSession aby osiągnąć to właściwie trzy rzeczy:

  1. tf.train.MonitoredTrainingSession tworzy się tf.train.Scaffold obiekt, który działa jak pająk w sieci; gromadzi elementy potrzebne do treningu, zapisywania i ładowania modelu.
  2. Tworzy obiekt tf.train.ChiefSessionCreator. Moja wiedza na ten temat jest ograniczona, ale z mojego rozumienia tego wynika, że ​​kiedy twój algorytm tf rozprzestrzenia się na wiele serwerów. Moim zdaniem jest to, że mówi komputerowi, na którym działa plik, że jest to główny komputer, i że tutaj należy zapisać katalog punktów kontrolnych, a rejestratory powinny rejestrować swoje dane tutaj, itp.
  3. Tworzy tf.train.CheckpointSaverHook, który służy do zapisywania punktów kontrolnych.

Aby to zadziałało, tf.train.CheckpointSaverHook i tf.train.ChiefSessionCreator muszą być przekazane te same odniesienia do katalogu i ringu punktu kontrolnego. Jeśli tf.train.MonitoredTrainingSession z jego parametrów w przykładzie powyżej miały być realizowane z 3 elementów powyżej, będzie to wyglądać mniej więcej tak:

checkpoint_dir = './ckpt_dir' 

scaffold = tf.train.Scaffold() 
saverhook = tf.train.CheckpointSaverHook(
    checkpoint_dir=checkpoint_dir, 
    save_secs=5 
    scaffold=scaffold 
) 
session_creator = tf.train.ChiefSessionCreator(
    scaffold=scaffold, 
    checkpoint_dir=checkpoint_dir 
) 

with tf.train.MonitoredSession(
    session_creator=session_creator, 
    hooks=[saverhook]) as mon_sess: 
     while not mon_sess.should_stop(): 
      mon_sess.run(train_op) 

W tym celu sesji sprawdzania pociąg + krzyż, ty trzeba tylko przechodzą dwa zestawy po jednym wykresie, a następnie uruchomić (w pętli powyżej)

mon_sess.run([train_op, cross_validation_loss]) 

Uruchomiony optymalizator szkolenia dla zbioru treningowego, jak również parametr validation_loss dla walidacji zestaw. Jeśli Twój wykres jest prawidłowo zaimplementowany, oznacza to, że wykres będzie trenowany tylko na zbiorze treningowym i sprawdzany tylko w zestawie sprawdzania poprawności krzyżowej.

Powiązane problemy