1 /*******************************************************************
2    Copyright (C) 2001-7 Leo Breiman, Adele Cutler and Merck & Co., Inc.
3 
4    This program is free software; you can redistribute it and/or
5    modify it under the terms of the GNU General Public License
6    as published by the Free Software Foundation; either version 2
7    of the License, or (at your option) any later version.
8 
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13 *******************************************************************/
14 
15 /******************************************************************
16  * buildtree and findbestsplit routines translated from Leo's
17  * original Fortran code.
18  *
19  *      copyright 1999 by leo Breiman
20  *      this is free software and can be used for any purpose.
21  *      It comes with no guarantee.
22  *
23  ******************************************************************/
24 #include <Rmath.h>
25 #include <R.h>
26 #include "rf.h"
27 
regTree(double * x,double * y,int mdim,int nsample,int * lDaughter,int * rDaughter,double * upper,double * avnode,int * nodestatus,int nrnodes,int * treeSize,int nthsize,int mtry,int * mbest,int * cat,double * tgini,int * varUsed)28 void regTree(double *x, double *y, int mdim, int nsample, int *lDaughter,
29              int *rDaughter,
30              double *upper, double *avnode, int *nodestatus, int nrnodes,
31              int *treeSize, int nthsize, int mtry, int *mbest, int *cat,
32              double *tgini, int *varUsed) {
33   int i, j, k, m, ncur, *jdex, *nodestart, *nodepop;
34   int ndstart, ndend, ndendl, nodecnt, jstat, msplit;
35   double d, ss, av, decsplit, ubest, sumnode;
36 
37   nodestart = (int *) Calloc(nrnodes, int);
38   nodepop   = (int *) Calloc(nrnodes, int);
39 
40   /* initialize some arrays for the tree */
41   zeroInt(nodestatus, nrnodes);
42   zeroInt(nodestart, nrnodes);
43   zeroInt(nodepop, nrnodes);
44   zeroDouble(avnode, nrnodes);
45 
46   jdex = (int *) Calloc(nsample, int);
47   for (i = 1; i <= nsample; ++i) jdex[i-1] = i;
48 
49   ncur = 0;
50   nodestart[0] = 0;
51   nodepop[0] = nsample;
52   nodestatus[0] = NODE_TOSPLIT;
53 
54   /* compute mean and sum of squares for Y */
55   av = 0.0;
56   ss = 0.0;
57   for (i = 0; i < nsample; ++i) {
58     d = y[jdex[i] - 1];
59     ss += i * (av - d) * (av - d) / (i + 1);
60     av = (i * av + d) / (i + 1);
61   }
62   avnode[0] = av;
63 
64   /* start main loop */
65   for (k = 0; k < nrnodes - 2; ++k) {
66     if (k > ncur || ncur >= nrnodes - 2) break;
67     /* skip if the node is not to be split */
68     if (nodestatus[k] != NODE_TOSPLIT) continue;
69 
70 #ifdef RF_DEBUG
71     Rprintf("regTree: k=%d, av=%f, ss=%f\n", k, av, ss);
72 #endif
73 
74     /* initialize for next call to findbestsplit */
75     ndstart = nodestart[k];
76     ndend = ndstart + nodepop[k] - 1;
77     nodecnt = nodepop[k];
78     sumnode = nodecnt * avnode[k];
79     jstat = 0;
80     decsplit = 0.0;
81 
82 #ifdef RF_DEBUG
83     Rprintf("before findBestSplit: ndstart=%d, ndend=%d, jstat=%d, decsplit=%f\n",
84             ndstart, ndend, jstat, decsplit);
85 #endif
86 
87     findBestSplit(x, jdex, y, mdim, nsample, ndstart, ndend, &msplit,
88                   &decsplit, &ubest, &ndendl, &jstat, mtry, sumnode,
89                   nodecnt, cat);
90 #ifdef RF_DEBUG
91     Rprintf(" after findBestSplit: ndstart=%d, ndend=%d, jstat=%d, decsplit=%f, msplit=%d\n",
92             ndstart, ndend, jstat, decsplit, msplit);
93 
94 #endif
95     if (jstat == 1) {
96       /* Node is terminal: Mark it as such and move on to the next. */
97       nodestatus[k] = NODE_TERMINAL;
98       continue;
99     }
100     /* Found the best split. */
101     mbest[k] = msplit;
102     varUsed[msplit - 1] = 1;
103     upper[k] = ubest;
104     tgini[msplit - 1] += decsplit;
105     nodestatus[k] = NODE_INTERIOR;
106 
107     /* leftnode no.= ncur+1, rightnode no. = ncur+2. */
108     nodepop[ncur + 1] = ndendl - ndstart + 1;
109     nodepop[ncur + 2] = ndend - ndendl;
110     nodestart[ncur + 1] = ndstart;
111     nodestart[ncur + 2] = ndendl + 1;
112 
113     /* compute mean and sum of squares for the left daughter node */
114     av = 0.0;
115     ss = 0.0;
116     for (j = ndstart; j <= ndendl; ++j) {
117       d = y[jdex[j]-1];
118       m = j - ndstart;
119       ss += m * (av - d) * (av - d) / (m + 1);
120       av = (m * av + d) / (m+1);
121     }
122     avnode[ncur+1] = av;
123     nodestatus[ncur+1] = NODE_TOSPLIT;
124     if (nodepop[ncur + 1] <= nthsize) {
125       nodestatus[ncur + 1] = NODE_TERMINAL;
126     }
127 
128     /* compute mean and sum of squares for the right daughter node */
129     av = 0.0;
130     ss = 0.0;
131     for (j = ndendl + 1; j <= ndend; ++j) {
132       d = y[jdex[j]-1];
133       m = j - (ndendl + 1);
134       ss += m * (av - d) * (av - d) / (m + 1);
135       av = (m * av + d) / (m + 1);
136     }
137     avnode[ncur + 2] = av;
138     nodestatus[ncur + 2] = NODE_TOSPLIT;
139     if (nodepop[ncur + 2] <= nthsize) {
140       nodestatus[ncur + 2] = NODE_TERMINAL;
141     }
142 
143     /* map the daughter nodes */
144     lDaughter[k] = ncur + 1 + 1;
145     rDaughter[k] = ncur + 2 + 1;
146     /* Augment the tree by two nodes. */
147     ncur += 2;
148 #ifdef RF_DEBUG
149     Rprintf(" after split: ldaughter=%d, rdaughter=%d, ncur=%d\n",
150             lDaughter[k], rDaughter[k], ncur);
151 #endif
152 
153   }
154   *treeSize = nrnodes;
155   for (k = nrnodes - 1; k >= 0; --k) {
156     if (nodestatus[k] == 0) (*treeSize)--;
157     if (nodestatus[k] == NODE_TOSPLIT) {
158       nodestatus[k] = NODE_TERMINAL;
159     }
160   }
161   Free(nodestart);
162   Free(jdex);
163   Free(nodepop);
164 }
165 
166 /*--------------------------------------------------------------*/
findBestSplit(double * x,int * jdex,double * y,int mdim,int nsample,int ndstart,int ndend,int * msplit,double * decsplit,double * ubest,int * ndendl,int * jstat,int mtry,double sumnode,int nodecnt,int * cat)167 void findBestSplit(double *x, int *jdex, double *y, int mdim, int nsample,
168                    int ndstart, int ndend, int *msplit, double *decsplit,
169                    double *ubest, int *ndendl, int *jstat, int mtry,
170                    double sumnode, int nodecnt, int *cat) {
171   int last, ncat[MAX_CAT], icat[MAX_CAT], lc, nl, nr, npopl, npopr, tieVar, tieVal;
172   int i, j, kv, l, *mind, *ncase;
173   double *xt, *ut, *v, *yl, sumcat[MAX_CAT], avcat[MAX_CAT], tavcat[MAX_CAT], ubestt;
174   double crit, critmax, critvar, suml, sumr, d, critParent;
175 
176   ut = (double *) Calloc(nsample, double);
177   xt = (double *) Calloc(nsample, double);
178   v  = (double *) Calloc(nsample, double);
179   yl = (double *) Calloc(nsample, double);
180   mind  = (int *) Calloc(mdim, int);
181   ncase = (int *) Calloc(nsample, int);
182   zeroDouble(avcat, MAX_CAT);
183   zeroDouble(tavcat, MAX_CAT);
184 
185   /* START BIG LOOP */
186   *msplit = -1;
187   *decsplit = 0.0;
188   critmax = 0.0;
189   ubestt = 0.0;
190   for (i=0; i < mdim; ++i) mind[i] = i;
191 
192   last = mdim - 1;
193   tieVar = 1;
194   for (i = 0; i < mtry; ++i) {
195     critvar = 0.0;
196     j = (int) (unif_rand() * (last+1));
197     kv = mind[j];
198     swapInt(mind[j], mind[last]);
199     last--;
200 
201     lc = cat[kv];
202     if (lc == 1) {
203       /* numeric variable */
204       for (j = ndstart; j <= ndend; ++j) {
205         xt[j] = x[kv + (jdex[j] - 1) * mdim];
206         yl[j] = y[jdex[j] - 1];
207       }
208     } else {
209       /* categorical variable */
210       zeroInt(ncat, MAX_CAT);
211       zeroDouble(sumcat, MAX_CAT);
212       for (j = ndstart; j <= ndend; ++j) {
213         l = (int) x[kv + (jdex[j] - 1) * mdim];
214         sumcat[l - 1] += y[jdex[j] - 1];
215         ncat[l - 1] ++;
216       }
217       /* Compute means of Y by category. */
218       for (j = 0; j < lc; ++j) {
219         avcat[j] = ncat[j] ? sumcat[j] / ncat[j] : 0.0;
220       }
221       /* Make the category mean the `pseudo' X data. */
222       for (j = 0; j < nsample; ++j) {
223         xt[j] = avcat[(int) x[kv + (jdex[j] - 1) * mdim] - 1];
224         yl[j] = y[jdex[j] - 1];
225       }
226     }
227     /* copy the x data in this node. */
228     for (j = ndstart; j <= ndend; ++j) v[j] = xt[j];
229     for (j = 1; j <= nsample; ++j) ncase[j - 1] = j;
230     R_qsort_I(v, ncase, ndstart + 1, ndend + 1);
231     if (v[ndstart] >= v[ndend]) continue;
232     /* ncase(n)=case number of v nth from bottom */
233     /* Start from the right and search to the left. */
234     critParent = sumnode * sumnode / nodecnt;
235     suml = 0.0;
236     sumr = sumnode;
237     npopl = 0;
238     npopr = nodecnt;
239     crit = 0.0;
240     tieVal = 1;
241     /* Search through the "gaps" in the x-variable. */
242     for (j = ndstart; j <= ndend - 1; ++j) {
243       d = yl[ncase[j] - 1];
244       suml += d;
245       sumr -= d;
246       npopl++;
247       npopr--;
248       if (v[j] < v[j+1]) {
249         crit = (suml * suml / npopl) + (sumr * sumr / npopr) - critParent;
250         if (crit > critvar) {
251           ubestt = (v[j] + v[j+1]) / 2.0;
252           critvar = crit;
253           tieVal = 1;
254         }
255         if (crit == critvar) {
256           tieVal++;
257           if (unif_rand() < 1.0 / tieVal) {
258             ubestt = (v[j] + v[j+1]) / 2.0;
259             critvar = crit;
260           }
261         }
262       }
263     }
264     if (critvar > critmax) {
265       *ubest = ubestt;
266       *msplit = kv + 1;
267       critmax = critvar;
268       for (j = ndstart; j <= ndend; ++j) {
269         ut[j] = xt[j];
270       }
271       if (cat[kv] > 1) {
272         for (j = 0; j < cat[kv]; ++j) tavcat[j] = avcat[j];
273       }
274       tieVar = 1;
275     }
276     if (critvar == critmax) {
277       tieVar++;
278       if (unif_rand() < 1.0 / tieVar) {
279         *ubest = ubestt;
280         *msplit = kv + 1;
281         critmax = critvar;
282         for (j = ndstart; j <= ndend; ++j) {
283           ut[j] = xt[j];
284         }
285         if (cat[kv] > 1) {
286           for (j = 0; j < cat[kv]; ++j) tavcat[j] = avcat[j];
287         }
288       }
289     }
290 
291   }
292   *decsplit = critmax;
293 
294   /* If best split can not be found, set to terminal node and return. */
295   if (*msplit != -1) {
296     nl = ndstart;
297     for (j = ndstart; j <= ndend; ++j) {
298       if (ut[j] <= *ubest) {
299         nl++;
300         ncase[nl-1] = jdex[j];
301       }
302     }
303     *ndendl = imax2(nl - 1, ndstart);
304     nr = *ndendl + 1;
305     for (j = ndstart; j <= ndend; ++j) {
306       if (ut[j] > *ubest) {
307         if (nr >= nsample) break;
308         nr++;
309         ncase[nr - 1] = jdex[j];
310       }
311     }
312     if (*ndendl >= ndend) *ndendl = ndend - 1;
313     for (j = ndstart; j <= ndend; ++j) jdex[j] = ncase[j];
314 
315     lc = cat[*msplit - 1];
316     if (lc > 1) {
317       for (j = 0; j < lc; ++j) {
318         icat[j] = (tavcat[j] < *ubest) ? 1 : 0;
319       }
320       *ubest = pack(lc, icat);
321     }
322   } else *jstat = 1;
323 
324   Free(ncase);
325   Free(mind);
326   Free(v);
327   Free(yl);
328   Free(xt);
329   Free(ut);
330 }
331 
332 /*====================================================================*/
predictRegTree(double * x,int nsample,int mdim,int * lDaughter,int * rDaughter,int * nodestatus,double * ypred,double * split,double * nodepred,int * splitVar,int treeSize,int * cat,int maxcat,int * nodex)333 void predictRegTree(double *x, int nsample, int mdim,
334                     int *lDaughter, int *rDaughter, int *nodestatus,
335                     double *ypred, double *split, double *nodepred,
336                     int *splitVar, int treeSize, int *cat, int maxcat,
337                     int *nodex) {
338   int i, j, k, m, *cbestsplit;
339   double dpack;
340 
341   /* decode the categorical splits */
342   if (maxcat > 1) {
343     cbestsplit = (int *) Calloc(maxcat * treeSize, int);
344     zeroInt(cbestsplit, maxcat * treeSize);
345     for (i = 0; i < treeSize; ++i) {
346       if (nodestatus[i] != NODE_TERMINAL && cat[splitVar[i] - 1] > 1) {
347         dpack = split[i];
348         /* unpack `npack' into bits */
349         /* unpack(dpack, maxcat, cbestsplit + i * maxcat); */
350         for (j = 0; j < cat[splitVar[i] - 1]; ++j) {
351           cbestsplit[j + i*maxcat] = ((unsigned long) dpack & 1) ? 1 : 0;
352           dpack = dpack / 2.0 ;
353           /* cbestsplit[j + i*maxcat] = npack & 1; */
354         }
355       }
356     }
357   }
358 
359   for (i = 0; i < nsample; ++i) {
360     k = 0;
361     while (nodestatus[k] != NODE_TERMINAL) { /* go down the tree */
362           m = splitVar[k] - 1;
363       if (cat[m] == 1) {
364         k = (x[m + i*mdim] <= split[k]) ?
365         lDaughter[k] - 1 : rDaughter[k] - 1;
366       } else {
367         /* Split by a categorical predictor */
368         k = cbestsplit[(int) x[m + i * mdim] - 1 + k * maxcat] ?
369         lDaughter[k] - 1 : rDaughter[k] - 1;
370       }
371     }
372     /* terminal node: assign prediction and move on to next */
373     ypred[i] = nodepred[k];
374     nodex[i] = k + 1;
375   }
376   if (maxcat > 1) Free(cbestsplit);
377 }
378