001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.evaluation.cluster.analyser;
031
032import java.util.Map;
033
034import org.apache.log4j.Logger;
035import org.openimaj.logger.LoggerUtils;
036
037/**
038 * The normalised mutual information of a cluster estimate
039 * 
040 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
041 */
042public class NMIClusterAnalyser implements ClusterAnalyser<NMIAnalysis> {
043
044        private final static Logger logger = Logger.getLogger(NMIClusterAnalyser.class);
045
046        @Override
047        public NMIAnalysis analyse(int[][] correct, int[][] estimated) {
048                final NMIAnalysis ret = new NMIAnalysis();
049                final Map<Integer, Integer> invCor = ClusterAnalyserUtils.invert(correct);
050                final Map<Integer, Integer> invEst = ClusterAnalyserUtils.invert(estimated);
051                ret.nmi = nmi(correct, estimated, invCor, invEst);
052                return ret;
053        }
054
055        private double nmi(int[][] c, int[][] e, Map<Integer, Integer> ic, Map<Integer, Integer> ie) {
056                final double N = Math.max(ic.size(), ie.size());
057                final double mi = mutualInformation(N, c, e, ic, ie);
058                LoggerUtils.debugFormat(logger, "Iec = %2.5f", mi);
059                final double ent_e = entropy(e, N);
060                LoggerUtils.debugFormat(logger, "He = %2.5f", ent_e);
061                final double ent_c = entropy(c, N);
062                LoggerUtils.debugFormat(logger, "Hc = %2.5f", ent_c);
063                return mi / ((ent_e + ent_c) / 2);
064        }
065
066        /**
067         * Maximum liklihood estimate of the entropy
068         * 
069         * @param clusters
070         * @param N
071         * @return
072         */
073        private double entropy(int[][] clusters, double N) {
074                double total = 0;
075                for (int k = 0; k < clusters.length; k++) {
076                        LoggerUtils.debugFormat(logger, "%2.1f/%2.1f * log2 ((%2.1f / %2.1f) )", (double) clusters[k].length, N,
077                                        (double) clusters[k].length, N);
078                        final double prop = clusters[k].length / N;
079                        total += prop * log2(prop);
080                }
081                return -total;
082        }
083
084        private double log2(double prop) {
085                if (prop == 0)
086                        return 0;
087                return Math.log(prop) / Math.log(2);
088        }
089
090        /**
091         * Maximum Liklihood estimate of the mutual information
092         * 
093         * @param c
094         * @param e
095         * @param ic
096         * @param ie
097         * @return
098         */
099        private double mutualInformation(double N, int[][] c, int[][] e, Map<Integer, Integer> ic, Map<Integer, Integer> ie) {
100                double mi = 0;
101                for (int k = 0; k < e.length; k++) {
102                        final double n_e = e[k].length;
103                        for (int j = 0; j < c.length; j++) {
104                                final double n_c = c[j].length;
105                                double both = 0;
106                                for (int i = 0; i < e[k].length; i++) {
107                                        final Integer itemCluster = ic.get(e[k][i]);
108                                        if (itemCluster == null)
109                                                continue;
110                                        if (itemCluster == j)
111                                                both++;
112                                }
113                                final double normProp = (both * N) / (n_c * n_e);
114                                // LoggerUtils.debugFormat(logger,"normprop = %2.5f",normProp);
115                                final double sum = (both / N) * (log2(normProp));
116                                mi += sum;
117
118                                // LoggerUtils.debugFormat(logger,"%2.1f/%2.1f * log2 ((%2.1f * %2.1f) / (%2.1f * %2.1f)) = %2.5f",both,N,both,N,n_c,n_e,sum);
119                        }
120                }
121                return mi;
122        }
123
124        // public static void main(String[] args) {
125        // LoggerUtils.prepareConsoleLogger();
126        // NMIClusterAnalyser an = new NMIClusterAnalyser();
127        // NMIAnalysis res = an.analyse(
128        // new int[][]{new int[]{1,2,3},new int[]{4,5,6}},
129        // // new int[][]{new int[]{1,2},new int[]{3},new int[]{4,5},new int[]{6}}
130        // // new int[][]{new int[]{1},new int[]{2},new int[]{3},new int[]{4},new
131        // int[]{5},new int[]{6}}
132        // new int[][]{new int[]{7,8,9}}
133        // );
134        // System.out.println(res);
135        // }
136
137}