1 /*******************************************************************
2 Copyright (C) 2001-7 Leo Breiman, Adele Cutler and Merck & Co., Inc.
3
4 This program is free software; you can redistribute it and/or
5 modify it under the terms of the GNU General Public License
6 as published by the Free Software Foundation; either version 2
7 of the License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
13 *******************************************************************/
14
15 /******************************************************************
16 * buildtree and findbestsplit routines translated from Leo's
17 * original Fortran code.
18 *
19 * copyright 1999 by leo Breiman
20 * this is free software and can be used for any purpose.
21 * It comes with no guarantee.
22 *
23 ******************************************************************/
#include <math.h>
#include <Rmath.h>
#include <R.h>
#include "rf.h"
27
regTree(double * x,double * y,int mdim,int nsample,int * lDaughter,int * rDaughter,double * upper,double * avnode,int * nodestatus,int nrnodes,int * treeSize,int nthsize,int mtry,int * mbest,int * cat,double * tgini,int * varUsed)28 void regTree(double *x, double *y, int mdim, int nsample, int *lDaughter,
29 int *rDaughter,
30 double *upper, double *avnode, int *nodestatus, int nrnodes,
31 int *treeSize, int nthsize, int mtry, int *mbest, int *cat,
32 double *tgini, int *varUsed) {
33 int i, j, k, m, ncur, *jdex, *nodestart, *nodepop;
34 int ndstart, ndend, ndendl, nodecnt, jstat, msplit;
35 double d, ss, av, decsplit, ubest, sumnode;
36
37 nodestart = (int *) Calloc(nrnodes, int);
38 nodepop = (int *) Calloc(nrnodes, int);
39
40 /* initialize some arrays for the tree */
41 zeroInt(nodestatus, nrnodes);
42 zeroInt(nodestart, nrnodes);
43 zeroInt(nodepop, nrnodes);
44 zeroDouble(avnode, nrnodes);
45
46 jdex = (int *) Calloc(nsample, int);
47 for (i = 1; i <= nsample; ++i) jdex[i-1] = i;
48
49 ncur = 0;
50 nodestart[0] = 0;
51 nodepop[0] = nsample;
52 nodestatus[0] = NODE_TOSPLIT;
53
54 /* compute mean and sum of squares for Y */
55 av = 0.0;
56 ss = 0.0;
57 for (i = 0; i < nsample; ++i) {
58 d = y[jdex[i] - 1];
59 ss += i * (av - d) * (av - d) / (i + 1);
60 av = (i * av + d) / (i + 1);
61 }
62 avnode[0] = av;
63
64 /* start main loop */
65 for (k = 0; k < nrnodes - 2; ++k) {
66 if (k > ncur || ncur >= nrnodes - 2) break;
67 /* skip if the node is not to be split */
68 if (nodestatus[k] != NODE_TOSPLIT) continue;
69
70 #ifdef RF_DEBUG
71 Rprintf("regTree: k=%d, av=%f, ss=%f\n", k, av, ss);
72 #endif
73
74 /* initialize for next call to findbestsplit */
75 ndstart = nodestart[k];
76 ndend = ndstart + nodepop[k] - 1;
77 nodecnt = nodepop[k];
78 sumnode = nodecnt * avnode[k];
79 jstat = 0;
80 decsplit = 0.0;
81
82 #ifdef RF_DEBUG
83 Rprintf("before findBestSplit: ndstart=%d, ndend=%d, jstat=%d, decsplit=%f\n",
84 ndstart, ndend, jstat, decsplit);
85 #endif
86
87 findBestSplit(x, jdex, y, mdim, nsample, ndstart, ndend, &msplit,
88 &decsplit, &ubest, &ndendl, &jstat, mtry, sumnode,
89 nodecnt, cat);
90 #ifdef RF_DEBUG
91 Rprintf(" after findBestSplit: ndstart=%d, ndend=%d, jstat=%d, decsplit=%f, msplit=%d\n",
92 ndstart, ndend, jstat, decsplit, msplit);
93
94 #endif
95 if (jstat == 1) {
96 /* Node is terminal: Mark it as such and move on to the next. */
97 nodestatus[k] = NODE_TERMINAL;
98 continue;
99 }
100 /* Found the best split. */
101 mbest[k] = msplit;
102 varUsed[msplit - 1] = 1;
103 upper[k] = ubest;
104 tgini[msplit - 1] += decsplit;
105 nodestatus[k] = NODE_INTERIOR;
106
107 /* leftnode no.= ncur+1, rightnode no. = ncur+2. */
108 nodepop[ncur + 1] = ndendl - ndstart + 1;
109 nodepop[ncur + 2] = ndend - ndendl;
110 nodestart[ncur + 1] = ndstart;
111 nodestart[ncur + 2] = ndendl + 1;
112
113 /* compute mean and sum of squares for the left daughter node */
114 av = 0.0;
115 ss = 0.0;
116 for (j = ndstart; j <= ndendl; ++j) {
117 d = y[jdex[j]-1];
118 m = j - ndstart;
119 ss += m * (av - d) * (av - d) / (m + 1);
120 av = (m * av + d) / (m+1);
121 }
122 avnode[ncur+1] = av;
123 nodestatus[ncur+1] = NODE_TOSPLIT;
124 if (nodepop[ncur + 1] <= nthsize) {
125 nodestatus[ncur + 1] = NODE_TERMINAL;
126 }
127
128 /* compute mean and sum of squares for the right daughter node */
129 av = 0.0;
130 ss = 0.0;
131 for (j = ndendl + 1; j <= ndend; ++j) {
132 d = y[jdex[j]-1];
133 m = j - (ndendl + 1);
134 ss += m * (av - d) * (av - d) / (m + 1);
135 av = (m * av + d) / (m + 1);
136 }
137 avnode[ncur + 2] = av;
138 nodestatus[ncur + 2] = NODE_TOSPLIT;
139 if (nodepop[ncur + 2] <= nthsize) {
140 nodestatus[ncur + 2] = NODE_TERMINAL;
141 }
142
143 /* map the daughter nodes */
144 lDaughter[k] = ncur + 1 + 1;
145 rDaughter[k] = ncur + 2 + 1;
146 /* Augment the tree by two nodes. */
147 ncur += 2;
148 #ifdef RF_DEBUG
149 Rprintf(" after split: ldaughter=%d, rdaughter=%d, ncur=%d\n",
150 lDaughter[k], rDaughter[k], ncur);
151 #endif
152
153 }
154 *treeSize = nrnodes;
155 for (k = nrnodes - 1; k >= 0; --k) {
156 if (nodestatus[k] == 0) (*treeSize)--;
157 if (nodestatus[k] == NODE_TOSPLIT) {
158 nodestatus[k] = NODE_TERMINAL;
159 }
160 }
161 Free(nodestart);
162 Free(jdex);
163 Free(nodepop);
164 }
165
166 /*--------------------------------------------------------------*/
findBestSplit(double * x,int * jdex,double * y,int mdim,int nsample,int ndstart,int ndend,int * msplit,double * decsplit,double * ubest,int * ndendl,int * jstat,int mtry,double sumnode,int nodecnt,int * cat)167 void findBestSplit(double *x, int *jdex, double *y, int mdim, int nsample,
168 int ndstart, int ndend, int *msplit, double *decsplit,
169 double *ubest, int *ndendl, int *jstat, int mtry,
170 double sumnode, int nodecnt, int *cat) {
171 int last, ncat[MAX_CAT], icat[MAX_CAT], lc, nl, nr, npopl, npopr, tieVar, tieVal;
172 int i, j, kv, l, *mind, *ncase;
173 double *xt, *ut, *v, *yl, sumcat[MAX_CAT], avcat[MAX_CAT], tavcat[MAX_CAT], ubestt;
174 double crit, critmax, critvar, suml, sumr, d, critParent;
175
176 ut = (double *) Calloc(nsample, double);
177 xt = (double *) Calloc(nsample, double);
178 v = (double *) Calloc(nsample, double);
179 yl = (double *) Calloc(nsample, double);
180 mind = (int *) Calloc(mdim, int);
181 ncase = (int *) Calloc(nsample, int);
182 zeroDouble(avcat, MAX_CAT);
183 zeroDouble(tavcat, MAX_CAT);
184
185 /* START BIG LOOP */
186 *msplit = -1;
187 *decsplit = 0.0;
188 critmax = 0.0;
189 ubestt = 0.0;
190 for (i=0; i < mdim; ++i) mind[i] = i;
191
192 last = mdim - 1;
193 tieVar = 1;
194 for (i = 0; i < mtry; ++i) {
195 critvar = 0.0;
196 j = (int) (unif_rand() * (last+1));
197 kv = mind[j];
198 swapInt(mind[j], mind[last]);
199 last--;
200
201 lc = cat[kv];
202 if (lc == 1) {
203 /* numeric variable */
204 for (j = ndstart; j <= ndend; ++j) {
205 xt[j] = x[kv + (jdex[j] - 1) * mdim];
206 yl[j] = y[jdex[j] - 1];
207 }
208 } else {
209 /* categorical variable */
210 zeroInt(ncat, MAX_CAT);
211 zeroDouble(sumcat, MAX_CAT);
212 for (j = ndstart; j <= ndend; ++j) {
213 l = (int) x[kv + (jdex[j] - 1) * mdim];
214 sumcat[l - 1] += y[jdex[j] - 1];
215 ncat[l - 1] ++;
216 }
217 /* Compute means of Y by category. */
218 for (j = 0; j < lc; ++j) {
219 avcat[j] = ncat[j] ? sumcat[j] / ncat[j] : 0.0;
220 }
221 /* Make the category mean the `pseudo' X data. */
222 for (j = 0; j < nsample; ++j) {
223 xt[j] = avcat[(int) x[kv + (jdex[j] - 1) * mdim] - 1];
224 yl[j] = y[jdex[j] - 1];
225 }
226 }
227 /* copy the x data in this node. */
228 for (j = ndstart; j <= ndend; ++j) v[j] = xt[j];
229 for (j = 1; j <= nsample; ++j) ncase[j - 1] = j;
230 R_qsort_I(v, ncase, ndstart + 1, ndend + 1);
231 if (v[ndstart] >= v[ndend]) continue;
232 /* ncase(n)=case number of v nth from bottom */
233 /* Start from the right and search to the left. */
234 critParent = sumnode * sumnode / nodecnt;
235 suml = 0.0;
236 sumr = sumnode;
237 npopl = 0;
238 npopr = nodecnt;
239 crit = 0.0;
240 tieVal = 1;
241 /* Search through the "gaps" in the x-variable. */
242 for (j = ndstart; j <= ndend - 1; ++j) {
243 d = yl[ncase[j] - 1];
244 suml += d;
245 sumr -= d;
246 npopl++;
247 npopr--;
248 if (v[j] < v[j+1]) {
249 crit = (suml * suml / npopl) + (sumr * sumr / npopr) - critParent;
250 if (crit > critvar) {
251 ubestt = (v[j] + v[j+1]) / 2.0;
252 critvar = crit;
253 tieVal = 1;
254 }
255 if (crit == critvar) {
256 tieVal++;
257 if (unif_rand() < 1.0 / tieVal) {
258 ubestt = (v[j] + v[j+1]) / 2.0;
259 critvar = crit;
260 }
261 }
262 }
263 }
264 if (critvar > critmax) {
265 *ubest = ubestt;
266 *msplit = kv + 1;
267 critmax = critvar;
268 for (j = ndstart; j <= ndend; ++j) {
269 ut[j] = xt[j];
270 }
271 if (cat[kv] > 1) {
272 for (j = 0; j < cat[kv]; ++j) tavcat[j] = avcat[j];
273 }
274 tieVar = 1;
275 }
276 if (critvar == critmax) {
277 tieVar++;
278 if (unif_rand() < 1.0 / tieVar) {
279 *ubest = ubestt;
280 *msplit = kv + 1;
281 critmax = critvar;
282 for (j = ndstart; j <= ndend; ++j) {
283 ut[j] = xt[j];
284 }
285 if (cat[kv] > 1) {
286 for (j = 0; j < cat[kv]; ++j) tavcat[j] = avcat[j];
287 }
288 }
289 }
290
291 }
292 *decsplit = critmax;
293
294 /* If best split can not be found, set to terminal node and return. */
295 if (*msplit != -1) {
296 nl = ndstart;
297 for (j = ndstart; j <= ndend; ++j) {
298 if (ut[j] <= *ubest) {
299 nl++;
300 ncase[nl-1] = jdex[j];
301 }
302 }
303 *ndendl = imax2(nl - 1, ndstart);
304 nr = *ndendl + 1;
305 for (j = ndstart; j <= ndend; ++j) {
306 if (ut[j] > *ubest) {
307 if (nr >= nsample) break;
308 nr++;
309 ncase[nr - 1] = jdex[j];
310 }
311 }
312 if (*ndendl >= ndend) *ndendl = ndend - 1;
313 for (j = ndstart; j <= ndend; ++j) jdex[j] = ncase[j];
314
315 lc = cat[*msplit - 1];
316 if (lc > 1) {
317 for (j = 0; j < lc; ++j) {
318 icat[j] = (tavcat[j] < *ubest) ? 1 : 0;
319 }
320 *ubest = pack(lc, icat);
321 }
322 } else *jstat = 1;
323
324 Free(ncase);
325 Free(mind);
326 Free(v);
327 Free(yl);
328 Free(xt);
329 Free(ut);
330 }
331
332 /*====================================================================*/
predictRegTree(double * x,int nsample,int mdim,int * lDaughter,int * rDaughter,int * nodestatus,double * ypred,double * split,double * nodepred,int * splitVar,int treeSize,int * cat,int maxcat,int * nodex)333 void predictRegTree(double *x, int nsample, int mdim,
334 int *lDaughter, int *rDaughter, int *nodestatus,
335 double *ypred, double *split, double *nodepred,
336 int *splitVar, int treeSize, int *cat, int maxcat,
337 int *nodex) {
338 int i, j, k, m, *cbestsplit;
339 double dpack;
340
341 /* decode the categorical splits */
342 if (maxcat > 1) {
343 cbestsplit = (int *) Calloc(maxcat * treeSize, int);
344 zeroInt(cbestsplit, maxcat * treeSize);
345 for (i = 0; i < treeSize; ++i) {
346 if (nodestatus[i] != NODE_TERMINAL && cat[splitVar[i] - 1] > 1) {
347 dpack = split[i];
348 /* unpack `npack' into bits */
349 /* unpack(dpack, maxcat, cbestsplit + i * maxcat); */
350 for (j = 0; j < cat[splitVar[i] - 1]; ++j) {
351 cbestsplit[j + i*maxcat] = ((unsigned long) dpack & 1) ? 1 : 0;
352 dpack = dpack / 2.0 ;
353 /* cbestsplit[j + i*maxcat] = npack & 1; */
354 }
355 }
356 }
357 }
358
359 for (i = 0; i < nsample; ++i) {
360 k = 0;
361 while (nodestatus[k] != NODE_TERMINAL) { /* go down the tree */
362 m = splitVar[k] - 1;
363 if (cat[m] == 1) {
364 k = (x[m + i*mdim] <= split[k]) ?
365 lDaughter[k] - 1 : rDaughter[k] - 1;
366 } else {
367 /* Split by a categorical predictor */
368 k = cbestsplit[(int) x[m + i * mdim] - 1 + k * maxcat] ?
369 lDaughter[k] - 1 : rDaughter[k] - 1;
370 }
371 }
372 /* terminal node: assign prediction and move on to next */
373 ypred[i] = nodepred[k];
374 nodex[i] = k + 1;
375 }
376 if (maxcat > 1) Free(cbestsplit);
377 }
378