2016-08-14 15 views
7

Jak mogę odczytać zmienne i ich stany z punktu kontrolnego?Tensorflow. Lista zmiennych w punkcie kontrolnym

Pracuję z AutoCenoderami, a mój punkt kontrolny zawiera kompletny stan sieci, tj. Koder, dekoder, optymalizator itp. Chcę się wygłupiać z kodowaniem i dlatego potrzebuję tylko części dekodera sieci w moim trybie oceny.

To samo pytanie w bardziej abstrakcyjny sposób: jak mogę odczytać tylko konkretne zmienne z istniejącego punktu kontrolnego do ponownego użycia w innym modelu?

Czy powinienem podać odpowiednią zmienną? Albo czy jest jakiś sposób, aby uzyskać coś takiego:

w_init = read_from_state(state_location, var_name) 

def read_from_state(state_location, var_name): 
    # the magic goes here 
    pass 

Odpowiedz

14

Jest list_variables metoda checkpoint_utils.py który pozwala zobaczyć wszystkie zapisane zmienne.

Jednak w przypadku Twojego przypadku użycie przywracania może być łatwiejsze. Jeśli znasz nazwy zmiennych podczas zapisywania punktu kontrolnego, możesz utworzyć nowy wygaszacz i nakazać zainicjowanie tych nazw w nowych obiektach Variable (prawdopodobnie o różnych nazwach). Jest to używane w przykładzie CIFAR do wybrania przywracania subset of variables. Zobacz Choosing which Variables to Save and Restore w Howto

0

Innym sposobem, który będzie drukować wszystkie tensory punktów kontrolnych (lub tylko jeden, jeśli określono) wraz z ich zawartością:

from tensorflow.python.tools import inspect_checkpoint as inch 
inch.print_tensors_in_checkpoint_file('path/to/ckpt', '', True) 
""" 
Args: 
    file_name: Name of the checkpoint file. 
    tensor_name: Name of the tensor in the checkpoint file to print. 
    all_tensors: Boolean indicating whether to print all tensors. 
""" 

Zawsze będzie drukować zawartość tensora.

A, podczas gdy jesteśmy w nim, oto jak używać checkpoint_utils, sugerowane przez poprzedniej odpowiedzi:

from tensorflow.contrib.framework.python.framework import checkpoint_utils 
    var_list = checkpoint_utils.list_variables('path/to/ckpt') 
    for v in var_list: print(v) 
Powiązane problemy