2017-01-12 10 views
6

Trenuję Generatywną Sieć Przeciwdziałania (GAN) w tensorflow, gdzie zasadniczo mamy dwie różne sieci, każda z własną optymalizacją.Przywróć podzestaw zmiennych w Tensorflow

self.G, self.layer = self.generator(self.inputCT,batch_size_tf) 
self.D, self.D_logits = self.discriminator(self.GT_1hot) 

... 

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step) 

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \ 
         .minimize(self.d_loss, var_list=self.d_vars) 

Problem polega na tym, że najpierw trenuję jedną z sieci (g), a następnie chcę razem ćwiczyć g i d. Jednak, kiedy wywołanie funkcji Load:

self.sess.run(tf.initialize_all_variables()) 
self.sess.graph.finalize() 

self.load(self.checkpoint_dir) 

def load(self, checkpoint_dir): 
    print(" [*] Reading checkpoints...") 

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
     ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 
     self.saver.restore(self.sess, ckpt.model_checkpoint_path) 
     return True 
    else: 
     return False 

mam błąd podobny do tego (o wiele więcej traceback):

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000 

mogę przywrócić sieć g i zachować szkolenia z tej funkcji, ale kiedy chcę wystrzelić d od podstaw, a g z przechowywanego modelu mam ten błąd.

Odpowiedz

17

Aby przywrócić podzestaw zmiennych, należy utworzyć nowy tf.train.Saver i przekazać mu określoną listę zmiennych do przywrócenia w opcjonalnym argumencie var_list.

Domyślnie tf.train.Saver stworzy ops, że (i) Zapisz każdą zmienną w swoim wykresie podczas rozmowy saver.save() oraz (ii) wyszukiwanie (według nazwy) każdej zmiennej w danym punkcie kontrolnym podczas rozmowy saver.restore(). Choć pracuje dla najbardziej typowych scenariuszy, trzeba dostarczyć więcej informacji na pracę z konkretnymi podzbiorów zmiennych:

  1. Jeśli chcesz tylko przywrócić podzbiór zmiennych, można uzyskać listę tych zmiennych dzwoniąc pod numer tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX), zakładając, że umieścisz sieć "g" we wspólnym bloku with tf.name_scope(G_NETWORK_PREFIX): lub tf.variable_scope(G_NETWORK_PREFIX):. Następnie możesz przekazać tę listę do konstruktora tf.train.Saver.

  2. Jeśli chcesz przywrócić podzbiór zmiennej i/albo zmiennych w punkcie kontrolnym mieć różne nazwy, można przekazać słownika jako var_list argument. Domyślnie każda zmienna w punkcie kontrolnym jest powiązana z kluczem , który jest wartością jego właściwości tf.Variable.name. Jeśli nazwa jest inna na wykresie docelowym (np. Z powodu dodania prefiksu zasięgu), możesz określić słownik, który odwzorowuje klucze ciągu (w pliku punktu kontrolnego) na obiekty tf.Variable (na wykresie docelowym).

0

Można utworzyć oddzielną instancję tf.train.Saver() z var_list argumentem ustawionym na zmiennych, które chcesz przywrócić. Utwórz osobną instancję, aby zapisać zmienne.

0

Zainspirowany przez @mrry, proponuję rozwiązanie tego problemu. Aby było jasne, sformułuję problem jako przywracanie podzbioru zmiennej z punktu kontrolnego, gdy model jest zbudowany na wcześniej wyszkolonym modelu. Po pierwsze, powinniśmy użyć funkcji print_tensors_in_checkpoint_file z biblioteki inspect_checkpoint lub po prostu wyodrębnić tę funkcję przez:

from tensorflow.python import pywrap_tensorflow 
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): 
    varlist=[] 
    reader = pywrap_tensorflow.NewCheckpointReader(file_name) 
    if all_tensors: 
     var_to_shape_map = reader.get_variable_to_shape_map() 
     for key in sorted(var_to_shape_map): 
     varlist.append(key) 
    return varlist 
varlist=print_tensors_in_checkpoint_file(file_name=the path of the ckpt file,all_tensors=True,tensor_name=None) 

Następnie używamy TF.get_collection() tak jak @mrry saied:

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 

Wreszcie możemy zainicjować wygaszacza przez:

saver = tf.train.Saver(variable[:len(varlist)]) 

Pełną wersję można znaleźć na moim github: https://github.com/pobingwanghai/tensorflow_trick/blob/master/restore_from_checkpoint.py

w mojej sytuacji , nowe zmienne są dodawane na końcu modelu, więc mogę po prostu użyć [: length()] do zidentyfikowania potrzebnych zmiennych, dla bardziej złożonej sytuacji, być może będziesz musiał wykonać pewne ręczne wyrównanie lub napisać proste funkcja dopasowywania ciągów w celu określenia wymaganego v ariables.

Powiązane problemy