1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.commons.math3.stat.correlation;
19 
20 import java.util.ArrayList;
21 import java.util.HashSet;
22 import java.util.List;
23 import java.util.Set;
24 
25 import org.apache.commons.math3.exception.DimensionMismatchException;
26 import org.apache.commons.math3.exception.MathIllegalArgumentException;
27 import org.apache.commons.math3.exception.util.LocalizedFormats;
28 import org.apache.commons.math3.linear.BlockRealMatrix;
29 import org.apache.commons.math3.linear.RealMatrix;
30 import org.apache.commons.math3.stat.ranking.NaNStrategy;
31 import org.apache.commons.math3.stat.ranking.NaturalRanking;
32 import org.apache.commons.math3.stat.ranking.RankingAlgorithm;
33 
34 /**
35  * Spearman's rank correlation. This implementation performs a rank
36  * transformation on the input data and then computes {@link PearsonsCorrelation}
37  * on the ranked data.
38  * <p>
39  * By default, ranks are computed using {@link NaturalRanking} with default
40  * strategies for handling NaNs and ties in the data (NaNs maximal, ties averaged).
41  * The ranking algorithm can be set using a constructor argument.
42  *
43  * @since 2.0
44  */
45 public class SpearmansCorrelation {
46 
47     /** Input data */
48     private final RealMatrix data;
49 
50     /** Ranking algorithm  */
51     private final RankingAlgorithm rankingAlgorithm;
52 
53     /** Rank correlation */
54     private final PearsonsCorrelation rankCorrelation;
55 
56     /**
57      * Create a SpearmansCorrelation without data.
58      */
SpearmansCorrelation()59     public SpearmansCorrelation() {
60         this(new NaturalRanking());
61     }
62 
63     /**
64      * Create a SpearmansCorrelation with the given ranking algorithm.
65      * <p>
66      * From version 4.0 onwards this constructor will throw an exception
67      * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
68      *
69      * @param rankingAlgorithm ranking algorithm
70      * @since 3.1
71      */
SpearmansCorrelation(final RankingAlgorithm rankingAlgorithm)72     public SpearmansCorrelation(final RankingAlgorithm rankingAlgorithm) {
73         data = null;
74         this.rankingAlgorithm = rankingAlgorithm;
75         rankCorrelation = null;
76     }
77 
78     /**
79      * Create a SpearmansCorrelation from the given data matrix.
80      *
81      * @param dataMatrix matrix of data with columns representing
82      * variables to correlate
83      */
SpearmansCorrelation(final RealMatrix dataMatrix)84     public SpearmansCorrelation(final RealMatrix dataMatrix) {
85         this(dataMatrix, new NaturalRanking());
86     }
87 
88     /**
89      * Create a SpearmansCorrelation with the given input data matrix
90      * and ranking algorithm.
91      * <p>
92      * From version 4.0 onwards this constructor will throw an exception
93      * if the provided {@link NaturalRanking} uses a {@link NaNStrategy#REMOVED} strategy.
94      *
95      * @param dataMatrix matrix of data with columns representing
96      * variables to correlate
97      * @param rankingAlgorithm ranking algorithm
98      */
SpearmansCorrelation(final RealMatrix dataMatrix, final RankingAlgorithm rankingAlgorithm)99     public SpearmansCorrelation(final RealMatrix dataMatrix, final RankingAlgorithm rankingAlgorithm) {
100         this.rankingAlgorithm = rankingAlgorithm;
101         this.data = rankTransform(dataMatrix);
102         rankCorrelation = new PearsonsCorrelation(data);
103     }
104 
105     /**
106      * Calculate the Spearman Rank Correlation Matrix.
107      *
108      * @return Spearman Rank Correlation Matrix
109      * @throws NullPointerException if this instance was created with no data
110      */
getCorrelationMatrix()111     public RealMatrix getCorrelationMatrix() {
112         return rankCorrelation.getCorrelationMatrix();
113     }
114 
115     /**
116      * Returns a {@link PearsonsCorrelation} instance constructed from the
117      * ranked input data. That is,
118      * <code>new SpearmansCorrelation(matrix).getRankCorrelation()</code>
119      * is equivalent to
120      * <code>new PearsonsCorrelation(rankTransform(matrix))</code> where
121      * <code>rankTransform(matrix)</code> is the result of applying the
122      * configured <code>RankingAlgorithm</code> to each of the columns of
123      * <code>matrix.</code>
124      *
125      * <p>Returns null if this instance was created with no data.</p>
126      *
127      * @return PearsonsCorrelation among ranked column data
128      */
getRankCorrelation()129     public PearsonsCorrelation getRankCorrelation() {
130         return rankCorrelation;
131     }
132 
133     /**
134      * Computes the Spearman's rank correlation matrix for the columns of the
135      * input matrix.
136      *
137      * @param matrix matrix with columns representing variables to correlate
138      * @return correlation matrix
139      */
computeCorrelationMatrix(final RealMatrix matrix)140     public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
141         final RealMatrix matrixCopy = rankTransform(matrix);
142         return new PearsonsCorrelation().computeCorrelationMatrix(matrixCopy);
143     }
144 
145     /**
146      * Computes the Spearman's rank correlation matrix for the columns of the
147      * input rectangular array.  The columns of the array represent values
148      * of variables to be correlated.
149      *
150      * @param matrix matrix with columns representing variables to correlate
151      * @return correlation matrix
152      */
computeCorrelationMatrix(final double[][] matrix)153     public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
154        return computeCorrelationMatrix(new BlockRealMatrix(matrix));
155     }
156 
157     /**
158      * Computes the Spearman's rank correlation coefficient between the two arrays.
159      *
160      * @param xArray first data array
161      * @param yArray second data array
162      * @return Returns Spearman's rank correlation coefficient for the two arrays
163      * @throws DimensionMismatchException if the arrays lengths do not match
164      * @throws MathIllegalArgumentException if the array length is less than 2
165      */
correlation(final double[] xArray, final double[] yArray)166     public double correlation(final double[] xArray, final double[] yArray) {
167         if (xArray.length != yArray.length) {
168             throw new DimensionMismatchException(xArray.length, yArray.length);
169         } else if (xArray.length < 2) {
170             throw new MathIllegalArgumentException(LocalizedFormats.INSUFFICIENT_DIMENSION,
171                                                    xArray.length, 2);
172         } else {
173             double[] x = xArray;
174             double[] y = yArray;
175             if (rankingAlgorithm instanceof NaturalRanking &&
176                 NaNStrategy.REMOVED == ((NaturalRanking) rankingAlgorithm).getNanStrategy()) {
177                 final Set<Integer> nanPositions = new HashSet<Integer>();
178 
179                 nanPositions.addAll(getNaNPositions(xArray));
180                 nanPositions.addAll(getNaNPositions(yArray));
181 
182                 x = removeValues(xArray, nanPositions);
183                 y = removeValues(yArray, nanPositions);
184             }
185             return new PearsonsCorrelation().correlation(rankingAlgorithm.rank(x), rankingAlgorithm.rank(y));
186         }
187     }
188 
189     /**
190      * Applies rank transform to each of the columns of <code>matrix</code>
191      * using the current <code>rankingAlgorithm</code>.
192      *
193      * @param matrix matrix to transform
194      * @return a rank-transformed matrix
195      */
rankTransform(final RealMatrix matrix)196     private RealMatrix rankTransform(final RealMatrix matrix) {
197         RealMatrix transformed = null;
198 
199         if (rankingAlgorithm instanceof NaturalRanking &&
200                 ((NaturalRanking) rankingAlgorithm).getNanStrategy() == NaNStrategy.REMOVED) {
201             final Set<Integer> nanPositions = new HashSet<Integer>();
202             for (int i = 0; i < matrix.getColumnDimension(); i++) {
203                 nanPositions.addAll(getNaNPositions(matrix.getColumn(i)));
204             }
205 
206             // if we have found NaN values, we have to update the matrix size
207             if (!nanPositions.isEmpty()) {
208                 transformed = new BlockRealMatrix(matrix.getRowDimension() - nanPositions.size(),
209                                                   matrix.getColumnDimension());
210                 for (int i = 0; i < transformed.getColumnDimension(); i++) {
211                     transformed.setColumn(i, removeValues(matrix.getColumn(i), nanPositions));
212                 }
213             }
214         }
215 
216         if (transformed == null) {
217             transformed = matrix.copy();
218         }
219 
220         for (int i = 0; i < transformed.getColumnDimension(); i++) {
221             transformed.setColumn(i, rankingAlgorithm.rank(transformed.getColumn(i)));
222         }
223 
224         return transformed;
225     }
226 
227     /**
228      * Returns a list containing the indices of NaN values in the input array.
229      *
230      * @param input the input array
231      * @return a list of NaN positions in the input array
232      */
getNaNPositions(final double[] input)233     private List<Integer> getNaNPositions(final double[] input) {
234         final List<Integer> positions = new ArrayList<Integer>();
235         for (int i = 0; i < input.length; i++) {
236             if (Double.isNaN(input[i])) {
237                 positions.add(i);
238             }
239         }
240         return positions;
241     }
242 
243     /**
244      * Removes all values from the input array at the specified indices.
245      *
246      * @param input the input array
247      * @param indices a set containing the indices to be removed
248      * @return the input array without the values at the specified indices
249      */
removeValues(final double[] input, final Set<Integer> indices)250     private double[] removeValues(final double[] input, final Set<Integer> indices) {
251         if (indices.isEmpty()) {
252             return input;
253         }
254         final double[] result = new double[input.length - indices.size()];
255         for (int i = 0, j = 0; i < input.length; i++) {
256             if (!indices.contains(i)) {
257                 result[j++] = input[i];
258             }
259         }
260         return result;
261     }
262 }
263