Krótka odpowiedź:
- Trzeba zdać globalny krok do optymalizatora możesz przekazać do mon_sess.run. Umożliwia to zapisywanie i pobieranie zapisanych punktów kontrolnych.
- Możliwe jest jednoczesne uruchomienie treningu + sesji walidacji krzyżowej za pośrednictwem pojedynczej sesji MonitoredTraining. Po pierwsze, musisz przejść przez partie treningowe i krzyżować partie walidacji poprzez oddzielne strumienie tego samego wykresu (polecam wyszukać this guide po informacje, jak to zrobić). Po drugie, musisz - do mon_sess.run() - przekazać optymalizator dla strumienia treningowego, a także parametr straty (/ parametru, który chcesz śledzić) strumienia sprawdzania krzyżowego. Jeśli chcesz uruchomić sesję testową oddzielnie od szkolenia, po prostu uruchom tylko zestaw testowy na wykresie i uruchom tylko test_loss (/ inne parametry, które chcesz śledzić) za pośrednictwem wykresu. Aby uzyskać więcej informacji o tym, jak to zrobić, spójrz poniżej.
Długa odpowiedź:
będę aktualizować moją odpowiedź jak ja dostać lepszy widok na to, co można zrobić z tf.train.MonitoredSession (tf.train.MonitoredTrainingSession jest po prostu tworzenie specjalistycznego wersję tf.train.MonitoredSession, jak można zobaczyć w source code).
Poniżej przedstawiono przykładowy kod pokazujący sposób zapisywania punktów kontrolnych co 5 sekund do "./ckpt_dir".Gdy przerwane, to restart na ostatniej zapisanej checkpoint:
def train(inputs, labels_onehot, global_step):
out = tf.contrib.layers.fully_connected(
inputs,
num_outputs=10,
activation_fn=tf.nn.sigmoid)
loss = tf.reduce_mean(
tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=out,
labels=labels_onehot), axis=1))
train_op = opt.minimize(loss, global_step=global_step)
return train_op
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
inputs = ...
labels_onehot = ...
train_op = train(inputs, labels_onehot, global_step)
with tf.train.MonitoredTrainingSession(
checkpoint_dir='./ckpt_dir',
save_checkpoint_secs=5,
hooks=[ ... ] # Choose your hooks
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
Co się dzieje w MonitoredTrainingSession aby osiągnąć to właściwie trzy rzeczy:
- tf.train.MonitoredTrainingSession tworzy się tf.train.Scaffold obiekt, który działa jak pająk w sieci; gromadzi elementy potrzebne do treningu, zapisywania i ładowania modelu.
- Tworzy obiekt tf.train.ChiefSessionCreator. Moja wiedza na ten temat jest ograniczona, ale z mojego rozumienia tego wynika, że kiedy twój algorytm tf rozprzestrzenia się na wiele serwerów. Moim zdaniem jest to, że mówi komputerowi, na którym działa plik, że jest to główny komputer, i że tutaj należy zapisać katalog punktów kontrolnych, a rejestratory powinny rejestrować swoje dane tutaj, itp.
- Tworzy tf.train.CheckpointSaverHook, który służy do zapisywania punktów kontrolnych.
Aby to zadziałało, tf.train.CheckpointSaverHook i tf.train.ChiefSessionCreator muszą być przekazane te same odniesienia do katalogu i ringu punktu kontrolnego. Jeśli tf.train.MonitoredTrainingSession z jego parametrów w przykładzie powyżej miały być realizowane z 3 elementów powyżej, będzie to wyglądać mniej więcej tak:
checkpoint_dir = './ckpt_dir'
scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
checkpoint_dir=checkpoint_dir,
save_secs=5
scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir
)
with tf.train.MonitoredSession(
session_creator=session_creator,
hooks=[saverhook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
W tym celu sesji sprawdzania pociąg + krzyż, ty trzeba tylko przechodzą dwa zestawy po jednym wykresie, a następnie uruchomić (w pętli powyżej)
mon_sess.run([train_op, cross_validation_loss])
Uruchomiony optymalizator szkolenia dla zbioru treningowego, jak również parametr validation_loss dla walidacji zestaw. Jeśli Twój wykres jest prawidłowo zaimplementowany, oznacza to, że wykres będzie trenowany tylko na zbiorze treningowym i sprawdzany tylko w zestawie sprawdzania poprawności krzyżowej.