/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.linear.learner.matlib;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.DiagonalMatrix;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.ml.linear.learner.matlib.init.InitStrategy;
import org.openimaj.ml.linear.learner.matlib.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.matlib.loss.LossFunction;
import org.openimaj.ml.linear.learner.matlib.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.matlib.regul.Regulariser;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.SparseMatrix;

/**
 * An implementation of stochastic gradient descent with proximal parameter
 * adjustment (for regularised parameters).
 *
 * Data is dealt with sequentially using a one-pass implementation of the
 * online proximal algorithm described in chapters 9 and 10 of:
 * The Geometry of Constrained Structured Prediction: Applications to Inference and
 * Learning of Natural Language Syntax, PhD thesis, Andre T. Martins
 *
 * The implementation does the following:
 *      - When an X,Y pair is received:
 *              - Update the currently held batch
 *              - If the batch is full:
 *                      - While there is a great deal of change in U and W:
 *                              - Calculate the gradient of W holding U fixed
 *                              - Proximal update of W
 *                              - Calculate the gradient of U holding W fixed
 *                              - Proximal update of U
 *                              - Calculate the gradient of the bias holding U and W fixed
 *                      - Flush the batch
 *              - Return the current U and W (same as last time if the batch isn't filled yet)
 *
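 * A minimal usage sketch (the matrix sizes and values below are illustrative
 * assumptions, not drawn from any particular dataset, and the default
 * parameters will usually need tuning):
 *
 * <pre>
 * {@code
 * MatlibBilinearSparseOnlineLearner learner = new MatlibBilinearSparseOnlineLearner();
 * Matrix X = SparseMatrix.sparse(10, 5); // (features x users)
 * X.put(0, 0, 1.0);
 * Matrix Y = SparseMatrix.sparse(1, 3); // (1 x tasks)
 * Y.put(0, 0, 1.0);
 * learner.process(X, Y);
 * Matrix prediction = learner.predict(X); // (tasks x tasks), predictions on the diagonal
 * }
 * </pre>
 *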
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 *
 */
public class MatlibBilinearSparseOnlineLearner implements OnlineLearner<Matrix,Matrix>, ReadWriteableBinary{

        static Logger logger = Logger.getLogger(MatlibBilinearSparseOnlineLearner.class);

        protected BilinearLearnerParameters params;
        protected Matrix w;
        protected Matrix u;
        protected LossFunction loss;
        protected Regulariser regul;
        protected Double lambda_w,lambda_u;
        protected Boolean biasMode;
        protected Matrix bias;
        protected Matrix diagX;
        protected Double eta0_u;
        protected Double eta0_w;

        private Boolean forceSparcity;

        private Boolean zStandardise;

        private boolean nodataseen;

        /**
         * The default parameters. These won't work with your dataset, I promise.
         */
        public MatlibBilinearSparseOnlineLearner() {
                this(new BilinearLearnerParameters());
        }
        /**
         * @param params the parameters used by this learner
         */
        public MatlibBilinearSparseOnlineLearner(BilinearLearnerParameters params) {
                this.params = params;
                reinitParams();
        }

        /**
         * Must be called if any parameters are changed.
         */
        public void reinitParams() {
                this.loss = this.params.getTyped(BilinearLearnerParameters.LOSS);
                this.regul = this.params.getTyped(BilinearLearnerParameters.REGUL);
                this.lambda_w = this.params.getTyped(BilinearLearnerParameters.LAMBDA_W);
                this.lambda_u = this.params.getTyped(BilinearLearnerParameters.LAMBDA_U);
                this.biasMode = this.params.getTyped(BilinearLearnerParameters.BIAS);
                this.eta0_u = this.params.getTyped(BilinearLearnerParameters.ETA0_U);
                this.eta0_w = this.params.getTyped(BilinearLearnerParameters.ETA0_W);
                this.forceSparcity = this.params.getTyped(BilinearLearnerParameters.FORCE_SPARCITY);
                this.zStandardise = this.params.getTyped(BilinearLearnerParameters.Z_STANDARDISE);
                if(!this.loss.isMatrixLoss())
                        this.loss = new MatLossFunction(this.loss);
                this.nodataseen = true;
        }
        private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
                final InitStrategy wstrat = getInitStrat(BilinearLearnerParameters.WINITSTRAT,x,y);
                final InitStrategy ustrat = getInitStrat(BilinearLearnerParameters.UINITSTRAT,x,y);
                this.w = wstrat.init(xrows, ycols);
                this.u = ustrat.init(xcols, ycols);

                this.bias = SparseMatrix.sparse(ycols,ycols);
                if(this.biasMode){
                        final InitStrategy bstrat = getInitStrat(BilinearLearnerParameters.BIASINITSTRAT,x,y);
                        this.bias = bstrat.init(ycols, ycols);
                        this.diagX = new DiagonalMatrix(ycols,1);
                }
        }

        private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
                final InitStrategy strat = this.params.getTyped(initstrat);
                return strat;
        }
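        /**
         * Process a single instance: X is expected to be (features x users) and
         * Y (1 x tasks). W and U are updated with alternating proximal gradient
         * steps (and the bias, when enabled, with a plain gradient step) until
         * the relative change in the parameters drops below the biconvex
         * tolerance or the maximum number of iterations is reached.
         */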
        @Override
        public void process(Matrix X, Matrix Y){
                final int nfeatures = X.rowCount();
                final int nusers = X.columnCount();
                final int ntasks = Y.columnCount();
//              int ninstances = Y.rowCount(); // Assume 1 instance!

                // the parameters are only initialised on the first instance seen
                if (this.w == null){
                        initParams(X,Y,nfeatures, nusers, ntasks); // Number of words, users and tasks
                }

                final Double dampening = this.params.getTyped(BilinearLearnerParameters.DAMPENING);
                final double weighting = 1.0 - dampening;

                logger.debug("... dampening w, u and bias by: " + weighting);

                // Adjust for weighting
                MatlibMatrixUtils.scaleInplace(this.w,weighting);
                MatlibMatrixUtils.scaleInplace(this.u,weighting);
                if(this.biasMode){
                        MatlibMatrixUtils.scaleInplace(this.bias,weighting);
                }
                // First expand Y s.t. blocks of rows contain the task values for each row of Y.
                // This means Yexp is ((n * t) x t)
                final SparseMatrix Yexp = expandY(Y);
                loss.setY(Yexp);
                int iter = 0;
                while(true) {
                        // We need to set the bias here because it is used in the loss calculation of U and W
                        if(this.biasMode) loss.setBias(this.bias);
                        iter += 1;

                        final double uLossWeight = etat(iter,eta0_u);
                        final double wLossWeighted = etat(iter,eta0_w);
                        final double weightedLambda_u = lambdat(iter,lambda_u);
                        final double weightedLambda_w = lambdat(iter,lambda_w);
                        // Dprime is tasks x nwords
                        Matrix Dprime = null;
                        if(this.nodataseen){
                                this.nodataseen = false;
                                Matrix fakeut = new SparseSingleValueInitStrat(1).init(this.u.columnCount(),this.u.rowCount());
                                Dprime = MatlibMatrixUtils.dotProductTranspose(fakeut, X); // i.e. fakeut . X^T
                        } else {
                                Dprime = MatlibMatrixUtils.dotProductTransposeTranspose(u, X); // i.e. u^T . X^T
                        }

                        // ... as is the cost function's X
                        if(zStandardise){
//                              Vector rowMean = CFMatrixUtils.rowMean(Dprime);
//                              CFMatrixUtils.minusEqualsCol(Dprime,rowMean);
                        }
                        loss.setX(Dprime);
                        final Matrix neww = updateW(this.w,wLossWeighted, weightedLambda_w);

                        // Vprime is nusers x tasks
                        final Matrix Vt = MatlibMatrixUtils.transposeDotProduct(neww,X); // i.e. (X^T . neww)^T
                        // ... so the loss function's X is (tasks x nusers)
                        loss.setX(Vt);
                        final Matrix newu = updateU(this.u,uLossWeight, weightedLambda_u);

                        final double sumchangew = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(neww, this.w));
                        final double totalw = MatlibMatrixUtils.normF(this.w);

                        final double sumchangeu = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newu, this.u));
                        final double totalu = MatlibMatrixUtils.normF(this.u);

                        double ratioU = 0;
                        if(totalu!=0) ratioU = sumchangeu/totalu;
                        double ratioW = 0;
                        if(totalw!=0) ratioW = sumchangew/totalw;
                        double ratioB = 0;
                        double ratio = ratioU + ratioW;
                        double totalbias = 0;
                        if(this.biasMode){
                                Matrix mult = MatlibMatrixUtils.dotProductTransposeTranspose(newu, X);
                                mult = MatlibMatrixUtils.dotProduct(mult, neww);
                                MatlibMatrixUtils.plusInplace(mult, bias);
                                // The bias is already included in mult, so clear it from the
                                // loss before calculating the bias gradient
                                loss.setBias(null);
                                loss.setX(diagX);
                                // Calculate gradient of bias (don't regularise)
                                final Matrix biasGrad = loss.gradient(mult);
                                final double biasLossWeight = biasEtat(iter);
                                final Matrix newbias = updateBias(biasGrad, biasLossWeight);

                                final double sumchangebias = MatlibMatrixUtils.normF(MatlibMatrixUtils.minus(newbias, bias));
                                totalbias = MatlibMatrixUtils.normF(this.bias);
                                if(totalbias!=0) ratioB = (sumchangebias/totalbias);
                                this.bias = newbias;
                                ratio += ratioB;
                                ratio/=3;
                        }
                        else{
                                ratio/=2;
                        }

                        // Commit the updated parameters; without this the convergence check
                        // would always compare against the original w and u
                        this.w = neww;
                        this.u = newu;

                        final Double biconvextol = this.params.getTyped("biconvex_tol");
                        final Integer maxiter = this.params.getTyped("biconvex_maxiter");
                        if(iter%3 == 0){
                                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f",iter,ratio));
                                logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
                                logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
                                logger.debug("Total U magnitude: " + totalu);
                                logger.debug("Total W magnitude: " + totalw);
                                logger.debug("Total Bias: " + totalbias);
                        }
                        if(biconvextol < 0 || ratio < biconvextol || iter >= maxiter) {
                                logger.debug("tolerance reached after iteration: " + iter);
                                logger.debug("W row sparsity: " + MatlibMatrixUtils.sparsity(w));
                                logger.debug("U row sparsity: " + MatlibMatrixUtils.sparsity(u));
                                logger.debug("Total U magnitude: " + totalu);
                                logger.debug("Total W magnitude: " + totalw);
                                logger.debug("Total Bias: " + totalbias);
                                break;
                        }
                }
        }

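        /**
         * Take a plain (unregularised) gradient descent step on the bias:
         * bias := bias - biasLossWeight * biasGrad. Note that biasGrad is
         * scaled in place.
         *
         * @param biasGrad the gradient of the loss with respect to the bias
         * @param biasLossWeight the learning rate for this step
         * @return the new bias matrix
         */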
        protected Matrix updateBias(Matrix biasGrad, double biasLossWeight) {
                final Matrix newbias = MatlibMatrixUtils.minus(
                                this.bias,
                                MatlibMatrixUtils.scaleInplace(
                                                biasGrad,
                                                biasLossWeight
                                )
                );
                return newbias;
        }
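        /**
         * Take one proximal gradient step on W: a gradient descent step scaled
         * by the learning rate, followed by the regulariser's proximal
         * operator, i.e. w := prox(w - eta * grad(w), lambda).
         *
         * @param currentW the current W
         * @param wLossWeighted the learning rate (eta) for this step
         * @param weightedLambda the regularisation weight (lambda) for this step
         * @return the new W
         */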
        protected Matrix updateW(Matrix currentW, double wLossWeighted, double weightedLambda) {
                final Matrix gradW = loss.gradient(currentW);
                MatlibMatrixUtils.scaleInplace(gradW,wLossWeighted);

                Matrix neww = MatlibMatrixUtils.minus(currentW,gradW);
                neww = regul.prox(neww, weightedLambda);
                return neww;
        }
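        /**
         * Take one proximal gradient step on U, mirroring
         * {@link #updateW(Matrix, double, double)} with W held fixed.
         *
         * @param currentU the current U
         * @param uLossWeight the learning rate (eta) for this step
         * @param uWeightedLambda the regularisation weight (lambda) for this step
         * @return the new U
         */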
        protected Matrix updateU(Matrix currentU, double uLossWeight, double uWeightedLambda) {
                final Matrix gradU = loss.gradient(currentU);
                MatlibMatrixUtils.scaleInplace(gradU,uLossWeight);
                Matrix newu = MatlibMatrixUtils.minus(currentU,gradU);
                newu = regul.prox(newu, uWeightedLambda);
                return newu;
        }
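        /**
         * The regularisation weight at a given iteration, decayed linearly:
         * lambda_t = lambda / t.
         */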
        private double lambdat(int iter, double lambda) {
                return lambda/iter;
        }
        /**
         * Given a flat value matrix, construct a sparse diagonal matrix with the
         * values on its diagonal; off-diagonal entries are set to NaN.
         * @param Y the flat (1 x tasks) value matrix
         * @return the diagonalised Y
         */
        public static SparseMatrix expandY(Matrix Y) {
                final int ntasks = Y.columnCount();
                final SparseMatrix Yexp = SparseMatrix.sparse(ntasks, ntasks);
                for (int touter = 0; touter < ntasks; touter++) {
                        for (int tinner = 0; tinner < ntasks; tinner++) {
                                if(tinner == touter){
                                        Yexp.put(touter, tinner, Y.get(0, tinner));
                                }
                                else{
                                        Yexp.put(touter, tinner, Double.NaN);
                                }
                        }
                }
                return Yexp;
        }
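        /**
         * The bias learning rate at a given iteration:
         * eta_t = eta0_bias / sqrt(t).
         */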
        private double biasEtat(int iter){
                final Double biasEta0 = this.params.getTyped(BilinearLearnerParameters.ETA0_BIAS);
                return biasEta0 / Math.sqrt(iter);
        }

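        /**
         * The learning rate at a given iteration, decayed every
         * {@link BilinearLearnerParameters#ETASTEPS} iterations:
         * eta_t = eta0 / sqrt(ceil(t / etaSteps)).
         */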
        private double etat(int iter,double eta0) {
                final Integer etaSteps = this.params.getTyped(BilinearLearnerParameters.ETASTEPS);
                final double sqrtCeil = Math.sqrt(Math.ceil(iter/(double)etaSteps));
                return eta(eta0) / sqrtCeil;
        }
        private double eta(double eta0) {
                return eta0;
        }

        /**
         * @return the current parameters
         */
        public BilinearLearnerParameters getParams() {
                return this.params;
        }

        /**
         * @return the current user matrix
         */
        public Matrix getU(){
                return this.u;
        }

        /**
         * @return the current word matrix
         */
        public Matrix getW(){
                return this.w;
        }
        /**
         * @return the current bias (null if {@link BilinearLearnerParameters#BIAS} is false)
         */
        public Matrix getBias() {
                if(this.biasMode)
                        return this.bias;
                else
                        return null;
        }

        /**
         * Expand the U parameters matrix by adding a set of rows.
         * If U is currently unset, this function does nothing (assuming U will be initialised in the first round).
         * The new U parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDUINITSTRAT}.
         * @param newUsers the number of new users to add
         */
        public void addU(int newUsers) {
                if(this.u == null) return; // If u has not been initialised, it will be on the first call to process
                final InitStrategy ustrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDUINITSTRAT,null,null);
                final Matrix newU = ustrat.init(newUsers, this.u.columnCount());
                this.u = MatlibMatrixUtils.vstack(this.u,newU);
        }

        /**
         * Expand the W parameters matrix by adding a set of rows.
         * If W is currently unset, this function does nothing (assuming W will be initialised in the first round).
         * The new W parameters are initialised using {@link BilinearLearnerParameters#EXPANDEDWINITSTRAT}.
         * @param newWords the number of new words to add
         */
        public void addW(int newWords) {
                if(this.w == null) return; // If w has not been initialised, it will be on the first call to process
                final InitStrategy wstrat = this.getInitStrat(BilinearLearnerParameters.EXPANDEDWINITSTRAT,null,null);
                final Matrix newW = wstrat.init(newWords, this.w.columnCount());
                this.w = MatlibMatrixUtils.vstack(this.w,newW);
        }

        @Override
        public MatlibBilinearSparseOnlineLearner clone(){
                final MatlibBilinearSparseOnlineLearner ret = new MatlibBilinearSparseOnlineLearner(this.getParams());
                ret.u = MatlibMatrixUtils.copy(this.u);
                ret.w = MatlibMatrixUtils.copy(this.w);
                if(this.biasMode){
                        ret.bias = MatlibMatrixUtils.copy(this.bias);
                }
                return ret;
        }
        /**
         * @param newu set the model's U
         */
        public void setU(Matrix newu) {
                this.u = newu;
        }

        /**
         * @param neww set the model's W
         */
        public void setW(Matrix neww) {
                this.w = neww;
        }
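        /**
         * Read the model from its binary format: three integers (number of
         * words, users and tasks) followed by the values of W, U and the bias
         * in column-major order; zero values are not stored in the sparse
         * matrices. This mirrors {@link #writeBinary(DataOutput)}.
         */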
        @Override
        public void readBinary(DataInput in) throws IOException {
                final int nwords = in.readInt();
                final int nusers = in.readInt();
                final int ntasks = in.readInt();

                this.w = SparseMatrix.sparse(nwords, ntasks);
                for (int t = 0; t < ntasks; t++) {
                        for (int r = 0; r < nwords; r++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.w.put(r, t, readDouble);
                                }
                        }
                }

                this.u = SparseMatrix.sparse(nusers, ntasks);
                for (int t = 0; t < ntasks; t++) {
                        for (int r = 0; r < nusers; r++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.u.put(r, t, readDouble);
                                }
                        }
                }

                this.bias = SparseMatrix.sparse(ntasks, ntasks);
                for (int t1 = 0; t1 < ntasks; t1++) {
                        for (int t2 = 0; t2 < ntasks; t2++) {
                                final double readDouble = in.readDouble();
                                if(readDouble != 0){
                                        this.bias.put(t1, t2, readDouble);
                                }
                        }
                }
        }
        @Override
        public byte[] binaryHeader() {
                return "".getBytes();
        }
        @Override
        public void writeBinary(DataOutput out) throws IOException {
                out.writeInt(w.rowCount());
                out.writeInt(u.rowCount());
                out.writeInt(u.columnCount());
                final double[] wdata = w.asColumnMajorArray();
                for (int i = 0; i < wdata.length; i++) {
                        out.writeDouble(wdata[i]);
                }
                final double[] udata = u.asColumnMajorArray();
                for (int i = 0; i < udata.length; i++) {
                        out.writeDouble(udata[i]);
                }
                final double[] biasdata = bias.asColumnMajorArray();
                for (int i = 0; i < biasdata.length; i++) {
                        out.writeDouble(biasdata[i]);
                }
        }
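        /**
         * Predict the responses for a single instance x (features x users) as
         * the diagonal of U^T . x^T . W (plus the bias when bias mode is
         * enabled).
         */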
        @Override
        public Matrix predict(Matrix x) {
                final Matrix xt = MatlibMatrixUtils.transpose(x);
                final Matrix mult = MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.dotProduct(MatlibMatrixUtils.transpose(u), xt),this.w);
                if(this.biasMode) MatlibMatrixUtils.plusInplace(mult,this.bias);
                final Matrix ydiag = new DiagonalMatrix(mult);
                return ydiag;
        }
}