/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.AbstractSparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.learner.init.ContextAwareInitStrategy;
import org.openimaj.ml.linear.learner.init.InitStrategy;
import org.openimaj.ml.linear.learner.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.regul.Regulariser;

/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * The Geometry of Constrained Structured Prediction: Applications to Inference
 * and Learning of Natural Language Syntax, PhD thesis, Andre T. Martins
 *
 * The implementation does the following:
 * 	- When an X,Y is received:
 * 		- Update the currently held batch
 * 		- If the batch is full:
 * 			- While there is a great deal of change in U and W:
 * 				- Calculate the gradient of W holding U fixed
 * 				- Proximal update of W
 * 				- Calculate the gradient of U holding W fixed
 * 				- Proximal update of U
 * 				- Calculate the gradient of the bias holding U and W fixed
 * 			- Flush the batch
 * 		- Return the current U and W (same as last time if the batch isn't full yet)
 *
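 * <p>
 * A minimal usage sketch (the matrix shapes follow
 * {@link #process(Matrix, Matrix)}; the default parameters here are
 * illustrative rather than recommended):
 * <pre>{@code
 * final BilinearLearnerParameters params = new BilinearLearnerParameters();
 * final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
 * // X is (nfeatures x nusers); Y is (1 x ntasks), i.e. one instance per round
 * learner.process(X, Y);
 * final Matrix w = learner.getW(); // (nfeatures x ntasks)
 * final Matrix u = learner.getU(); // (nusers x ntasks)
 * final Matrix scores = learner.predict(X); // a (1 x ntasks) row of predictions
 * }</pre>
 *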
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class BilinearSparseOnlineLearner implements OnlineLearner<Matrix,Matrix>, ReadWriteableBinary {

	static Logger logger = Logger.getLogger(BilinearSparseOnlineLearner.class);

	protected BilinearLearnerParameters params;
	protected Matrix w;
	protected Matrix u;
	protected SparseMatrixFactoryMTJ smf = SparseMatrixFactoryMTJ.INSTANCE;
	protected LossFunction loss;
	protected Regulariser regul;
	protected Double lambda_w,lambda_u;
	protected Boolean biasMode;
	protected Matrix bias;
	protected Matrix diagX;
	protected Double eta0_u;
	protected Double eta0_w;

	private Boolean forceSparcity;

	private Boolean zStandardise;

	private boolean nodataseen;

	private double eta_gamma;

	private double biasEta0;

	/**
	 * The default parameters. These won't work with your dataset, I promise.
	 */
	public BilinearSparseOnlineLearner() {
		this(new BilinearLearnerParameters());
	}
	/**
	 * @param params the parameters used by this learner
	 */
	public BilinearSparseOnlineLearner(BilinearLearnerParameters params) {
		this.params = params;
		reinitParams();
	}

	/**
	 * Must be called if any parameters are changed
	 */
	public void reinitParams() {
		this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
		this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
		this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
		this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
		this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
		this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
		this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
		this.biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
		this.eta_gamma = params.getTyped(BilinearLearnerParameters.ETA_GAMMA);
		this.forceSparcity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
		this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
		if(!this.loss.isMatrixLoss())
			this.loss = new MatLossFunction(this.loss);
		this.nodataseen = true;
	}
	private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
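		// w is (xrows x ycols), i.e. (nwords x ntasks); u is (xcols x ycols),
		// i.e. (nusers x ntasks)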
		final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT,x,y);
		final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT,x,y);
		this.w = wstrat.init(xrows, ycols);
		this.u = ustrat.init(xcols, ycols);
		if(this.forceSparcity)
		{
			this.u = CFMatrixUtils.asSparseColumn(this.u);
			this.w = CFMatrixUtils.asSparseColumn(this.w);
		}

		this.bias = smf.createMatrix(ycols,ycols);
		if(this.biasMode){
			final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT,x,y);
			this.bias = bstrat.init(ycols, ycols);
			this.diagX = smf.createIdentity(ycols, ycols);
		}
	}

	private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
		final InitStrategy strat = this.params.getTyped(initstrat);
		if(strat instanceof ContextAwareInitStrategy){
			final ContextAwareInitStrategy<Matrix, Matrix> cwStrat = this.params.getTyped(initstrat);
			cwStrat.setLearner(this);
			cwStrat.setContext(x, y);
			return cwStrat;
		}
		return strat;
	}
	@Override
	public void process(Matrix X, Matrix Y){
		prepareNextRound(X, Y);
		int iter = 0;
		Matrix xt = X.transpose();
		Matrix xtrows = xt;
		if(xt instanceof AbstractSparseMatrix){
			xtrows = CFMatrixUtils.asSparseRow(xt);
		}
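		// Alternate proximal updates of W, U (and the bias) until the average
		// relative change in the parameters falls below the biconvex tolerance
		// (or the maximum number of iterations is reached)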
		while(true) {
			// We need to set the bias here because it is used in the loss calculation of U and W
			if(this.biasMode) loss.setBias(this.bias);
			iter += 1;

			// Perform the bilinear operation
			final Matrix neww = updateW(xt,eta0_w, lambda_u);
			final Matrix newu = updateU(xtrows,neww,eta0_u, lambda_w);
			Matrix newbias = null;
			if(this.biasMode){
				newbias = updateBias(xt, newu, neww, biasEta0);
			}

			// Check whether the bilinear steps can stop by measuring how much
			// each parameter has changed, proportional to its magnitude

			double ratioB = 0;
			double totalbias = 0;

			final double sumchangew = CFMatrixUtils.absSum(neww.minus(this.w));
			final double totalw = CFMatrixUtils.absSum(this.w);

			final double sumchangeu = CFMatrixUtils.absSum(newu.minus(this.u));
			final double totalu = CFMatrixUtils.absSum(this.u);

			double ratioU = 0;
			if(totalu!=0) ratioU = sumchangeu/totalu;
			double ratioW = 0;
			if(totalw!=0) ratioW = sumchangew/totalw;
			double ratio = ratioU + ratioW;
			if(this.biasMode){
				final double sumchangebias = CFMatrixUtils.absSum(newbias.minus(this.bias));
				totalbias = CFMatrixUtils.absSum(this.bias);
				if(totalbias!=0) ratioB = (sumchangebias/totalbias);
				ratio += ratioB;
				ratio/=3;
			} else {
				ratio/=2;
			}

			/*
			 * This is not simply a matter of type: the 0 values of the sparse
			 * matrix are also removed, which is very important.
			 */
			if(this.forceSparcity)
			{
				this.u = CFMatrixUtils.asSparseColumn(newu);
				this.w = CFMatrixUtils.asSparseColumn(neww);
			}
			else{
				this.w = neww;
				this.u = newu;
			}

			if(this.biasMode){
				this.bias = newbias;
			}

			final Double biconvextol = this.params.getTyped("biconvex_tol");
			final Integer maxiter = this.params.getTyped("biconvex_maxiter");
			if(iter%3 == 0){
				logger.debug(String.format("Iter: %d. Last Ratio: %2.3f",iter,ratio));
				logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
				logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
			}
			if(biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
				logger.debug("tolerance reached after iteration: " + iter);
				logger.debug("W row sparsity: " + CFMatrixUtils.rowSparsity(w));
				logger.debug("U row sparsity: " + CFMatrixUtils.rowSparsity(u));
				logger.debug("Total U magnitude: " + totalu);
				logger.debug("Total W magnitude: " + totalw);
				logger.debug("Total Bias: " + totalbias);
				break;
			}
		}
	}
	private void prepareNextRound(Matrix X, Matrix Y) {
		final int nfeatures = X.getNumRows();
		final int nusers = X.getNumColumns();
		final int ntasks = Y.getNumColumns();
//		int ninstances = Y.getNumRows(); // Assume 1 instance!

		// only initialise when the current parameters are null
		if (this.w == null){
			initParams(X,Y,nfeatures, nusers, ntasks); // Number of words, users and tasks
		}

		final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
		final double weighting = 1.0 - dampening;
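		// Dampening acts as a forgetting factor: the previous parameters are
		// scaled towards zero before the new round's data is incorporated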
		logger.debug("... dampening w, u and bias by: " + weighting);

		// Adjust for weighting
		this.w.scaleEquals(weighting);
		this.u.scaleEquals(weighting);
		if(this.biasMode){
			this.bias.scaleEquals(weighting);
		}
		// First expand Y s.t. blocks of rows contain the task values for each row of Y.
		// This means Yexp is (n * t x t)
		final SparseMatrix Yexp = expandY(Y);
		loss.setY(Yexp);
	}

	protected Matrix updateBias(Matrix xt, Matrix nu, Matrix nw, double biasLossWeight) {
		Matrix newut = nu.transpose();
		Matrix utxt = CFMatrixUtils.fastdot(newut,xt);
		Matrix utxtw = CFMatrixUtils.fastdot(utxt,nw);
		final Matrix mult = utxtw.plus(this.bias);
		// We must set bias to null!
		loss.setBias(null);
		loss.setX(diagX);
		// Calculate gradient of bias (don't regularise)
		final Matrix biasGrad = loss.gradient(mult);
		Matrix newbias = null;
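		// Backtracking line search: take a gradient step of size 1/biasLossWeight
		// and shrink the step (by scaling the weight by eta_gamma) until the
		// backtracking condition of the loss is satisfied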
		for (int i = 0; i < 1000; i++) {
			logger.debug("... Line searching etab = " + biasLossWeight);
			newbias = this.bias.clone();
			Matrix scaledGradB = biasGrad.scale(1./biasLossWeight);
			newbias = CFMatrixUtils.fastminus(newbias,scaledGradB);

			if(loss.test_backtrack(this.bias, biasGrad, newbias, biasLossWeight))
				break;
			biasLossWeight *= eta_gamma;
		}
		return newbias;
	}
	protected Matrix updateW(Matrix xt, double wLossWeighted, double weightedLambda) {
		// Dprime is (tasks x nwords) and becomes the loss function's X
		Matrix Dprime = null;
		Matrix ut = this.u.transpose();
		if(this.nodataseen){
			this.nodataseen = false;
			Matrix fakeu = new SparseSingleValueInitStrat(1).init(this.u.getNumColumns(), this.u.getNumRows());
			Dprime = CFMatrixUtils.fastdot(fakeu,xt);
		} else {
			Dprime = CFMatrixUtils.fastdot(ut, xt);
		}

		if(zStandardise){
			Vector rowMean = CFMatrixUtils.rowMean(Dprime);
			CFMatrixUtils.minusEqualsCol(Dprime,rowMean);
		}
		loss.setX(Dprime);
		final Matrix gradW = loss.gradient(this.w);
		logger.debug("Abs w_grad: " + CFMatrixUtils.absSum(gradW));
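		// Backtracking line search: a proximal gradient step of size
		// 1/wLossWeighted followed by the regulariser's proximal operator;
		// shrink the step (scale the weight by eta_gamma) until the
		// backtracking condition holds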
		Matrix neww = null;
		for (int i = 0; i < 1000; i++) {
			logger.debug("... Line searching etaw = " + wLossWeighted);
			neww = this.w.clone();
			Matrix scaledGradW = gradW.scale(1./wLossWeighted);
			neww = CFMatrixUtils.fastminus(neww,scaledGradW);
			neww = regul.prox(neww, weightedLambda/wLossWeighted);
			if(loss.test_backtrack(this.w, gradW, neww, wLossWeighted))
				break;
			wLossWeighted *= eta_gamma;
		}

		return neww;
	}
	protected Matrix updateU(Matrix xtrows, Matrix neww, double uLossWeight, double uWeightedLambda) {
		// Vprime is nusers x tasks
		final Matrix Vprime = CFMatrixUtils.fastdot(xtrows,neww);
		// ... so the loss function's X is (tasks x nusers)
		Matrix Vt = CFMatrixUtils.asSparseRow(Vprime.transpose());
		if(zStandardise){
			Vector rowMean = CFMatrixUtils.rowMean(Vt);
			CFMatrixUtils.minusEqualsCol(Vt,rowMean);
		}
		loss.setX(Vt);
		final Matrix gradU = loss.gradient(this.u);
		logger.debug("Abs u_grad: " + CFMatrixUtils.absSum(gradU));
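		// Same backtracking line search as in updateW, applied to U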
		Matrix newu = null;
		for (int i = 0; i < 1000; i++) {
			logger.debug("... Line searching etau = " + uLossWeight);
			newu = this.u.clone();
			Matrix scaledGradU = gradU.scale(1./uLossWeight);
			newu = CFMatrixUtils.fastminus(newu,scaledGradU);
			newu = regul.prox(newu, uWeightedLambda/uLossWeight);
			if(loss.test_backtrack(this.u, gradU, newu, uLossWeight))
				break;
			uLossWeight *= eta_gamma;
		}

		return newu;
	}
	private double lambdat(int iter, double lambda) {
		return lambda/iter;
	}
	/**
	 * Given a flat value matrix, construct a sparse matrix with the values
	 * along the diagonal and {@code Double.NaN} in every off-diagonal cell
	 * @param Y a flat (1 x ntasks) value matrix
	 * @return the diagonalised Y, an (ntasks x ntasks) sparse matrix
	 */
	public static SparseMatrix expandY(Matrix Y) {
		final int ntasks = Y.getNumColumns();
		final SparseMatrix Yexp = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
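		// e.g. a (1 x 2) Y of [y1 y2] becomes [[y1, NaN], [NaN, y2]]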
		for (int touter = 0; touter < ntasks; touter++) {
			for (int tinner = 0; tinner < ntasks; tinner++) {
				if(tinner == touter){
					Yexp.setElement(touter, tinner, Y.getElement(0, tinner));
				}
				else{
					Yexp.setElement(touter, tinner, Double.NaN);
				}
			}
		}
		return Yexp;
	}

	protected double etat(int iter,double eta0) {
		final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
		final double sqrtCeil = Math.sqrt(Math.ceil(iter/(double)etaSteps));
		return eta(eta0) / sqrtCeil;
	}
	private double eta(double eta0) {
		return eta0;
	}

	/**
	 * @return the current parameters
	 */
	public BilinearLearnerParameters getParams() {
		return this.params;
	}

	/**
	 * @return the current user matrix
	 */
	public Matrix getU(){
		return this.u;
	}

	/**
	 * @return the current word matrix
	 */
	public Matrix getW(){
		return this.w;
	}
	/**
	 * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
	 */
	public Matrix getBias() {
		if(this.biasMode)
			return this.bias;
		else
			return null;
	}

	/**
	 * Expand the U parameters matrix by adding a set of rows.
	 * If U is currently unset, this function does nothing (assuming U will be initialised in the first round).
	 * The new U parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}
	 * @param newUsers the number of new users to add
	 */
	public void addU(int newUsers) {
		if(this.u == null) return; // If u has not been initialised, it will be on first process
		final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT,null,null);
		final Matrix newU = ustrat.init(newUsers, this.u.getNumColumns());
		this.u = CFMatrixUtils.vstack(this.u,newU);
	}

	/**
	 * Expand the W parameters matrix by adding a set of rows.
	 * If W is currently unset, this function does nothing (assuming W will be initialised in the first round).
	 * The new W parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}
	 * @param newWords the number of new words to add
	 */
	public void addW(int newWords) {
		if(this.w == null) return; // If w has not been initialised, it will be on first process
		final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT,null,null);
		final Matrix newW = wstrat.init(newWords, this.w.getNumColumns());
		this.w = CFMatrixUtils.vstack(this.w,newW);
	}

	@Override
	public BilinearSparseOnlineLearner clone(){
		final BilinearSparseOnlineLearner ret = new BilinearSparseOnlineLearner(this.getParams());
		ret.u = this.u.clone();
		ret.w = this.w.clone();
		if(this.biasMode){
			ret.bias = this.bias.clone();
		}
		return ret;
	}
	/**
	 * @param newu set the model's U
	 */
	public void setU(Matrix newu) {
		this.u = newu;
	}

	/**
	 * @param neww set the model's W
	 */
	public void setW(Matrix neww) {
		this.w = neww;
	}
	@Override
	public void readBinary(DataInput in) throws IOException {
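		// Binary format (see writeBinary): nwords, nusers and ntasks as ints,
		// followed by the values of W, U and the bias, read column by column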
		final int nwords = in.readInt();
		final int nusers = in.readInt();
		final int ntasks = in.readInt();

		this.w = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nwords, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nwords; r++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.w.setElement(r, t, readDouble);
				}
			}
		}

		this.u = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nusers, ntasks);
		for (int t = 0; t < ntasks; t++) {
			for (int r = 0; r < nusers; r++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.u.setElement(r, t, readDouble);
				}
			}
		}

		this.bias = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
		for (int t1 = 0; t1 < ntasks; t1++) {
			for (int t2 = 0; t2 < ntasks; t2++) {
				final double readDouble = in.readDouble();
				if(readDouble != 0){
					this.bias.setElement(t1, t2, readDouble);
				}
			}
		}
	}
	@Override
	public byte[] binaryHeader() {
		return "".getBytes();
	}
	@Override
	public void writeBinary(DataOutput out) throws IOException {
		out.writeInt(w.getNumRows());
		out.writeInt(u.getNumRows());
		out.writeInt(u.getNumColumns());
		final double[] wdata = CFMatrixUtils.getData(w);
		for (int i = 0; i < wdata.length; i++) {
			out.writeDouble(wdata[i]);
		}
		final double[] udata = CFMatrixUtils.getData(u);
		for (int i = 0; i < udata.length; i++) {
			out.writeDouble(udata[i]);
		}
		final double[] biasdata = CFMatrixUtils.getData(bias);
		for (int i = 0; i < biasdata.length; i++) {
			out.writeDouble(biasdata[i]);
		}
	}

	@Override
	public Matrix predict(Matrix x) {
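		// Prediction is diag(U' X' W) (plus the bias when enabled): one score
		// per task, returned as a (1 x ntasks) row matrix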
		final Matrix mult = this.u.transpose().times(x.transpose()).times(this.w);
		if(this.biasMode)mult.plusEquals(this.bias);
		final Vector ydiag = CFMatrixUtils.diag(mult);
		final Matrix createIdentity = SparseMatrixFactoryMTJ.INSTANCE.createIdentity(1, ydiag.getDimensionality());
		createIdentity.setRow(0, ydiag);
		return createIdentity;
	}
}