/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner;

import gov.sandia.cognition.math.matrix.Matrix;

import org.apache.log4j.Logger;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.math.matrix.CFMatrixUtils;

/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 * <p>
 * Data is dealt with sequentially using a one-pass implementation of the online
 * proximal algorithm described in chapters 9 and 10 of: The Geometry of
 * Constrained Structured Prediction: Applications to Inference and Learning of
 * Natural Language Syntax, PhD thesis, Andre F. T. Martins.
 * <p>
 * This is a direct extension of the {@link BilinearSparseOnlineLearner} but
 * instead of a mixed update scheme (i.e. at each iteration W and U are updated
 * together) we have an unmixed scheme where W is updated for a number of
 * iterations, followed by U for a number of iterations, with the whole W/U
 * alternation itself repeated for a number of iterations.
 * <p>
 * The implementation does the following:
 * <ul>
 * <li>When an X,Y pair is received:
 * <ul>
 * <li>Update the currently held batch
 * <li>If the batch is full:
 * <ul>
 * <li>While there is a great deal of change in U and W:
 * <ul>
 * <li>While there is a great deal of change in W:
 * <ul>
 * <li>Calculate the gradient of W holding U fixed
 * <li>Proximal update of W
 * <li>Calculate the gradient of the bias holding U and W fixed
 * </ul>
 * <li>While there is a great deal of change in U:
 * <ul>
 * <li>Calculate the gradient of U holding W fixed
 * <li>Proximal update of U
 * <li>Calculate the gradient of the bias holding U and W fixed
 * </ul>
 * </ul>
 * <li>Flush the batch
 * </ul>
 * <li>Return the current U and W (same as last time if the batch isn't full yet)
 * </ul>
 * </ul>
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
@Reference(
		author = { "Andre F. T. Martins" },
		title = "The Geometry of Constrained Structured Prediction: Applications to Inference and Learning of Natural Language Syntax",
		type = ReferenceType.Phdthesis,
		year = "2012")
public class BilinearUnmixedSparseOnlineLearner extends BilinearSparseOnlineLearner {

	static Logger logger = Logger.getLogger(BilinearUnmixedSparseOnlineLearner.class);

	@Override
	protected Matrix updateW(Matrix currentW, double wLossWeighted, double weightedLambda) {
		Matrix current = currentW;
		int iter = 0;
		final Double biconvextol = this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL);
		final Integer maxiter = this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER);
		while (true) {
			// Single proximal gradient step on W with U held fixed
			final Matrix newcurrent = super.updateW(current, wLossWeighted, weightedLambda);
			// Relative change: |W_new - W_old| / |W_old| (element-wise absolute sums)
			final double sumchange = CFMatrixUtils.absSum(current.minus(newcurrent));
			final double total = CFMatrixUtils.absSum(current);
			final double ratio = sumchange / total;
			current = newcurrent;
			// Stop once the relative change is within tolerance or the iteration cap is hit
			if (ratio < biconvextol || iter >= maxiter) {
				logger.debug("W tolerance reached after iteration: " + iter);
				break;
			}
			iter++;
		}
		return current;
	}

	@Override
	protected Matrix updateU(Matrix currentU, Matrix neww, double uLossWeighted, double weightedLambda) {
		Matrix current = currentU;
		int iter = 0;
		final Double biconvextol = this.params.getTyped(BilinearLearnerParameters.BICONVEX_TOL);
		final Integer maxiter = this.params.getTyped(BilinearLearnerParameters.BICONVEX_MAXITER);
		while (true) {
			// Single proximal gradient step on U with the newly updated W held fixed
			final Matrix newcurrent = super.updateU(current, neww, uLossWeighted, weightedLambda);
			// Relative change: |U_new - U_old| / |U_old| (element-wise absolute sums)
			final double sumchange = CFMatrixUtils.absSum(current.minus(newcurrent));
			final double total = CFMatrixUtils.absSum(current);
			final double ratio = sumchange / total;
			current = newcurrent;
			// Stop once the relative change is within tolerance or the iteration cap is hit
			if (ratio < biconvextol || iter >= maxiter) {
				logger.debug("U tolerance reached after iteration: " + iter);
				break;
			}
			iter++;
		}
		return current;
	}
}
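
/*
 * Usage sketch (illustrative only, not part of the original source): this
 * assumes the OnlineLearner-style API inherited from
 * BilinearSparseOnlineLearner -- process(X, Y) to feed one round of data and
 * getW()/getU() to read back the learned matrices -- and that BICONVEX_TOL
 * and BICONVEX_MAXITER are set in the learner's BilinearLearnerParameters
 * (defaults apply otherwise). The matrices x and y below are hypothetical.
 *
 *     final BilinearUnmixedSparseOnlineLearner learner =
 *             new BilinearUnmixedSparseOnlineLearner();
 *     // Each call updates the held batch; when the batch fills, the
 *     // alternating W-then-U proximal updates above run until the relative
 *     // change drops below BICONVEX_TOL or BICONVEX_MAXITER is reached.
 *     learner.process(x, y);
 *     final Matrix w = learner.getW();
 *     final Matrix u = learner.getU();
 */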