1 //
2 //  runSparcc.cpp
3 //  PDSSparCC
4 //
5 //  Created by Patrick Schloss on 10/31/12.
6 //  Copyright (c) 2012 University of Michigan. All rights reserved.
7 //
8 
9 #include "calcsparcc.h"
10 #include "linearalgebra.h"
11 
12 /**************************************************************************************************/
13 
CalcSparcc(vector<vector<float>> sharedVector,int maxIterations,int numSamplings,string method)14 CalcSparcc::CalcSparcc(vector<vector<float> > sharedVector, int maxIterations, int numSamplings, string method){
15     try {
16         m = MothurOut::getInstance();
17         numOTUs = (int)sharedVector[0].size();
18         numGroups = (int)sharedVector.size();
19         normalizationMethod = method;
20         int numOTUs = (int)sharedVector[0].size();
21 
22         addPseudoCount(sharedVector);
23 
24         vector<vector<vector<float> > > allCorrelations(numSamplings);
25 
26         //    float cycClockStart = clock();
27         //    unsigned long long cycTimeStart = time(NULL);
28 
29         for(int i=0;i<numSamplings;i++){
30 
31             if (m->getControl_pressed()) { break; }
32             vector<float> logFractions =  getLogFractions(sharedVector, method);
33             getT_Matrix(logFractions);     //this step is slow...
34             getT_Vector();
35             getD_Matrix();
36             vector<float> basisVariances = getBasisVariances();     //this step is slow...
37             vector<vector<float> > correlation = getBasisCorrelations(basisVariances);
38 
39             excluded.resize(numOTUs);
40             for(int j=0;j<numOTUs;j++){ excluded[j].assign(numOTUs, 0); }
41 
42             float maxRho = 1;
43             int excludeRow = -1;
44             int excludeColumn = -1;
45 
46             int iter = 0;
47             while(maxRho > 0.10 && iter < maxIterations){
48                 maxRho = getExcludedPairs(correlation, excludeRow, excludeColumn);
49                 excludeValues(excludeRow, excludeColumn);
50                 vector<float> excludedBasisVariances = getBasisVariances();
51                 correlation = getBasisCorrelations(excludedBasisVariances);
52                 iter++;
53             }
54             allCorrelations[i] = correlation;
55         }
56 
57         if (!m->getControl_pressed()) {
58             if(numSamplings > 1){
59                 getMedian(allCorrelations);
60             }
61             else{
62                 median = allCorrelations[0];
63             }
64         }
65     }
66     catch(exception& e) {
67         m->errorOut(e, "CalcSparcc", "CalcSparcc");
68         exit(1);
69     }
70 }
71 
72 /**************************************************************************************************/
73 
addPseudoCount(vector<vector<float>> & sharedVector)74 void CalcSparcc::addPseudoCount(vector<vector<float> >& sharedVector){
75     try {
76         for(int i=0;i<numGroups;i++){   //iterate across the groups
77             if (m->getControl_pressed()) { return; }
78             for(int j=0;j<numOTUs;j++){
79                 sharedVector[i][j] += 1;
80             }
81         }
82     }
83     catch(exception& e) {
84         m->errorOut(e, "CalcSparcc", "addPseudoCount");
85         exit(1);
86     }
87 }
88 /**************************************************************************************************/
89 
getLogFractions(vector<vector<float>> sharedVector,string method)90 vector<float> CalcSparcc::getLogFractions(vector<vector<float> > sharedVector, string method){   //dirichlet by default
91     try {
92         vector<float> logSharedFractions(numGroups * numOTUs, 0);
93 
94         if(method == "dirichlet"){
95             vector<float> alphas(numGroups);
96             for(int i=0;i<numGroups;i++){   //iterate across the groups
97                 if (m->getControl_pressed()) { return logSharedFractions; }
98                 alphas = util.randomDirichlet(sharedVector[i]);
99 
100                 for(int j=0;j<numOTUs;j++){
101 									logSharedFractions[i * numOTUs + j] = alphas[j];
102 								}
103             }
104         }
105         else if(method == "relabund"){
106             for(int i=0;i<numGroups;i++){
107                 if (m->getControl_pressed()) { return logSharedFractions; }
108                 float total = 0.0;
109                 for(int j=0;j<numOTUs;j++){
110                     total += sharedVector[i][j];
111                 }
112                 for(int j=0;j<numOTUs;j++){
113                     logSharedFractions[i * numOTUs + j] = sharedVector[i][j]/total;
114                 }
115             }
116         }
117 
118         for(int i=0;i<logSharedFractions.size();i++){
119             logSharedFractions[i] = log(logSharedFractions[i]);
120         }
121 
122         return logSharedFractions;
123     }
124     catch(exception& e) {
125         m->errorOut(e, "CalcSparcc", "addPseudoCount");
126         exit(1);
127     }
128 
129 }
130 
131 /**************************************************************************************************/
132 
getT_Matrix(vector<float> sharedFractions)133 void CalcSparcc::getT_Matrix(vector<float> sharedFractions){
134     try {
135         tMatrix.resize(numOTUs * numOTUs, 0);
136 
137         vector<float> diff(numGroups);
138         for(int j1=0;j1<numOTUs;j1++){
139             for(int j2=0;j2<j1;j2++){
140                 if (m->getControl_pressed()) { return; }
141                 float mean = 0.0;
142                 for(int i=0;i<numGroups;i++){
143                     diff[i] = sharedFractions[i * numOTUs + j1] - sharedFractions[i * numOTUs + j2];
144                     mean += diff[i];
145                 }
146 
147                 mean /= float(numGroups);
148                 float variance = 0.0;
149                 for(int i=0;i<numGroups;i++){
150                     variance += (diff[i] - mean) * (diff[i] - mean);
151                 }
152                 variance /= (float)(numGroups-1);
153 
154                 tMatrix[j1 * numOTUs + j2] = variance;
155                 tMatrix[j2 * numOTUs + j1] = tMatrix[j1 * numOTUs + j2];
156             }
157         }
158     }
159     catch(exception& e) {
160         m->errorOut(e, "CalcSparcc", "getT_Matrix");
161         exit(1);
162     }
163 
164 }
165 
166 /**************************************************************************************************/
167 
getT_Vector()168 void CalcSparcc::getT_Vector(){
169     try {
170         tVector.assign(numOTUs, 0);
171 
172         for(int j1=0;j1<numOTUs;j1++){
173             if (m->getControl_pressed()) { return; }
174             for(int j2=0;j2<numOTUs;j2++){
175                 tVector[j1] += tMatrix[j1 * numOTUs + j2];
176             }
177         }
178     }
179     catch(exception& e) {
180         m->errorOut(e, "CalcSparcc", "getT_Vector");
181         exit(1);
182     }
183 }
184 
185 /**************************************************************************************************/
186 
getD_Matrix()187 void CalcSparcc::getD_Matrix(){
188     try {
189         float d = numOTUs - 1.0;
190 
191         dMatrix.resize(numOTUs);
192         for(int i=0;i<numOTUs;i++){
193             if (m->getControl_pressed()) { return; }
194             dMatrix[i].resize(numOTUs, 1);
195             dMatrix[i][i] = d;
196         }
197     }
198     catch(exception& e) {
199         m->errorOut(e, "CalcSparcc", "getD_Matrix");
200         exit(1);
201     }
202 }
203 
204 /**************************************************************************************************/
205 
getBasisVariances()206 vector<float> CalcSparcc::getBasisVariances(){
207     try {
208         LinearAlgebra LA;
209 
210         vector<float> variances = LA.solveEquations(dMatrix, tVector);
211 
212         for(int i=0;i<variances.size();i++){
213             if (m->getControl_pressed()) { return variances; }
214             if(variances[i] < 0){   variances[i] = 1e-4;    }
215         }
216 
217         return variances;
218     }
219     catch(exception& e) {
220         m->errorOut(e, "CalcSparcc", "getBasisVariances");
221         exit(1);
222     }
223 }
224 
225 /**************************************************************************************************/
226 
getBasisCorrelations(vector<float> basisVariance)227 vector<vector<float> > CalcSparcc::getBasisCorrelations(vector<float> basisVariance){
228     try {
229         vector<vector<float> > rho(numOTUs);
230         for(int i=0;i<numOTUs;i++){ rho[i].resize(numOTUs, 0);    }
231 
232         for(int i=0;i<numOTUs;i++){
233             float var_i = basisVariance[i];
234             float sqrt_var_i = sqrt(var_i);
235 
236             rho[i][i] = 1.00;
237 
238             for(int j=0;j<i;j++){
239                 if (m->getControl_pressed()) { return rho; }
240                 float var_j = basisVariance[j];
241 
242                 rho[i][j] = (var_i + var_j - tMatrix[i * numOTUs + j]) / (2.0 * sqrt_var_i * sqrt(var_j));
243                 if(rho[i][j] > 1.0)         {   rho[i][j] = 1.0;   }
244                 else if(rho[i][j] < -1.0)   {   rho[i][j] = -1.0;  }
245 
246                 rho[j][i] = rho[i][j];
247             }
248         }
249 
250         return rho;
251     }
252     catch(exception& e) {
253         m->errorOut(e, "CalcSparcc", "getBasisCorrelations");
254         exit(1);
255     }
256 }
257 
258 /**************************************************************************************************/
259 
getExcludedPairs(vector<vector<float>> rho,int & maxRow,int & maxColumn)260 float CalcSparcc::getExcludedPairs(vector<vector<float> > rho, int& maxRow, int& maxColumn){
261     try {
262         float maxRho = 0;
263         maxRow = -1;
264         maxColumn = -1;
265 
266         for(int i=0;i<numOTUs;i++){
267 
268             for(int j=0;j<i;j++){
269                 if (m->getControl_pressed()) { return maxRho; }
270                 float tester = abs(rho[i][j]);
271 
272                 if(tester > maxRho && excluded[i][j] != 1){
273                     maxRho = tester;
274                     maxRow = i;
275                     maxColumn = j;
276                 }
277             }
278 
279         }
280 
281         return maxRho;
282     }
283     catch(exception& e) {
284         m->errorOut(e, "CalcSparcc", "getExcludedPairs");
285         exit(1);
286     }
287 }
288 
289 /**************************************************************************************************/
290 
excludeValues(int excludeRow,int excludeColumn)291 void CalcSparcc::excludeValues(int excludeRow, int excludeColumn){
292     try {
293         tVector[excludeRow] -= tMatrix[excludeRow * numOTUs + excludeColumn];
294         tVector[excludeColumn] -= tMatrix[excludeRow * numOTUs + excludeColumn];
295 
296         dMatrix[excludeRow][excludeColumn] = 0;
297         dMatrix[excludeColumn][excludeRow] = 0;
298         dMatrix[excludeRow][excludeRow]--;
299         dMatrix[excludeColumn][excludeColumn]--;
300 
301         excluded[excludeRow][excludeColumn] = 1;
302         excluded[excludeColumn][excludeRow] = 1;
303     }
304     catch(exception& e) {
305         m->errorOut(e, "CalcSparcc", "excludeValues");
306         exit(1);
307     }
308 }
309 
310 /**************************************************************************************************/
311 
getMedian(vector<vector<vector<float>>> allCorrelations)312 void CalcSparcc::getMedian(vector<vector<vector<float> > > allCorrelations){
313     try {
314         int numSamples = (int)allCorrelations.size();
315         median.resize(numOTUs);
316         for(int i=0;i<numOTUs;i++){ median[i].assign(numOTUs, 1);   }
317 
318         vector<float> hold(numSamples);
319 
320         for(int i=0;i<numOTUs;i++){
321             for(int j=0;j<i;j++){
322                 if (m->getControl_pressed()) { return; }
323 
324                 for(int k=0;k<numSamples;k++){
325                     hold[k] = allCorrelations[k][i][j];
326                 }
327 
328                 sort(hold.begin(), hold.end());
329                 median[i][j] = hold[int(numSamples * 0.5)];
330                 median[j][i] = median[i][j];
331             }
332         }
333     }
334     catch(exception& e) {
335         m->errorOut(e, "CalcSparcc", "getMedian");
336         exit(1);
337     }
338 }
339 
340 /**************************************************************************************************/
341