1 /*****************************************************************
2 Copyright (C) 2001-2012 Leo Breiman, Adele Cutler and Merck & Co., Inc.
3 
4 This program is free software; you can redistribute it and/or
5 modify it under the terms of the GNU General Public License
6 as published by the Free Software Foundation; either version 2
7 of the License, or (at your option) any later version.
8 
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13 
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
17 
18 C driver for Breiman & Cutler's random forest code.
19 Re-written from the original main program in Fortran.
20 Andy Liaw Feb. 7, 2002.
21 Modifications to get the forest out Matt Wiener Feb. 26, 2002.
22 *****************************************************************/
23 
24 #include <R.h>
25 #include <R_ext/Utils.h>
26 #include "rf.h"
27 
/* Compute OOB predictions (jest) and OOB error rates (errtr) from the
   OOB votes accumulated so far in counttr / out. */
void oob(int nsample, int nclass, int *cl, int *jtr, int *jerr,
         int *counttr, int *out, double *errtr, int *jest,
         double *cutoff);

/* Add one tree's test-set votes, re-predict the test set and, when the
   test set is labelled, recompute the test-set error rates. */
void TestSetError(double *countts, int *jts, int *clts, int *jet,
                  int ntest, int nclass, int nvote, double *errts,
                  int labelts, int *nclts, double *cutoff);
34 
35 /*  Define the R RNG for use from Fortran. */
F77_SUB(rrand)36 void F77_SUB(rrand)(double *r) { *r = unif_rand(); }
37 
classRF(double * x,int * dimx,int * cl,int * ncl,int * cat,int * maxcat,int * sampsize,int * strata,int * Options,int * ntree,int * nvar,int * ipi,double * classwt,double * cut,int * nodesize,int * outcl,int * counttr,double * prox,double * imprt,double * impsd,double * impmat,int * nrnodes,int * ndbigtree,int * nodestatus,int * bestvar,int * treemap,int * nodeclass,double * xbestsplit,double * errtr,int * testdat,double * xts,int * clts,int * nts,double * countts,int * outclts,int * labelts,double * proxts,double * errts,int * inbag)38 void classRF(double *x, int *dimx, int *cl, int *ncl, int *cat, int *maxcat,
39              int *sampsize, int *strata, int *Options, int *ntree, int *nvar,
40              int *ipi, double *classwt, double *cut, int *nodesize,
41              int *outcl, int *counttr, double *prox,
42              double *imprt, double *impsd, double *impmat, int *nrnodes,
43              int *ndbigtree, int *nodestatus, int *bestvar, int *treemap,
44              int *nodeclass, double *xbestsplit, double *errtr,
45              int *testdat, double *xts, int *clts, int *nts, double *countts,
46              int *outclts, int *labelts, double *proxts, double *errts,
47              int *inbag) {
48   /******************************************************************
49    *  C wrapper for random forests:  get input from R and drive
50    *  the Fortran routines.
51    *
52    *  Input:
53    *
54    *  x:        matrix of predictors (transposed!)
55    *  dimx:     two integers: number of variables and number of cases
56    *  cl:       class labels of the data
57    *  ncl:      number of classes in the response
58    *  cat:      integer vector of number of classes in the predictor;
59    *            1=continuous
60    * maxcat:    maximum of cat
61    * Options:   7 integers: (0=no, 1=yes)
62    *     add a second class (for unsupervised RF)?
63    *         1: sampling from product of marginals
64    *         2: sampling from product of uniforms
65    *     assess variable importance?
66    *     calculate proximity?
67    *     calculate proximity based on OOB predictions?
68    *     calculate outlying measure?
69    *     how often to print output?
70    *     keep the forest for future prediction?
71    *  ntree:    number of trees
72    *  nvar:     number of predictors to use for each split
73    *  ipi:      0=use class proportion as prob.; 1=use supplied priors
74    *  pi:       double vector of class priors
75    *  nodesize: minimum node size: no node with fewer than ndsize
76    *            cases will be split
77    *
78    *  Output:
79    *
80    *  outcl:    class predicted by RF
81    *  counttr:  matrix of votes (transposed!)
82    *  imprt:    matrix of variable importance measures
83    *  impmat:   matrix of local variable importance measures
84    *  prox:     matrix of proximity (if iprox=1)
85    ******************************************************************/
86 
87   int nsample0, mdim, nclass, addClass, mtry, ntest, nsample, ndsize,
88   mimp, nimp, near, nuse, noutall, nrightall, nrightimpall,
89   keepInbag, nstrata;
90   int jb, j, n, m, k, idxByNnode, idxByNsample, imp, localImp, iprox,
91   oobprox, keepf, replace, stratify, trace, *nright,
92   *nrightimp, *nout, *nclts, Ntree;
93 
94   int *out, *nodepop, *jin, *nodex,
95   *nodexts, *nodestart, *ta, *ncase, *jerr, *varUsed,
96   *jtr, *classFreq, *idmove, *jvr,
97   *at, *a, *b, *mind, *nind, *jts, *oobpair;
98   int **strata_idx, *strata_size, last, ktmp, nEmpty, ntry;
99 
100   double av=0.0, delta=0.0;
101 
102   double *tgini, *tx, *wl, *classpop, *tclasscat, *tclasspop, *win,
103   *tp, *wr, *bestsplitnext, *bestsplit;
104 
105   addClass = Options[0];
106   imp      = Options[1];
107   localImp = Options[2];
108   iprox    = Options[3];
109   oobprox  = Options[4];
110   trace    = Options[5];
111   keepf    = Options[6];
112   replace  = Options[7];
113   stratify = Options[8];
114   keepInbag = Options[9];
115   mdim     = dimx[0];
116   nsample0 = dimx[1];
117   nclass   = (*ncl==1) ? 2 : *ncl;
118   ndsize   = *nodesize;
119   Ntree    = *ntree;
120   mtry     = *nvar;
121   ntest    = *nts;
122   nsample = addClass ? (nsample0 + nsample0) : nsample0;
123   mimp = imp ? mdim : 1;
124   nimp = imp ? nsample : 1;
125   near = iprox ? nsample0 : 1;
126   if (trace == 0) trace = Ntree + 1;
127 
128   tgini =      (double *) S_alloc(mdim, sizeof(double));
129   wl =         (double *) S_alloc(nclass, sizeof(double));
130   wr =         (double *) S_alloc(nclass, sizeof(double));
131   classpop =   (double *) S_alloc(nclass* *nrnodes, sizeof(double));
132   tclasscat =  (double *) S_alloc(nclass*MAX_CAT, sizeof(double));
133   tclasspop =  (double *) S_alloc(nclass, sizeof(double));
134   tx =         (double *) S_alloc(nsample, sizeof(double));
135   win =        (double *) S_alloc(nsample, sizeof(double));
136   tp =         (double *) S_alloc(nsample, sizeof(double));
137   bestsplitnext = (double *) S_alloc(*nrnodes, sizeof(double));
138   bestsplit =     (double *) S_alloc(*nrnodes, sizeof(double));
139 
140   out =           (int *) S_alloc(nsample, sizeof(int));
141   nodepop =       (int *) S_alloc(*nrnodes, sizeof(int));
142   nodestart =     (int *) S_alloc(*nrnodes, sizeof(int));
143   jin =           (int *) S_alloc(nsample, sizeof(int));
144   nodex =         (int *) S_alloc(nsample, sizeof(int));
145   nodexts =       (int *) S_alloc(ntest, sizeof(int));
146   ta =            (int *) S_alloc(nsample, sizeof(int));
147   ncase =         (int *) S_alloc(nsample, sizeof(int));
148   jerr =          (int *) S_alloc(nsample, sizeof(int));
149   varUsed =       (int *) S_alloc(mdim, sizeof(int));
150   jtr =           (int *) S_alloc(nsample, sizeof(int));
151   jvr =           (int *) S_alloc(nsample, sizeof(int));
152   classFreq =     (int *) S_alloc(nclass, sizeof(int));
153   jts =           (int *) S_alloc(ntest, sizeof(int));
154   idmove =        (int *) S_alloc(nsample, sizeof(int));
155   at =            (int *) S_alloc(mdim*nsample, sizeof(int));
156   a =             (int *) S_alloc(mdim*nsample, sizeof(int));
157   b =             (int *) S_alloc(mdim*nsample, sizeof(int));
158   mind =          (int *) S_alloc(mdim, sizeof(int));
159   nright =        (int *) S_alloc(nclass, sizeof(int));
160   nrightimp =     (int *) S_alloc(nclass, sizeof(int));
161   nout =          (int *) S_alloc(nclass, sizeof(int));
162   if (oobprox) {
163     oobpair = (int *) S_alloc(near*near, sizeof(int));
164   }
165 
166   /* Count number of cases in each class. */
167   zeroInt(classFreq, nclass);
168   for (n = 0; n < nsample; ++n) classFreq[cl[n] - 1] ++;
169   /* Normalize class weights. */
170   normClassWt(cl, nsample, nclass, *ipi, classwt, classFreq);
171 
172   if (stratify) {
173     /* Count number of strata and frequency of each stratum. */
174     nstrata = 0;
175     for (n = 0; n < nsample0; ++n)
176       if (strata[n] > nstrata) nstrata = strata[n];
177       /* Create the array of pointers, each pointing to a vector
178        of indices of where data of each stratum is. */
179       strata_size = (int  *) S_alloc(nstrata, sizeof(int));
180       for (n = 0; n < nsample0; ++n) {
181         strata_size[strata[n] - 1] ++;
182       }
183       strata_idx =  (int **) S_alloc(nstrata, sizeof(int *));
184       for (n = 0; n < nstrata; ++n) {
185         strata_idx[n] = (int *) S_alloc(strata_size[n], sizeof(int));
186       }
187       zeroInt(strata_size, nstrata);
188       for (n = 0; n < nsample0; ++n) {
189         strata_size[strata[n] - 1] ++;
190         strata_idx[strata[n] - 1][strata_size[strata[n] - 1] - 1] = n;
191       }
192   } else {
193     nind = replace ? NULL : (int *) S_alloc(nsample, sizeof(int));
194   }
195 
196   /*    INITIALIZE FOR RUN */
197   if (*testdat) zeroDouble(countts, ntest * nclass);
198   zeroInt(counttr, nclass * nsample);
199   zeroInt(out, nsample);
200   zeroDouble(tgini, mdim);
201   zeroDouble(errtr, (nclass + 1) * Ntree);
202 
203   if (*labelts) {
204     nclts  = (int *) S_alloc(nclass, sizeof(int));
205     for (n = 0; n < ntest; ++n) nclts[clts[n]-1]++;
206     zeroDouble(errts, (nclass + 1) * Ntree);
207   }
208 
209   if (imp) {
210     zeroDouble(imprt, (nclass+2) * mdim);
211     zeroDouble(impsd, (nclass+1) * mdim);
212     if (localImp) zeroDouble(impmat, nsample * mdim);
213   }
214   if (iprox) {
215     zeroDouble(prox, nsample0 * nsample0);
216     if (*testdat) zeroDouble(proxts, ntest * (ntest + nsample0));
217   }
218   /* pre-sort x data */
219   makeA(x, mdim, nsample, cat, at, b);
220 
221   R_CheckUserInterrupt();
222 
223 
224   /* Starting the main loop over number of trees. */
225   GetRNGstate();
226   if (trace <= Ntree) {
227     /* Print header for running output. */
228     Rprintf("ntree      OOB");
229     for (n = 1; n <= nclass; ++n) Rprintf("%7i", n);
230     if (*labelts) {
231       Rprintf("|    Test");
232       for (n = 1; n <= nclass; ++n) Rprintf("%7i", n);
233     }
234     Rprintf("\n");
235   }
236   idxByNnode = 0;
237   idxByNsample = 0;
238   for (jb = 0; jb < Ntree; jb++) {
239     /* Do we need to simulate data for the second class? */
240     if (addClass) {
241       createClass(x, nsample0, nsample, mdim);
242       makeA(x, mdim, nsample, cat, at, b);
243     }
244 
245     do {
246       zeroInt(nodestatus + idxByNnode, *nrnodes);
247       zeroInt(treemap + 2*idxByNnode, 2 * *nrnodes);
248       zeroDouble(xbestsplit + idxByNnode, *nrnodes);
249       zeroInt(nodeclass + idxByNnode, *nrnodes);
250       zeroInt(varUsed, mdim);
251       /* TODO: Put all sampling code into a function. */
252       /* drawSample(sampsize, nsample, ); */
253       if (stratify) {  /* stratified sampling */
254       zeroInt(jin, nsample);
255         zeroDouble(tclasspop, nclass);
256         zeroDouble(win, nsample);
257         if (replace) {  /* with replacement */
258       for (n = 0; n < nstrata; ++n) {
259         for (j = 0; j < sampsize[n]; ++j) {
260           ktmp = (int) (unif_rand() * strata_size[n]);
261           k = strata_idx[n][ktmp];
262           tclasspop[cl[k] - 1] += classwt[cl[k] - 1];
263           win[k] += classwt[cl[k] - 1];
264           jin[k] += 1;
265         }
266       }
267         } else { /* stratified sampling w/o replacement */
268       /* re-initialize the index array */
269       zeroInt(strata_size, nstrata);
270           for (j = 0; j < nsample; ++j) {
271             strata_size[strata[j] - 1] ++;
272             strata_idx[strata[j] - 1][strata_size[strata[j] - 1] - 1] = j;
273           }
274           /* sampling without replacement */
275           for (n = 0; n < nstrata; ++n) {
276             last = strata_size[n] - 1;
277             for (j = 0; j < sampsize[n]; ++j) {
278               ktmp = (int) (unif_rand() * (last+1));
279               k = strata_idx[n][ktmp];
280               swapInt(strata_idx[n][last], strata_idx[n][ktmp]);
281               last--;
282               tclasspop[cl[k] - 1] += classwt[cl[k]-1];
283               win[k] += classwt[cl[k]-1];
284               jin[k] += 1;
285             }
286           }
287         }
288       } else {  /* unstratified sampling */
289           ntry = 0;
290         do {
291           nEmpty = 0;
292           zeroInt(jin, nsample);
293           zeroDouble(tclasspop, nclass);
294           zeroDouble(win, nsample);
295           if (replace) {
296             for (n = 0; n < *sampsize; ++n) {
297               k = unif_rand() * nsample;
298               tclasspop[cl[k] - 1] += classwt[cl[k]-1];
299               win[k] += classwt[cl[k]-1];
300               jin[k] += 1;
301             }
302           } else {
303             for (n = 0; n < nsample; ++n) nind[n] = n;
304             last = nsample - 1;
305             for (n = 0; n < *sampsize; ++n) {
306               ktmp = (int) (unif_rand() * (last+1));
307               k = nind[ktmp];
308               swapInt(nind[ktmp], nind[last]);
309               last--;
310               tclasspop[cl[k] - 1] += classwt[cl[k]-1];
311               win[k] += classwt[cl[k]-1];
312               jin[k] += 1;
313             }
314           }
315           /* check if any class is missing in the sample */
316           for (n = 0; n < nclass; ++n) {
317             if (tclasspop[n] == 0.0) nEmpty++;
318           }
319           ntry++;
320         } while (nclass - nEmpty < 2 && ntry <= 30);
321         /* If there are still fewer than two classes in the data, throw an error. */
322         if (nclass - nEmpty < 2) error("Still have fewer than two classes in the in-bag sample after 30 attempts.");
323       }
324 
325       /* If need to keep indices of inbag data, do that here. */
326       if (keepInbag) {
327         for (n = 0; n < nsample0; ++n) {
328           inbag[n + idxByNsample] = jin[n];
329         }
330       }
331 
332       /* Copy the original a matrix back. */
333       memcpy(a, at, sizeof(int) * mdim * nsample);
334       modA(a, &nuse, nsample, mdim, cat, *maxcat, ncase, jin);
335 
336       F77_CALL(buildtree)(a, b, cl, cat, maxcat, &mdim, &nsample,
337                &nclass,
338                treemap + 2*idxByNnode, bestvar + idxByNnode,
339                bestsplit, bestsplitnext, tgini,
340                nodestatus + idxByNnode, nodepop,
341                nodestart, classpop, tclasspop, tclasscat,
342                ta, nrnodes, idmove, &ndsize, ncase,
343                &mtry, varUsed, nodeclass + idxByNnode,
344                ndbigtree + jb, win, wr, wl, &mdim,
345                &nuse, mind);
346       /* if the "tree" has only the root node, start over */
347     } while (ndbigtree[jb] == 1);
348 
349     Xtranslate(x, mdim, *nrnodes, nsample, bestvar + idxByNnode,
350                bestsplit, bestsplitnext, xbestsplit + idxByNnode,
351                nodestatus + idxByNnode, cat, ndbigtree[jb]);
352 
353     /*  Get test set error */
354     if (*testdat) {
355       predictClassTree(xts, ntest, mdim, treemap + 2*idxByNnode,
356                        nodestatus + idxByNnode, xbestsplit + idxByNnode,
357                        bestvar + idxByNnode,
358                        nodeclass + idxByNnode, ndbigtree[jb],
359                        cat, nclass, jts, nodexts, *maxcat);
360       TestSetError(countts, jts, clts, outclts, ntest, nclass, jb+1,
361                    errts + jb*(nclass+1), *labelts, nclts, cut);
362     }
363 
364     /*  Get out-of-bag predictions and errors. */
365     predictClassTree(x, nsample, mdim, treemap + 2*idxByNnode,
366                      nodestatus + idxByNnode, xbestsplit + idxByNnode,
367                      bestvar + idxByNnode,
368                      nodeclass + idxByNnode, ndbigtree[jb],
369                      cat, nclass, jtr, nodex, *maxcat);
370 
371     zeroInt(nout, nclass);
372     noutall = 0;
373     for (n = 0; n < nsample; ++n) {
374       if (jin[n] == 0) {
375         /* increment the OOB votes */
376         counttr[n*nclass + jtr[n] - 1] ++;
377         /* count number of times a case is OOB */
378         out[n]++;
379         /* count number of OOB cases in the current iteration.
380          nout[n] is the number of OOB cases for the n-th class.
381          noutall is the number of OOB cases overall. */
382         nout[cl[n] - 1]++;
383         noutall++;
384       }
385     }
386 
387     /* Compute out-of-bag error rate. */
388     oob(nsample, nclass, cl, jtr, jerr, counttr, out,
389         errtr + jb*(nclass+1), outcl, cut);
390 
391     if ((jb+1) % trace == 0) {
392       Rprintf("%5i: %6.2f%%", jb+1, 100.0*errtr[jb * (nclass+1)]);
393       for (n = 1; n <= nclass; ++n) {
394         Rprintf("%6.2f%%", 100.0 * errtr[n + jb * (nclass+1)]);
395       }
396       if (*labelts) {
397         Rprintf("| ");
398         for (n = 0; n <= nclass; ++n) {
399           Rprintf("%6.2f%%", 100.0 * errts[n + jb * (nclass+1)]);
400         }
401       }
402       Rprintf("\n");
403 #ifdef WIN32
404       R_FlushConsole();
405       R_ProcessEvents();
406 #endif
407       R_CheckUserInterrupt();
408     }
409 
410     /*  DO PROXIMITIES */
411     if (iprox) {
412       computeProximity(prox, oobprox, nodex, jin, oobpair, near);
413       /* proximity for test data */
414       if (*testdat) {
415         computeProximity(proxts, 0, nodexts, jin, oobpair, ntest);
416         /* Compute proximity between testset and training set. */
417         for (n = 0; n < ntest; ++n) {
418           for (k = 0; k < near; ++k) {
419             if (nodexts[n] == nodex[k])
420               proxts[n + ntest * (k+ntest)] += 1.0;
421           }
422         }
423       }
424     }
425 
426     /*  DO VARIABLE IMPORTANCE  */
427     if (imp) {
428       nrightall = 0;
429       /* Count the number of correct prediction by the current tree
430        among the OOB samples, by class. */
431       zeroInt(nright, nclass);
432       for (n = 0; n < nsample; ++n) {
433         /* out-of-bag and predicted correctly: */
434         if (jin[n] == 0 && jtr[n] == cl[n]) {
435           nright[cl[n] - 1]++;
436           nrightall++;
437         }
438       }
439       for (m = 0; m < mdim; ++m) {
440         if (varUsed[m]) {
441           nrightimpall = 0;
442           zeroInt(nrightimp, nclass);
443           for (n = 0; n < nsample; ++n) tx[n] = x[m + n*mdim];
444           /* Permute the m-th variable. */
445           permuteOOB(m, x, jin, nsample, mdim);
446           /* Predict the modified data using the current tree. */
447           predictClassTree(x, nsample, mdim, treemap + 2*idxByNnode,
448                            nodestatus + idxByNnode,
449                            xbestsplit + idxByNnode,
450                            bestvar + idxByNnode,
451                            nodeclass + idxByNnode, ndbigtree[jb],
452                                                             cat, nclass, jvr, nodex, *maxcat);
453           /* Count how often correct predictions are made with
454            the modified data. */
455           for (n = 0; n < nsample; n++) {
456             /* Restore the original data for that variable. */
457             x[m + n*mdim] = tx[n];
458             if (jin[n] == 0) {
459               if (jvr[n] == cl[n]) {
460                 +								nrightimp[cl[n] - 1]++;
461                 nrightimpall++;
462               }
463               if (localImp && jvr[n] != jtr[n]) {
464                 if (cl[n] == jvr[n]) {
465                   impmat[m + n*mdim] -= 1.0;
466                 } else {
467                   impmat[m + n*mdim] += 1.0;
468                 }
469               }
470             }
471           }
472           /* Accumulate decrease in proportions of correct
473            predictions. */
474           /* class-specific measures first: */
475           for (n = 0; n < nclass; ++n) {
476             if (nout[n] > 0) {
477               delta = ((double) (nright[n] - nrightimp[n])) / nout[n];
478               imprt[m + n*mdim] += delta;
479               impsd[m + n*mdim] += delta * delta;
480             }
481           }
482           /* overall measure, across all classes: */
483           if (noutall > 0) {
484             delta = ((double)(nrightall - nrightimpall)) / noutall;
485             imprt[m + nclass*mdim] += delta;
486             impsd[m + nclass*mdim] += delta * delta;
487           }
488         }
489       }
490     }
491 
492     R_CheckUserInterrupt();
493 #ifdef WIN32
494     R_ProcessEvents();
495 #endif
496     if (keepf) idxByNnode += *nrnodes;
497     if (keepInbag) idxByNsample += nsample0;
498   }
499   PutRNGstate();
500 
501   /*  Final processing of variable importance. */
502   for (m = 0; m < mdim; m++) tgini[m] /= Ntree;
503   if (imp) {
504     for (m = 0; m < mdim; ++m) {
505       if (localImp) { /* casewise measures */
506   for (n = 0; n < nsample; ++n) impmat[m + n*mdim] /= out[n];
507       }
508       /* class-specific measures */
509       for (k = 0; k < nclass; ++k) {
510         av = imprt[m + k*mdim] / Ntree;
511         impsd[m + k*mdim] =
512           sqrt(((impsd[m + k*mdim] / Ntree) - av*av) / Ntree);
513         imprt[m + k*mdim] = av;
514         /* imprt[m + k*mdim] = (se <= 0.0) ? -1000.0 - av : av / se; */
515       }
516       /* overall measures */
517       av = imprt[m + nclass*mdim] / Ntree;
518       impsd[m + nclass*mdim] =
519         sqrt(((impsd[m + nclass*mdim] / Ntree) - av*av) / Ntree);
520       imprt[m + nclass*mdim] = av;
521       imprt[m + (nclass+1)*mdim] = tgini[m];
522     }
523   } else {
524     for (m = 0; m < mdim; ++m) imprt[m] = tgini[m];
525   }
526 
527   /*  PROXIMITY DATA ++++++++++++++++++++++++++++++++*/
528   if (iprox) {
529     for (n = 0; n < near; ++n) {
530       for (k = n + 1; k < near; ++k) {
531         prox[near*k + n] /= oobprox ?
532         (oobpair[near*k + n] > 0 ? oobpair[near*k + n] : 1) :
533         Ntree;
534         prox[near*n + k] = prox[near*k + n];
535       }
536       prox[near*n + n] = 1.0;
537     }
538     if (*testdat) {
539       for (n = 0; n < ntest; ++n) {
540         for (k = 0; k < ntest + nsample; ++k)
541           proxts[ntest*k + n] /= Ntree;
542         proxts[ntest * n + n] = 1.0;
543       }
544     }
545   }
546 }
547 
548 
/*
 * Predict a test set with a stored classification forest.
 *
 * For each of the *ntree trees (tree j occupies *nrnodes slots of the
 * node arrays, starting at j * *nrnodes), the *ntest cases in x are
 * dropped down the tree.
 *   - jts / node receive that tree's predicted classes / terminal nodes.
 *     These blocks advance per tree only when *keepPred / *nodes is set;
 *     otherwise each tree overwrites the first block.
 *   - countts (nclass x ntest, countts[j + n*nclass]) accumulates votes.
 *   - when *prox is set, proxMat (ntest x ntest) accumulates proximities
 *     and is finally divided by *ntree, with a unit diagonal.
 * jet[n] is the class (1-based) maximizing (votes / ntree) / cutoff[j].
 * Ties are broken uniformly at random.
 * pid is not used here.
 */
void classForest(int *mdim, int *ntest, int *nclass, int *maxcat,
                 int *nrnodes, int *ntree, double *x, double *xbestsplit,
                 double *pid, double *cutoff, double *countts, int *treemap,
                 int *nodestatus, int *cat, int *nodeclass, int *jts,
                 int *jet, int *bestvar, int *node, int *treeSize,
                 int *keepPred, int *prox, double *proxMat, int *nodes) {
  int j, n, n1, n2, idxNodes, offset1, offset2, ntie;
  double crit, cmax;

  zeroDouble(countts, *nclass * *ntest);
  idxNodes = 0;
  offset1 = 0;
  offset2 = 0;

  for (j = 0; j < *ntree; ++j) {
    /* predict by the j-th tree */
    predictClassTree(x, *ntest, *mdim, treemap + 2*idxNodes,
                     nodestatus + idxNodes, xbestsplit + idxNodes,
                     bestvar + idxNodes, nodeclass + idxNodes,
                     treeSize[j], cat, *nclass,
                     jts + offset1, node + offset2, *maxcat);
    /* accumulate votes: */
    for (n = 0; n < *ntest; ++n) {
      countts[jts[n + offset1] - 1 + n * *nclass] += 1.0;
    }

    /* if desired, do proximities for this round (no in-bag / OOB info) */
    if (*prox) computeProximity(proxMat, 0, node + offset2, NULL, NULL,
                                *ntest);
    idxNodes += *nrnodes;
    if (*keepPred) offset1 += *ntest;
    if (*nodes)    offset2 += *ntest;
  }

  /* Aggregated prediction is the class with the maximum votes/cutoff */
  for (n = 0; n < *ntest; ++n) {
    cmax = 0.0;
    ntie = 1;
    for (j = 0; j < *nclass; ++j) {
      crit = (countts[j + n * *nclass] / *ntree) / cutoff[j];
      if (crit > cmax) {
        jet[n] = j + 1;
        cmax = crit;
        ntie = 1;
      }
      /* Break ties at random (reservoir sampling): the k-th class found at
         the current maximum takes over with probability 1/k. ntie is
         incremented only after the draw. Incrementing first would make
         ties among k classes non-uniform. */
      if (crit == cmax) {
        if (unif_rand() < 1.0 / ntie) jet[n] = j + 1;
        ntie++;
      }
    }
  }

  /* if proximities requested, do the final adjustment
     (division by number of trees) */
  if (*prox) {
    for (n1 = 0; n1 < *ntest; ++n1) {
      for (n2 = n1 + 1; n2 < *ntest; ++n2) {
        proxMat[n1 + n2 * *ntest] /= *ntree;
        proxMat[n2 + n1 * *ntest] = proxMat[n1 + n2 * *ntest];
      }
      proxMat[n1 + n1 * *ntest] = 1.0;
    }
  }
}
615 
/*
 * Out-of-bag predictions and error rates for the forest grown so far.
 *
 * Each case with out[n] > 0 has been out-of-bag at least once. For such a
 * case the predicted class jest[n] (1-based) is the class j maximizing
 *   (counttr[j + n*nclass] / out[n]) / cutoff[j],
 * with ties broken uniformly at random. jerr[n] is 1 for misclassified
 * OOB cases and 0 otherwise. Cases that have never been OOB keep their
 * previous jest[] value and are not counted.
 *
 * errtr receives nclass+1 rates: errtr[0] is the overall OOB error rate
 * and errtr[k] is the error rate among OOB cases of class k. If no case
 * (or no case of class k) has been OOB yet, the rate is 0/0, i.e. NaN.
 *
 * Modified by A. Liaw 1/10/2003 (Deal with cutoff)
 * Re-written in C by A. Liaw 3/08/2004
 */
void oob(int nsample, int nclass, int *cl, int *jtr, int *jerr,
         int *counttr, int *out, double *errtr, int *jest,
         double *cutoff) {
  int j, n, noob, *noobcl, ntie;
  double qq, smaxtr;

  /* S_alloc returns zero-filled storage, so the per-class counts start at 0. */
  noobcl = (int *) S_alloc(nclass, sizeof(int));
  zeroInt(jerr, nsample);
  zeroDouble(errtr, nclass+1);

  noob = 0;
  for (n = 0; n < nsample; ++n) {
    if (out[n]) {
      noob++;
      noobcl[cl[n]-1]++;
      smaxtr = 0.0;
      ntie = 1;
      for (j = 0; j < nclass; ++j) {
        qq = (((double) counttr[j + n*nclass]) / out[n]) / cutoff[j];
        /* if vote / cutoff is larger than current max, re-set max and
           change predicted class to the current class */
        if (qq > smaxtr) {
          smaxtr = qq;
          jest[n] = j+1;
          ntie = 1;
        }
        /* Break ties at random (reservoir sampling): the k-th class found
           at the current maximum takes over with probability 1/k, so ntie
           is incremented only after the draw. */
        if (qq == smaxtr) {
          if (unif_rand() < 1.0 / ntie) jest[n] = j+1;
          ntie++;
        }
      }
      if (jest[n] != cl[n]) {
        errtr[cl[n]] += 1.0;
        errtr[0] += 1.0;
        jerr[n] = 1;
      }
    }
  }
  errtr[0] /= noob;
  for (n = 1; n <= nclass; ++n) errtr[n] /= noobcl[n-1];
}
667 
668 
/*
 * Add one tree's test-set votes and update test-set predictions and error.
 *
 * countts: nclass x ntest vote counts (countts[j + n*nclass]); the current
 *          tree's predictions jts[] (1-based classes) are added to it.
 * jet:     out: predicted class (1-based) per test case, the class j
 *          maximizing (votes / nvote) / cutoff[j]. Ties are broken
 *          uniformly at random.
 * nvote:   number of trees that have voted so far.
 * errts:   written only if labelts is non-zero. Receives nclass+1 error
 *          rates against clts: errts[0] overall, errts[k] for class k.
 *          nclts[k-1] must hold the number of test cases of class k; a
 *          class with no test cases yields 0/0 = NaN.
 */
void TestSetError(double *countts, int *jts, int *clts, int *jet, int ntest,
                  int nclass, int nvote, double *errts,
                  int labelts, int *nclts, double *cutoff) {
  int j, n, ntie;
  double cmax, crit;

  for (n = 0; n < ntest; ++n) countts[jts[n]-1 + n*nclass] += 1.0;

  /*  Prediction is the class with the maximum votes */
  for (n = 0; n < ntest; ++n) {
    cmax = 0.0;
    ntie = 1;
    for (j = 0; j < nclass; ++j) {
      crit = (countts[j + n*nclass] / nvote) / cutoff[j];
      if (crit > cmax) {
        jet[n] = j+1;
        cmax = crit;
        ntie = 1;
      }
      /* Break ties at random (reservoir sampling): the k-th class found at
         the current maximum takes over with probability 1/k, so ntie is
         incremented only after the draw. */
      if (crit == cmax) {
        if (unif_rand() < 1.0 / ntie) jet[n] = j+1;
        ntie++;
      }
    }
  }
  if (labelts) {
    zeroDouble(errts, nclass + 1);
    for (n = 0; n < ntest; ++n) {
      if (jet[n] != clts[n]) {
        errts[0] += 1.0;
        errts[clts[n]] += 1.0;
      }
    }
    errts[0] /= ntest;
    for (n = 1; n <= nclass; ++n) errts[n] /= nclts[n-1];
  }
}
710