/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.AbstractSparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.learner.init.ContextAwareInitStrategy;
import org.openimaj.ml.linear.learner.init.InitStrategy;
import org.openimaj.ml.linear.learner.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.regul.Regulariser;
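/*
 * A minimal usage sketch (illustrative only, not part of the original source).
 * The matrices and round count below are assumed to come from the caller; any
 * BilinearLearnerParameters configuration providing the keys read in
 * reinitParams() will do:
 *
 *   final BilinearLearnerParameters params = new BilinearLearnerParameters();
 *   final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
 *   // each x[i] is (nfeatures x nusers), each y[i] is (1 x ntasks)
 *   for (int i = 0; i < nRounds; i++) {
 *       learner.process(x[i], y[i]);
 *   }
 *   final Matrix scores = learner.predict(xNew); // a (1 x ntasks) row of predictions
 */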
/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * The Geometry of Constrained Structured Prediction: Applications to Inference
 * and Learning of Natural Language Syntax, PhD thesis, Andre F. T. Martins
 *
 * The implementation does the following:
 * 	- When an X,Y is received:
 * 		- Update the currently held batch
 * 		- If the batch is full:
 * 			- While there is a great deal of change in U and W:
 * 				- Calculate the gradient of W holding U fixed
 * 				- Proximal update of W
 * 				- Calculate the gradient of U holding W fixed
 * 				- Proximal update of U
 * 				- Calculate the gradient of the bias holding U and W fixed
 * 			- Flush the batch
 * 		- Return the current U and W (the same as last time if the batch isn't full yet)
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class BilinearSparseOnlineLearner implements OnlineLearner<Matrix,Matrix>, ReadWriteableBinary{

	static Logger logger = Logger.getLogger(BilinearSparseOnlineLearner.class);

	protected BilinearLearnerParameters params;
	protected Matrix w;
	protected Matrix u;
	protected SparseMatrixFactoryMTJ smf = SparseMatrixFactoryMTJ.INSTANCE;
	protected LossFunction loss;
	protected Regulariser regul;
	protected Double lambda_w,lambda_u;
	protected Boolean biasMode;
	protected Matrix bias;
	protected Matrix diagX;
	protected Double eta0_u;
	protected Double eta0_w;

	private Boolean forceSparsity;

	private Boolean zStandardise;

	private boolean nodataseen;

	private double eta_gamma;

	private double biasEta0;

	/**
	 * The default parameters. These won't work with your dataset, I promise.
	 */
	public BilinearSparseOnlineLearner() {
		this(new BilinearLearnerParameters());
	}

	/**
	 * @param params the parameters used by this learner
	 */
	public BilinearSparseOnlineLearner(BilinearLearnerParameters params) {
		this.params = params;
		reinitParams();
	}

	/**
	 * Must be called if any parameters are changed
	 */
	public void reinitParams() {
		this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
		this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
		this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
		this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
		this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
		this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
		this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
		this.biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
		this.eta_gamma = params.getTyped(BilinearLearnerParameters.ETA_GAMMA);
		this.forceSparsity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
		this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
		if(!this.loss.isMatrixLoss())
			this.loss = new MatLossFunction(this.loss);
		this.nodataseen = true;
	}

	private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
		final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT,x,y);
		final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT,x,y);
		this.w = wstrat.init(xrows, ycols);
		this.u = ustrat.init(xcols, ycols);
		if(this.forceSparsity)
		{
			this.u = CFMatrixUtils.asSparseColumn(this.u);
			this.w = CFMatrixUtils.asSparseColumn(this.w);
		}

		this.bias = smf.createMatrix(ycols,ycols);
		if(this.biasMode){
			final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT,x,y);
			this.bias = bstrat.init(ycols, ycols);
			this.diagX = smf.createIdentity(ycols, ycols);
		}
	}
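	/*
	 * Dimension conventions (derived from initParams and prepareNextRound):
	 * for an X of (nfeatures x nusers) and a Y of (1 x ntasks),
	 *
	 *   W    is (nfeatures x ntasks)  -- the per-word/feature parameters
	 *   U    is (nusers    x ntasks)  -- the per-user parameters
	 *   bias is (ntasks    x ntasks)  -- only its diagonal survives predict()
	 *
	 * so that U'X'W is (ntasks x ntasks) and holds the per-task scores on its
	 * diagonal, which is what predict() extracts.
	 */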
	private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
		final InitStrategy strat = this.params.getTyped(initstrat);
		if(strat instanceof ContextAwareInitStrategy){
			final ContextAwareInitStrategy<Matrix, Matrix> cwStrat = this.params.getTyped(initstrat);
			cwStrat.setLearner(this);
			cwStrat.setContext(x, y);
			return cwStrat;
		}
		return strat;
	}

	@Override
	public void process(Matrix X, Matrix Y){
		prepareNextRound(X, Y);
		int iter = 0;
		final Matrix xt = X.transpose();
		Matrix xtrows = xt;
		if(xt instanceof AbstractSparseMatrix){
			xtrows = CFMatrixUtils.asSparseRow(xt);
		}
		while(true) {
			// The bias must be set here because it is used in the loss calculation of U and W
			if(this.biasMode) loss.setBias(this.bias);
			iter += 1;

			// Perform the bilinear operation
			final Matrix neww = updateW(xt,eta0_w, lambda_w);
			final Matrix newu = updateU(xtrows,neww,eta0_u, lambda_u);
			Matrix newbias = null;
			if(this.biasMode){
				newbias = updateBias(xt, newu, neww, biasEta0);
			}

			// Check whether the bilinear steps can stop by measuring how much
			// everything has changed, proportionally
			double ratioB = 0;
			double totalbias = 0;

			final double sumchangew = CFMatrixUtils.absSum(neww.minus(this.w));
			final double totalw = CFMatrixUtils.absSum(this.w);

			final double sumchangeu = CFMatrixUtils.absSum(newu.minus(this.u));
			final double totalu = CFMatrixUtils.absSum(this.u);

			double ratioU = 0;
			if(totalu!=0) ratioU = sumchangeu/totalu;
			double ratioW = 0;
			if(totalw!=0) ratioW = sumchangew/totalw;
			double ratio = ratioU + ratioW;
			if(this.biasMode){
				final double sumchangebias = CFMatrixUtils.absSum(newbias.minus(this.bias));
				totalbias = CFMatrixUtils.absSum(this.bias);
				if(totalbias!=0) ratioB = (sumchangebias/totalbias);
				ratio += ratioB;
				ratio /= 3;
			} else {
				ratio /= 2;
			}

			// Not simply a matter of type: the conversion also removes the
			// explicit 0 values of the sparse matrix, which is very important.
			if(this.forceSparsity)
			{
				this.u = CFMatrixUtils.asSparseColumn(newu);
				this.w = CFMatrixUtils.asSparseColumn(neww);
			}
			else{
				this.w = neww;
				this.u = newu;
			}

			if(this.biasMode){
				this.bias = newbias;
			}

			final Double biconvextol = this.params.getTyped("biconvex_tol");
			final Integer maxiter = this.params.getTyped("biconvex_maxiter");
			if(iter%3 == 0){
				logger.debug(String.format("Iter: %d. Last Ratio: %2.3f",iter,ratio));
				logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
				logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
			}
			if(biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
				logger.debug("Tolerance reached after iteration: " + iter);
				logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
				logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
				break;
			}
		}
	}

	private void prepareNextRound(Matrix X, Matrix Y) {
		final int nfeatures = X.getNumRows();
		final int nusers = X.getNumColumns();
		final int ntasks = Y.getNumColumns();
//		final int ninstances = Y.getNumRows(); // Assume 1 instance!

		// Only initialise the parameters when they are currently null
		if (this.w == null){
			initParams(X,Y,nfeatures, nusers, ntasks); // Number of words, users and tasks
		}

		final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
		final double weighting = 1.0 - dampening;

		logger.debug("... dampening w, u and bias by: " + weighting);

		// Adjust for weighting
		this.w.scaleEquals(weighting);
		this.u.scaleEquals(weighting);
		if(this.biasMode){
			this.bias.scaleEquals(weighting);
		}
		// First expand Y s.t. blocks of rows contain the task values for each row of Y.
		// This means Yexp has (n * t x t)
		final SparseMatrix Yexp = expandY(Y);
		loss.setY(Yexp);
	}
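	/*
	 * Each of the three updates below takes a proximal gradient step with a
	 * backtracking search over the (inverse) step size eta. A sketch for W
	 * (U follows the same pattern; the bias skips the prox):
	 *
	 *   w_new = prox_{lambda/eta}( w - (1/eta) * grad(loss)(w) )
	 *
	 * where prox is Regulariser.prox and grad(loss) is LossFunction.gradient.
	 * If LossFunction.test_backtrack rejects the step, eta is scaled by
	 * ETA_GAMMA (assumed > 1, so the step shrinks) and the step is retried,
	 * up to 1000 times.
	 */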
	protected Matrix updateBias(Matrix xt, Matrix nu, Matrix nw, double biasLossWeight) {
		final Matrix newut = nu.transpose();
		final Matrix utxt = CFMatrixUtils.fastdot(newut,xt);
		final Matrix utxtw = CFMatrixUtils.fastdot(utxt,nw);
		final Matrix mult = utxtw.plus(this.bias);
		// The loss's bias must be unset here so the gradient is taken with
		// respect to the bias itself
		loss.setBias(null);
		loss.setX(diagX);
		// Calculate the gradient of the bias (don't regularise)
		final Matrix biasGrad = loss.gradient(mult);
		Matrix newbias = null;
		for (int i = 0; i < 1000; i++) {
			logger.debug("... Line searching etab = " + biasLossWeight);
			newbias = this.bias.clone();
			final Matrix scaledGradB = biasGrad.scale(1./biasLossWeight);
			newbias = CFMatrixUtils.fastminus(newbias,scaledGradB);

			if(loss.test_backtrack(this.bias, biasGrad, newbias, biasLossWeight))
				break;
			biasLossWeight *= eta_gamma;
		}
		return newbias;
	}

	protected Matrix updateW(Matrix xt, double wLossWeight, double wWeightedLambda) {
		// Dprime is tasks x nwords
		Matrix Dprime = null;
		final Matrix ut = this.u.transpose();
		if(this.nodataseen){
			this.nodataseen = false;
			// Before any data has been seen, use a single-valued fake U so the
			// first W gradient is well defined
			final Matrix fakeu = new SparseSingleValueInitStrat(1).init(this.u.getNumColumns(), this.u.getNumRows());
			Dprime = CFMatrixUtils.fastdot(fakeu,xt);
		} else {
			Dprime = CFMatrixUtils.fastdot(ut, xt);
		}

		// ... Dprime becomes the loss function's X
		if(zStandardise){
			final Vector rowMean = CFMatrixUtils.rowMean(Dprime);
			CFMatrixUtils.minusEqualsCol(Dprime,rowMean);
		}
		loss.setX(Dprime);
		final Matrix gradW = loss.gradient(this.w);
		logger.debug("Abs w_grad: " + CFMatrixUtils.absSum(gradW));
		Matrix neww = null;
		for (int i = 0; i < 1000; i++) {
			logger.debug("... Line searching etaw = " + wLossWeight);
			neww = this.w.clone();
			final Matrix scaledGradW = gradW.scale(1./wLossWeight);
			neww = CFMatrixUtils.fastminus(neww,scaledGradW);
			neww = regul.prox(neww, wWeightedLambda/wLossWeight);
			if(loss.test_backtrack(this.w, gradW, neww, wLossWeight))
				break;
			wLossWeight *= eta_gamma;
		}

		return neww;
	}
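	/*
	 * In updateU the roles are reversed relative to updateW: the data seen by
	 * the loss is (X'W)', i.e. (ntasks x nusers), so the gradient is taken
	 * with respect to U directly and has U's (nusers x ntasks) shape.
	 */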
	protected Matrix updateU(Matrix xtrows, Matrix neww, double uLossWeight, double uWeightedLambda) {
		// Vprime is nusers x tasks
		final Matrix Vprime = CFMatrixUtils.fastdot(xtrows,neww);
		// ... so the loss function's X is (tasks x nusers)
		final Matrix Vt = CFMatrixUtils.asSparseRow(Vprime.transpose());
		if(zStandardise){
			final Vector rowMean = CFMatrixUtils.rowMean(Vt);
			CFMatrixUtils.minusEqualsCol(Vt,rowMean);
		}
		loss.setX(Vt);
		final Matrix gradU = loss.gradient(this.u);
		logger.debug("Abs u_grad: " + CFMatrixUtils.absSum(gradU));
		Matrix newu = null;
		for (int i = 0; i < 1000; i++) {
			logger.debug("... Line searching etau = " + uLossWeight);
			newu = this.u.clone();
			final Matrix scaledGradU = gradU.scale(1./uLossWeight);
			newu = CFMatrixUtils.fastminus(newu,scaledGradU);
			newu = regul.prox(newu, uWeightedLambda/uLossWeight);
			if(loss.test_backtrack(this.u, gradU, newu, uLossWeight))
				break;
			uLossWeight *= eta_gamma;
		}

		return newu;
	}

	private double lambdat(int iter, double lambda) {
		return lambda/iter;
	}

	/**
	 * Given a flat value matrix, makes a diagonal sparse matrix containing the
	 * values as the diagonal
	 * @param Y the flat (1 x ntasks) value matrix
	 * @return the diagonalised Y
	 */
	public static SparseMatrix expandY(Matrix Y) {
		final int ntasks = Y.getNumColumns();
		final SparseMatrix Yexp = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
		for (int touter = 0; touter < ntasks; touter++) {
			for (int tinner = 0; tinner < ntasks; tinner++) {
				if(tinner == touter){
					Yexp.setElement(touter, tinner, Y.getElement(0, tinner));
				}
				else{
					Yexp.setElement(touter, tinner, Double.NaN);
				}
			}
		}
		return Yexp;
	}
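	/*
	 * For example (values illustrative): a single-instance Y of
	 *
	 *   [ 1.0  -1.0  0.5 ]
	 *
	 * expands to
	 *
	 *   [ 1.0  NaN  NaN ]
	 *   [ NaN -1.0  NaN ]
	 *   [ NaN  NaN  0.5 ]
	 *
	 * with NaN marking the off-diagonal task combinations, presumably so the
	 * loss can ignore them.
	 */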
	protected double etat(int iter, double eta0) {
		final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
		final double sqrtCeil = Math.sqrt(Math.ceil(iter/(double)etaSteps));
		return eta(eta0) / sqrtCeil;
	}

	private double eta(double eta0) {
		return eta0;
	}

	/**
	 * @return the current parameters
	 */
	public BilinearLearnerParameters getParams() {
		return this.params;
	}

	/**
	 * @return the current user matrix
	 */
	public Matrix getU(){
		return this.u;
	}

	/**
	 * @return the current word matrix
	 */
	public Matrix getW(){
		return this.w;
	}

	/**
	 * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
	 */
	public Matrix getBias() {
		if(this.biasMode)
			return this.bias;
		else
			return null;
	}

	/**
	 * Expand the U parameters matrix by adding a set of rows.
	 * If U is currently unset, this function does nothing (assuming U will be
	 * initialised in the first round).
	 * The new U parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}
	 * @param newUsers the number of new users to add
	 */
	public void addU(int newUsers) {
		if(this.u == null) return; // If U has not been initialised, it will be on the first process
		final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT,null,null);
		final Matrix newU = ustrat.init(newUsers, this.u.getNumColumns());
		this.u = CFMatrixUtils.vstack(this.u,newU);
	}

	/**
	 * Expand the W parameters matrix by adding a set of rows.
	 * If W is currently unset, this function does nothing (assuming W will be
	 * initialised in the first round).
	 * The new W parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}
	 * @param newWords the number of new words to add
	 */
	public void addW(int newWords) {
		if(this.w == null) return; // If W has not been initialised, it will be on the first process
		final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT,null,null);
		final Matrix newW = wstrat.init(newWords, this.w.getNumColumns());
		this.w = CFMatrixUtils.vstack(this.w,newW);
	}

	@Override
	public BilinearSparseOnlineLearner clone(){
		final BilinearSparseOnlineLearner ret = new BilinearSparseOnlineLearner(this.getParams());
		ret.u = this.u.clone();
		ret.w = this.w.clone();
		if(this.biasMode){
			ret.bias = this.bias.clone();
		}
		return ret;
	}

	/**
	 * @param newu set the model's U
	 */
	public void setU(Matrix newu) {
		this.u = newu;
	}

	/**
	 * @param neww set the model's W
	 */
	public void setW(Matrix neww) {
		this.w = neww;
	}
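	/*
	 * Serialisation layout shared by readBinary and writeBinary: three ints
	 * (W's rows, U's rows, U's columns == ntasks) followed by every element of
	 * W, then U, then the bias as doubles. The read loops below consume the
	 * doubles column by column, so this assumes CFMatrixUtils.getData emits
	 * elements in column-major order; zeros are skipped on read to keep the
	 * matrices sparse.
	 */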
	@Override
	public void readBinary(DataInput in) throws IOException {
		final int nwords = in.readInt();
		final int nusers = in.readInt();
		final int ntasks = in.readInt();

		this.w = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nwords, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nwords; r++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.w.setElement(r, t, readDouble);
				}
			}
		}

		this.u = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nusers, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nusers; r++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.u.setElement(r, t, readDouble);
				}
			}
		}

		this.bias = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
		for (int t1 = 0; t1 < ntasks; t1++) {
			for (int t2 = 0; t2 < ntasks; t2++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.bias.setElement(t1, t2, readDouble);
				}
			}
		}
	}

	@Override
	public byte[] binaryHeader() {
		return "".getBytes();
	}

	@Override
	public void writeBinary(DataOutput out) throws IOException {
		out.writeInt(w.getNumRows());
		out.writeInt(u.getNumRows());
		out.writeInt(u.getNumColumns());
		final double[] wdata = CFMatrixUtils.getData(w);
		for (int i = 0; i < wdata.length; i++) {
			out.writeDouble(wdata[i]);
		}
		final double[] udata = CFMatrixUtils.getData(u);
		for (int i = 0; i < udata.length; i++) {
			out.writeDouble(udata[i]);
		}
		final double[] biasdata = CFMatrixUtils.getData(bias);
		for (int i = 0; i < biasdata.length; i++) {
			out.writeDouble(biasdata[i]);
		}
	}

	@Override
	public Matrix predict(Matrix x) {
		// U'X'W is (ntasks x ntasks); the per-task predictions lie on its diagonal
		final Matrix mult = this.u.transpose().times(x.transpose()).times(this.w);
		if(this.biasMode) mult.plusEquals(this.bias);
		final Vector ydiag = CFMatrixUtils.diag(mult);
		final Matrix createIdentity = SparseMatrixFactoryMTJ.INSTANCE.createIdentity(1, ydiag.getDimensionality());
		createIdentity.setRow(0, ydiag);
		return createIdentity;
	}
}