2016-06-12 13 views
5

Wierzę, że ciężko jest mi zrozumieć, jak działają wykresy w tensorflow i jak uzyskać do nich dostęp. Mam intuicję, że linie pod "z wykresem:" utworzą wykres jako pojedynczy byt. Dlatego postanowiłem stworzyć klasę, która będzie budować wykres, gdy zostanie utworzona i będzie posiadać funkcję, która uruchomi wykres, jak następuje;Tensorflow: Tworzenie wykresu w klasie i przeprowadzanie go ouside

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      prediction = ... 
      cost  = ... 
      optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(optimizer, feed_dict) 
      loss = sess.run(cost, feed_dict) 
      ... 
     return variables 

Kolejne kroki są do tworzenia głównego plik zbierze się parametry dla danej klasy, do zbudowania wykresu, a następnie uruchomić go;

#Main file 
... 
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... } 

#Building graph 
G = Graph(parameters_dict) 
P = G.launchG(Input) 
... 

Jest to bardzo eleganckie dla mnie, ale nie działa (oczywiście). Rzeczywiście, wygląda na to, że funkcje launchG nie mają dostępu do węzłów zdefiniowanych na wykresie, które dają mi taki błąd;

---> 26 sess.run(optimizer, feed_dict) 

NameError: name 'optimizer' is not defined 

Być może to moja python (i tensorflow) zrozumienie, że jest zbyt ograniczony, ale byłem pod dziwnym wrażeniem, że z wykresu (G) utworzona, prowadzenie sesji z tego wykresu jako argument powinien dać dostęp do węzłów w nim, bez wymogu udzielenia wyraźnego dostępu.

Jakieś oświecenie?

Odpowiedz

7

Węzły prediction, cost, a optimizer są zmienne lokalne utworzone w metodzie __init__, nie mogą być dostępne w metodzie launchG.

Najprostszym fix byłoby zadeklarować je jako atrybuty klasy Graph:

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      self.prediction = ... 
      self.cost  = ... 
      self.optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(self.optimizer, feed_dict) 
      loss = sess.run(self.cost, feed_dict) 
      ... 
     return variables 

Można również pobrać węzły wykresu przy użyciu ich dokładną nazwę z graph.get_tensor_by_name i graph.get_operation_by_name.

Powiązane problemy