001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.neuralnet;
031
032import gov.sandia.cognition.math.matrix.Matrix;
033import gov.sandia.cognition.math.matrix.MatrixFactory;
034import gov.sandia.cognition.math.matrix.Vector;
035import gov.sandia.cognition.math.matrix.mtj.DenseMatrix;
036import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
037
038import org.openimaj.data.RandomData;
039import org.openimaj.image.DisplayUtilities;
040import org.openimaj.image.FImage;
041import org.openimaj.image.MBFImage;
042import org.openimaj.image.colour.ColourMap;
043import org.openimaj.math.matrix.CFMatrixUtils;
044import org.openimaj.math.matrix.MatrixUtils;
045import org.openimaj.util.function.Function;
046
047
048/**
049 * Implement an online version of the backprop algorithm against an 2D 
050 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
051 *
052 */
053public class OnlineBackpropOneHidden {
054        
055
056        private static final double LEARNRATE = 0.005;
057        private Matrix weightsL1;
058        private Matrix weightsL2;
059        MatrixFactory<? extends Matrix> DMF = DenseMatrixFactoryMTJ.getDenseDefault();
060        /**
061         * @param nInput the number of input values
062         * @param nHidden the number of hidden values
063         * @param nFinal the number of final values
064         */
065        private Function<Double, Double> g;
066        private Function<Matrix, Matrix> gMat;
067        private Function<Double, Double> gPrime;
068        private Function<Matrix, Matrix> gPrimeMat;
069        public OnlineBackpropOneHidden(int nInput, int nHidden, int nFinal) {
070                double[][] weightsL1dat = RandomData.getRandomDoubleArray(nInput+1,nHidden, -1, 1.);
071                double[][] weightsL2dat = RandomData.getRandomDoubleArray(nHidden+1,nFinal , -1, 1.);
072                
073                
074                weightsL1 = DMF.copyArray(weightsL1dat);
075                weightsL2 = DMF.copyArray(weightsL2dat);;
076                
077                g = new Function<Double,Double>(){
078
079                        @Override
080                        public Double apply(Double in) {
081                                
082                                return 1. / (1 + Math.exp(-in));
083                        }
084                        
085                };
086                
087                gPrime = new Function<Double,Double>(){
088
089                        @Override
090                        public Double apply(Double in) {
091                                
092                                return g.apply(in) * (1 - g.apply(in));
093                        }
094                        
095                };
096                
097                gPrimeMat = new Function<Matrix,Matrix>(){
098
099                        @Override
100                        public Matrix apply(Matrix in) {
101                                Matrix out = DMF.copyMatrix(in);
102                                for (int i = 0; i < in.getNumRows(); i++) {
103                                        for (int j = 0; j < in.getNumColumns(); j++) {
104                                                out.setElement(i, j, gPrime.apply(in.getElement(i, j)));
105                                        }
106                                }
107                                return out;
108                        }
109                        
110                };
111                
112                gMat = new Function<Matrix,Matrix>(){
113
114                        @Override
115                        public Matrix apply(Matrix in) {
116                                Matrix out = DMF.copyMatrix(in);
117                                for (int i = 0; i < in.getNumRows(); i++) {
118                                        for (int j = 0; j < in.getNumColumns(); j++) {
119                                                out.setElement(i, j, g.apply(in.getElement(i, j)));
120                                        }
121                                }
122                                return out;
123                        }
124                        
125                };
126        }
127        
128        public void update(double[] x, double[] y){
129                Matrix X = prepareMatrix(x);
130                Matrix Y = DMF.copyArray(new double[][]{y});
131                
132                Matrix hiddenOutput = weightsL1.transpose().times(X); // nHiddenLayers x nInputs (usually 2 x 1)
133                Matrix gHiddenOutput = prepareMatrix(gMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers + 1 x nInputs (usually 3x1)
134                Matrix gPrimeHiddenOutput = prepareMatrix(gPrimeMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers + 1 x nInputs (usually 3x1)
135                Matrix finalOutput = weightsL2.transpose().times(gHiddenOutput);
136                Matrix finalOutputGPrime = gPrimeMat.apply(finalOutput); // nFinalLayers x nInputs (usually 1x1)
137                
138                Matrix errmat = Y.minus(finalOutput);
139                double err = errmat.sumOfColumns().sum();
140                
141                Matrix dL2 = finalOutputGPrime.times(gHiddenOutput.transpose()).scale(err * LEARNRATE).transpose(); // should be nHiddenLayers + 1 x nInputs (3 x 1)
142                Matrix dL1 = finalOutputGPrime.times(weightsL2.transpose().times(gPrimeHiddenOutput).times(X.transpose())).scale(err * LEARNRATE).transpose();
143                
144                dL1 = repmat(dL1,1,weightsL1.getNumColumns());
145                dL2 = repmat(dL2,1,weightsL2.getNumColumns());
146                
147                this.weightsL1.plusEquals(dL1);
148                this.weightsL2.plusEquals(dL2);
149                
150        }
151        
152        private Matrix repmat(Matrix dL1, int nRows, int nCols) {
153                Matrix out = DMF.createMatrix(nRows * dL1.getNumRows(), nCols * dL1.getNumColumns());
154                for (int i = 0; i < nRows; i++) {
155                        for (int j = 0; j < nCols; j++) {
156                                out.setSubMatrix(i * dL1.getNumRows(), j * dL1.getNumColumns(), dL1);
157                        }
158                }
159                return out;
160        }
161
162        public Matrix predict(double[] x){
163                Matrix X = prepareMatrix(x);
164                
165                Matrix hiddenTimes = weightsL1.transpose().times(X);
166                Matrix hiddenVal = prepareMatrix(gMat.apply(hiddenTimes).getColumn(0));
167                Matrix finalTimes = weightsL2.transpose().times(hiddenVal);
168                Matrix finalVal = gMat.apply(finalTimes);
169                
170                return finalVal;
171                
172        }
173        
174        
175        private Matrix prepareMatrix(Vector y) {
176                Matrix Y = DMF.createMatrix(1, y.getDimensionality() + 1);
177                Y.setElement(0, 0, 1);
178                Y.setSubMatrix(0, 1, DMF.copyRowVectors(y));
179                return Y.transpose();
180        }
181
182        private Matrix prepareMatrix(double[] y) {
183                Matrix Y = DMF.createMatrix(1, y.length + 1);
184                Y.setElement(0, 0, 1);
185                Y.setSubMatrix(0, 1, DMF.copyArray(new double[][]{y}));
186                return Y.transpose();
187        }
188        
189        public static void main(String[] args) throws InterruptedException {
190                OnlineBackpropOneHidden bp = new OnlineBackpropOneHidden(2, 2, 1);
191                FImage img = new FImage(200,200);
192                img = imagePredict(bp,img);
193                ColourMap m = ColourMap.Hot;
194                
195                DisplayUtilities.displayName(m.apply(img), "xor");
196                int npixels = img.width*img.height;
197                int half = img.width/2;
198                int[] pixels = RandomData.getUniqueRandomInts(npixels, 0, npixels);
199                while(true){
200//                      for (int i = 0; i < pixels.length; i++) {
201//                              int pixel = pixels[i];
202//                              int y = pixel / img.width;
203//                              int x = pixel - (y * img.width);
204//                              bp.update(new double[]{x < half ? -1 : 1,y < half ? -1 : 1},new double[]{xorValue(half,x,y)});
205////                            Thread.sleep(5);
206//                      }
207                        bp.update(new double[]{0,0},new double[]{0});
208                        bp.update(new double[]{1,1},new double[]{0});
209                        bp.update(new double[]{0,1},new double[]{1});
210                        bp.update(new double[]{1,0},new double[]{1});
211                        imagePredict(bp, img);
212                        DisplayUtilities.displayName(m.apply(img),"xor");
213                }
214        }
215
216        private static FImage imagePredict(OnlineBackpropOneHidden bp, FImage img) {
217                double[] pos = new double[2];
218                int half = img.width/2;
219                for (int y = 0; y < img.height; y++) {
220                        for (int x = 0; x < img.width; x++) {
221                                pos[0] = x < half ? 0 : 1;
222                                pos[1] = y < half ? 0 : 1;
223                                float ret = (float) bp.predict(pos).getElement(0, 0);
224                                img.pixels[y][x] = ret;
225                        }
226                }
227                return img;
228        }
229}