2015-12-28 18 views
8

Próbuję zaimplementować sugestię z odpowiedziami: Tensorflow: how to save/restore a model?tensorflow: zapisywanie i odtwarzanie sesji

Mam obiekt, który owija tensorflow model w sklearn stylu.

import tensorflow as tf 
class tflasso(): 
    saver = tf.train.Saver() 
    def __init__(self, 
       learning_rate = 2e-2, 
       training_epochs = 5000, 
        display_step = 50, 
        BATCH_SIZE = 100, 
        ALPHA = 1e-5, 
        checkpoint_dir = "./", 
      ): 
     ... 

    def _create_network(self): 
     ... 


    def _load_(self, sess, checkpoint_dir = None): 
     if checkpoint_dir: 
      self.checkpoint_dir = checkpoint_dir 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      self.saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

    def fit(self, train_X, train_Y , load = True): 
     self.X = train_X 
     self.xlen = train_X.shape[1] 
     # n_samples = y.shape[0] 

     self._create_network() 
     tot_loss = self._create_loss() 
     optimizer = tf.train.AdagradOptimizer(self.learning_rate).minimize(tot_loss) 

     # Initializing the variables 
     init = tf.initialize_all_variables() 
     " training per se" 
     getb = batchgen(self.BATCH_SIZE) 

     yvar = train_Y.var() 
     print(yvar) 
     # Launch the graph 
     NUM_CORES = 3 # Choose how many cores to use. 
     sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES, 
                  intra_op_parallelism_threads=NUM_CORES) 
     with tf.Session(config= sess_config) as sess: 
      sess.run(init) 
      if load: 
       self._load_(sess) 
      # Fit all training data 
      for epoch in range(self.training_epochs): 
       for (_x_, _y_) in getb(train_X, train_Y): 
        _y_ = np.reshape(_y_, [-1, 1]) 
        sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_}) 
       # Display logs per epoch step 
       if (1+epoch) % self.display_step == 0: 
        cost = sess.run(tot_loss, 
          feed_dict={ self.vars.xx: train_X, 
            self.vars.yy: np.reshape(train_Y, [-1, 1])}) 
        rsq = 1 - cost/yvar 
        logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq) 
        print(logstr) 
        self.saver.save(sess, self.checkpoint_dir + 'model.ckpt', 
         global_step= 1+ epoch) 

      print("Optimization Finished!") 
     return self 

Kiedy biegnę:

tfl = tflasso() 
tfl.fit(train_X, train_Y , load = False) 

mam wyjścia:

Epoch: 50 cost = 38.4705 R^2 = -1.2036 
    b1: 0.118122 
Epoch: 100 cost = 26.4506 R^2 = -0.5151 
    b1: 0.133597 
Epoch: 150 cost = 22.4330 R^2 = -0.2850 
    b1: 0.142261 
Epoch: 200 cost = 20.0361 R^2 = -0.1477 
    b1: 0.147998 

Jednak gdy próbuję odzyskać parametry (nawet bez zabijania Object): tfl.fit(train_X, train_Y , load = True)

Dostaję dziwne wyniki. Po pierwsze, załadowana wartość nie odpowiada zapisanej.

loading a session 
loaded b1: 0.1   <------- Loaded another value than saved 
Epoch: 50 cost = 30.8483 R^2 = -0.7670 
    b1: 0.137484 

Jaki jest właściwy sposób ładowania i prawdopodobnie pierwsza inspekcja zapisanych zmiennych?

+0

Dokumentacja tensorflow jest pozbawiona całkiem prostych przykładów, trzeba przekopać się w folderach z przykładami i nadać jej sens głównie pod własnym numerem – diffeomorphism

Odpowiedz

10

TL; DR: Należy starać się przerobienie tej klasy, dzięki czemu self.create_network() nazywa (i) tylko raz, oraz (ii) przed tf.train.Saver() jest skonstruowany.

Istnieją tutaj dwa subtelne problemy, które wynikają ze struktury kodu i domyślnego zachowania się tf.train.Saver constructor. Gdy tworzysz wygaszacz bez żadnych argumentów (jak w kodzie), gromadzi on bieżący zestaw zmiennych w twoim programie i dodaje ops do wykresu w celu ich zapisania i przywrócenia. W twoim kodzie, gdy zadzwonisz pod numer tflasso(), skonstruujesz wygaszacz i nie będzie żadnych zmiennych (ponieważ create_network() nie został jeszcze wywołany). W rezultacie punkt kontrolny powinien być pusty.

Druga kwestia to domyślnie — — Format zapisanego punktu kontrolnego to mapa z name property of a variable do aktualnej wartości. Jeśli tworzysz dwie zmienne o tej samej nazwie, będą automatycznie „uniquified” przez TensorFlow:

v = tf.Variable(..., name="weights") 
assert v.name == "weights" 
w = tf.Variable(..., name="weights") 
assert v.name == "weights_1" # The "_1" is added by TensorFlow. 

Konsekwencją tego jest to, że podczas rozmowy self.create_network() w drugim naborze do tfl.fit(), zmienne będą miały różne nazwy od nazw, które są przechowywane w punkcie kontrolnym — lub byłyby, gdyby wygaszacz został skonstruowany po sieci. (Można tego uniknąć poprzez przepuszczenie Imieniny Variable słownika do konstruktora wygaszacza, ale zazwyczaj jest to dość kłopotliwe.)

Istnieją dwa główne sposoby ich obejścia:

  1. W każdym wywołaniu tflasso.fit(), tworzyć cały model na nowo, definiując nowy tf.Graph, następnie na tym wykresie budując sieć i tworząc tf.train.Saver.

  2. ZALECANE Tworzenie sieci, wówczas tf.train.Saver w konstruktorze tflasso i wykorzystać ten wykres na każdym wywołaniu tflasso.fit().Zauważ, że możesz potrzebować trochę więcej pracy, aby zreorganizować rzeczy (w szczególności, nie jestem pewien, co robisz z self.X i self.xlen), ale powinno być możliwe osiągnięcie tego z placeholders i karmienie.

+0

dziękuję! 'Xlen' jest używane w' self._create_network() ', aby ustawić rozmiar wejściowy' X' (placeholder init: 'self.vars.xx = tf.placeholder (" float ", shape = [None, self.xlen ])). Z tego co mówisz, preferowanym sposobem jest przekazanie 'xlen' do inicjalizatora. –

+0

Czy istnieje sposób resetowania uniquifier/clear old tf variables po ponownej inicjalizacji obiektu? –

+1

Aby to zrobić, musisz utworzyć nowy 'tf.Graph' i uczynić go domyślnym przed: (i) stworzeniem sieci i (ii) zrobieniem' Saver'. Jeśli otoczysz ciało 'tflasso.fit()' w bloku 'with tf.Graph(). As_default():' i przeniesiesz konstrukcję 'Saver' wewnątrz tego bloku, nazwy powinny być takie same za każdym razem, gdy wywołaj 'fit()'. – mrry

Powiązane problemy