2015-08-06 8 views
5

Błąd:TypeError "Bad wejście argument funkcji Theano"

TypeError: ('Bad input argument to theano function with name "c2.py:77" at index 1(0-based)', 'Wrong number of dimensions: expected 2, got 1 with shape (128L,).')

Proszę doradzić jak to naprawić?

kod i dane można pobrać pod tym linkiem: http://u.163.com/axfWJ81e i wpisać ten kod: QU90WxTZ

I tu jest mój kod:

# -*- coding: utf-8 -*- 
import os 
import pandas as pd 
import theano 
from theano import tensor as T 
import numpy as np 

def normalizeX(X): 
    return X/255.0 
data = pd.read_csv("digits3a.csv") 
trX = normalizeX(data.values[:, 1:].astype(float)) 
trY = data.values[:, 0] 
data = pd.read_csv("digits3b.csv") 
teX = normalizeX(data.values.astype(float)) 

def floatX(X): 
    return np.asarray(X, dtype=theano.config.floatX) 

def init_weights(shape): 
    return theano.shared(floatX(np.random.randn(*shape) * 0.01)) 

def model(X, w): 
    return T.nnet.softmax(T.dot(X, w)) 

X = T.fmatrix() 
Y = T.fmatrix() 
w = init_weights((784, 10)) 
py_x = model(X, w) 
y_pred = T.argmax(py_x, axis=1) 
cost = T.mean(T.nnet.categorical_crossentropy(py_x, Y)) 
gradient = T.grad(cost=cost, wrt=w) 
update = [[w, w - gradient * 0.05]] 
train = theano.function(inputs=[X, Y], outputs=cost, updates=update, allow_input_downcast=True) 
predict = theano.function(inputs=[X], outputs=y_pred, allow_input_downcast=True) 

for i in range(10): 
    for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)): 
     cost = train(trX[start:end], trY[start:end]) 
    print i, np.mean(np.argmax(teY, axis=1) == predict(teX)) 
+0

Który z nich to linia 77? – eickenberg

Odpowiedz

2

Problemem jest to, że powiedzieć Theano Y macierz wartości zmiennoprzecinkowych, ale wartość podana dla Y jest wektorem liczb całkowitych.

Nie jest do końca jasne, co jest poprawne, ale podejrzewam, że zamierzasz być wektorem liczb całkowitych i użyć 1-gorącego wariantu entropii krzyżowej. Jeśli tak, problem można rozwiązać, zmieniając definicję Theano z Y na

Y = T.lvector() 
Powiązane problemy