/*****************************************************************
   Copyright (C) 2001-2012 Leo Breiman, Adele Cutler and Merck & Co., Inc.

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License
   as published by the Free Software Foundation; either version 2
   of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

   C driver for Breiman & Cutler's random forest code.
   Re-written from the original main program in Fortran.
   Andy Liaw Feb. 7, 2002.
   Modifications to get the forest out Matt Wiener Feb. 26, 2002.
 *****************************************************************/

#include <R.h>
#include <R_ext/Utils.h>
#include "rf.h"

/* Forward declarations for the helpers defined at the bottom of this file. */
void oob(int nsample, int nclass, int *cl, int *jtr, int *jerr,
         int *counttr, int *out, double *errtr, int *jest, double *cutoff);

void TestSetError(double *countts, int *jts, int *clts, int *jet, int ntest,
                  int nclass, int nvote, double *errts,
                  int labelts, int *nclts, double *cutoff);

35 /* Define the R RNG for use from Fortran. */
F77_SUB(rrand)36 void F77_SUB(rrand)(double *r) { *r = unif_rand(); }
37
classRF(double * x,int * dimx,int * cl,int * ncl,int * cat,int * maxcat,int * sampsize,int * strata,int * Options,int * ntree,int * nvar,int * ipi,double * classwt,double * cut,int * nodesize,int * outcl,int * counttr,double * prox,double * imprt,double * impsd,double * impmat,int * nrnodes,int * ndbigtree,int * nodestatus,int * bestvar,int * treemap,int * nodeclass,double * xbestsplit,double * errtr,int * testdat,double * xts,int * clts,int * nts,double * countts,int * outclts,int * labelts,double * proxts,double * errts,int * inbag)38 void classRF(double *x, int *dimx, int *cl, int *ncl, int *cat, int *maxcat,
39 int *sampsize, int *strata, int *Options, int *ntree, int *nvar,
40 int *ipi, double *classwt, double *cut, int *nodesize,
41 int *outcl, int *counttr, double *prox,
42 double *imprt, double *impsd, double *impmat, int *nrnodes,
43 int *ndbigtree, int *nodestatus, int *bestvar, int *treemap,
44 int *nodeclass, double *xbestsplit, double *errtr,
45 int *testdat, double *xts, int *clts, int *nts, double *countts,
46 int *outclts, int *labelts, double *proxts, double *errts,
47 int *inbag) {
48 /******************************************************************
49 * C wrapper for random forests: get input from R and drive
50 * the Fortran routines.
51 *
52 * Input:
53 *
54 * x: matrix of predictors (transposed!)
55 * dimx: two integers: number of variables and number of cases
56 * cl: class labels of the data
57 * ncl: number of classes in the response
58 * cat: integer vector of number of classes in the predictor;
59 * 1=continuous
60 * maxcat: maximum of cat
61 * Options: 7 integers: (0=no, 1=yes)
62 * add a second class (for unsupervised RF)?
63 * 1: sampling from product of marginals
64 * 2: sampling from product of uniforms
65 * assess variable importance?
66 * calculate proximity?
67 * calculate proximity based on OOB predictions?
68 * calculate outlying measure?
69 * how often to print output?
70 * keep the forest for future prediction?
71 * ntree: number of trees
72 * nvar: number of predictors to use for each split
73 * ipi: 0=use class proportion as prob.; 1=use supplied priors
74 * pi: double vector of class priors
75 * nodesize: minimum node size: no node with fewer than ndsize
76 * cases will be split
77 *
78 * Output:
79 *
80 * outcl: class predicted by RF
81 * counttr: matrix of votes (transposed!)
82 * imprt: matrix of variable importance measures
83 * impmat: matrix of local variable importance measures
84 * prox: matrix of proximity (if iprox=1)
85 ******************************************************************/
86
87 int nsample0, mdim, nclass, addClass, mtry, ntest, nsample, ndsize,
88 mimp, nimp, near, nuse, noutall, nrightall, nrightimpall,
89 keepInbag, nstrata;
90 int jb, j, n, m, k, idxByNnode, idxByNsample, imp, localImp, iprox,
91 oobprox, keepf, replace, stratify, trace, *nright,
92 *nrightimp, *nout, *nclts, Ntree;
93
94 int *out, *nodepop, *jin, *nodex,
95 *nodexts, *nodestart, *ta, *ncase, *jerr, *varUsed,
96 *jtr, *classFreq, *idmove, *jvr,
97 *at, *a, *b, *mind, *nind, *jts, *oobpair;
98 int **strata_idx, *strata_size, last, ktmp, nEmpty, ntry;
99
100 double av=0.0, delta=0.0;
101
102 double *tgini, *tx, *wl, *classpop, *tclasscat, *tclasspop, *win,
103 *tp, *wr, *bestsplitnext, *bestsplit;
104
105 addClass = Options[0];
106 imp = Options[1];
107 localImp = Options[2];
108 iprox = Options[3];
109 oobprox = Options[4];
110 trace = Options[5];
111 keepf = Options[6];
112 replace = Options[7];
113 stratify = Options[8];
114 keepInbag = Options[9];
115 mdim = dimx[0];
116 nsample0 = dimx[1];
117 nclass = (*ncl==1) ? 2 : *ncl;
118 ndsize = *nodesize;
119 Ntree = *ntree;
120 mtry = *nvar;
121 ntest = *nts;
122 nsample = addClass ? (nsample0 + nsample0) : nsample0;
123 mimp = imp ? mdim : 1;
124 nimp = imp ? nsample : 1;
125 near = iprox ? nsample0 : 1;
126 if (trace == 0) trace = Ntree + 1;
127
128 tgini = (double *) S_alloc(mdim, sizeof(double));
129 wl = (double *) S_alloc(nclass, sizeof(double));
130 wr = (double *) S_alloc(nclass, sizeof(double));
131 classpop = (double *) S_alloc(nclass* *nrnodes, sizeof(double));
132 tclasscat = (double *) S_alloc(nclass*MAX_CAT, sizeof(double));
133 tclasspop = (double *) S_alloc(nclass, sizeof(double));
134 tx = (double *) S_alloc(nsample, sizeof(double));
135 win = (double *) S_alloc(nsample, sizeof(double));
136 tp = (double *) S_alloc(nsample, sizeof(double));
137 bestsplitnext = (double *) S_alloc(*nrnodes, sizeof(double));
138 bestsplit = (double *) S_alloc(*nrnodes, sizeof(double));
139
140 out = (int *) S_alloc(nsample, sizeof(int));
141 nodepop = (int *) S_alloc(*nrnodes, sizeof(int));
142 nodestart = (int *) S_alloc(*nrnodes, sizeof(int));
143 jin = (int *) S_alloc(nsample, sizeof(int));
144 nodex = (int *) S_alloc(nsample, sizeof(int));
145 nodexts = (int *) S_alloc(ntest, sizeof(int));
146 ta = (int *) S_alloc(nsample, sizeof(int));
147 ncase = (int *) S_alloc(nsample, sizeof(int));
148 jerr = (int *) S_alloc(nsample, sizeof(int));
149 varUsed = (int *) S_alloc(mdim, sizeof(int));
150 jtr = (int *) S_alloc(nsample, sizeof(int));
151 jvr = (int *) S_alloc(nsample, sizeof(int));
152 classFreq = (int *) S_alloc(nclass, sizeof(int));
153 jts = (int *) S_alloc(ntest, sizeof(int));
154 idmove = (int *) S_alloc(nsample, sizeof(int));
155 at = (int *) S_alloc(mdim*nsample, sizeof(int));
156 a = (int *) S_alloc(mdim*nsample, sizeof(int));
157 b = (int *) S_alloc(mdim*nsample, sizeof(int));
158 mind = (int *) S_alloc(mdim, sizeof(int));
159 nright = (int *) S_alloc(nclass, sizeof(int));
160 nrightimp = (int *) S_alloc(nclass, sizeof(int));
161 nout = (int *) S_alloc(nclass, sizeof(int));
162 if (oobprox) {
163 oobpair = (int *) S_alloc(near*near, sizeof(int));
164 }
165
166 /* Count number of cases in each class. */
167 zeroInt(classFreq, nclass);
168 for (n = 0; n < nsample; ++n) classFreq[cl[n] - 1] ++;
169 /* Normalize class weights. */
170 normClassWt(cl, nsample, nclass, *ipi, classwt, classFreq);
171
172 if (stratify) {
173 /* Count number of strata and frequency of each stratum. */
174 nstrata = 0;
175 for (n = 0; n < nsample0; ++n)
176 if (strata[n] > nstrata) nstrata = strata[n];
177 /* Create the array of pointers, each pointing to a vector
178 of indices of where data of each stratum is. */
179 strata_size = (int *) S_alloc(nstrata, sizeof(int));
180 for (n = 0; n < nsample0; ++n) {
181 strata_size[strata[n] - 1] ++;
182 }
183 strata_idx = (int **) S_alloc(nstrata, sizeof(int *));
184 for (n = 0; n < nstrata; ++n) {
185 strata_idx[n] = (int *) S_alloc(strata_size[n], sizeof(int));
186 }
187 zeroInt(strata_size, nstrata);
188 for (n = 0; n < nsample0; ++n) {
189 strata_size[strata[n] - 1] ++;
190 strata_idx[strata[n] - 1][strata_size[strata[n] - 1] - 1] = n;
191 }
192 } else {
193 nind = replace ? NULL : (int *) S_alloc(nsample, sizeof(int));
194 }
195
196 /* INITIALIZE FOR RUN */
197 if (*testdat) zeroDouble(countts, ntest * nclass);
198 zeroInt(counttr, nclass * nsample);
199 zeroInt(out, nsample);
200 zeroDouble(tgini, mdim);
201 zeroDouble(errtr, (nclass + 1) * Ntree);
202
203 if (*labelts) {
204 nclts = (int *) S_alloc(nclass, sizeof(int));
205 for (n = 0; n < ntest; ++n) nclts[clts[n]-1]++;
206 zeroDouble(errts, (nclass + 1) * Ntree);
207 }
208
209 if (imp) {
210 zeroDouble(imprt, (nclass+2) * mdim);
211 zeroDouble(impsd, (nclass+1) * mdim);
212 if (localImp) zeroDouble(impmat, nsample * mdim);
213 }
214 if (iprox) {
215 zeroDouble(prox, nsample0 * nsample0);
216 if (*testdat) zeroDouble(proxts, ntest * (ntest + nsample0));
217 }
218 /* pre-sort x data */
219 makeA(x, mdim, nsample, cat, at, b);
220
221 R_CheckUserInterrupt();
222
223
224 /* Starting the main loop over number of trees. */
225 GetRNGstate();
226 if (trace <= Ntree) {
227 /* Print header for running output. */
228 Rprintf("ntree OOB");
229 for (n = 1; n <= nclass; ++n) Rprintf("%7i", n);
230 if (*labelts) {
231 Rprintf("| Test");
232 for (n = 1; n <= nclass; ++n) Rprintf("%7i", n);
233 }
234 Rprintf("\n");
235 }
236 idxByNnode = 0;
237 idxByNsample = 0;
238 for (jb = 0; jb < Ntree; jb++) {
239 /* Do we need to simulate data for the second class? */
240 if (addClass) {
241 createClass(x, nsample0, nsample, mdim);
242 makeA(x, mdim, nsample, cat, at, b);
243 }
244
245 do {
246 zeroInt(nodestatus + idxByNnode, *nrnodes);
247 zeroInt(treemap + 2*idxByNnode, 2 * *nrnodes);
248 zeroDouble(xbestsplit + idxByNnode, *nrnodes);
249 zeroInt(nodeclass + idxByNnode, *nrnodes);
250 zeroInt(varUsed, mdim);
251 /* TODO: Put all sampling code into a function. */
252 /* drawSample(sampsize, nsample, ); */
253 if (stratify) { /* stratified sampling */
254 zeroInt(jin, nsample);
255 zeroDouble(tclasspop, nclass);
256 zeroDouble(win, nsample);
257 if (replace) { /* with replacement */
258 for (n = 0; n < nstrata; ++n) {
259 for (j = 0; j < sampsize[n]; ++j) {
260 ktmp = (int) (unif_rand() * strata_size[n]);
261 k = strata_idx[n][ktmp];
262 tclasspop[cl[k] - 1] += classwt[cl[k] - 1];
263 win[k] += classwt[cl[k] - 1];
264 jin[k] += 1;
265 }
266 }
267 } else { /* stratified sampling w/o replacement */
268 /* re-initialize the index array */
269 zeroInt(strata_size, nstrata);
270 for (j = 0; j < nsample; ++j) {
271 strata_size[strata[j] - 1] ++;
272 strata_idx[strata[j] - 1][strata_size[strata[j] - 1] - 1] = j;
273 }
274 /* sampling without replacement */
275 for (n = 0; n < nstrata; ++n) {
276 last = strata_size[n] - 1;
277 for (j = 0; j < sampsize[n]; ++j) {
278 ktmp = (int) (unif_rand() * (last+1));
279 k = strata_idx[n][ktmp];
280 swapInt(strata_idx[n][last], strata_idx[n][ktmp]);
281 last--;
282 tclasspop[cl[k] - 1] += classwt[cl[k]-1];
283 win[k] += classwt[cl[k]-1];
284 jin[k] += 1;
285 }
286 }
287 }
288 } else { /* unstratified sampling */
289 ntry = 0;
290 do {
291 nEmpty = 0;
292 zeroInt(jin, nsample);
293 zeroDouble(tclasspop, nclass);
294 zeroDouble(win, nsample);
295 if (replace) {
296 for (n = 0; n < *sampsize; ++n) {
297 k = unif_rand() * nsample;
298 tclasspop[cl[k] - 1] += classwt[cl[k]-1];
299 win[k] += classwt[cl[k]-1];
300 jin[k] += 1;
301 }
302 } else {
303 for (n = 0; n < nsample; ++n) nind[n] = n;
304 last = nsample - 1;
305 for (n = 0; n < *sampsize; ++n) {
306 ktmp = (int) (unif_rand() * (last+1));
307 k = nind[ktmp];
308 swapInt(nind[ktmp], nind[last]);
309 last--;
310 tclasspop[cl[k] - 1] += classwt[cl[k]-1];
311 win[k] += classwt[cl[k]-1];
312 jin[k] += 1;
313 }
314 }
315 /* check if any class is missing in the sample */
316 for (n = 0; n < nclass; ++n) {
317 if (tclasspop[n] == 0.0) nEmpty++;
318 }
319 ntry++;
320 } while (nclass - nEmpty < 2 && ntry <= 30);
321 /* If there are still fewer than two classes in the data, throw an error. */
322 if (nclass - nEmpty < 2) error("Still have fewer than two classes in the in-bag sample after 30 attempts.");
323 }
324
325 /* If need to keep indices of inbag data, do that here. */
326 if (keepInbag) {
327 for (n = 0; n < nsample0; ++n) {
328 inbag[n + idxByNsample] = jin[n];
329 }
330 }
331
332 /* Copy the original a matrix back. */
333 memcpy(a, at, sizeof(int) * mdim * nsample);
334 modA(a, &nuse, nsample, mdim, cat, *maxcat, ncase, jin);
335
336 F77_CALL(buildtree)(a, b, cl, cat, maxcat, &mdim, &nsample,
337 &nclass,
338 treemap + 2*idxByNnode, bestvar + idxByNnode,
339 bestsplit, bestsplitnext, tgini,
340 nodestatus + idxByNnode, nodepop,
341 nodestart, classpop, tclasspop, tclasscat,
342 ta, nrnodes, idmove, &ndsize, ncase,
343 &mtry, varUsed, nodeclass + idxByNnode,
344 ndbigtree + jb, win, wr, wl, &mdim,
345 &nuse, mind);
346 /* if the "tree" has only the root node, start over */
347 } while (ndbigtree[jb] == 1);
348
349 Xtranslate(x, mdim, *nrnodes, nsample, bestvar + idxByNnode,
350 bestsplit, bestsplitnext, xbestsplit + idxByNnode,
351 nodestatus + idxByNnode, cat, ndbigtree[jb]);
352
353 /* Get test set error */
354 if (*testdat) {
355 predictClassTree(xts, ntest, mdim, treemap + 2*idxByNnode,
356 nodestatus + idxByNnode, xbestsplit + idxByNnode,
357 bestvar + idxByNnode,
358 nodeclass + idxByNnode, ndbigtree[jb],
359 cat, nclass, jts, nodexts, *maxcat);
360 TestSetError(countts, jts, clts, outclts, ntest, nclass, jb+1,
361 errts + jb*(nclass+1), *labelts, nclts, cut);
362 }
363
364 /* Get out-of-bag predictions and errors. */
365 predictClassTree(x, nsample, mdim, treemap + 2*idxByNnode,
366 nodestatus + idxByNnode, xbestsplit + idxByNnode,
367 bestvar + idxByNnode,
368 nodeclass + idxByNnode, ndbigtree[jb],
369 cat, nclass, jtr, nodex, *maxcat);
370
371 zeroInt(nout, nclass);
372 noutall = 0;
373 for (n = 0; n < nsample; ++n) {
374 if (jin[n] == 0) {
375 /* increment the OOB votes */
376 counttr[n*nclass + jtr[n] - 1] ++;
377 /* count number of times a case is OOB */
378 out[n]++;
379 /* count number of OOB cases in the current iteration.
380 nout[n] is the number of OOB cases for the n-th class.
381 noutall is the number of OOB cases overall. */
382 nout[cl[n] - 1]++;
383 noutall++;
384 }
385 }
386
387 /* Compute out-of-bag error rate. */
388 oob(nsample, nclass, cl, jtr, jerr, counttr, out,
389 errtr + jb*(nclass+1), outcl, cut);
390
391 if ((jb+1) % trace == 0) {
392 Rprintf("%5i: %6.2f%%", jb+1, 100.0*errtr[jb * (nclass+1)]);
393 for (n = 1; n <= nclass; ++n) {
394 Rprintf("%6.2f%%", 100.0 * errtr[n + jb * (nclass+1)]);
395 }
396 if (*labelts) {
397 Rprintf("| ");
398 for (n = 0; n <= nclass; ++n) {
399 Rprintf("%6.2f%%", 100.0 * errts[n + jb * (nclass+1)]);
400 }
401 }
402 Rprintf("\n");
403 #ifdef WIN32
404 R_FlushConsole();
405 R_ProcessEvents();
406 #endif
407 R_CheckUserInterrupt();
408 }
409
410 /* DO PROXIMITIES */
411 if (iprox) {
412 computeProximity(prox, oobprox, nodex, jin, oobpair, near);
413 /* proximity for test data */
414 if (*testdat) {
415 computeProximity(proxts, 0, nodexts, jin, oobpair, ntest);
416 /* Compute proximity between testset and training set. */
417 for (n = 0; n < ntest; ++n) {
418 for (k = 0; k < near; ++k) {
419 if (nodexts[n] == nodex[k])
420 proxts[n + ntest * (k+ntest)] += 1.0;
421 }
422 }
423 }
424 }
425
426 /* DO VARIABLE IMPORTANCE */
427 if (imp) {
428 nrightall = 0;
429 /* Count the number of correct prediction by the current tree
430 among the OOB samples, by class. */
431 zeroInt(nright, nclass);
432 for (n = 0; n < nsample; ++n) {
433 /* out-of-bag and predicted correctly: */
434 if (jin[n] == 0 && jtr[n] == cl[n]) {
435 nright[cl[n] - 1]++;
436 nrightall++;
437 }
438 }
439 for (m = 0; m < mdim; ++m) {
440 if (varUsed[m]) {
441 nrightimpall = 0;
442 zeroInt(nrightimp, nclass);
443 for (n = 0; n < nsample; ++n) tx[n] = x[m + n*mdim];
444 /* Permute the m-th variable. */
445 permuteOOB(m, x, jin, nsample, mdim);
446 /* Predict the modified data using the current tree. */
447 predictClassTree(x, nsample, mdim, treemap + 2*idxByNnode,
448 nodestatus + idxByNnode,
449 xbestsplit + idxByNnode,
450 bestvar + idxByNnode,
451 nodeclass + idxByNnode, ndbigtree[jb],
452 cat, nclass, jvr, nodex, *maxcat);
453 /* Count how often correct predictions are made with
454 the modified data. */
455 for (n = 0; n < nsample; n++) {
456 /* Restore the original data for that variable. */
457 x[m + n*mdim] = tx[n];
458 if (jin[n] == 0) {
459 if (jvr[n] == cl[n]) {
460 + nrightimp[cl[n] - 1]++;
461 nrightimpall++;
462 }
463 if (localImp && jvr[n] != jtr[n]) {
464 if (cl[n] == jvr[n]) {
465 impmat[m + n*mdim] -= 1.0;
466 } else {
467 impmat[m + n*mdim] += 1.0;
468 }
469 }
470 }
471 }
472 /* Accumulate decrease in proportions of correct
473 predictions. */
474 /* class-specific measures first: */
475 for (n = 0; n < nclass; ++n) {
476 if (nout[n] > 0) {
477 delta = ((double) (nright[n] - nrightimp[n])) / nout[n];
478 imprt[m + n*mdim] += delta;
479 impsd[m + n*mdim] += delta * delta;
480 }
481 }
482 /* overall measure, across all classes: */
483 if (noutall > 0) {
484 delta = ((double)(nrightall - nrightimpall)) / noutall;
485 imprt[m + nclass*mdim] += delta;
486 impsd[m + nclass*mdim] += delta * delta;
487 }
488 }
489 }
490 }
491
492 R_CheckUserInterrupt();
493 #ifdef WIN32
494 R_ProcessEvents();
495 #endif
496 if (keepf) idxByNnode += *nrnodes;
497 if (keepInbag) idxByNsample += nsample0;
498 }
499 PutRNGstate();
500
501 /* Final processing of variable importance. */
502 for (m = 0; m < mdim; m++) tgini[m] /= Ntree;
503 if (imp) {
504 for (m = 0; m < mdim; ++m) {
505 if (localImp) { /* casewise measures */
506 for (n = 0; n < nsample; ++n) impmat[m + n*mdim] /= out[n];
507 }
508 /* class-specific measures */
509 for (k = 0; k < nclass; ++k) {
510 av = imprt[m + k*mdim] / Ntree;
511 impsd[m + k*mdim] =
512 sqrt(((impsd[m + k*mdim] / Ntree) - av*av) / Ntree);
513 imprt[m + k*mdim] = av;
514 /* imprt[m + k*mdim] = (se <= 0.0) ? -1000.0 - av : av / se; */
515 }
516 /* overall measures */
517 av = imprt[m + nclass*mdim] / Ntree;
518 impsd[m + nclass*mdim] =
519 sqrt(((impsd[m + nclass*mdim] / Ntree) - av*av) / Ntree);
520 imprt[m + nclass*mdim] = av;
521 imprt[m + (nclass+1)*mdim] = tgini[m];
522 }
523 } else {
524 for (m = 0; m < mdim; ++m) imprt[m] = tgini[m];
525 }
526
527 /* PROXIMITY DATA ++++++++++++++++++++++++++++++++*/
528 if (iprox) {
529 for (n = 0; n < near; ++n) {
530 for (k = n + 1; k < near; ++k) {
531 prox[near*k + n] /= oobprox ?
532 (oobpair[near*k + n] > 0 ? oobpair[near*k + n] : 1) :
533 Ntree;
534 prox[near*n + k] = prox[near*k + n];
535 }
536 prox[near*n + n] = 1.0;
537 }
538 if (*testdat) {
539 for (n = 0; n < ntest; ++n) {
540 for (k = 0; k < ntest + nsample; ++k)
541 proxts[ntest*k + n] /= Ntree;
542 proxts[ntest * n + n] = 1.0;
543 }
544 }
545 }
546 }
547
548
classForest(int * mdim,int * ntest,int * nclass,int * maxcat,int * nrnodes,int * ntree,double * x,double * xbestsplit,double * pid,double * cutoff,double * countts,int * treemap,int * nodestatus,int * cat,int * nodeclass,int * jts,int * jet,int * bestvar,int * node,int * treeSize,int * keepPred,int * prox,double * proxMat,int * nodes)549 void classForest(int *mdim, int *ntest, int *nclass, int *maxcat,
550 int *nrnodes, int *ntree, double *x, double *xbestsplit,
551 double *pid, double *cutoff, double *countts, int *treemap,
552 int *nodestatus, int *cat, int *nodeclass, int *jts,
553 int *jet, int *bestvar, int *node, int *treeSize,
554 int *keepPred, int *prox, double *proxMat, int *nodes) {
555 int j, n, n1, n2, idxNodes, offset1, offset2, *junk, ntie;
556 double crit, cmax;
557
558 zeroDouble(countts, *nclass * *ntest);
559 idxNodes = 0;
560 offset1 = 0;
561 offset2 = 0;
562 junk = NULL;
563
564 for (j = 0; j < *ntree; ++j) {
565 /* predict by the j-th tree */
566 predictClassTree(x, *ntest, *mdim, treemap + 2*idxNodes,
567 nodestatus + idxNodes, xbestsplit + idxNodes,
568 bestvar + idxNodes, nodeclass + idxNodes,
569 treeSize[j], cat, *nclass,
570 jts + offset1, node + offset2, *maxcat);
571 /* accumulate votes: */
572 for (n = 0; n < *ntest; ++n) {
573 countts[jts[n + offset1] - 1 + n * *nclass] += 1.0;
574 }
575
576 /* if desired, do proximities for this round */
577 if (*prox) computeProximity(proxMat, 0, node + offset2, junk, junk,
578 *ntest);
579 idxNodes += *nrnodes;
580 if (*keepPred) offset1 += *ntest;
581 if (*nodes) offset2 += *ntest;
582 }
583
584 /* Aggregated prediction is the class with the maximum votes/cutoff */
585 for (n = 0; n < *ntest; ++n) {
586 cmax = 0.0;
587 ntie = 1;
588 for (j = 0; j < *nclass; ++j) {
589 crit = (countts[j + n * *nclass] / *ntree) / cutoff[j];
590 if (crit > cmax) {
591 jet[n] = j + 1;
592 cmax = crit;
593 ntie = 1;
594 }
595 /* Break ties at random: */
596 if (crit == cmax) {
597 ntie++;
598 if (unif_rand() < 1.0 / ntie) jet[n] = j + 1;
599 }
600 }
601 }
602
603 /* if proximities requested, do the final adjustment
604 (division by number of trees) */
605 if (*prox) {
606 for (n1 = 0; n1 < *ntest; ++n1) {
607 for (n2 = n1 + 1; n2 < *ntest; ++n2) {
608 proxMat[n1 + n2 * *ntest] /= *ntree;
609 proxMat[n2 + n1 * *ntest] = proxMat[n1 + n2 * *ntest];
610 }
611 proxMat[n1 + n1 * *ntest] = 1.0;
612 }
613 }
614 }
615
616 /*
617 Modified by A. Liaw 1/10/2003 (Deal with cutoff)
618 Re-written in C by A. Liaw 3/08/2004
619 */
oob(int nsample,int nclass,int * cl,int * jtr,int * jerr,int * counttr,int * out,double * errtr,int * jest,double * cutoff)620 void oob(int nsample, int nclass, int *cl, int *jtr,int *jerr,
621 int *counttr, int *out, double *errtr, int *jest,
622 double *cutoff) {
623 int j, n, noob, *noobcl, ntie;
624 double qq, smax, smaxtr;
625
626 noobcl = (int *) S_alloc(nclass, sizeof(int));
627 zeroInt(jerr, nsample);
628 zeroDouble(errtr, nclass+1);
629
630 noob = 0;
631 for (n = 0; n < nsample; ++n) {
632 if (out[n]) {
633 noob++;
634 noobcl[cl[n]-1]++;
635 smax = 0.0;
636 smaxtr = 0.0;
637 ntie = 1;
638 for (j = 0; j < nclass; ++j) {
639 qq = (((double) counttr[j + n*nclass]) / out[n]) / cutoff[j];
640 if (j+1 != cl[n]) smax = (qq > smax) ? qq : smax;
641 /* if vote / cutoff is larger than current max, re-set max and
642 change predicted class to the current class */
643 if (qq > smaxtr) {
644 smaxtr = qq;
645 jest[n] = j+1;
646 ntie = 1;
647 }
648 /* break tie at random */
649 if (qq == smaxtr) {
650 ntie++;
651 if (unif_rand() < 1.0 / ntie) {
652 smaxtr = qq;
653 jest[n] = j+1;
654 }
655 }
656 }
657 if (jest[n] != cl[n]) {
658 errtr[cl[n]] += 1.0;
659 errtr[0] += 1.0;
660 jerr[n] = 1;
661 }
662 }
663 }
664 errtr[0] /= noob;
665 for (n = 1; n <= nclass; ++n) errtr[n] /= noobcl[n-1];
666 }
667
668
/* Add the current tree's test-set votes (jts) into countts, derive the
   aggregated prediction jet (max of votes/nvote/cutoff, ties broken at
   random), and — when the test set is labeled — compute overall and
   per-class error rates into errts[0..nclass]. */
void TestSetError(double *countts, int *jts, int *clts, int *jet, int ntest,
                  int nclass, int nvote, double *errts,
                  int labelts, int *nclts, double *cutoff) {
    int j, n, ntie;
    double cmax, crit;

    for (n = 0; n < ntest; ++n) countts[jts[n] - 1 + n * nclass] += 1.0;

    /*  Prediction is the class with the maximum votes */
    for (n = 0; n < ntest; ++n) {
        cmax = 0.0;
        ntie = 1;
        for (j = 0; j < nclass; ++j) {
            crit = (countts[j + n * nclass] / nvote) / cutoff[j];
            if (crit > cmax) {
                jet[n] = j + 1;
                cmax = crit;
                ntie = 1;
            }
            /* Break ties at random: */
            if (crit == cmax) {
                ntie++;
                if (unif_rand() < 1.0 / ntie) {
                    jet[n] = j + 1;
                    cmax = crit;
                }
            }
        }
    }
    if (labelts) {
        zeroDouble(errts, nclass + 1);
        for (n = 0; n < ntest; ++n) {
            if (jet[n] != clts[n]) {
                errts[0] += 1.0;
                errts[clts[n]] += 1.0;
            }
        }
        errts[0] /= ntest;
        /* NOTE(review): nclts[n-1] == 0 for a class absent from the test
           labels would divide by zero (NaN) — matches upstream. */
        for (n = 1; n <= nclass; ++n) errts[n] /= nclts[n - 1];
    }
}
