/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.openimaj.io.IOUtils;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Fold;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator.Mode;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

import com.google.common.primitives.Doubles;
import com.jmatio.io.MatFileWriter;
import com.jmatio.types.MLArray;

/**
 * Optimise the lambda regularisation and eta0 learning-rate parameters with a
 * line search.
 *
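 * A minimal usage sketch (note that the hard-coded ROOT and OUTPUT_ROOT paths
 * below are machine-specific and will likely need changing first):
 *
 * <pre>
 * final LambdaSearchAustrian exp = new LambdaSearchAustrian();
 * exp.performExperiment();
 * </pre>
 *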
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class LambdaSearchAustrian {

        private static final int NFOLDS = 1;
        private static final String ROOT = "/Users/ss/Experiments/bilinear/austrian/";
        private static final String OUTPUT_ROOT = "/Users/ss/Dropbox/TrendMiner/Collaboration/StreamingBilinear2014/experiments";
        private final Logger logger = Logger.getLogger(getClass());

        /**
         * @param args
         *            ignored
         * @throws IOException
         */
        public static void main(String[] args) throws IOException {
                final LambdaSearchAustrian exp = new LambdaSearchAustrian();
                exp.performExperiment();
        }

        private final long expStartTime = System.currentTimeMillis();

        /**
         * Run the full experiment: for each fold, line-search the parameters and
         * evaluate the best learner on the test set.
         *
         * @throws IOException
         */
        public void performExperiment() throws IOException {
                final List<BillMatlabFileDataGenerator.Fold> folds = prepareFolds();
                final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(
                                new File(dataFromRoot("normalised.mat")), "user_vsr_for_polls_SINA",
                                new File(dataFromRoot("unnormalised.mat")),
                                98, false,
                                folds);
                prepareExperimentLog();
                final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
                for (int i = 0; i < bmfdg.nFolds(); i++) {
                        logger.info("Starting Fold: " + i);
                        final BilinearSparseOnlineLearner best = lineSearchParams(i, bmfdg);
                        logger.debug("Best params found! Starting test...");
                        bmfdg.setFold(i, Mode.TEST);
                        eval.setLearner(best);
                        final double ev = eval.evaluate(bmfdg.generateAll());
                        logger.debug("Test RMSE: " + ev);
                }
        }

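        /**
         * For each parameter setting produced by {@link #parameterLineSearch()},
         * train a learner on the training portion of the given fold, score it
         * against the validation portion, and return the learner with the lowest
         * loss.
         */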
        private BilinearSparseOnlineLearner lineSearchParams(int fold, BillMatlabFileDataGenerator source) {
                BilinearSparseOnlineLearner best = null;
                double bestScore = Double.MAX_VALUE;
                final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
                int j = 0;
                final List<BilinearLearnerParameters> parameterLineSearch = parameterLineSearch();
                logger.info("Optimising params, searching: " + parameterLineSearch.size());
                for (final BilinearLearnerParameters next : parameterLineSearch) {
                        logger.info(String.format("Optimising params %d/%d", j + 1, parameterLineSearch.size()));
                        logger.debug("Current Params:\n" + next.toString());
                        final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(next);
                        // Train the model with the new parameters
                        source.setFold(fold, Mode.TRAINING);
                        Pair<Matrix> pair = null;
                        logger.debug("Training...");
                        while ((pair = source.generate()) != null) {
                                learner.process(pair.firstObject(), pair.secondObject());
                        }
                        logger.debug("Generating score of validation set");
                        // Score against the validation set
                        source.setFold(fold, Mode.VALIDATION);
                        eval.setLearner(learner);
                        final double loss = eval.evaluate(source.generateAll());
                        logger.debug("Total RMSE: " + loss);
                        logger.debug("U sparsity: " + CFMatrixUtils.sparsity(learner.getU()));
                        logger.debug("W sparsity: " + CFMatrixUtils.sparsity(learner.getW()));
                        // Record the best learner so far
                        if (loss < bestScore) {
                                logger.info("New best score detected!");
                                bestScore = loss;
                                best = learner;
                                logger.info("New Best Config:\n" + best.getParams());
                                logger.info("New Best Loss: " + loss);
                                saveFoldParameterLearner(fold, j, learner);
                        }
                        j++;
                }
                return best;
        }

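        /**
         * Serialise the learner for this fold and parameter index, both as an
         * OpenIMAJ binary and as a MATLAB .mat file holding the u, w and (when
         * enabled) bias matrices.
         */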
        private void saveFoldParameterLearner(int fold, int j, BilinearSparseOnlineLearner learner) {
                // Save the learner state
                final File learnerOut = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
                                "learner_%d", j));
                final File learnerOutMat = new File(String.format("%s/fold_%d", currentOutputRoot(), fold), String.format(
                                "learner_%d.mat", j));
                learnerOut.getParentFile().mkdirs();
                try {
                        IOUtils.writeBinary(learnerOut, learner);
                        final Collection<MLArray> data = new ArrayList<MLArray>();
                        data.add(CFMatrixUtils.toMLArray("u", learner.getU()));
                        data.add(CFMatrixUtils.toMLArray("w", learner.getW()));
                        if (learner.getBias() != null) {
                                data.add(CFMatrixUtils.toMLArray("b", learner.getBias()));
                        }
                        // The MatFileWriter constructor writes the .mat file as a side effect
                        new MatFileWriter(learnerOutMat, data);
                } catch (final IOException e) {
                        throw new RuntimeException(e);
                }
        }

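        /**
         * Build the list of parameter settings to search. Candidate values for
         * each parameter are registered with the line-search iterator; the
         * single-element lists below yield just one setting, but the search
         * widens naturally, e.g.:
         *
         * <pre>
         * iter.addIteration(BilinearLearnerParameters.ETA0_U, Doubles.asList(new double[] { 0.0001, 0.001 }));
         * </pre>
         */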
        private List<BilinearLearnerParameters> parameterLineSearch() {
                final BilinearLearnerParameters params = prepareParams();
                final BilinearLearnerParametersLineSearch iter = new BilinearLearnerParametersLineSearch(params);

                iter.addIteration(BilinearLearnerParameters.ETA0_U, Doubles.asList(new double[] { 0.0001 }));
                iter.addIteration(BilinearLearnerParameters.ETA0_W, Doubles.asList(new double[] { 0.005 }));
                iter.addIteration(BilinearLearnerParameters.ETA0_BIAS, Doubles.asList(new double[] { 50 }));
                iter.addIteration(BilinearLearnerParameters.LAMBDA_U, Doubles.asList(new double[] { 0.00001 }));
                iter.addIteration(BilinearLearnerParameters.LAMBDA_W, Doubles.asList(new double[] { 0.00001 }));

                final List<BilinearLearnerParameters> ret = new ArrayList<BilinearLearnerParameters>();
                for (final BilinearLearnerParameters param : iter) {
                        ret.add(param);
                }
                return ret;
        }

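        /**
         * Build the cross-validation folds. Each fold is a contiguous window
         * over the instances: training indices fill the window except for a
         * central validation block, and the test indices immediately follow the
         * window.
         */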
        private List<Fold> prepareFolds() {
                final List<Fold> set_fold = new ArrayList<BillMatlabFileDataGenerator.Fold>();

                // Fold sizes (as specified by Bill, 24/02/2014)
                final int step = 5; // test size
                final int t_size = 48; // training size
                final int v_size = 8; // validation size
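                // Worked example for fold 0: total = 48 and tt = 23, so training
                // covers indices 0..18 and 27..47, validation 19..26, test 48..52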
                for (int i = 0; i < NFOLDS; i++) {
                        final int total = i * step + t_size;
                        final int[] training = new int[total - v_size];
                        final int[] test = new int[step];
                        final int[] validation = new int[v_size];
                        int j = 0;
                        int traini = 0;
                        final int tt = (int) Math.round(total / 2.) - 1;
                        // First half of the training window, up to the validation block
                        for (; j < tt - v_size / 2; j++, traini++) {
                                training[traini] = j;
                        }
                        // Validation block in the middle of the window
                        for (int k = 0; k < validation.length; k++, j++) {
                                validation[k] = j;
                        }
                        // Second half of the training window
                        for (; j < total; j++, traini++) {
                                training[traini] = j;
                        }
                        // Test instances immediately follow the window
                        for (int k = 0; k < test.length; k++, j++) {
                                test[k] = j;
                        }
                        final Fold foldi = new Fold(training, test, validation);
                        set_fold.add(foldi);
                }
                // Mirrors the MATLAB structure set_fold{1,1}
                return set_fold;
        }

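        /**
         * Construct the base parameter set. The searched parameters (the eta0
         * and lambda values) are deliberately left null here; the line search
         * fills them in.
         */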
        private BilinearLearnerParameters prepareParams() {
                final BilinearLearnerParameters params = new BilinearLearnerParameters();

                // Searched parameters: values are provided by the line search
                params.put(BilinearLearnerParameters.ETA0_U, null);
                params.put(BilinearLearnerParameters.ETA0_W, null);
                params.put(BilinearLearnerParameters.LAMBDA_U, null);
                params.put(BilinearLearnerParameters.LAMBDA_W, null);
                params.put(BilinearLearnerParameters.ETA0_BIAS, null);

                // Fixed parameters
                params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01);
                params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
                params.put(BilinearLearnerParameters.BIAS, true);
                params.put(BilinearLearnerParameters.WINITSTRAT, new SparseZerosInitStrategy());
                params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
                params.put(BilinearLearnerParameters.LOSS, new MatSquareLossFunction());
                return params;
        }

        /**
         * @param data
         *            a filename within the data root
         * @return the path of the data file under ROOT
         */
        public static String dataFromRoot(String data) {
                return String.format("%s/%s", ROOT, data);
        }

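        /**
         * Configure log4j: INFO and above goes to the console, while DEBUG and
         * above goes to a fresh log file under the experiment root.
         */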
        protected void prepareExperimentLog() throws IOException {
                // Create and configure the console appender
                final ConsoleAppender console = new ConsoleAppender();
                final String PATTERN = "[%p->%C{1}] %m%n";
                console.setLayout(new PatternLayout(PATTERN));
                console.setThreshold(Level.INFO);
                console.activateOptions();
                // Add the appender to the root logger
                Logger.getRootLogger().addAppender(console);
                final File expRoot = prepareExperimentRoot();

                // Log everything to a fresh, timestamped file under the experiment root
                final File logFile = new File(expRoot, "log");
                if (logFile.exists())
                        logFile.delete();
                final String TIMED_PATTERN = "[%d{HH:mm:ss} %p->%C{1}] %m%n";
                final FileAppender file = new FileAppender(new PatternLayout(TIMED_PATTERN), logFile.getAbsolutePath());
                file.setThreshold(Level.DEBUG);
                file.activateOptions();
                Logger.getRootLogger().addAppender(file);
                logger.info("Experiment root: " + expRoot);
        }

        /**
         * @return the experiment root directory, created if it does not yet exist
         * @throws IOException
         *             if the directory could not be created
         */
        public File prepareExperimentRoot() throws IOException {
                final String experimentRoot = currentOutputRoot();
                final File expRoot = new File(experimentRoot);
                if (expRoot.exists() && expRoot.isDirectory())
                        return expRoot;
                logger.debug("Experiment root: " + expRoot);
                if (!expRoot.mkdirs())
                        throw new IOException("Couldn't prepare experiment output");
                return expRoot;
        }

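        /**
         * @return the output directory for this run: OUTPUT_ROOT, then the
         *         experiment set name, then the experiment start time
         */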
        private String currentOutputRoot() {
                return String.format("%s/%s/%s", OUTPUT_ROOT, getExperimentSetName(), "" + currentExperimentTime());
        }

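        /**
         * @return the experiment start time, fixed at construction so that every
         *         fold of a run writes beneath the same output directory
         */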
        private long currentExperimentTime() {
                return expStartTime;
        }

        private String getExperimentSetName() {
                return "streamingBilinear/optimiselambda";
        }
}