1 // [[Rcpp::depends(RcppProgress)]]
2 #include <progress.hpp>
3 #include <Rcpp.h>
4 using namespace Rcpp;
5 
6 // [[Rcpp::plugins(cpp11)]]
7 
8 // [[Rcpp::export]]
nn_matchC(const IntegerMatrix & mm_,const IntegerVector & treat_,const IntegerVector & ord_,const IntegerVector & ratio,const int & max_rat,const LogicalVector & discarded,const int & reuse_max,const Nullable<NumericVector> & distance_=R_NilValue,const Nullable<NumericMatrix> & distance_mat_=R_NilValue,const Nullable<IntegerVector> & exact_=R_NilValue,const Nullable<double> & caliper_dist_=R_NilValue,const Nullable<NumericVector> & caliper_covs_=R_NilValue,const Nullable<NumericMatrix> & calcovs_covs_mat_=R_NilValue,const Nullable<NumericMatrix> & mah_covs_=R_NilValue,const Nullable<IntegerMatrix> & antiexact_covs_=R_NilValue,const bool & disl_prog=false)9 IntegerMatrix nn_matchC(const IntegerMatrix& mm_,
10                         const IntegerVector& treat_,
11                         const IntegerVector& ord_,
12                         const IntegerVector& ratio,
13                         const int& max_rat,
14                         const LogicalVector& discarded,
15                         const int& reuse_max,
16                         const Nullable<NumericVector>& distance_ = R_NilValue,
17                         const Nullable<NumericMatrix>& distance_mat_ = R_NilValue,
18                         const Nullable<IntegerVector>& exact_ = R_NilValue,
19                         const Nullable<double>& caliper_dist_ = R_NilValue,
20                         const Nullable<NumericVector>& caliper_covs_ = R_NilValue,
21                         const Nullable<NumericMatrix>& calcovs_covs_mat_ = R_NilValue,
22                         const Nullable<NumericMatrix>& mah_covs_ = R_NilValue,
23                         const Nullable<IntegerMatrix>& antiexact_covs_ = R_NilValue,
24                         const bool& disl_prog = false)
25   {
26 
27   // Initialize
28 
29   NumericVector distance, caliper_covs;
30   double caliper_dist;
31   NumericMatrix distance_mat, calcovs_covs_mat, mah_covs, mah_covs_c;
32   IntegerMatrix antiexact_covs;
33   IntegerVector exact, exact_c, antiexact_col;
34 
35   bool use_dist_mat = false;
36   bool use_exact = false;
37   bool use_caliper_dist = false;
38   bool use_caliper_covs = false;
39   bool use_mah_covs = false;
40   bool use_antiexact = false;
41   bool use_reuse_max = false;
42 
43   // Info about original treat
44   int n_ = treat_.size();
45   IntegerVector ind_ = Range(0, n_ - 1);
46   IntegerVector ind1_ = ind_[treat_ == 1];
47   IntegerVector ind0_ = ind_[treat_ == 0];
48   int n1_ = ind1_.size();
49   int n0_ = n_ - n1_;
50 
51   // Output matrix with sample indices of C units
52   IntegerMatrix mm = mm_;
53 
54   // Store who has been matched
55   IntegerVector matched = rep(0, n_);
56   matched[discarded] = n1_; //discarded are unmatchable
57 
58   // After discarding
59 
60   IntegerVector ind = ind_[!discarded];
61   IntegerVector treat = treat_[!discarded];
62   IntegerVector ind0 = ind[treat == 0];
63   int n0 = ind0.size();
64 
65   int t, t_ind, min_ind, c_chosen, num_eligible, cal_len, t_rat, n_anti;
66   double dt, cal_var_t;
67 
68   NumericVector cal_var, cal_diff, ps_diff, diff, dist_t, mah_covs_t, mah_covs_row,
69                 match_distance(n0);
70 
71   IntegerVector c_eligible(n0), indices(n0);
72   LogicalVector finite_match_distance(n0);
73 
74   if (distance_.isNotNull()) {
75     distance = distance_;
76   }
77   if (exact_.isNotNull()) {
78     exact = exact_;
79     use_exact = true;
80   }
81   if (caliper_dist_.isNotNull()) {
82     caliper_dist = as<double>(caliper_dist_);
83     use_caliper_dist = true;
84     ps_diff = NumericVector(n_);
85   }
86   if (caliper_covs_.isNotNull()) {
87     caliper_covs = caliper_covs_;
88     use_caliper_covs = true;
89     cal_len = caliper_covs.size();
90     cal_diff = NumericVector(n0);
91   }
92   if (calcovs_covs_mat_.isNotNull()) {
93     calcovs_covs_mat = as<NumericMatrix>(calcovs_covs_mat_);
94   }
95   if (mah_covs_.isNotNull()) {
96     mah_covs = as<NumericMatrix>(mah_covs_);
97     NumericVector mah_covs_row(mah_covs.ncol());
98     use_mah_covs = true;
99   } else {
100     if (distance_mat_.isNotNull()) {
101       distance_mat = as<NumericMatrix>(distance_mat_);
102 
103       // IntegerVector ind0_ = ind_[treat_ == 0];
104       NumericVector dist_t(n0_);
105       use_dist_mat = true;
106     }
107     ps_diff = NumericVector(n_);
108   }
109   if (antiexact_covs_.isNotNull()) {
110     antiexact_covs = as<IntegerMatrix>(antiexact_covs_);
111     n_anti = antiexact_covs.ncol();
112     use_antiexact = true;
113   }
114   if (reuse_max < n1_) {
115     use_reuse_max = true;
116   }
117 
118   bool ps_diff_assigned;
119 
120   //progress bar
121   int prog_length;
122   if (!use_reuse_max) prog_length = n1_ + 1;
123   else prog_length = max_rat*n1_ + 1;
124   Progress p(prog_length, disl_prog);
125 
126   //Counters
127   int rat, i, x, j, j_, a, k;
128   k = -1;
129 
130   //Matching
131   for (rat = 0; rat < max_rat; ++rat) {
132     for (i = 0; i < n1_; ++i) {
133 
134       k++;
135       if (k % 500 == 0) Rcpp::checkUserInterrupt();
136 
137       p.increment();
138 
139       if (all(as<IntegerVector>(matched[ind0]) >= reuse_max).is_true()){
140         break;
141       }
142 
143       t = ord_[i] - 1;   // index among treated
144       t_ind = ind1_[t]; // index among sample
145 
146       if (matched[t_ind]) {
147         continue;
148       }
149 
150       //Check if unit has enough matches
151       t_rat = ratio[t];
152 
153       if (t_rat < rat + 1) {
154         continue;
155       }
156 
157       c_eligible = ind0; // index among sample
158 
159       c_eligible = c_eligible[as<IntegerVector>(matched[c_eligible]) < reuse_max];
160 
161       if (use_exact) {
162         exact_c = exact[c_eligible];
163         c_eligible = c_eligible[exact_c == exact[t_ind]];
164       }
165 
166       if (c_eligible.size() == 0) {
167         continue;
168       }
169 
170       if (use_antiexact) {
171         for (a = 0; a < n_anti; ++a) {
172           antiexact_col = antiexact_covs(_, a);
173           antiexact_col = antiexact_col[c_eligible];
174           c_eligible = c_eligible[antiexact_col != antiexact_col[t_ind]];
175         }
176       }
177 
178       if (c_eligible.size() == 0) {
179         continue;
180       }
181 
182       ps_diff_assigned = false;
183 
184       if (use_caliper_dist) {
185         if (use_dist_mat) {
186           dist_t = distance_mat.row(t);
187           diff = dist_t[match(c_eligible, ind0_) - 1];
188         } else {
189           dt = distance[t_ind];
190           diff = Rcpp::abs(as<NumericVector>(distance[c_eligible]) - dt);
191         }
192 
193         ps_diff[c_eligible] = diff;
194         ps_diff_assigned = true;
195 
196         c_eligible = c_eligible[diff <= caliper_dist];
197 
198         if (c_eligible.size() == 0) {
199           continue;
200         }
201       }
202 
203       if (use_caliper_covs) {
204         for (x = 0; (x < cal_len) && c_eligible.size() > 0; ++x) {
205           cal_var = calcovs_covs_mat( _ , x );
206 
207           cal_var_t = cal_var[t_ind];
208 
209           diff = Rcpp::abs(as<NumericVector>(cal_var[c_eligible]) - cal_var_t);
210 
211           cal_diff = diff;
212 
213           c_eligible = c_eligible[cal_diff <= caliper_covs[x]];
214         }
215 
216         if (c_eligible.size() == 0) {
217           continue;
218         }
219       }
220 
221       //Compute distances among eligible
222       num_eligible = c_eligible.size();
223 
224       //If replace and few eligible controls, assign all and move on
225       if (!use_reuse_max && (num_eligible <= t_rat)) {
226         for (j = 0; j < num_eligible; ++j) {
227           mm( t , j ) = c_eligible[j] + 1;
228         }
229         continue;
230       }
231 
232       if (use_mah_covs) {
233 
234         match_distance = rep(0.0, num_eligible);
235         mah_covs_t = mah_covs( t_ind , _ );
236 
237         for (j = 0; j < num_eligible; j++) {
238           j_ = c_eligible[j];
239           mah_covs_row = mah_covs(c_eligible[j], _);
240           match_distance[j] = sqrt(sum(pow(mah_covs_t - mah_covs_row, 2.0)));
241         }
242 
243       } else if (ps_diff_assigned) {
244         match_distance = ps_diff[c_eligible]; //c_eligible might have shrunk since previous assignment
245       } else if (use_dist_mat) {
246         dist_t = distance_mat.row(t);
247         match_distance = dist_t[match(c_eligible, ind0_) - 1];
248       } else {
249         dt = distance[t_ind];
250         match_distance = Rcpp::abs(as<NumericVector>(distance[c_eligible]) - dt);
251       }
252 
253       //Remove infinite distances
254       finite_match_distance = is_finite(match_distance);
255       c_eligible = c_eligible[finite_match_distance];
256       if (c_eligible.size() == 0) {
257         continue;
258       }
259       match_distance = match_distance[finite_match_distance];
260 
261       if (!use_reuse_max) {
262         //When matching w/ replacement, get t_rat closest control units
263         indices = Range(0, num_eligible - 1);
264 
265         std::partial_sort(indices.begin(), indices.begin() + t_rat, indices.end(),
266                           [&match_distance](int k, int j) {return match_distance[k] < match_distance[j];});
267 
268         for (j = 0; j < t_rat; ++j) {
269           min_ind = indices[j];
270           mm( t , j ) = c_eligible[min_ind] + 1;
271         }
272       }
273       else {
274         min_ind = which_min(match_distance);
275         c_chosen = c_eligible[min_ind];
276 
277         mm( t , rat ) = c_chosen + 1; // + 1 because C indexing starts at 0 but mm is sent to R
278 
279         matched[c_chosen] = matched[c_chosen] + 1;
280       }
281     }
282 
283     if (!use_reuse_max) break;
284   }
285 
286   p.update(prog_length);
287 
288   return mm;
289 }
290