2015-11-23 17 views
22

Próbuję zdefiniować własne RNNCell (Echo State Network) w Tensorflow, zgodnie z poniższą definicją.Jak mogę zaimplementować niestandardowy RNN (w szczególności ESN) w Tensorflow?

x (t + 1) = tanh (Win * U (t) + W * x (t) + WFB * y (t))

y (t) = Wout * z (t)

z (t) = [x (t), u (t)]

x jest stanem, u jest wejściem, y jest wynikiem. Win, W i Wfb nie nadają się do treningu. Wszystkie masy są losowo inicjowane, a W jest modyfikowany w następujący sposób: „. Należy pewien procent elementów W 0, wagi W, aby utrzymać widmowej promień poniżej 1,0

że tego kodu do generowania równanie

x = tf.Variable(tf.reshape(tf.zeros([N]), [-1, N]), trainable=False, name="state_vector") 
W = tf.Variable(tf.random_normal([N, N], 0.0, 0.05), trainable=False) 
# TODO: setup W according to the ESN paper 
W_x = tf.matmul(x, W) 

u = tf.placeholder("float", [None, K], name="input_vector") 
W_in = tf.Variable(tf.random_normal([K, N], 0.0, 0.05), trainable=False) 
W_in_u = tf.matmul(u, W_in) 

z = tf.concat(1, [x, u]) 
W_out = tf.Variable(tf.random_normal([K + N, L], 0.0, 0.05)) 
y = tf.matmul(z, W_out) 
W_fb = tf.Variable(tf.random_normal([L, N], 0.0, 0.05), trainable=False) 
W_fb_y = tf.matmul(y, W_fb) 

x_next = tf.tanh(W_in_u + W_x + W_fb_y) 

y_ = tf.placeholder("float", [None, L], name="train_output") 

Mój problem jest dwojaki. Po pierwsze nie wiem jak zaimplementować to jako nadklasą RNNCell. Po drugie nie wiem jak wygenerować tensor w według powyższej specyfikacji.

Każda pomoc na temat któregokolwiek z tych pytań jest bardzo doceniane. Może uda mi się wymyślić sposób na przygotowanie W, ale na pewno jak diabli nie rozumiem, jak wdrożyć własne RNN jako nadklasa RNNCell.

Odpowiedz

10

dać szybkie podsumowanie:

spojrzeć w kodzie źródłowym TensorFlow pod python/ops/rnn_cell.py też sprawdzić jak podklasę RNNCell. Zazwyczaj jest to:

class MyRNNCell(RNNCell): 
    def __init__(...): 

    @property 
    def output_size(self): 
    ... 

    @property 
    def state_size(self): 
    ... 

    def __call__(self, input_, state, name=None): 
    ... your per-step iteration here ... 
Powiązane problemy