/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.openimaj.io.IOUtils;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Fold;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Mode;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

import com.google.common.primitives.Doubles;
import com.jmatio.io.MatFileWriter;
import com.jmatio.types.MLArray;

/**
 * Optimise lambda, eta0 and the learning rates with a line search.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class LambdaSearchAustrian {

    private static final int NFOLDS = 1;
    private static final String ROOT = "/Users/ss/Experiments/bilinear/austrian/";
    private static final String OUTPUT_ROOT = "/Users/ss/Dropbox/TrendMiner/Collaboration/StreamingBilinear2014/experiments";
    private final Logger logger = Logger.getLogger(getClass());
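
    // Expected input layout (inferred from performExperiment() below): ROOT
    // should contain the MAT files normalised.mat and unnormalised.mat.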

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        final LambdaSearchAustrian exp = new LambdaSearchAustrian();
        exp.performExperiment();
    }

    private long expStartTime = System.currentTimeMillis();

    /**
     * Run the line search over each fold: train on the training set, pick the
     * best parameters on the validation set and report RMSE on the test set.
     *
     * @throws IOException
     */
    public void performExperiment() throws IOException {
        final List<BillMatlabFileDataGenerator.Fold> folds = prepareFolds();
        final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(
                new File(dataFromRoot("normalised.mat")), "user_vsr_for_polls_SINA",
                new File(dataFromRoot("unnormalised.mat")),
                98, false,
                folds
        );
        prepareExperimentLog();
        final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
        for (int i = 0; i < bmfdg.nFolds(); i++) {
            logger.info("Starting Fold: " + i);
            final BilinearSparseOnlineLearner best = lineSearchParams(i, bmfdg);
            logger.debug("Best params found! Starting test...");
            bmfdg.setFold(i, Mode.TEST);
            eval.setLearner(best);
            final double ev = eval.evaluate(bmfdg.generateAll());
            logger.debug("Test RMSE: " + ev);
        }
    }

    private BilinearSparseOnlineLearner lineSearchParams(int fold, BillMatlabFileDataGenerator source) {
        BilinearSparseOnlineLearner best = null;
        double bestScore = Double.MAX_VALUE;
        final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
        int j = 0;
        final List<BilinearLearnerParameters> parameterLineSearch = parameterLineSearch();
        logger.info("Optimising params, searching: " + parameterLineSearch.size());
        for (final BilinearLearnerParameters next : parameterLineSearch) {
            logger.info(String.format("Optimising params %d/%d", j + 1, parameterLineSearch.size()));
            logger.debug("Current Params:\n" + next.toString());
            final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(next);
            // Train the model with the new parameters
            source.setFold(fold, Mode.TRAINING);
            Pair<Matrix> pair = null;
            logger.debug("Training...");
            while ((pair = source.generate()) != null) {
                learner.process(pair.firstObject(), pair.secondObject());
            }
            logger.debug("Generating score of validation set");
            // Validate with the validation set
            source.setFold(fold, Mode.VALIDATION);
            eval.setLearner(learner);
            final double loss = eval.evaluate(source.generateAll());
            logger.debug("Total RMSE: " + loss);
            logger.debug("U sparsity: " + CFMatrixUtils.sparsity(learner.getU()));
            logger.debug("W sparsity: " + CFMatrixUtils.sparsity(learner.getW()));
            // Record the best learner so far
            if (loss < bestScore) {
                logger.info("New best score detected!");
                bestScore = loss;
                best = learner;
                logger.info("New Best Config:\n" + best.getParams());
                logger.info("New Best Loss:" + loss);
                saveFoldParameterLearner(fold, j, learner);
            }
            j++;
        }
        return best;
    }

    private void saveFoldParameterLearner(int fold, int j, BilinearSparseOnlineLearner learner) {
        // Save the learner state, both as a serialised binary and as a MAT
        // file holding the U, W and bias matrices
        final File learnerOut = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
                "learner_%d", j));
        final File learnerOutMat = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
                "learner_%d.mat", j));
        learnerOut.getParentFile().mkdirs();
        try {
            IOUtils.writeBinary(learnerOut, learner);
            final Collection<MLArray> data = new ArrayList<MLArray>();
            data.add(CFMatrixUtils.toMLArray("u", learner.getU()));
            data.add(CFMatrixUtils.toMLArray("w", learner.getW()));
            if (learner.getBias() != null) {
                data.add(CFMatrixUtils.toMLArray("b", learner.getBias()));
            }
            // The MatFileWriter constructor writes the file as a side effect
            new MatFileWriter(learnerOutMat, data);
        } catch (final IOException e) {
            throw new RuntimeException(e);
        }
    }
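
    /*
     * Reading a saved learner back should be possible along these lines (a
     * sketch; it assumes OpenIMAJ's IOUtils.read(File, Class) as the
     * counterpart of the writeBinary call above):
     *
     *   BilinearSparseOnlineLearner learner =
     *       IOUtils.read(learnerOut, BilinearSparseOnlineLearner.class);
     */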

    private List<BilinearLearnerParameters> parameterLineSearch() {
        final BilinearLearnerParameters params = prepareParams();
        final BilinearLearnerParametersLineSearch iter = new BilinearLearnerParametersLineSearch(params);

        iter.addIteration(BilinearLearnerParameters.ETA0_U, Doubles.asList(new double[] { 0.0001 }));
        iter.addIteration(BilinearLearnerParameters.ETA0_W, Doubles.asList(new double[] { 0.005 }));
        iter.addIteration(BilinearLearnerParameters.ETA0_BIAS, Doubles.asList(new double[] { 50 }));
        iter.addIteration(BilinearLearnerParameters.LAMBDA_U, Doubles.asList(new double[] { 0.00001 }));
        iter.addIteration(BilinearLearnerParameters.LAMBDA_W, Doubles.asList(new double[] { 0.00001 }));

        final List<BilinearLearnerParameters> ret = new ArrayList<BilinearLearnerParameters>();
        for (final BilinearLearnerParameters param : iter) {
            ret.add(param);
        }
        return ret;
    }
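
    /*
     * Each list in parameterLineSearch() holds a single value, so the search
     * currently visits exactly one configuration; supplying several
     * candidates per parameter turns it into a full grid, e.g. (illustrative
     * values, not from the original experiment):
     *
     *   iter.addIteration(BilinearLearnerParameters.LAMBDA_U,
     *       Doubles.asList(new double[] { 0.00001, 0.0001, 0.001 }));
     */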

    private List<Fold> prepareFolds() {
        final List<Fold> set_fold = new ArrayList<BillMatlabFileDataGenerator.Fold>();

        // [24/02/2014 16:58:23] .@bill:
        final int step = 5; // % test_size
        final int t_size = 48; // % training_size
        final int v_size = 8;
        for (int i = 0; i < NFOLDS; i++) {
            final int total = i * step + t_size;
            final int[] training = new int[total - v_size];
            final int[] test = new int[step];
            final int[] validation = new int[v_size];
            int j = 0;
            int traini = 0;
            final int tt = (int) Math.round(total / 2.) - 1;
            // The first total indices are training, except for a central
            // block of v_size validation indices; the step indices straight
            // after the training range form the test set.
            for (; j < tt - v_size / 2; j++, traini++) {
                training[traini] = j;
            }
            for (int k = 0; k < validation.length; k++, j++) {
                validation[k] = j;
            }
            for (; j < total; j++, traini++) {
                training[traini] = j;
            }
            for (int k = 0; k < test.length; k++, j++) {
                test[k] = j;
            }
            final Fold foldi = new Fold(training, test, validation);
            set_fold.add(foldi);
        }
        // [24/02/2014 16:59:07] .@bill: set_fold{1,1}
        return set_fold;
    }

    private BilinearLearnerParameters prepareParams() {
        final BilinearLearnerParameters params = new BilinearLearnerParameters();

        // The null-valued parameters are the ones filled in by the line search
        params.put(BilinearLearnerParameters.ETA0_U, null);
        params.put(BilinearLearnerParameters.ETA0_W, null);
        params.put(BilinearLearnerParameters.LAMBDA_U, null);
        params.put(BilinearLearnerParameters.LAMBDA_W, null);
        params.put(BilinearLearnerParameters.ETA0_BIAS, null);

        params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01);
        params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
        params.put(BilinearLearnerParameters.BIAS, true);
        params.put(BilinearLearnerParameters.WINITSTRAT, new SparseZerosInitStrategy());
        params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
        params.put(BilinearLearnerParameters.LOSS, new MatSquareLossFunction());
        return params;
    }

    /**
     * @param data the name of the data file
     * @return the data file from the root
     */
    public static String dataFromRoot(String data) {
        return String.format("%s/%s", ROOT, data);
    }

    protected void prepareExperimentLog() throws IOException {
        // Create and configure a console appender
        final ConsoleAppender console = new ConsoleAppender();
        final String PATTERN = "[%p->%C{1}] %m%n";
        console.setLayout(new PatternLayout(PATTERN));
        console.setThreshold(Level.INFO);
        console.activateOptions();
        // Add the appender to the root logger
        Logger.getRootLogger().addAppender(console);
        final File expRoot = prepareExperimentRoot();

        final File logFile = new File(expRoot, "log");
        if (logFile.exists())
            logFile.delete();
        final String TIMED_PATTERN = "[%d{HH:mm:ss} %p->%C{1}] %m%n";
        final FileAppender file = new FileAppender(new PatternLayout(TIMED_PATTERN), logFile.getAbsolutePath());
        file.setThreshold(Level.DEBUG);
        file.activateOptions();
        Logger.getRootLogger().addAppender(file);
        logger.info("Experiment root: " + expRoot);
    }

    /**
     * @return the root directory into which this experiment's output is
     *         written, creating it if necessary
     * @throws IOException if the directory could not be created
     */
    public File prepareExperimentRoot() throws IOException {
        final String experimentRoot = currentOutputRoot();
        final File expRoot = new File(experimentRoot);
        if (expRoot.exists() && expRoot.isDirectory())
            return expRoot;
        logger.debug("Experiment root: " + expRoot);
        if (!expRoot.mkdirs())
            throw new IOException("Couldn't prepare experiment output");
        return expRoot;
    }

    private String currentOutputRoot() {
        return String.format("%s/%s/%s", OUTPUT_ROOT, getExperimentSetName(), "" + currentExperimentTime());
    }

    private long currentExperimentTime() {
        return expStartTime;
    }

    private String getExperimentSetName() {
        return "streamingBilinear/optimiselambda";
    }
}