2016-08-24 14 views
10

Napisałem RNN language model using TensorFlow. Model jest implementowany jako klasa RNN. Struktura wykresu jest wbudowana w konstruktorze, a następnie uruchamia go metoda RNN.train i RNN.test.Jak ustawić stan TensorFlow RNN, gdy state_is_tuple = True?

Chcę móc zresetować stan RNN po przejściu do nowego dokumentu w zestawie treningowym lub gdy chcę uruchomić zestaw sprawdzania poprawności podczas treningu. Robię to, zarządzając stanem wewnątrz pętli treningowej, przekazując go do wykresu za pomocą słownika.

W konstruktorze I zdefiniować w RNN jak tak

cell = tf.nn.rnn_cell.LSTMCell(hidden_units) 
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers) 
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32) 
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state") 
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True, 
                initial_state=self.state) 

Pętla trening wygląda następująco

for document in document: 
    state = session.run(self.reset_state) 
    for x, y in document: 
      _, state = session.run([self.train_step, self.next_state], 
           feed_dict={self.x:x, self.y:y, self.state:state}) 

x i y są partie danych treningowych w dokumencie. Chodzi o to, że przekazuję najnowszy stan po każdej partii, z wyjątkiem sytuacji, gdy zaczynam nowy dokument, kiedy zeruję stan, uruchamiając self.reset_state.

To wszystko działa. Teraz chcę zmienić mój RNN, aby użyć zalecanego state_is_tuple=True. Nie wiem jednak, jak przekazać bardziej skomplikowany obiekt stanu LSTM przez słownik kanału. Również nie wiem, jakie argumenty należy przekazać do linii self.state = tf.placeholder(...) w moim konstruktorze.

Jaka jest tutaj poprawna strategia? Wciąż nie ma wiele przykładowego kodu lub dokumentacji dla dynamic_rnn dostępnych.


TensorFlow kwestie 2695 i 2838 pojawiają się istotne.

A blog post na WILDML rozwiązuje te problemy, ale nie podaje bezpośrednio odpowiedzi.

Zobacz także TensorFlow: Remember LSTM state for next batch (stateful LSTM).

+0

sprawdź "rnn_cell._unpacked_state" i "rnn_cell._packed_state". Są one używane w 'rnn._dynamic_rnn_loop()' do przekazania stanu jako listy tensorów argumentów do funkcji pętli. – JunkMechanic

+0

Nie widzę ciągów "_unpacked_state" i "_packed_state" w najnowszym źródle TensorFlow. Czy te nazwy zostały zmienione? –

+0

Hmm. Te zostały usunięte. Zamiast tego wprowadzono nowy moduł 'tf.python.util.nest' z analogami' flatten' i 'pack_sequence_as'. – JunkMechanic

Odpowiedz

13

Jednym z problemów z symbolem zastępczym Tensorflow jest to, że można go podać tylko z listą Pythona lub tablicą Numpy (chyba). Nie można więc zapisać stanu między przebiegami w krotkach LSTMStateTuple.

Rozwiązałem to przez zapisywanie stanu w tensora jak ten

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

masz dwa składniki w warstwie LSTM stan komórka i ukryty stan, to jest to co „2 " pochodzi z. (Ten artykuł jest świetny: https://arxiv.org/pdf/1506.00019.pdf)

Przy budowie wykres ty rozpakować i utworzyć stan krotny takiego:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size]) 
l = tf.unpack(state_placeholder, axis=0) 
rnn_tuple_state = tuple(
     [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1]) 
      for idx in range(num_layers)] 
) 

Następnie pojawi się nowy stan zwykły sposób

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) 
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True) 

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state) 

It nie powinno tak być ... może pracują nad rozwiązaniem.

+0

Jeśli masz tylko jedną warstwę, staje się ona 'state_placeholder = tf.placeholder (tf.float32, [2, batch_size, state_size])' i 'initial_state = np.zeros ((2, batch_size, state_size))'? – Lukeyb

1

Prostym sposobem na podawanie w stanie RNN jest po prostu podawanie obu składników krotki stanu osobno.

# Constructing the graph 
self.state = rnn_cell.zero_state(...) 
self.output, self.next_state = tf.nn.dynamic_rnn(
    rnn_cell, 
    self.input, 
    initial_state=self.state) 

# Running with initial state 
output, state = sess.run([self.output, self.next_state], feed_dict={ 
    self.input: input 
}) 

# Running with subsequent state: 
output, state = sess.run([self.output, self.next_state], feed_dict={ 
    self.input: input, 
    self.state[0]: state[0], 
    self.state[1]: state[1] 
}) 
Powiązane problemy