/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner.matlib;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.DiagonalMatrix;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.ml.linear.learner.matlib.init.InitStrategy;
import org.openimaj.ml.linear.learner.matlib.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.matlib.loss.LossFunction;
import org.openimaj.ml.linear.learner.matlib.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.matlib.regul.Regulariser;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.SparseMatrix;


/**
 * An implementation of stochastic gradient descent with proximal parameter adjustment
 * (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * The Geometry of Constrained Structured Prediction: Applications to Inference and
 * Learning of Natural Language Syntax, PhD thesis, Andre T. Martins
 *
 * The implementation does the following:
 *   - When an X,Y pair is received:
 *     - Update the currently held batch
 *     - If the batch is full:
 *       - While there is a great deal of change in U and W:
 *         - Calculate the gradient of W holding U fixed
 *         - Proximal update of W
 *         - Calculate the gradient of U holding W fixed
 *         - Proximal update of U
 *         - Calculate the gradient of the bias holding U and W fixed
 *       - Flush the batch
 *     - Return the current U and W (same as last time if the batch isn't filled yet)
 *
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class MatlibBilinearSparseOnlineLearner implements OnlineLearner<Matrix,Matrix>, ReadWriteableBinary {

	static Logger logger = Logger.getLogger(MatlibBilinearSparseOnlineLearner.class);

	protected BilinearLearnerParameters params;
	protected Matrix w;
	protected Matrix u;
	protected LossFunction loss;
	protected Regulariser regul;
	protected Double lambda_w,lambda_u;
	protected Boolean biasMode;
	protected Matrix bias;
	protected Matrix diagX;
	protected Double eta0_u;
	protected Double eta0_w;

	private Boolean forceSparcity;

	private Boolean zStandardise;

	private boolean nodataseen;

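	/*
	 * A minimal usage sketch (illustrative only, not part of the original source).
	 * The learner is driven by feeding one instance at a time to process(X, Y),
	 * where X is a (features x users) matrix and Y is a single-row (1 x tasks)
	 * matrix of task values; predictions come back as a diagonal (tasks x tasks)
	 * matrix from predict(). The variables trainX, trainY and testX below are
	 * hypothetical placeholders.
	 *
	 *   final BilinearLearnerParameters params = new BilinearLearnerParameters();
	 *   final MatlibBilinearSparseOnlineLearner learner = new MatlibBilinearSparseOnlineLearner(params);
	 *   learner.process(trainX, trainY);              // one online update
	 *   final Matrix scores = learner.predict(testX); // per-task predictions on the diagonal
	 */
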
	/**
	 * The default parameters. These won't work with your dataset, I promise.
	 */
	public MatlibBilinearSparseOnlineLearner() {
		this(new BilinearLearnerParameters());
	}

	/**
	 * @param params the parameters used by this learner
	 */
	public MatlibBilinearSparseOnlineLearner(BilinearLearnerParameters params) {
		this.params = params;
		reinitParams();
	}

	/**
	 * Must be called if any parameters are changed.
	 */
	public void reinitParams() {
		this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
		this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
		this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
		this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
		this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
		this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
		this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
		this.forceSparcity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
		this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
		if(!this.loss.isMatrixLoss())
			this.loss = new MatLossFunction(this.loss);
		this.nodataseen = true;
	}

	private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
		final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT,x,y);
		final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT,x,y);
		this.w = wstrat.init(xrows, ycols);
		this.u = ustrat.init(xcols, ycols);

		this.bias = SparseMatrix.sparse(ycols,ycols);
		if(this.biasMode){
			final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT,x,y);
			this.bias = bstrat.init(ycols, ycols);
			this.diagX = new DiagonalMatrix(ycols,1);
		}
	}

	private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
		final InitStrategy strat = this.params.getTyped(initstrat);
		return strat;
	}

	@Override
	public void process(Matrix X, Matrix Y){
		final int nfeatures = X.rowCount();
		final int nusers = X.columnCount();
		final int ntasks = Y.columnCount();
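		// Shape convention (inferred from initParams and predict): X is
		// (nfeatures x nusers) and Y a single-instance (1 x ntasks) row, so
		// w is initialised as (nfeatures x ntasks) and u as (nusers x ntasks).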
//		int ninstances = Y.rowCount(); // Assume 1 instance!

		// only initialise when the current parameters are null
		if (this.w == null){
			initParams(X,Y,nfeatures, nusers, ntasks); // Number of words, users and tasks
		}

		final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
		final double weighting = 1.0 - dampening;

		logger.debug("... dampening w, u and bias by: " + weighting);

		// Adjust for weighting
		MatlibMatrixUtils.scaleInplace(this.w,weighting);
		MatlibMatrixUtils.scaleInplace(this.u,weighting);
		if(this.biasMode){
			MatlibMatrixUtils.scaleInplace(this.bias,weighting);
		}
		// First expand Y s.t. blocks of rows contain the task values for each row of Y.
		// This means Yexp has (n * t x t)
		final SparseMatrix Yexp = expandY(Y);
		loss.setY(Yexp);
		int iter = 0;
		while(true) {
			// We need to set the bias here because it is used in the loss calculation of U and W
			if(this.biasMode) loss.setBias(this.bias);
			iter += 1;

			final double uLossWeight = etat(iter,eta0_u);
			final double wLossWeighted = etat(iter,eta0_w);
			final double weightedLambda_u = lambdat(iter,lambda_u);
			final double weightedLambda_w = lambdat(iter,lambda_w);
			// Dprime is tasks x nwords
			Matrix Dprime = null;
			if(this.nodataseen){
				this.nodataseen = false;
				Matrix fakeut = new SparseSingleValueInitStrat(1).init(this.u.columnCount(),this.u.rowCount());
				Dprime = MatlibMatrixUtils.dotProductTranspose(fakeut, X); // i.e. fakeut . X^T
			} else {
				Dprime = MatlibMatrixUtils.dotProductTransposeTranspose(u, X); // i.e. u^T . X^T
			}

			// ... as is the cost function's X
			if(zStandardise){
//				Vector rowMean = CFMatrixUtils.rowMean(Dprime);
//				CFMatrixUtils.minusEqualsCol(Dprime,rowMean);
			}
			loss.setX(Dprime);
			final Matrix neww = updateW(this.w,wLossWeighted, weightedLambda_w);

			// Vprime is nusers x tasks
			final Matrix Vt = MatlibMatrixUtils.transposeDotProduct(neww,X); // i.e. (X^T.neww)^T X.transpose().times(neww);
			// ... so the loss function's X is (tasks x nusers)
			loss.setX(Vt);
			final Matrix newu = updateU(this.u,uLossWeight, weightedLambda_u);

			final double sumchangew = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(neww, this.w));
			final double totalw = MatlibMatrixUtils.normF(this.w);

			final double sumchangeu = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newu, this.u));
			final double totalu = MatlibMatrixUtils.normF(this.u);

			double ratioU = 0;
			if(totalu!=0) ratioU = sumchangeu/totalu;
			double ratioW = 0;
			if(totalw!=0) ratioW = sumchangew/totalw;
			// Commit the proximal updates so the next iteration (and getW/getU) see them
			this.w = neww;
			this.u = newu;
			double ratioB = 0;
			double ratio = ratioU + ratioW;
			double totalbias = 0;
			if(this.biasMode){
				Matrix mult = MatlibMatrixUtils.dotProductTransposeTranspose(newu, X);
				mult = MatlibMatrixUtils.dotProduct(mult, neww);
				MatlibMatrixUtils.plusInplace(mult, bias);
				// We must set bias to null!
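				// (mult already contains the bias contribution added above, so the
				// bias is presumably cleared from the loss to avoid counting it
				// twice; diagX stands in for X so that the gradient below is taken
				// with respect to the bias term alone.)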
				loss.setBias(null);
				loss.setX(diagX);
				// Calculate gradient of bias (don't regularise)
				final Matrix biasGrad = loss.gradient(mult);
				final double biasLossWeight = biasEtat(iter);
				final Matrix newbias = updateBias(biasGrad, biasLossWeight);

				final double sumchangebias = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newbias, bias));
				totalbias = MatlibMatrixUtils.normF(this.bias);
				if(totalbias!=0) ratioB = (sumchangebias/totalbias);
				this.bias = newbias;
				ratio += ratioB;
				ratio /= 3;
			}
			else{
				ratio /= 2;
			}

			final Double biconvextol = this.params.getTyped("biconvex_tol");
			final Integer maxiter = this.params.getTyped("biconvex_maxiter");
			if(iter%3 == 0){
				logger.debug(String.format("Iter: %d. Last Ratio: %2.3f",iter,ratio));
				logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
				logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
			}
			if(biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
				logger.debug("tolerance reached after iteration: " + iter);
				logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
				logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
				break;
			}
		}
	}

	protected Matrix updateBias(Matrix biasGrad, double biasLossWeight) {
		final Matrix newbias = MatlibMatrixUtils.minus(
				this.bias,
				MatlibMatrixUtils.scaleInplace(
						biasGrad,
						biasLossWeight
				)
		);
		return newbias;
	}

	protected Matrix updateW(Matrix currentW, double wLossWeighted, double weightedLambda) {
		final Matrix gradW = loss.gradient(currentW);
		MatlibMatrixUtils.scaleInplace(gradW,wLossWeighted);

		Matrix neww = MatlibMatrixUtils.minus(currentW,gradW);
		neww = regul.prox(neww, weightedLambda);
		return neww;
	}

	protected Matrix updateU(Matrix currentU, double uLossWeight, double uWeightedLambda) {
		final Matrix gradU = loss.gradient(currentU);
		MatlibMatrixUtils.scaleInplace(gradU,uLossWeight);
		Matrix newu = MatlibMatrixUtils.minus(currentU,gradU);
		newu = regul.prox(newu, uWeightedLambda);
		return newu;
	}

	private double lambdat(int iter, double lambda) {
		return lambda/iter;
	}

	/**
	 * Given a flat value matrix, makes a diagonal sparse matrix containing the values along the diagonal
	 * @param Y the flat value matrix
	 * @return the diagonalised Y
	 */
	public static SparseMatrix expandY(Matrix Y) {
		final int ntasks = Y.columnCount();
		final SparseMatrix Yexp = SparseMatrix.sparse(ntasks, ntasks);
		for (int touter = 0; touter < ntasks; touter++) {
			for (int tinner = 0; tinner < ntasks; tinner++) {
				if(tinner == touter){
					Yexp.put(touter, tinner, Y.get(0, tinner));
				}
				else{
					Yexp.put(touter, tinner, Double.NaN);
				}
			}
		}
		return Yexp;
	}
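
	/*
	 * For illustration (values invented): expandY of a 1x3 row [0.2 0.8 0.5]
	 * produces the 3x3 sparse matrix
	 *
	 *   0.2  NaN  NaN
	 *   NaN  0.8  NaN
	 *   NaN  NaN  0.5
	 *
	 * i.e. the task values sit on the diagonal and every off-diagonal entry is
	 * filled with NaN.
	 */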

	private double biasEtat(int iter){
		final Double biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
		return biasEta0 / Math.sqrt(iter);
	}

	private double etat(int iter, double eta0) {
		final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
		final double sqrtCeil = Math.sqrt(Math.ceil(iter/(double)etaSteps));
		return eta(eta0) / sqrtCeil;
	}

	private double eta(double eta0) {
		return eta0;
	}

	/**
	 * @return the current parameters
	 */
	public BilinearLearnerParameters getParams() {
		return this.params;
	}

	/**
	 * @return the current user matrix
	 */
	public Matrix getU(){
		return this.u;
	}

	/**
	 * @return the current word matrix
	 */
	public Matrix getW(){
		return this.w;
	}

	/**
	 * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
	 */
	public Matrix getBias() {
		if(this.biasMode)
			return this.bias;
		else
			return null;
	}

	/**
	 * Expand the U parameters matrix by adding a set of rows.
	 * If U is currently unset, this function does nothing (assuming U will be initialised in the first round).
	 * The new U parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}.
	 * @param newUsers the number of new users to add
	 */
	public void addU(int newUsers) {
		if(this.u == null) return; // If u has not been initialised, it will be on the first call to process
		final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT,null,null);
		final Matrix newU = ustrat.init(newUsers, this.u.columnCount());
		this.u = MatlibMatrixUtils.vstack(this.u,newU);
	}

	/**
	 * Expand the W parameters matrix by adding a set of rows.
	 * If W is currently unset, this function does nothing (assuming W will be initialised in the first round).
	 * The new W parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}.
	 * @param newWords the number of new words to add
	 */
	public void addW(int newWords) {
		if(this.w == null) return; // If w has not been initialised, it will be on the first call to process
		final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT,null,null);
		final Matrix newW = wstrat.init(newWords, this.w.columnCount());
		this.w = MatlibMatrixUtils.vstack(this.w,newW);
	}

	@Override
	public MatlibBilinearSparseOnlineLearner clone(){
		final MatlibBilinearSparseOnlineLearner ret = new MatlibBilinearSparseOnlineLearner(this.getParams());
		ret.u = MatlibMatrixUtils.copy(this.u);
		ret.w = MatlibMatrixUtils.copy(this.w);
		if(this.biasMode){
			ret.bias = MatlibMatrixUtils.copy(this.bias);
		}
		return ret;
	}

	/**
	 * @param newu set the model's U
	 */
	public void setU(Matrix newu) {
		this.u = newu;
	}

	/**
	 * @param neww set the model's W
	 */
	public void setW(Matrix neww) {
		this.w = neww;
	}
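
	/*
	 * Binary format used by readBinary/writeBinary below: three ints giving the
	 * number of words, users and tasks, followed by the column-major doubles of
	 * w (nwords x ntasks), u (nusers x ntasks) and bias (ntasks x ntasks); zero
	 * entries are skipped when the sparse matrices are rebuilt on read.
	 */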

	@Override
	public void readBinary(DataInput in) throws IOException {
		final int nwords = in.readInt();
		final int nusers = in.readInt();
		final int ntasks = in.readInt();

		this.w = SparseMatrix.sparse(nwords, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nwords; r++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.w.put(r, t, readDouble);
				}
			}
		}

		this.u = SparseMatrix.sparse(nusers, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nusers; r++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.u.put(r, t, readDouble);
				}
			}
		}

		this.bias = SparseMatrix.sparse(ntasks, ntasks);
		for (int t1 = 0; t1 < ntasks; t1++) {
			for (int t2 = 0; t2 < ntasks; t2++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.bias.put(t1, t2, readDouble);
				}
			}
		}
	}

	@Override
	public byte[] binaryHeader() {
		return "".getBytes();
	}

	@Override
	public void writeBinary(DataOutput out) throws IOException {
		out.writeInt(w.rowCount());
		out.writeInt(u.rowCount());
		out.writeInt(u.columnCount());
		final double[] wdata = w.asColumnMajorArray();
		for (int i = 0; i < wdata.length; i++) {
			out.writeDouble(wdata[i]);
		}
		final double[] udata = u.asColumnMajorArray();
		for (int i = 0; i < udata.length; i++) {
			out.writeDouble(udata[i]);
		}
		final double[] biasdata = bias.asColumnMajorArray();
		for (int i = 0; i < biasdata.length; i++) {
			out.writeDouble(biasdata[i]);
		}
	}

	@Override
	public Matrix predict(Matrix x) {
		final Matrix xt = MatlibMatrixUtils.transpose(x);
		final Matrix mult = MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.transpose(u), xt),this.w);
		if(this.biasMode) MatlibMatrixUtils.plusInplace(mult,this.bias);
		final Matrix ydiag = new DiagonalMatrix(mult);
		return ydiag;
	}
}