Próbuję wdrożyć FFNN w Javie z backpropagation i nie mam pojęcia, co robię źle. To zadziałało, gdy miałem tylko jeden neuron w sieci, ale napisałem inną klasę do obsługi większych sieci i nic się nie zbiegnie. Wydaje się, że jest to problem w matematyce - a raczej moja implementacja matematyki - ale sprawdziłem to kilka razy i nie mogę znaleźć niczego złego. To powinno działać. Klasa
Węzeł class
Implementacja sieci neuronowej w java
package arr;
import util.ActivationFunction;
import util.Functions;
public class Node {
public ActivationFunction f;
public double output;
public double error;
private double sumInputs;
private double sumErrors;
public Node(){
sumInputs = 0;
sumErrors = 0;
f = Functions.SIG;
output = 0;
error = 0;
}
public Node(ActivationFunction func){
this();
this.f = func;
}
public void addIW(double iw){
sumInputs += iw;
}
public void addIW(double input, double weight){
sumInputs += (input*weight);
}
public double calculateOut(){
output = f.eval(sumInputs);
return output;
}
public void addEW(double ew){
sumErrors+=ew;
}
public void addEW(double error, double weight){
sumErrors+=(error*weight);
}
public double calculateError(){
error = sumErrors * f.deriv(sumInputs);
return error;
}
public void resetValues(){
sumErrors = 0;
sumInputs = 0;
}
}
LineNetwork:
package arr;
import util.Functions;
public class LineNetwork {
public double[][][] weights; //layer of node to, # of node to, # of node from
public Node[][] nodes; //layer, #
public double lc;
public LineNetwork(){
weights = new double[2][][];
weights[0] = new double[2][1];
weights[1] = new double[1][3];
initializeWeights();
nodes = new Node[2][];
nodes[0] = new Node[2];
nodes[1] = new Node[1];
initializeNodes();
lc = 1;
}
private void initializeWeights(){
for(double[][] layer: weights)
for(double[] curNode: layer)
for(int i=0; i<curNode.length; i++)
curNode[i] = Math.random()/10;
}
private void initializeNodes(){
for(Node[] layer: nodes)
for(int i=0; i<layer.length; i++)
layer[i] = new Node();
nodes[nodes.length-1][0].f = Functions.HSF;
}
public double feedForward(double[] inputs) {
for(int j=0; j<nodes[0].length; j++)
nodes[0][j].addIW(inputs[j], weights[0][j][0]);
double[] outputs = new double[nodes[0].length];
for(int i=0; i<nodes[0].length; i++)
outputs[i] = nodes[0][i].calculateOut();
for(int l=1; l<nodes.length; l++){
for(int i=0; i<nodes[l].length; i++){
for(int j=0; j<nodes[l-1].length; j++)
nodes[l][i].addIW(
outputs[j],
weights[l][i][j]);
nodes[l][i].addIW(weights[l][i][weights[l][i].length-1]);
}
outputs = new double[nodes[l].length];
for(int i=0; i<nodes[l].length; i++)
outputs[i] = nodes[l][i].calculateOut();
}
return outputs[0];
}
public void backpropagate(double[] inputs, double expected) {
nodes[nodes.length-1][0].addEW(expected-nodes[nodes.length-1][0].output);
for(int l=nodes.length-2; l>=0; l--){
for(Node n: nodes[l+1])
n.calculateError();
for(int i=0; i<nodes[l].length; i++)
for(int j=0; j<nodes[l+1].length; j++)
nodes[l][i].addEW(nodes[l+1][j].error, weights[l+1][j][i]);
for(int j=0; j<nodes[l+1].length; j++){
for(int i=0; i<nodes[l].length; i++)
weights[l+1][j][i] += nodes[l][i].output*lc*nodes[l+1][j].error;
weights[l+1][j][nodes[l].length] += lc*nodes[l+1][j].error;
}
}
for(int i=0; i<nodes[0].length; i++){
weights[0][i][0] += inputs[i]*lc*nodes[0][i].calculateError();
}
}
public double train(double[] inputs, double expected) {
double r = feedForward(inputs);
backpropagate(inputs, expected);
return r;
}
public void resetValues() {
for(Node[] layer: nodes)
for(Node n: layer)
n.resetValues();
}
public static void main(String[] args) {
LineNetwork ln = new LineNetwork();
System.out.println(str2d(ln.weights[0]));
for(int i=0; i<10000; i++){
double[] in = {Math.round(Math.random()),Math.round(Math.random())};
int out = 0;
if(in[1]==1^in[0] ==1) out = 1;
ln.resetValues();
System.out.print(i+": {"+in[0]+", "+in[1]+"}: "+out+" ");
System.out.println((int)ln.train(in, out));
}
System.out.println(str2d(ln.weights[0]));
}
private static String str2d(double[][] a){
String str = "[";
for(double[] arr: a)
str = str + str1d(arr) + ",\n";
str = str.substring(0, str.length()-2)+"]";
return str;
}
private static String str1d(double[] a){
String str = "[";
for(double d: a)
str = str+d+", ";
str = str.substring(0, str.length()-2)+"]";
return str;
}
}
krótkie wyjaśnienie struktury: każdy węzeł ma aktywacja funkcji F; f.eval
ocenia funkcję, a f.deriv
ocenia jej pochodną. Functions.SIG
jest standardową funkcją sigmoidalną, a Functions.HSF
jest funkcją kroku Heaviside. Aby ustawić wejścia funkcji, należy wywołać addIW
z wartością, która już obejmuje wagę poprzedniego wyjścia. Podobnie jest w przypadku propagacji wstecznej z addEW
. Węzły są zorganizowane w tablicy 2d, a wagi są zorganizowane osobno w tablicy 3d zgodnie z opisem.
Zdaję sobie sprawę, że może to być trochę trudne do zapamiętania - i na pewno zdaję sobie sprawę, ile konwencji Java łamie ten kod - ale doceniam każdą pomoc, którą każdy może zaoferować.
EDYCJA: Ponieważ to pytanie i mój kod są gigantycznymi ścianami tekstu, jeśli istnieje linia zawierająca wiele skomplikowanych wyrażeń w nawiasach, których nie chcesz wymyślić, dodaj komentarz lub coś, co mnie pyta, a ja " Postaram się odpowiedzieć tak szybko, jak tylko mogę.
EDYCJA 2: Szczególny problem polega na tym, że ta sieć nie jest zbieżna na XOR. Oto niektóre wyjściowy zilustrować następująco:
9995 {1.0, 0.0} 1 1
9996 {0.0, 1.0} 1 1
9997 {0.0, 0.0} 0 1
9998 {0.0, 1.0}: 1 0
9999 {0.0, 1.0} 1 1
Każda linia w formacieTEST NUMBER: {INPUTS}: EXPECTED ACTUAL
wymaga siećtrain
każdym teście, tak więc sieć backpropagating 10000 razy.
Oto dwa dodatkowe zajęcia, jeśli ktoś chce go uruchomić:
package util;
public class Functions {
public static final ActivationFunction LIN = new ActivationFunction(){
public double eval(double x) {
return x;
}
public double deriv(double x) {
return 1;
}
};
public static final ActivationFunction SIG = new ActivationFunction(){
public double eval(double x) {
return 1/(1+Math.exp(-x));
}
public double deriv(double x) {
double ev = eval(x);
return ev * (1-ev);
}
};
public static final ActivationFunction HSF = new ActivationFunction(){
public double eval(double x) {
if(x>0) return 1;
return 0;
}
public double deriv(double x) {
return (1);
}
};
}
package util;
public interface ActivationFunction {
public double eval(double x);
public double deriv(double x);
}
Teraz to nawet dłużej. Cerować.
Jaki jest konkretny problem? Jaki jest oczekiwany wynik? Czy możesz zrobić krótszy program, aby go odtworzyć? W obecnym brzmieniu głosuję, aby zamknąć to ze względu na "Pytania dotyczące pomocy w zakresie debugowania (" dlaczego ten kod nie działa? ") Muszą zawierać pożądane zachowanie, konkretny problem lub błąd oraz najkrótszy kod niezbędny do odtworzenia go w pytaniu Pytania bez wyraźnego stwierdzenia problemu nie są przydatne dla innych czytelników. " –
Jeśli możesz wyszkolić pojedynczy neuron, problem prawdopodobnie występuje w twojej metodzie backpropagate. Czy próbowałeś obliczenia "od ręki" z małą siatką do porównania? Pomoże to również, jeśli możesz opublikować brakujące klasy, aby Twój kod mógł zostać uruchomiony. – jbkm
@KErlandsson: Dodałem konkretny problem, przyjrzę się krótszemu programowi, ale na pewno zajmie to trochę czasu, ponieważ nie jestem całkowicie pewien, co nie działa i co będę w stanie wyciągnąć. –