2017-05-04 36 views
6

Używam Python API for Tensorflow. Staram się realizować Rosenbrock function podane poniżej bez użycia pętli Pythona:Opis pętli while w Tensorflow

Rosenbrock function

Moja obecna implementacja jest następująca:

def rosenbrock(data_tensor): 
    columns = tf.unstack(data_tensor) 

    summation = 0 
    for i in range(1, len(columns) - 1): 
     first_term = tf.square(tf.subtract(columns[i + 1], tf.square(columns[i]))) 
     second_term = tf.square(tf.subtract(columns[i], 1.0)) 
     summation += tf.add(tf.multiply(100.0, first_term), second_term) 

    return summation 

Próbowałem realizacji sumowanie w tf.while_loop(); jednak uznałem, że interfejs API jest nieco nieintuicyjny, jeśli chodzi o używanie indeksu liczbowego, który ma pozostać oddzielony od danych. Przykład podany w documentation wykorzystuje dane jako wskaźnik (lub vice versa)

i = tf.constant(0) 
c = lambda i: tf.less(i, 10) 
b = lambda i: tf.add(i, 1) 
r = tf.while_loop(c, b, [i]) 
+0

Czy należy używać pętli for? Jaka jest korzyść z używania while_loop? Czy to konieczne? – lerner

+0

W swoim kodzie powyżej pętla for wykona kod Pythona. Jeśli nazwiemy ciało jego pętli for "f", wtedy można pomyśleć o kodzie Pythona jako wykonywaniu f, f, f, f, f, ... f. Tak więc N wywoła tę funkcję "ciała" N, a wykres funkcji będzie miał tę funkcję N razy. Jeśli użyjesz tf.while_loop, zobaczysz tę funkcję tylko raz na wykresie. –

+0

Zaletą tf.while_loop jest: 1) możesz wykonywać iteracje równolegle i 2) możesz mieć stałe runtime w twoich instrukcjach warunków. Na przykład, jeśli chcesz uruchomić optymalizator, dopóki nie zostanie spełniona pewna tolerancja, musisz użyć wariantu tf.while_loop, ponieważ python nie może wstępnie ocenić warunku –

Odpowiedz

10

Można to osiągnąć za pomocą tf.while_loop() i standardowe tuples jak w drugim przykładzie w documentation.

def rosenbrock(data_tensor): 
    columns = tf.unstack(data_tensor) 

    # Track both the loop index and summation in a tuple in the form (index, summation) 
    index_summation = (tf.constant(1), tf.constant(0.0)) 

    # The loop condition, note the loop condition is 'i < n-1' 
    def condition(index, summation): 
     return tf.less(index, tf.subtract(tf.shape(columns)[0], 1)) 

    # The loop body, this will return a result tuple in the same form (index, summation) 
    def body(index, summation): 
     x_i = tf.gather(columns, index) 
     x_ip1 = tf.gather(columns, tf.add(index, 1)) 

     first_term = tf.square(tf.subtract(x_ip1, tf.square(x_i))) 
     second_term = tf.square(tf.subtract(x_i, 1.0)) 
     summand = tf.add(tf.multiply(100.0, first_term), second_term) 

     return tf.add(index, 1), tf.add(summation, summand) 

    # We do not care about the index value here, return only the summation 
    return tf.while_loop(condition, body, index_summation)[1] 

Należy zauważyć, że przyrost wskaźnika powinien nastąpić w ciele pętli podobnej do pętli standardowej. W podanym rozwiązaniu jest to pierwsza pozycja w krotce zwróconej przez funkcję body().

Dodatkowo funkcja warunku pętli musi nadać parametrowi sumę, mimo że nie jest używana w tym konkretnym przykładzie.

Powiązane problemy