1 /*
2  # This file is part of libkd.
3  # Licensed under a 3-clause BSD style license - see LICENSE
4  */
5 
6 #include <string.h>
7 #include <math.h>
8 #include <stdlib.h>
9 #include <assert.h>
10 
11 #include "os-features.h"
12 #include "dualtree_nearestneighbour.h"
13 #include "dualtree.h"
14 #include "mathutil.h"
15 
16 struct rs_params {
17     kdtree_t* xtree;
18     kdtree_t* ytree;
19 
20     anbool notself;
21 
22     double* node_nearest_d2;
23 
24     double d2;
25     double* nearest_d2;
26     int* nearest_ind;
27     int* count_in_range;
28 };
29 typedef struct rs_params rs_params;
30 
31 static anbool rs_within_range(void* params, kdtree_t* searchtree, int searchnode,
32                               kdtree_t* querytree, int querynode);
33 static void rs_handle_result(void* extra, kdtree_t* searchtree, int searchnode,
34                              kdtree_t* querytree, int querynode);
35 
/*
 * For every point in ytree, find its nearest neighbour in xtree using a
 * dual-tree search (xtree is passed to dualtree_search() as the search tree,
 * ytree as the query tree).
 *
 * maxdist2       : squared distance limit.  For the nearest-neighbour search,
 *                  0.0 means "no limit" (HUGE_VAL is used instead).  The
 *                  count-in-range threshold uses the value exactly as passed.
 * nearest_d2     : must be non-NULL.  If *nearest_d2 is NULL, an array of
 *                  kdtree_n(ytree) doubles is allocated and returned through
 *                  it (caller frees); otherwise the caller's array is used.
 *                  On return, entry y is the squared distance to Y point y's
 *                  nearest X point, or the limit if none was found.
 * nearest_ind    : must be non-NULL; same allocation rules (ints).  Entry y
 *                  is the X index of the nearest neighbour, or -1.
 * count_in_range : optional (may be NULL).  If *count_in_range is NULL, a
 *                  zero-filled array is allocated and returned through it.
 *                  A caller-supplied array is NOT cleared: for each X point
 *                  at squared distance < maxdist2, entry y is incremented.
 * notself        : if non-zero, an X point and a Y point with the same index
 *                  are never paired (meaningful when xtree == ytree).
 *
 * Point indices are those used by kdtree_get_data() on the respective tree.
 *
 * If an allocation fails, every array allocated by this call is freed, the
 * output pointers are left unchanged, and caller-supplied arrays are not
 * modified.
 */
void dualtree_nearestneighbour(kdtree_t* xtree, kdtree_t* ytree, double maxdist2,
                               double** nearest_d2, int** nearest_ind,
                               int** count_in_range,
                               int notself) {
    int i, NY, NNY;
    // dual-tree search callback functions
    dualtree_callbacks callbacks;
    rs_params params;
    // Arrays allocated by this call (as opposed to supplied by the caller);
    // only these may be freed if a later allocation fails.
    int* own_count = NULL;
    double* own_d2 = NULL;
    int* own_ind = NULL;

    // These two inputs must be non-NULL (they are essential return values);
    // but they may point to pointers that are NULL (indicating that the caller
    // wants us to allocate and return new arrays).
    assert(nearest_d2);
    assert(nearest_ind);

    memset(&callbacks, 0, sizeof(dualtree_callbacks));
    callbacks.decision = rs_within_range;
    callbacks.decision_extra = &params;
    callbacks.result = rs_handle_result;
    callbacks.result_extra = &params;

    // set search params
    NY = kdtree_n(ytree);
    NNY = kdtree_nnodes(ytree);
    memset(&params, 0, sizeof(params));
    params.xtree = xtree;
    params.ytree = ytree;
    params.notself = notself;
    // The count-in-range threshold is the value as passed, i.e. *before* the
    // 0.0 -> HUGE_VAL substitution applied to the nearest-neighbour limit.
    params.d2 = maxdist2;

    // Allocate everything before touching any caller-supplied array, so a
    // failure leaves the caller's data unmodified.  A zero-length request may
    // legitimately return NULL, so that is not treated as a failure.
    params.count_in_range = NULL;
    if (count_in_range) {
        if (*count_in_range) {
            params.count_in_range = *count_in_range;
        } else {
            own_count = calloc(NY, sizeof(*own_count));
            if (!own_count && NY > 0)
                goto fail;
            params.count_in_range = own_count;
        }
    }

    // were we given a d2 array?
    if (*nearest_d2) {
        params.nearest_d2 = *nearest_d2;
    } else {
        own_d2 = malloc(NY * sizeof(*own_d2));
        if (!own_d2 && NY > 0)
            goto fail;
        params.nearest_d2 = own_d2;
    }

    // were we given an ind array?
    if (*nearest_ind) {
        params.nearest_ind = *nearest_ind;
    } else {
        own_ind = malloc(NY * sizeof(*own_ind));
        if (!own_ind && NY > 0)
            goto fail;
        params.nearest_ind = own_ind;
    }

    params.node_nearest_d2 = malloc(NNY * sizeof(*params.node_nearest_d2));
    if (!params.node_nearest_d2 && NNY > 0)
        goto fail;

    // 0.0 means "no distance limit" for the nearest-neighbour search.
    if (maxdist2 == 0.0)
        maxdist2 = HUGE_VAL;
    for (i=0; i<NY; i++)
        params.nearest_d2[i] = maxdist2;
    for (i=0; i<NY; i++)
        params.nearest_ind[i] = -1;
    // Per-Y-node upper bound on the nearest-neighbour squared distance of the
    // points in that node; lowered as the search proceeds.
    for (i=0; i<NNY; i++)
        params.node_nearest_d2[i] = maxdist2;

    dualtree_search(xtree, ytree, &callbacks);

    // Return array addresses
    if (own_count)
        *count_in_range = own_count;
    *nearest_d2 = params.nearest_d2;
    *nearest_ind = params.nearest_ind;
    free(params.node_nearest_d2);
    return;

 fail:
    // Only arrays allocated above are released; caller-owned ones are kept.
    free(own_count);
    free(own_d2);
    free(own_ind);
}
105 
rs_within_range(void * vparams,kdtree_t * xtree,int xnode,kdtree_t * ytree,int ynode)106 static anbool rs_within_range(void* vparams,
107                               kdtree_t* xtree, int xnode,
108                               kdtree_t* ytree, int ynode) {
109     rs_params* p = (rs_params*)vparams;
110     double maxd2;
111 
112     // count-in-range is actually more like rangesearch...
113     if (p->count_in_range) {
114         if (kdtree_node_node_mindist2_exceeds(xtree, xnode, ytree, ynode, p->d2))
115             return FALSE;
116         return TRUE;
117     }
118 
119     if (kdtree_node_node_mindist2_exceeds(xtree, xnode, ytree, ynode,
120                                           p->node_nearest_d2[ynode]))
121         return FALSE;
122 
123     maxd2 = kdtree_node_node_maxdist2(xtree, xnode, ytree, ynode);
124     if (maxd2 < p->node_nearest_d2[ynode]) {
125         // update this node and its children.
126         p->node_nearest_d2[ynode] = maxd2;
127         if (!KD_IS_LEAF(ytree, ynode)) {
128             int child = KD_CHILD_LEFT(ynode);
129             p->node_nearest_d2[child] = MIN(p->node_nearest_d2[child], maxd2);
130             child = KD_CHILD_RIGHT(ynode);
131             p->node_nearest_d2[child] = MIN(p->node_nearest_d2[child], maxd2);
132         }
133     }
134     return TRUE;
135 }
136 
137 /**
138  This callback gets called when we've reached a node in the Y tree and
139  a node in the X tree (one or both may be leaves), and it's time to
140  look at individual data points.
141  */
rs_handle_result(void * vparams,kdtree_t * xtree,int xnode,kdtree_t * ytree,int ynode)142 static void rs_handle_result(void* vparams,
143                              kdtree_t* xtree, int xnode,
144                              kdtree_t* ytree, int ynode) {
145     int xl, xr, yl, yr;
146     int x, y;
147     rs_params* p = (rs_params*)vparams;
148     int D = ytree->ndim;
149     double checkd2;
150 
151     xl = kdtree_left (xtree, xnode);
152     xr = kdtree_right(xtree, xnode);
153     yl = kdtree_left (ytree, ynode);
154     yr = kdtree_right(ytree, ynode);
155 
156     for (y=yl; y<=yr; y++) {
157         void* py = kdtree_get_data(ytree, y);
158 
159         if (p->count_in_range) {
160             checkd2 = p->d2;
161         } else {
162             p->nearest_d2[y] = MIN(p->nearest_d2[y], p->node_nearest_d2[ynode]);
163             checkd2 = p->nearest_d2[y];
164         }
165 
166         // check if we can eliminate the whole x node for this y point...
167         if (kdtree_node_point_mindist2_exceeds(xtree, xnode, py, checkd2))
168             continue;
169 
170         for (x=xl; x<=xr; x++) {
171             double d2;
172             void* px;
173             if (p->notself && (y == x))
174                 continue;
175             px = kdtree_get_data(xtree, x);
176             d2 = distsq(px, py, D);
177 
178             if (p->count_in_range) {
179                 if (d2 < p->d2) {
180                     p->count_in_range[y]++;
181                 }
182             }
183 
184             if (d2 > p->nearest_d2[y])
185                 continue;
186             p->nearest_d2[y] = d2;
187             p->nearest_ind[y] = x;
188         }
189     }
190 }
191 
192