Note that there are some explanatory texts on larger screens.

plurals
  1. Implementing a Neural Network in Java: Training and Backpropagation issues
    text
    copied!<p>I'm trying to implement a feed-forward neural network in Java. I've created three classes NNeuron, NLayer and NNetwork. The "simple" calculations seem fine (I get correct sums/activations/outputs), but when it comes to the training process, I don't seem to get correct results. Can anyone, please tell what I'm doing wrong ? The whole code for the NNetwork class is quite long, so I'm posting the part that is causing the problem: [EDIT]: this is actually pretty much all of the NNetwork class</p> <pre class="lang-java prettyprint-override"><code>import java.util.ArrayList; import java.util.Arrays; import java.util.List; public class NNetwork { public static final double defaultLearningRate = 0.4; public static final double defaultMomentum = 0.8; private NLayer inputLayer; private ArrayList&lt;NLayer&gt; hiddenLayers; private NLayer outputLayer; private ArrayList&lt;NLayer&gt; layers; private double momentum = NNetwork1.defaultMomentum; // alpha: momentum, default! 0.3 private ArrayList&lt;Double&gt; learningRates; public NNetwork (int nInputs, int nOutputs, Integer... neuronsPerHiddenLayer) { this(nInputs, nOutputs, Arrays.asList(neuronsPerHiddenLayer)); } public NNetwork (int nInputs, int nOutputs, List&lt;Integer&gt; neuronsPerHiddenLayer) { // the number of neurons on the last layer build so far (i.e. 
the number of inputs for each neuron of the next layer) int prvOuts = 1; this.layers = new ArrayList&lt;&gt;(); // input layer this.inputLayer = new NLayer(nInputs, prvOuts, this); this.inputLayer.setAllWeightsTo(1.0); this.inputLayer.setAllBiasesTo(0.0); this.inputLayer.useSigmaForOutput(false); prvOuts = nInputs; this.layers.add(this.inputLayer); // hidden layers this.hiddenLayers = new ArrayList&lt;&gt;(); for (int i=0 ; i&lt;neuronsPerHiddenLayer.size() ; i++) { this.hiddenLayers.add(new NLayer(neuronsPerHiddenLayer.get(i), prvOuts, this)); prvOuts = neuronsPerHiddenLayer.get(i); } this.layers.addAll(this.hiddenLayers); // output layer this.outputLayer = new NLayer(nOutputs, prvOuts, this); this.layers.add(this.outputLayer); this.initCoeffs(); } private void initCoeffs () { this.learningRates = new ArrayList&lt;&gt;(); // learning rates of the hidden layers for (int i=0 ; i&lt;this.hiddenLayers.size(); i++) this.learningRates.add(NNetwork1.defaultLearningRate); // learning rate of the output layer this.learningRates.add(NNetwork1.defaultLearningRate); } public double getLearningRate (int layerIndex) { if (layerIndex &gt; 0 &amp;&amp; layerIndex &lt;= this.hiddenLayers.size()+1) { return this.learningRates.get(layerIndex-1); } else { return 0; } } public ArrayList&lt;Double&gt; getLearningRates () { return this.learningRates; } public void setLearningRate (int layerIndex, double newLearningRate) { if (layerIndex &gt; 0 &amp;&amp; layerIndex &lt;= this.hiddenLayers.size()+1) { this.learningRates.set( layerIndex-1, newLearningRate); } } public void setLearningRates (Double... newLearningRates) { this.setLearningRates(Arrays.asList(newLearningRates)); } public void setLearningRates (List&lt;Double&gt; newLearningRates) { int len = (this.learningRates.size() &lt;= newLearningRates.size()) ? 
this.learningRates.size() : newLearningRates.size(); for (int i=0; i&lt;len; i++) this.learningRates .set(i, newLearningRates.get(i)); } public double getMomentum () { return this.momentum; } public void setMomentum (double momentum) { this.momentum = momentum; } public NNeuron getNeuron (int layerIndex, int neuronIndex) { if (layerIndex == 0) return this.inputLayer.getNeurons().get(neuronIndex); else if (layerIndex == this.hiddenLayers.size()+1) return this.outputLayer.getNeurons().get(neuronIndex); else return this.hiddenLayers.get(layerIndex-1).getNeurons().get(neuronIndex); } public ArrayList&lt;Double&gt; getOutput (ArrayList&lt;Double&gt; inputs) { ArrayList&lt;Double&gt; lastOuts = inputs; // the last computed outputs of the last 'called' layer so far // input layer //lastOuts = this.inputLayer.getOutput(lastOuts); lastOuts = this.getInputLayerOutputs(lastOuts); // hidden layers for (NLayer layer : this.hiddenLayers) lastOuts = layer.getOutput(lastOuts); // output layer lastOuts = this.outputLayer.getOutput(lastOuts); return lastOuts; } public ArrayList&lt;ArrayList&lt;Double&gt;&gt; getAllOutputs (ArrayList&lt;Double&gt; inputs) { ArrayList&lt;ArrayList&lt;Double&gt;&gt; outs = new ArrayList&lt;&gt;(); // input layer outs.add(this.getInputLayerOutputs(inputs)); // hidden layers for (NLayer layer : this.hiddenLayers) outs.add(layer.getOutput(outs.get(outs.size()-1))); // output layer outs.add(this.outputLayer.getOutput(outs.get(outs.size()-1))); return outs; } public ArrayList&lt;ArrayList&lt;Double&gt;&gt; getAllSums (ArrayList&lt;Double&gt; inputs) { //* ArrayList&lt;ArrayList&lt;Double&gt;&gt; sums = new ArrayList&lt;&gt;(); ArrayList&lt;Double&gt; lastOut; // input layer sums.add(inputs); lastOut = this.getInputLayerOutputs(inputs); // hidden nodes for (NLayer layer : this.hiddenLayers) { sums.add(layer.getSums(lastOut)); lastOut = layer.getOutput(lastOut); } // output layer sums.add(this.outputLayer.getSums(lastOut)); return sums; } public 
ArrayList&lt;Double&gt; getInputLayerOutputs (ArrayList&lt;Double&gt; inputs) { ArrayList&lt;Double&gt; outs = new ArrayList&lt;&gt;(); for (int i=0 ; i&lt;this.inputLayer.getNeurons().size() ; i++) outs.add(this .inputLayer .getNeuron(i) .getOutput(inputs.get(i))); return outs; } public void changeWeights ( ArrayList&lt;ArrayList&lt;Double&gt;&gt; deltaW, ArrayList&lt;ArrayList&lt;Double&gt;&gt; inputSet, ArrayList&lt;ArrayList&lt;Double&gt;&gt; targetSet, boolean checkError) { for (int i=0 ; i&lt;deltaW.size()-1 ; i++) this.hiddenLayers.get(i).changeWeights(deltaW.get(i), inputSet, targetSet, checkError); this.outputLayer.changeWeights(deltaW.get(deltaW.size()-1), inputSet, targetSet, checkError); } public int train2 ( ArrayList&lt;ArrayList&lt;Double&gt;&gt; inputSet, ArrayList&lt;ArrayList&lt;Double&gt;&gt; targetSet, double maxError, int maxIterations) { ArrayList&lt;Double&gt; input, target; ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; prvNetworkDeltaW = null; double error; int i = 0, j = 0, traininSetLength = inputSet.size(); do // during each itreration... { error = 0.0; for (j = 0; j &lt; traininSetLength; j++) // ... for each training element... { input = inputSet.get(j); target = targetSet.get(j); prvNetworkDeltaW = this.train2_bp(input, target, prvNetworkDeltaW); // ... 
do backpropagation, and return the new weight deltas error += this.getInputMeanSquareError(input, target); } i++; } while (error &gt; maxError &amp;&amp; i &lt; maxIterations); // iterate as much as necessary/possible return i; } public ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; train2_bp ( ArrayList&lt;Double&gt; input, ArrayList&lt;Double&gt; target, ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; prvNetworkDeltaW) { ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerSums = this.getAllSums(input); // the sums for each layer ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerOutputs = this.getAllOutputs(input); // the outputs of each layer // get the layer deltas (inc the input layer that is null) ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerDeltas = this.train2_getLayerDeltas(layerSums, layerOutputs, target); // get the weight deltas ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; networkDeltaW = this.train2_getWeightDeltas(layerOutputs, layerDeltas, prvNetworkDeltaW); // change the weights this.train2_updateWeights(networkDeltaW); return networkDeltaW; } public void train2_updateWeights (ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; networkDeltaW) { for (int i=1; i&lt;this.layers.size(); i++) this.layers.get(i).train2_updateWeights(networkDeltaW.get(i)); } public ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; train2_getWeightDeltas ( ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerOutputs, ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerDeltas, ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; prvNetworkDeltaW) { ArrayList&lt;ArrayList&lt;ArrayList&lt;Double&gt;&gt;&gt; networkDeltaW = new ArrayList&lt;&gt;(this.layers.size()); ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerDeltaW; ArrayList&lt;Double&gt; neuronDeltaW; for (int i=0; i&lt;this.layers.size(); i++) networkDeltaW.add(new ArrayList&lt;ArrayList&lt;Double&gt;&gt;()); double deltaW, x, learningRate, prvDeltaW, d; int i, j, k; for 
(i=this.layers.size()-1; i&gt;0; i--) // for each layer { learningRate = this.getLearningRate(i); layerDeltaW = new ArrayList&lt;&gt;(); networkDeltaW.set(i, layerDeltaW); for (j=0; j&lt;this.layers.get(i).getNeurons().size(); j++) // for each neuron of this layer { neuronDeltaW = new ArrayList&lt;&gt;(); layerDeltaW.add(neuronDeltaW); for (k=0; k&lt;this.layers.get(i-1).getNeurons().size(); k++) // for each weight (i.e. each neuron of the previous layer) { d = layerDeltas.get(i).get(j); x = layerOutputs.get(i-1).get(k); prvDeltaW = (prvNetworkDeltaW != null) ? prvNetworkDeltaW.get(i).get(j).get(k) : 0.0; deltaW = -learningRate * d * x + this.momentum * prvDeltaW; neuronDeltaW.add(deltaW); } // the bias !! d = layerDeltas.get(i).get(j); x = 1; prvDeltaW = (prvNetworkDeltaW != null) ? prvNetworkDeltaW.get(i).get(j).get(prvNetworkDeltaW.get(i).get(j).size()-1) : 0.0; deltaW = -learningRate * d * x + this.momentum * prvDeltaW; neuronDeltaW.add(deltaW); } } return networkDeltaW; } ArrayList&lt;ArrayList&lt;Double&gt;&gt; train2_getLayerDeltas ( ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerSums, ArrayList&lt;ArrayList&lt;Double&gt;&gt; layerOutputs, ArrayList&lt;Double&gt; target) { // get ouput deltas ArrayList&lt;Double&gt; outputDeltas = new ArrayList&lt;&gt;(); // the output layer deltas double oErr, // output error given a target s, // sum o, // output d; // delta int nOutputs = target.size(), // @TODO ?== this.outputLayer.size() nLayers = this.hiddenLayers.size()+2; // @TODO ?== layerOutputs.size() for (int i=0; i&lt;nOutputs; i++) // for each neuron... { s = layerSums.get(nLayers-1).get(i); o = layerOutputs.get(nLayers-1).get(i); oErr = (target.get(i) - o); d = -oErr * this.getNeuron(nLayers-1, i).sigmaPrime(s); // @TODO "s" or "o" ?? 
outputDeltas.add(d); } // get hidden deltas ArrayList&lt;ArrayList&lt;Double&gt;&gt; hiddenDeltas = new ArrayList&lt;&gt;(); for (int i=0; i&lt;this.hiddenLayers.size(); i++) hiddenDeltas.add(new ArrayList&lt;Double&gt;()); NLayer nextLayer = this.outputLayer; ArrayList&lt;Double&gt; nextDeltas = outputDeltas; int h, k, nHidden = this.hiddenLayers.size(), nNeurons = this.hiddenLayers.get(nHidden-1).getNeurons().size(); double wdSum = 0.0; for (int i=nHidden-1; i&gt;=0; i--) // for each hidden layer { hiddenDeltas.set(i, new ArrayList&lt;Double&gt;()); for (h=0; h&lt;nNeurons; h++) { wdSum = 0.0; for (k=0; k&lt;nextLayer.getNeurons().size(); k++) { wdSum += nextLayer.getNeuron(k).getWeight(h) * nextDeltas.get(k); } s = layerSums.get(i+1).get(h); d = this.getNeuron(i+1, h).sigmaPrime(s) * wdSum; hiddenDeltas.get(i).add(d); } nextLayer = this.hiddenLayers.get(i); nextDeltas = hiddenDeltas.get(i); } ArrayList&lt;ArrayList&lt;Double&gt;&gt; deltas = new ArrayList&lt;&gt;(); // input layer deltas: void deltas.add(null); // hidden layers deltas deltas.addAll(hiddenDeltas); // output layer deltas deltas.add(outputDeltas); return deltas; } public double getInputMeanSquareError (ArrayList&lt;Double&gt; input, ArrayList&lt;Double&gt; target) { double diff, mse=0.0; ArrayList&lt;Double&gt; output = this.getOutput(input); for (int i=0; i&lt;target.size(); i++) { diff = target.get(i) - output.get(i); mse += (diff * diff); } mse /= 2.0; return mse; } } </code></pre> <p>Some methods' names (with their return values/types) are quite self-explanatory, like "this.getAllSums" that returns the sums (sum(x_i*w_i) for each neuron) of each layer, "this.getAllOutputs" that return the outputs (sigmoid(sum) for each neuron) of each layer and "this.getNeuron(i,j)" that returns the j'th neuron of the i'th layer.</p> <p>Thank you in advance for your help :)</p>
 

Querying!

 
Guidance

SQuiL has stopped working due to an internal error.

If you are curious you may find further information in the browser console, which is accessible through the devtools (F12).

Reload