1 #include <vector>
2 #include <stdio.h>
3 #include <math.h>
4 #include <iostream>
5 #include <fstream>
6 #include <stdint.h>
7 #include "c_utils.hpp"
8 #include <float.h>
9 
// Approximate equality for doubles with a fixed *absolute* tolerance:
// true when |f1 - f2| <= FLT_EPSILON (about 1.19e-7).
// The tolerance does not scale with the operands' magnitude, so for large
// values this degenerates to exact comparison. A NaN operand yields false.
bool isEqual(double f1, double f2) {
  const double gap = fabs(f1 - f2);
  return gap <= FLT_EPSILON;
}
13 
// Per-dimension maximum over n points of dimension m, stored contiguously
// (coordinate d of point i is pts[m*i + d]).
//
// Returns a std::malloc'd array of m doubles that the caller owns and must
// release with free(). When n == 0 every entry is -DBL_MAX. If the
// allocation fails, the null pointer is returned instead of being written
// through.
double* max_pts(double *pts, uint64_t n, uint32_t m)
{
  double* max = static_cast<double*>(std::malloc(m*sizeof(double)));
  if (!max) return max;  // allocation failed: nothing to fill
  for (uint32_t d = 0; d < m; d++) max[d] = -DBL_MAX;
  for (uint64_t i = 0; i < n; i++) {
    const double *row = pts + m*i;  // coordinates of point i
    for (uint32_t d = 0; d < m; d++)
      if (row[d] > max[d]) max[d] = row[d];
  }
  return max;
}
27 
// Per-dimension minimum over n points of dimension m, stored contiguously
// (coordinate d of point i is pts[m*i + d]).
//
// Returns a std::malloc'd array of m doubles that the caller owns and must
// release with free(). When n == 0 every entry is DBL_MAX. If the
// allocation fails, the null pointer is returned instead of being written
// through.
double* min_pts(double *pts, uint64_t n, uint32_t m)
{
  double* min = static_cast<double*>(std::malloc(m*sizeof(double)));
  if (!min) return min;  // allocation failed: nothing to fill
  for (uint32_t d = 0; d < m; d++) min[d] = DBL_MAX;
  for (uint64_t i = 0; i < n; i++) {
    const double *row = pts + m*i;  // coordinates of point i
    for (uint32_t d = 0; d < m; d++)
      if (row[d] < min[d]) min[d] = row[d];
  }
  return min;
}
41 
// Return the position i in [Lidx, Ridx] (inclusive) whose referenced point
// idx[i] has the largest coordinate d (point k, dimension d at pts[m*k + d]).
// Ties go to the earliest position. If the range is empty (Ridx < Lidx), or
// no coordinate in it exceeds -DBL_MAX, Lidx is returned.
uint64_t argmax_pts_dim(double *pts, uint64_t *idx,
                        uint32_t m, uint32_t d,
                        uint64_t Lidx, uint64_t Ridx)
{
  uint64_t best_pos = Lidx;
  double best_val = -DBL_MAX;
  uint64_t pos = Lidx;
  while (pos <= Ridx) {
    const double val = pts[m*idx[pos] + d];
    if (val > best_val) {
      best_pos = pos;
      best_val = val;
    }
    ++pos;
  }
  return best_pos;
}
56 
// Return the position i in [Lidx, Ridx] (inclusive) whose referenced point
// idx[i] has the smallest coordinate d (point k, dimension d at pts[m*k + d]).
// Ties go to the earliest position. If the range is empty (Ridx < Lidx), or
// no coordinate in it is below DBL_MAX, Lidx is returned.
uint64_t argmin_pts_dim(double *pts, uint64_t *idx,
                        uint32_t m, uint32_t d,
                        uint64_t Lidx, uint64_t Ridx)
{
  uint64_t best_pos = Lidx;
  double best_val = DBL_MAX;
  uint64_t pos = Lidx;
  while (pos <= Ridx) {
    const double val = pts[m*idx[pos] + d];
    if (val < best_val) {
      best_pos = pos;
      best_val = val;
    }
    ++pos;
  }
  return best_pos;
}
71 
72 // http://www.comp.dit.ie/rlawlor/Alg_DS/sorting/quickSort.c
quickSort(double * pts,uint64_t * idx,uint32_t ndim,uint32_t d,int64_t l,int64_t r)73 void quickSort(double *pts, uint64_t *idx,
74                uint32_t ndim, uint32_t d,
75                int64_t l, int64_t r)
76 {
77   int64_t j;
78   if( l < r )
79     {
80       // divide and conquer
81       j = partition(pts, idx, ndim, d, l, r, (l+r)/2);
82       quickSort(pts, idx, ndim, d, l, j-1);
83       quickSort(pts, idx, ndim, d, j+1, r);
84     }
85 }
86 
// Stable insertion sort of idx[l..r] (inclusive) by coordinate d of the
// referenced points (point k, dimension d at pts[ndim*k + d]). Only idx is
// permuted. Ranges holding fewer than two entries are left untouched.
void insertSort(double *pts, uint64_t *idx,
                uint32_t ndim, uint32_t d,
                int64_t l, int64_t r)
{
  if (l >= r) return;
  for (int64_t cur = l + 1; cur <= r; ++cur) {
    const uint64_t moving = idx[cur];
    const double key = pts[ndim*moving + d];
    // Shift strictly larger entries one slot right to open a hole for `moving`;
    // equal entries stay in front, which keeps the sort stable.
    int64_t hole = cur;
    while (hole > l && pts[ndim*idx[hole-1] + d] > key) {
      idx[hole] = idx[hole-1];
      --hole;
    }
    idx[hole] = moving;
  }
}
105 
pivot(double * pts,uint64_t * idx,uint32_t ndim,uint32_t d,int64_t l,int64_t r)106 int64_t pivot(double *pts, uint64_t *idx,
107               uint32_t ndim, uint32_t d,
108               int64_t l, int64_t r)
109 {
110   if (r < l) {
111     return -1;
112   } else if (r == l) {
113     return l;
114   } else if ((r - l) < 5) {
115     insertSort(pts, idx, ndim, d, l, r);
116     return (l+r)/2;
117   }
118 
119   int64_t i, subr, m5;
120   uint64_t t;
121   int64_t nsub = 0;
122   for (i = l; i <= r; i+=5) {
123     subr = i + 4;
124     if (subr > r) subr = r;
125 
126     insertSort(pts, idx, ndim, d, i, subr);
127     m5 = (i+subr)/2;
128     t = idx[m5]; idx[m5] = idx[l + nsub]; idx[l + nsub] = t;
129 
130     nsub++;
131   }
132   return pivot(pts, idx, ndim, d, l, l+nsub-1);
133   // return select(pts, idx, ndim, d, l, l+nsub-1, (nsub/2)+(nsub%2));
134 }
135 
partition_given_pivot(double * pts,uint64_t * idx,uint32_t ndim,uint32_t d,int64_t l,int64_t r,double pivot)136 int64_t partition_given_pivot(double *pts, uint64_t *idx,
137 			      uint32_t ndim, uint32_t d,
138 			      int64_t l, int64_t r, double pivot) {
139   // If all less than pivot, j will remain r
140   // If all greater than pivot, j will be l-1
141   if (r < l)
142     return -1;
143   int64_t i, j, tp = -1;
144   uint64_t t;
145   for (i = l, j = r; i <= j; ) {
146     if ((pts[ndim*idx[i]+d] > pivot) && (pts[ndim*idx[j]+d] <= pivot)) {
147       t = idx[i]; idx[i] = idx[j]; idx[j] = t;
148     }
149     if (isEqual(pts[ndim*idx[i]+d], pivot)) tp = i;
150     // if (pts[ndim*idx[i]+d] == pivot) tp = i;
151     if (pts[ndim*idx[i]+d] <= pivot) i++;
152     if (pts[ndim*idx[j]+d] > pivot) j--;
153   }
154   if ((tp >= 0) && (tp != j)) {
155     t = idx[tp]; idx[tp] = idx[j]; idx[j] = t;
156   }
157 
158   return j;
159 }
160 
partition(double * pts,uint64_t * idx,uint32_t ndim,uint32_t d,int64_t l,int64_t r,int64_t p)161 int64_t partition(double *pts, uint64_t *idx,
162                   uint32_t ndim, uint32_t d,
163                   int64_t l, int64_t r, int64_t p)
164 {
165   double pivot;
166   int64_t j;
167   uint64_t t;
168   if (r < l)
169     return -1;
170   pivot = pts[ndim*idx[p]+d];
171   t = idx[p]; idx[p] = idx[l]; idx[l] = t;
172 
173   j = partition_given_pivot(pts, idx, ndim, d, l+1, r, pivot);
174 
175   t = idx[l]; idx[l] = idx[j]; idx[j] = t;
176 
177   return j;
178 }
179 
180 // https://en.wikipedia.org/wiki/Median_of_medians
select(double * pts,uint64_t * idx,uint32_t ndim,uint32_t d,int64_t l0,int64_t r0,int64_t n)181 int64_t select(double *pts, uint64_t *idx,
182                uint32_t ndim, uint32_t d,
183                int64_t l0, int64_t r0, int64_t n)
184 {
185   int64_t p;
186   int64_t l = l0, r = r0;
187 
188   while ( 1 ) {
189     if (l == r) return l;
190 
191     p = pivot(pts, idx, ndim, d, l, r);
192     p = partition(pts, idx, ndim, d, l, r, p);
193     if (p < 0)
194       return -1;
195     else if (n == (p-l0+1)) {
196       return p;
197     } else if (n < (p-l0+1)) {
198       r = p - 1;
199     } else {
200       l = p + 1;
201     }
202   }
203 }
204 
split(double * all_pts,uint64_t * all_idx,uint64_t Lidx,uint64_t n,uint32_t ndim,double * mins,double * maxes,int64_t & split_idx,double & split_val,bool use_sliding_midpoint)205 uint32_t split(double *all_pts, uint64_t *all_idx,
206                uint64_t Lidx, uint64_t n, uint32_t ndim,
207                double *mins, double *maxes,
208                int64_t &split_idx, double &split_val,
209 	       bool use_sliding_midpoint) {
210   // Return immediately if variables empty
211   if ((n == 0) || (ndim == 0)) {
212     split_idx = -1;
213     split_val = 0.0;
214     return 0;
215   }
216 
217   // Find dimension to split along
218   uint32_t dmax, d;
219   dmax = 0;
220   for (d = 1; d < ndim; d++)
221     if ((maxes[d]-mins[d]) > (maxes[dmax]-mins[dmax]))
222       dmax = d;
223   if (maxes[dmax] == mins[dmax]) {
224     // all points singular
225     return ndim;
226   }
227 
228   if (use_sliding_midpoint) {
229     // Split at middle, then slide midpoint as necessary
230     split_val = (mins[dmax] + maxes[dmax])/2.0;
231     split_idx = partition_given_pivot(all_pts, all_idx, ndim, dmax,
232 				      Lidx, Lidx+n-1, split_val);
233     if (split_idx == (int64_t)(Lidx-1)) {
234       uint64_t t;
235       split_idx = argmin_pts_dim(all_pts, all_idx, ndim, dmax, Lidx, Lidx+n-1);
236       t = all_idx[split_idx]; all_idx[split_idx] = all_idx[Lidx]; all_idx[Lidx] = t;
237       split_idx = Lidx;
238       split_val = all_pts[ndim*all_idx[split_idx] + dmax];
239     } else if (split_idx == (int64_t)(Lidx+n-1)) {
240       uint64_t t;
241       split_idx = argmax_pts_dim(all_pts, all_idx, ndim, dmax, Lidx, Lidx+n-1);
242       t = all_idx[split_idx]; all_idx[split_idx] = all_idx[Lidx+n-1]; all_idx[Lidx+n-1] = t;
243       split_idx = Lidx+n-2;
244       split_val = all_pts[ndim*all_idx[split_idx] + dmax];
245     }
246   } else {
247     // Find median along dimension
248     int64_t nsel = (n/2) + (n%2);
249     split_idx = select(all_pts, all_idx, ndim, dmax, Lidx, Lidx+n-1, nsel);
250     split_val = all_pts[ndim*all_idx[split_idx] + dmax];
251   }
252 
253   return dmax;
254 }
255