2016-03-06 11 views
26

Powiedzmy mam następujący kod:Jak dodać, jeśli warunek na wykresie TensorFlow?

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input") 
condition = tf.placeholder("int32", shape=[1, 1], name = "condition") 
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights") 
b = tf.Variable(tf.zeros([label_option]), name = "bias") 

if condition > 0: 
    y = tf.nn.softmax(tf.matmul(x, W) + b) 
else: 
    y = tf.nn.softmax(tf.matmul(x, W) - b) 

Czy praca if oświadczenie w obliczeniach (nie sądzę tak)? Jeśli nie, jak dodać instrukcję if do wykresu obliczeniowego TensorFlow?

Odpowiedz

51

Masz rację, że instrukcja if nie działa tutaj, ponieważ warunek jest oceniany w czasie budowy wykresu, podczas gdy prawdopodobnie chcesz, aby warunek był zależny od wartości podawanej do elementu zastępczego w czasie wykonywania. (W rzeczywistości, to zawsze pierwszy oddział, bo condition > 0 ocenia na Tensor, który jest "truthy" in Python.)

Wspieranie warunkową przepływ sterowania, TensorFlow zapewnia operatorowi tf.cond(), który ocenia jeden z dwóch oddziałów, w zależności od warunek boolowski. Aby pokazać, jak go używać, będę przepisać swój program tak, że condition jest skalarne tf.int32 wartość dla uproszczenia:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input") 
condition = tf.placeholder(tf.int32, shape=[], name="condition") 
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights") 
b = tf.Variable(tf.zeros([label_option]), name="bias") 

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b) 
+1

Dziękuję bardzo za wyjaśnienie w szczegółach! –

+1

@mrry Czy obie gałęzie są domyślnie wykonywane? Mam tf.cond (c, lambda x: train_op1, lambda x: train_op2) i oba train_ops są wykonywane przy każdym wykonaniu warunku niezależnie od wartości c. czy robię coś źle? –

+5

@PiotrDabkowski To jest czasami zaskakujące zachowanie 'tf.cond()', które jest dotykane [w dokumentach] (https://www.tensorflow.org/api_docs/python/tf/cond). W skrócie, musisz utworzyć operacje, które chcesz uruchomić warunkowo * wewnątrz * odpowiednich lambd. Wszystko, co tworzysz poza lambdami, ale odsyłasz do którejkolwiek gałęzi, zostanie wykonane w obu przypadkach. – mrry

Powiązane problemy