2016-08-19 14 views
7

Chcę replikować następujący kod numpy w tensorflow. Na przykład chcę przypisać 0 do wszystkich indeksów tensorów, które poprzednio miały wartość 1.Warunkowe przypisanie wartości tensora w TensorFlow

a = np.array([1, 2, 3, 1]) 
a[a==1] = 0 

# a should be [0, 2, 3, 0] 

Jeśli piszę podobny kod w tensorflow otrzymuję następujący błąd.

TypeError: 'Tensor' object does not support item assignment 

warunek w nawiasach kwadratowych powinna być arbitralne, jak w a[a<1] = 0.

Czy istnieje sposób na zrealizowanie tego "warunkowego zlecenia" (z braku lepszego imienia) w tensorflow?

Odpowiedz

13

Several comparison operators są dostępne w ramach interfejsu API TensorFlow.

Jednak nic nie jest równoznaczne ze zwięzłą składnią NumPy, jeśli chodzi o bezpośrednie manipulowanie tensorami. Aby wykonać tę samą operację, należy użyć operatorów indywidualnych comparison, where i assign.

Odpowiednik kodu do NumPy przykładzie jest to:

import tensorflow as tf 

a = tf.Variable([1,2,3,1])  
start_op = tf.global_variables_initializer()  
comparison = tf.equal(a, tf.constant(1))  
conditional_assignment_op = a.assign(tf.where (comparison, tf.zeros_like(a), a)) 

with tf.Session() as session: 
    # Equivalent to: a = np.array([1, 2, 3, 1]) 
    session.run(start_op) 
    print(a.eval())  
    # Equivalent to: a[a==1] = 0 
    session.run(conditional_assignment_op) 
    print(a.eval()) 

# Output is: 
# [1 2 3 1] 
# [0 2 3 0] 

Sprawozdania drukujące są oczywiście opcjonalne, są one po prostu tam, aby wykazać kod wykonuje poprawnie.

0

ja też dopiero zaczynają używać tensorflow Może ktoś wypełni moje podejście bardziej intuicyjne

import tensorflow as tf 

conditionVal = 1 
init_a = tf.constant([1, 2, 3, 1], dtype=tf.int32, name='init_a') 
a = tf.Variable(init_a, dtype=tf.int32, name='a') 
target = tf.fill(a.get_shape(), conditionVal, name='target') 

init = tf.initialize_all_variables() 
condition = tf.not_equal(a, target) 
defaultValues = tf.zeros(a.get_shape(), dtype=a.dtype) 
calculate = tf.select(condition, a, defaultValues) 

with tf.Session() as session: 
    session.run(init) 
    session.run(calculate) 
    print(calculate.eval()) 

Głównym problemem jest to, że jest to trudne do zrealizowania „niestandardową logikę”. jeśli nie potrafisz wyjaśnić swojej logiki w liniowych kategoriach matematycznych, musisz napisać bibliotekę "custom op" dla tensorflow (more details here)

+0

Technicznie nie aktualizuje ona 'a', tzn. brakuje kroku przydziału żądanego przez OP –

+0

tf. select() jest przestarzałe: https://github.com/tensorflow/tensorflow/issues/6899 – gizzmole

Powiązane problemy