Napisałem RNN language model using TensorFlow. Model jest implementowany jako klasa RNN
. Struktura wykresu jest wbudowana w konstruktorze, a następnie uruchamia go metoda RNN.train
i RNN.test
.Jak ustawić stan TensorFlow RNN, gdy state_is_tuple = True?
Chcę móc zresetować stan RNN po przejściu do nowego dokumentu w zestawie treningowym lub gdy chcę uruchomić zestaw sprawdzania poprawności podczas treningu. Robię to, zarządzając stanem wewnątrz pętli treningowej, przekazując go do wykresu za pomocą słownika.
W konstruktorze I zdefiniować w RNN jak tak
cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
initial_state=self.state)
Pętla trening wygląda następująco
for document in document:
state = session.run(self.reset_state)
for x, y in document:
_, state = session.run([self.train_step, self.next_state],
feed_dict={self.x:x, self.y:y, self.state:state})
x
i y
są partie danych treningowych w dokumencie. Chodzi o to, że przekazuję najnowszy stan po każdej partii, z wyjątkiem sytuacji, gdy zaczynam nowy dokument, kiedy zeruję stan, uruchamiając self.reset_state
.
To wszystko działa. Teraz chcę zmienić mój RNN, aby użyć zalecanego state_is_tuple=True
. Nie wiem jednak, jak przekazać bardziej skomplikowany obiekt stanu LSTM przez słownik kanału. Również nie wiem, jakie argumenty należy przekazać do linii self.state = tf.placeholder(...)
w moim konstruktorze.
Jaka jest tutaj poprawna strategia? Wciąż nie ma wiele przykładowego kodu lub dokumentacji dla dynamic_rnn
dostępnych.
TensorFlow kwestie 2695 i 2838 pojawiają się istotne.
A blog post na WILDML rozwiązuje te problemy, ale nie podaje bezpośrednio odpowiedzi.
Zobacz także TensorFlow: Remember LSTM state for next batch (stateful LSTM).
sprawdź "rnn_cell._unpacked_state" i "rnn_cell._packed_state". Są one używane w 'rnn._dynamic_rnn_loop()' do przekazania stanu jako listy tensorów argumentów do funkcji pętli. – JunkMechanic
Nie widzę ciągów "_unpacked_state" i "_packed_state" w najnowszym źródle TensorFlow. Czy te nazwy zostały zmienione? –
Hmm. Te zostały usunięte. Zamiast tego wprowadzono nowy moduł 'tf.python.util.nest' z analogami' flatten' i 'pack_sequence_as'. – JunkMechanic