2011-11-15 14 views
8

Próbuję użyć PyBrain dla niektórych prostych szkoleń NN. Nie wiem, jak to zrobić, aby załadować dane treningowe z pliku. Nie jest wyjaśnione na ich stronie internetowej w dowolnym miejscu. Nie obchodzi mnie format, ponieważ mogę go teraz zbudować, ale muszę to zrobić w pliku, zamiast ręcznie dodawać kolejne wiersze, ponieważ będę miał kilkaset wierszy.Jak załadować dane treningowe w PyBrain?

+1

Kilkaset wierszy oznacza, że ​​masz bardzo mały zestaw i nie należy martwić się o wydajność. Ale czy PyBrain nie akceptuje tylko tablic NumPy? –

+0

Nie wiem, dopiero zaczynam go używać, ale nigdzie nie mówią, jak używać tablic NumPy z ich NN:/ –

Odpowiedz

21

Oto jak to zrobiłem:

 
ds = SupervisedDataSet(6,3) 

tf = open('mycsvfile.csv','r') 

for line in tf.readlines(): 
    data = [float(x) for x in line.strip().split(',') if x != ''] 
    indata = tuple(data[:6]) 
    outdata = tuple(data[6:]) 
    ds.addSample(indata,outdata) 

n = buildNetwork(ds.indim,8,8,ds.outdim,recurrent=True) 
t = BackpropTrainer(n,learningrate=0.01,momentum=0.5,verbose=True) 
t.trainOnDataset(ds,1000) 
t.testOnData(verbose=True) 

W tym przypadku sieć neuronowa ma 6 wejść i 3 wyjścia. Plik csv ma ​​9 wartości w każdym wierszu oddzielonych przecinkami. Pierwsze 6 wartości to wartości wejściowe, a ostatnie trzy są wyjściami.

+0

, to świetnie, dziękuję bardzo. Czy wiesz, jak uzyskać dostęp do wartości ciężaru dla każdego neuronu? –

+1

Możesz uzyskać dostęp do poszczególnych warstw w następujący sposób: n ['in'] dla warstwy wejściowej i n ['out'] dla wyjścia lub n ['hidden0'] dla pierwszej ukrytej warstwy. Nie wiem, ale zgaduję, że możesz wtedy uzyskać dostęp do węzłów warstwy w jakiś sposób. dir (n ['in']) powinien dać ci wskazówkę co możesz zrobić – c0m4

+0

Nie mogę znaleźć, jak to zrobić. Zrobię nowe pytanie. Dziękuję za pomoc. –

1

po prostu użyć tablic pandy w ten sposób

import pandas as pd 

ds = SupervisedDataSet(6,3) 

dataset = pd.read_csv('mycsvfile.csv','r', delimiter=',',skiprows=1) 
ds.setfield('input' dataset.values[:,0:6]) 
ds.setfield('target', dataset.values[:,-2:-1]) 

i jesteś dobry, aby przejść.

Powiązane problemy