1 // [[Rcpp::depends(RcppProgress)]]
2 #include <progress.hpp>
3 #include <Rcpp.h>
4 using namespace Rcpp;
5
6 // [[Rcpp::plugins(cpp11)]]
7
8 // [[Rcpp::export]]
nn_matchC(const IntegerMatrix & mm_,const IntegerVector & treat_,const IntegerVector & ord_,const IntegerVector & ratio,const int & max_rat,const LogicalVector & discarded,const int & reuse_max,const Nullable<NumericVector> & distance_=R_NilValue,const Nullable<NumericMatrix> & distance_mat_=R_NilValue,const Nullable<IntegerVector> & exact_=R_NilValue,const Nullable<double> & caliper_dist_=R_NilValue,const Nullable<NumericVector> & caliper_covs_=R_NilValue,const Nullable<NumericMatrix> & calcovs_covs_mat_=R_NilValue,const Nullable<NumericMatrix> & mah_covs_=R_NilValue,const Nullable<IntegerMatrix> & antiexact_covs_=R_NilValue,const bool & disl_prog=false)9 IntegerMatrix nn_matchC(const IntegerMatrix& mm_,
10 const IntegerVector& treat_,
11 const IntegerVector& ord_,
12 const IntegerVector& ratio,
13 const int& max_rat,
14 const LogicalVector& discarded,
15 const int& reuse_max,
16 const Nullable<NumericVector>& distance_ = R_NilValue,
17 const Nullable<NumericMatrix>& distance_mat_ = R_NilValue,
18 const Nullable<IntegerVector>& exact_ = R_NilValue,
19 const Nullable<double>& caliper_dist_ = R_NilValue,
20 const Nullable<NumericVector>& caliper_covs_ = R_NilValue,
21 const Nullable<NumericMatrix>& calcovs_covs_mat_ = R_NilValue,
22 const Nullable<NumericMatrix>& mah_covs_ = R_NilValue,
23 const Nullable<IntegerMatrix>& antiexact_covs_ = R_NilValue,
24 const bool& disl_prog = false)
25 {
26
27 // Initialize
28
29 NumericVector distance, caliper_covs;
30 double caliper_dist;
31 NumericMatrix distance_mat, calcovs_covs_mat, mah_covs, mah_covs_c;
32 IntegerMatrix antiexact_covs;
33 IntegerVector exact, exact_c, antiexact_col;
34
35 bool use_dist_mat = false;
36 bool use_exact = false;
37 bool use_caliper_dist = false;
38 bool use_caliper_covs = false;
39 bool use_mah_covs = false;
40 bool use_antiexact = false;
41 bool use_reuse_max = false;
42
43 // Info about original treat
44 int n_ = treat_.size();
45 IntegerVector ind_ = Range(0, n_ - 1);
46 IntegerVector ind1_ = ind_[treat_ == 1];
47 IntegerVector ind0_ = ind_[treat_ == 0];
48 int n1_ = ind1_.size();
49 int n0_ = n_ - n1_;
50
51 // Output matrix with sample indices of C units
52 IntegerMatrix mm = mm_;
53
54 // Store who has been matched
55 IntegerVector matched = rep(0, n_);
56 matched[discarded] = n1_; //discarded are unmatchable
57
58 // After discarding
59
60 IntegerVector ind = ind_[!discarded];
61 IntegerVector treat = treat_[!discarded];
62 IntegerVector ind0 = ind[treat == 0];
63 int n0 = ind0.size();
64
65 int t, t_ind, min_ind, c_chosen, num_eligible, cal_len, t_rat, n_anti;
66 double dt, cal_var_t;
67
68 NumericVector cal_var, cal_diff, ps_diff, diff, dist_t, mah_covs_t, mah_covs_row,
69 match_distance(n0);
70
71 IntegerVector c_eligible(n0), indices(n0);
72 LogicalVector finite_match_distance(n0);
73
74 if (distance_.isNotNull()) {
75 distance = distance_;
76 }
77 if (exact_.isNotNull()) {
78 exact = exact_;
79 use_exact = true;
80 }
81 if (caliper_dist_.isNotNull()) {
82 caliper_dist = as<double>(caliper_dist_);
83 use_caliper_dist = true;
84 ps_diff = NumericVector(n_);
85 }
86 if (caliper_covs_.isNotNull()) {
87 caliper_covs = caliper_covs_;
88 use_caliper_covs = true;
89 cal_len = caliper_covs.size();
90 cal_diff = NumericVector(n0);
91 }
92 if (calcovs_covs_mat_.isNotNull()) {
93 calcovs_covs_mat = as<NumericMatrix>(calcovs_covs_mat_);
94 }
95 if (mah_covs_.isNotNull()) {
96 mah_covs = as<NumericMatrix>(mah_covs_);
97 NumericVector mah_covs_row(mah_covs.ncol());
98 use_mah_covs = true;
99 } else {
100 if (distance_mat_.isNotNull()) {
101 distance_mat = as<NumericMatrix>(distance_mat_);
102
103 // IntegerVector ind0_ = ind_[treat_ == 0];
104 NumericVector dist_t(n0_);
105 use_dist_mat = true;
106 }
107 ps_diff = NumericVector(n_);
108 }
109 if (antiexact_covs_.isNotNull()) {
110 antiexact_covs = as<IntegerMatrix>(antiexact_covs_);
111 n_anti = antiexact_covs.ncol();
112 use_antiexact = true;
113 }
114 if (reuse_max < n1_) {
115 use_reuse_max = true;
116 }
117
118 bool ps_diff_assigned;
119
120 //progress bar
121 int prog_length;
122 if (!use_reuse_max) prog_length = n1_ + 1;
123 else prog_length = max_rat*n1_ + 1;
124 Progress p(prog_length, disl_prog);
125
126 //Counters
127 int rat, i, x, j, j_, a, k;
128 k = -1;
129
130 //Matching
131 for (rat = 0; rat < max_rat; ++rat) {
132 for (i = 0; i < n1_; ++i) {
133
134 k++;
135 if (k % 500 == 0) Rcpp::checkUserInterrupt();
136
137 p.increment();
138
139 if (all(as<IntegerVector>(matched[ind0]) >= reuse_max).is_true()){
140 break;
141 }
142
143 t = ord_[i] - 1; // index among treated
144 t_ind = ind1_[t]; // index among sample
145
146 if (matched[t_ind]) {
147 continue;
148 }
149
150 //Check if unit has enough matches
151 t_rat = ratio[t];
152
153 if (t_rat < rat + 1) {
154 continue;
155 }
156
157 c_eligible = ind0; // index among sample
158
159 c_eligible = c_eligible[as<IntegerVector>(matched[c_eligible]) < reuse_max];
160
161 if (use_exact) {
162 exact_c = exact[c_eligible];
163 c_eligible = c_eligible[exact_c == exact[t_ind]];
164 }
165
166 if (c_eligible.size() == 0) {
167 continue;
168 }
169
170 if (use_antiexact) {
171 for (a = 0; a < n_anti; ++a) {
172 antiexact_col = antiexact_covs(_, a);
173 antiexact_col = antiexact_col[c_eligible];
174 c_eligible = c_eligible[antiexact_col != antiexact_col[t_ind]];
175 }
176 }
177
178 if (c_eligible.size() == 0) {
179 continue;
180 }
181
182 ps_diff_assigned = false;
183
184 if (use_caliper_dist) {
185 if (use_dist_mat) {
186 dist_t = distance_mat.row(t);
187 diff = dist_t[match(c_eligible, ind0_) - 1];
188 } else {
189 dt = distance[t_ind];
190 diff = Rcpp::abs(as<NumericVector>(distance[c_eligible]) - dt);
191 }
192
193 ps_diff[c_eligible] = diff;
194 ps_diff_assigned = true;
195
196 c_eligible = c_eligible[diff <= caliper_dist];
197
198 if (c_eligible.size() == 0) {
199 continue;
200 }
201 }
202
203 if (use_caliper_covs) {
204 for (x = 0; (x < cal_len) && c_eligible.size() > 0; ++x) {
205 cal_var = calcovs_covs_mat( _ , x );
206
207 cal_var_t = cal_var[t_ind];
208
209 diff = Rcpp::abs(as<NumericVector>(cal_var[c_eligible]) - cal_var_t);
210
211 cal_diff = diff;
212
213 c_eligible = c_eligible[cal_diff <= caliper_covs[x]];
214 }
215
216 if (c_eligible.size() == 0) {
217 continue;
218 }
219 }
220
221 //Compute distances among eligible
222 num_eligible = c_eligible.size();
223
224 //If replace and few eligible controls, assign all and move on
225 if (!use_reuse_max && (num_eligible <= t_rat)) {
226 for (j = 0; j < num_eligible; ++j) {
227 mm( t , j ) = c_eligible[j] + 1;
228 }
229 continue;
230 }
231
232 if (use_mah_covs) {
233
234 match_distance = rep(0.0, num_eligible);
235 mah_covs_t = mah_covs( t_ind , _ );
236
237 for (j = 0; j < num_eligible; j++) {
238 j_ = c_eligible[j];
239 mah_covs_row = mah_covs(c_eligible[j], _);
240 match_distance[j] = sqrt(sum(pow(mah_covs_t - mah_covs_row, 2.0)));
241 }
242
243 } else if (ps_diff_assigned) {
244 match_distance = ps_diff[c_eligible]; //c_eligible might have shrunk since previous assignment
245 } else if (use_dist_mat) {
246 dist_t = distance_mat.row(t);
247 match_distance = dist_t[match(c_eligible, ind0_) - 1];
248 } else {
249 dt = distance[t_ind];
250 match_distance = Rcpp::abs(as<NumericVector>(distance[c_eligible]) - dt);
251 }
252
253 //Remove infinite distances
254 finite_match_distance = is_finite(match_distance);
255 c_eligible = c_eligible[finite_match_distance];
256 if (c_eligible.size() == 0) {
257 continue;
258 }
259 match_distance = match_distance[finite_match_distance];
260
261 if (!use_reuse_max) {
262 //When matching w/ replacement, get t_rat closest control units
263 indices = Range(0, num_eligible - 1);
264
265 std::partial_sort(indices.begin(), indices.begin() + t_rat, indices.end(),
266 [&match_distance](int k, int j) {return match_distance[k] < match_distance[j];});
267
268 for (j = 0; j < t_rat; ++j) {
269 min_ind = indices[j];
270 mm( t , j ) = c_eligible[min_ind] + 1;
271 }
272 }
273 else {
274 min_ind = which_min(match_distance);
275 c_chosen = c_eligible[min_ind];
276
277 mm( t , rat ) = c_chosen + 1; // + 1 because C indexing starts at 0 but mm is sent to R
278
279 matched[c_chosen] = matched[c_chosen] + 1;
280 }
281 }
282
283 if (!use_reuse_max) break;
284 }
285
286 p.update(prog_length);
287
288 return mm;
289 }
290