001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.neuralnet; 031 032import gov.sandia.cognition.math.matrix.Matrix; 033import gov.sandia.cognition.math.matrix.MatrixFactory; 034import gov.sandia.cognition.math.matrix.Vector; 035import gov.sandia.cognition.math.matrix.mtj.DenseMatrix; 036import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ; 037 038import org.openimaj.data.RandomData; 039import org.openimaj.image.DisplayUtilities; 040import org.openimaj.image.FImage; 041import org.openimaj.image.MBFImage; 042import org.openimaj.image.colour.ColourMap; 043import org.openimaj.math.matrix.CFMatrixUtils; 044import org.openimaj.math.matrix.MatrixUtils; 045import org.openimaj.util.function.Function; 046 047 048/** 049 * Implement an online version of the backprop algorithm against an 2D 050 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 051 * 052 */ 053public class OnlineBackpropOneHidden { 054 055 056 private static final double LEARNRATE = 0.005; 057 private Matrix weightsL1; 058 private Matrix weightsL2; 059 MatrixFactory<? extends Matrix> DMF = DenseMatrixFactoryMTJ.getDenseDefault(); 060 /** 061 * @param nInput the number of input values 062 * @param nHidden the number of hidden values 063 * @param nFinal the number of final values 064 */ 065 private Function<Double, Double> g; 066 private Function<Matrix, Matrix> gMat; 067 private Function<Double, Double> gPrime; 068 private Function<Matrix, Matrix> gPrimeMat; 069 public OnlineBackpropOneHidden(int nInput, int nHidden, int nFinal) { 070 double[][] weightsL1dat = RandomData.getRandomDoubleArray(nInput+1,nHidden, -1, 1.); 071 double[][] weightsL2dat = RandomData.getRandomDoubleArray(nHidden+1,nFinal , -1, 1.); 072 073 074 weightsL1 = DMF.copyArray(weightsL1dat); 075 weightsL2 = DMF.copyArray(weightsL2dat);; 076 077 g = new Function<Double,Double>(){ 078 079 @Override 080 public Double apply(Double in) { 081 082 return 1. / (1 + Math.exp(-in)); 083 } 084 085 }; 086 087 gPrime = new Function<Double,Double>(){ 088 089 @Override 090 public Double apply(Double in) { 091 092 return g.apply(in) * (1 - g.apply(in)); 093 } 094 095 }; 096 097 gPrimeMat = new Function<Matrix,Matrix>(){ 098 099 @Override 100 public Matrix apply(Matrix in) { 101 Matrix out = DMF.copyMatrix(in); 102 for (int i = 0; i < in.getNumRows(); i++) { 103 for (int j = 0; j < in.getNumColumns(); j++) { 104 out.setElement(i, j, gPrime.apply(in.getElement(i, j))); 105 } 106 } 107 return out; 108 } 109 110 }; 111 112 gMat = new Function<Matrix,Matrix>(){ 113 114 @Override 115 public Matrix apply(Matrix in) { 116 Matrix out = DMF.copyMatrix(in); 117 for (int i = 0; i < in.getNumRows(); i++) { 118 for (int j = 0; j < in.getNumColumns(); j++) { 119 out.setElement(i, j, g.apply(in.getElement(i, j))); 120 } 121 } 122 return out; 123 } 124 125 }; 126 } 127 128 public void update(double[] x, double[] y){ 129 Matrix X = prepareMatrix(x); 130 Matrix Y = DMF.copyArray(new double[][]{y}); 131 132 Matrix hiddenOutput = weightsL1.transpose().times(X); // nHiddenLayers x nInputs (usually 2 x 1) 133 Matrix gHiddenOutput = prepareMatrix(gMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers + 1 x nInputs (usually 3x1) 134 Matrix gPrimeHiddenOutput = prepareMatrix(gPrimeMat.apply(hiddenOutput).getColumn(0)); // nHiddenLayers + 1 x nInputs (usually 3x1) 135 Matrix finalOutput = weightsL2.transpose().times(gHiddenOutput); 136 Matrix finalOutputGPrime = gPrimeMat.apply(finalOutput); // nFinalLayers x nInputs (usually 1x1) 137 138 Matrix errmat = Y.minus(finalOutput); 139 double err = errmat.sumOfColumns().sum(); 140 141 Matrix dL2 = finalOutputGPrime.times(gHiddenOutput.transpose()).scale(err * LEARNRATE).transpose(); // should be nHiddenLayers + 1 x nInputs (3 x 1) 142 Matrix dL1 = finalOutputGPrime.times(weightsL2.transpose().times(gPrimeHiddenOutput).times(X.transpose())).scale(err * LEARNRATE).transpose(); 143 144 dL1 = repmat(dL1,1,weightsL1.getNumColumns()); 145 dL2 = repmat(dL2,1,weightsL2.getNumColumns()); 146 147 this.weightsL1.plusEquals(dL1); 148 this.weightsL2.plusEquals(dL2); 149 150 } 151 152 private Matrix repmat(Matrix dL1, int nRows, int nCols) { 153 Matrix out = DMF.createMatrix(nRows * dL1.getNumRows(), nCols * dL1.getNumColumns()); 154 for (int i = 0; i < nRows; i++) { 155 for (int j = 0; j < nCols; j++) { 156 out.setSubMatrix(i * dL1.getNumRows(), j * dL1.getNumColumns(), dL1); 157 } 158 } 159 return out; 160 } 161 162 public Matrix predict(double[] x){ 163 Matrix X = prepareMatrix(x); 164 165 Matrix hiddenTimes = weightsL1.transpose().times(X); 166 Matrix hiddenVal = prepareMatrix(gMat.apply(hiddenTimes).getColumn(0)); 167 Matrix finalTimes = weightsL2.transpose().times(hiddenVal); 168 Matrix finalVal = gMat.apply(finalTimes); 169 170 return finalVal; 171 172 } 173 174 175 private Matrix prepareMatrix(Vector y) { 176 Matrix Y = DMF.createMatrix(1, y.getDimensionality() + 1); 177 Y.setElement(0, 0, 1); 178 Y.setSubMatrix(0, 1, DMF.copyRowVectors(y)); 179 return Y.transpose(); 180 } 181 182 private Matrix prepareMatrix(double[] y) { 183 Matrix Y = DMF.createMatrix(1, y.length + 1); 184 Y.setElement(0, 0, 1); 185 Y.setSubMatrix(0, 1, DMF.copyArray(new double[][]{y})); 186 return Y.transpose(); 187 } 188 189 public static void main(String[] args) throws InterruptedException { 190 OnlineBackpropOneHidden bp = new OnlineBackpropOneHidden(2, 2, 1); 191 FImage img = new FImage(200,200); 192 img = imagePredict(bp,img); 193 ColourMap m = ColourMap.Hot; 194 195 DisplayUtilities.displayName(m.apply(img), "xor"); 196 int npixels = img.width*img.height; 197 int half = img.width/2; 198 int[] pixels = RandomData.getUniqueRandomInts(npixels, 0, npixels); 199 while(true){ 200// for (int i = 0; i < pixels.length; i++) { 201// int pixel = pixels[i]; 202// int y = pixel / img.width; 203// int x = pixel - (y * img.width); 204// bp.update(new double[]{x < half ? -1 : 1,y < half ? -1 : 1},new double[]{xorValue(half,x,y)}); 205//// Thread.sleep(5); 206// } 207 bp.update(new double[]{0,0},new double[]{0}); 208 bp.update(new double[]{1,1},new double[]{0}); 209 bp.update(new double[]{0,1},new double[]{1}); 210 bp.update(new double[]{1,0},new double[]{1}); 211 imagePredict(bp, img); 212 DisplayUtilities.displayName(m.apply(img),"xor"); 213 } 214 } 215 216 private static FImage imagePredict(OnlineBackpropOneHidden bp, FImage img) { 217 double[] pos = new double[2]; 218 int half = img.width/2; 219 for (int y = 0; y < img.height; y++) { 220 for (int x = 0; x < img.width; x++) { 221 pos[0] = x < half ? 0 : 1; 222 pos[1] = y < half ? 0 : 1; 223 float ret = (float) bp.predict(pos).getElement(0, 0); 224 img.pixels[y][x] = ret; 225 } 226 } 227 return img; 228 } 229}