2016-01-24 12 views
6

Czy można użyć elementu zastępczego dla parametru num_split do tf.split()?Używanie zmiennej dla num_splits dla tf.split()

idealnie bym chciał zrobić coś takiego:

num_splits = tf.placeholder(tf.int32) 
inputs = tf.placeholder(tf.int32, [5, None]) 
split_inputs = tf.split(1, num_splits, inputs) 


TypeError: Expected int for argument 'num_split' not <tensorflow.python.framework.ops.Tensor object at 0x10b819ad0>. 

Istnieje ewentualnie coś nie tak z moim podejściem. Szukam wyliczenia w całym wymiarze w tensor o zmiennym kształcie. Dzięki!

Odpowiedz

11

Istnieje ogólna filozofia "tensorowej in-tensor out" dla podstawowych operacji graficznych, więc może to uprościć sytuację, jeśli można zrestrukturyzować swoje obliczenia, aby radzić sobie z pojedynczym tensorem zmiennej wielkości zamiast zmiennej liczby tensorów.

Ops jak pack, unpack, split czynienia z wieloma tensorów ale skompilować do OPS "tensor in/tensor-out" podczas wykresu czas budowy, dlatego num_splits musi zostać naprawiony. Operacje takie jak dynamic_partition, dynamic_stitch, dequeue_many przejmują część tej funkcji dla pojedynczych tensorów ze zmienną 0 -ty wymiar.

Jeśli naprawdę potrzebujesz radzić sobie ze zmienną liczbą tensorów, typowym podejściem jest łamanie obliczeń w wielu połączeniach session.run, z jednym tensorem wejściowym na każde połączenie run i powiązanie elementów za pomocą kolejek. Jest slice_input_producer który dzieli wejście zmiennej wielkości wzdłuż 0'th wymiar i tworzy tensor dla każdego wiersza, więc jeśli chciał oceniać myfunction w pętli na każdym rzędzie inputs można zrobić to

def myfunction(vector): 
    result = tf.reduce_sum(vector) 
    print_result = tf.Print(result, [result], "myfunction called ") 
    return print_result 

MAX_ROWS = 10 

# input matrix with 2 columns and unknown number of rows (<MAX_ROWS) 
inputs = tf.placeholder(tf.int32, [None, 2]) 
# copy of inputs, will need to have a persistent copy of it because we will 
# be fetching rows in different session.run calls 
data = tf.Variable(inputs, validate_shape=False) 
# input producer that iterates over the rows and pushes them onto Queue 
row = tf.train.slice_input_producer([data], num_epochs=1, shuffle=False)[0] 
myfunction_op = myfunction(row) 

# this op will save placeholder values into the variable 
init_op = tf.initialize_all_variables() 

# Coordinator is not necessary in this case, but you'll need it if you have 
# more than one Queue in order to close all queues together 
sess = tf.Session() 
coord = tf.train.Coordinator() 
threads = tf.train.start_queue_runners(sess=sess, coord=coord) 

sess.run([init_op], feed_dict={inputs:[[0, 0], [1, 1], [2, 2]]}) 

try: 
    for i in range(MAX_ROWS): 
    sess.run([myfunction_op]) 
except tf.errors.OutOfRangeError: 
    print('Done iterating') 
finally: 
    # When done, ask other threads to stop. 
    coord.request_stop() 

przypadku uruchomienia to, powinieneś zobaczyć

Powiązane problemy