001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.linear.learner;
031
032import gov.sandia.cognition.math.matrix.Matrix;
033
034import org.apache.log4j.Logger;
035import org.openimaj.citation.annotation.Reference;
036import org.openimaj.citation.annotation.ReferenceType;
037import org.openimaj.math.matrix.CFMatrixUtils;
038
039/**
 * An implementation of a stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
042 * <p>
043 * Data is dealt with sequentially using a one pass implementation of the online
044 * proximal algorithm described in chapter 9 and 10 of: The Geometry of
045 * Constrained Structured Prediction: Applications to Inference and Learning of
046 * Natural Language Syntax, PhD, Andre T. Martins
047 * <p>
048 * This is a direct extension of the {@link BilinearSparseOnlineLearner} but
049 * instead of a mixed update scheme (i.e. for a number of iterations W and U are
050 * updated synchronously) we have an unmixed scheme where W is updated for a
051 * number of iterations, followed by U for a number of iterations continuing as
052 * a whole for a number of iterations
053 * <p>
054 * The implementation does the following:
055 * <ul>
056 * <li>When an X,Y is received:
057 * <ul>
058 * <li>Update currently held batch
059 * <li>If the batch is full:
060 * <ul>
061 * <li>While There is a great deal of change in U and W:
062 * <ul>
063 * <li>While There is a great deal of change in W:
064 * <ul>
065 * <li>Calculate the gradient of W holding U fixed
066 * <li>Proximal update of W
067 * <li>Calculate the gradient of Bias holding U and W fixed
068 * </ul>
069 * <li>While There is a great deal of change in U:
070 * <ul>
071 * <li>Calculate the gradient of U holding W fixed
072 * <li>Proximal update of U
073 * <li>Calculate the gradient of Bias holding U and W fixed
074 * </ul>
075 * </ul>
076 * <li>flush the batch
077 * </ul>
 * <li>return current U and W (same as last time if the batch isn't filled yet)
079 * </ul>
080 * </ul>
081 * 
082 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
083 * 
084 */
085@Reference(
086                author = { "Andre F. T. Martins" },
087                title = "The Geometry of Constrained Structured Prediction: Applications to Inference and Learning of Natural Language Syntax",
088                type = ReferenceType.Phdthesis,
089                year = "2012")
090public class BilinearUnmixedSparseOnlineLearner extends BilinearSparseOnlineLearner {
091
092        static Logger logger = Logger.getLogger(BilinearUnmixedSparseOnlineLearner.class);
093
094        @Override
095        protected Matrix updateW(Matrix currentW, double wLossWeighted, double weightedLambda) {
096                Matrix current = currentW;
097                int iter = 0;
098                final Double biconvextol = this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL);
099                final Integer maxiter = this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER);
100                while (true) {
101                        final Matrix newcurrent = super.updateW(current, wLossWeighted, weightedLambda);
102                        final double sumchange = CFMatrixUtils.absSum(current.minus(newcurrent));
103                        final double total = CFMatrixUtils.absSum(current);
104                        final double ratio = sumchange / total;
105                        current = newcurrent;
106                        if (ratio < biconvextol || iter >= maxiter) {
107                                logger.debug("W tolerance reached after iteration: " + iter);
108                                break;
109                        }
110                        iter++;
111                }
112                return current;
113        }
114
115        @Override
116        protected Matrix updateU(Matrix currentU, Matrix neww, double uLossWeighted, double weightedLambda) {
117                Matrix current = currentU;
118                int iter = 0;
119                final Double biconvextol = this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL);
120                final Integer maxiter = this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER);
121                while (true) {
122                        final Matrix newcurrent = super.updateU(current, neww, uLossWeighted, weightedLambda);
123                        final double sumchange = CFMatrixUtils.absSum(current.minus(newcurrent));
124                        final double total = CFMatrixUtils.absSum(current);
125                        final double ratio = sumchange / total;
126                        current = newcurrent;
127                        if (ratio < biconvextol || iter >= maxiter) {
128                                logger.debug("U tolerance reached after iteration: " + iter);
129                                break;
130                        }
131                        iter++;
132                }
133                return current;
134        }
135}