1 #include "vctrs.h"
2 #include "utils.h"
3 #include "translate.h"
4 #include <strings.h>
5 
// Signal an R error (with no call attached) explaining why `x` and `y`
// cannot be compared. `x` and `y` are not currently used in the message.
//
// Declared noreturn because `Rf_errorcall()` never returns and callers depend
// on that (e.g. a `case` that calls this is followed directly by another
// `case`). The attribute lets the compiler see that, instead of warning about
// fallthrough or missing return values at call sites.
static __attribute__((noreturn)) void stop_not_comparable(SEXP x, SEXP y, const char* message) {
  Rf_errorcall(R_NilValue, "`x` and `y` are not comparable: %s", message);
}
9 
// Three-way comparison of two ints: -1 if x < y, 1 if x > y, 0 if equal.
// Subtracting the relational results cannot overflow, unlike `x - y`.
// https://stackoverflow.com/questions/10996418
static inline int icmp(int x, int y) {
  return (x > y) - (x < y);
}

// `qsort()`-compatible comparator for arrays of `int`.
// The arguments point at const data, so they are read through `const int*`
// rather than having their qualifier cast away.
int qsort_icmp(const void* x, const void* y) {
  const int* p_x = (const int*) x;
  const int* p_y = (const int*) y;
  return icmp(*p_x, *p_y);
}
17 
// Three-way comparison of two doubles: -1 if x < y, 1 if x > y, otherwise 0.
// Every ordered comparison involving NaN is false, so a NaN operand yields 0.
// Callers in this file only pass values already known not to be NaN.
static int dcmp(double x, double y) {
  if (x < y) {
    return -1;
  }
  if (x > y) {
    return 1;
  }
  return 0;
}
21 
22 // Assume translation handled by `vec_normalize_encoding()`
scmp(SEXP x,SEXP y)23 static inline int scmp(SEXP x, SEXP y) {
24   if (x == y) {
25     return 0;
26   }
27 
28   int cmp = strcmp(CHAR(x), CHAR(y));
29   return cmp / abs(cmp);
30 }
31 
32 // -----------------------------------------------------------------------------
33 
lgl_compare_scalar(const int * x,const int * y,bool na_equal)34 static inline int lgl_compare_scalar(const int* x, const int* y, bool na_equal) {
35   int xi = *x;
36   int yj = *y;
37 
38   if (na_equal) {
39     return icmp(xi, yj);
40   } else {
41     return (xi == NA_LOGICAL || yj == NA_LOGICAL) ? NA_INTEGER : icmp(xi, yj);
42   }
43 }
44 
int_compare_scalar(const int * x,const int * y,bool na_equal)45 static inline int int_compare_scalar(const int* x, const int* y, bool na_equal) {
46   int xi = *x;
47   int yj = *y;
48 
49   if (na_equal) {
50     return icmp(xi, yj);
51   } else {
52     return (xi == NA_INTEGER || yj == NA_INTEGER) ? NA_INTEGER : icmp(xi, yj);
53   }
54 }
55 
dbl_compare_scalar(const double * x,const double * y,bool na_equal)56 static inline int dbl_compare_scalar(const double* x, const double* y, bool na_equal) {
57   double xi = *x;
58   double yj = *y;
59 
60   if (na_equal) {
61     enum vctrs_dbl_class x_class = dbl_classify(xi);
62     enum vctrs_dbl_class y_class = dbl_classify(yj);
63 
64     switch (x_class) {
65     case vctrs_dbl_number: {
66       switch (y_class) {
67       case vctrs_dbl_number: return dcmp(xi, yj);
68       case vctrs_dbl_missing: return 1;
69       case vctrs_dbl_nan: return 1;
70       }
71     }
72     case vctrs_dbl_missing: {
73       switch (y_class) {
74       case vctrs_dbl_number: return -1;
75       case vctrs_dbl_missing: return 0;
76       case vctrs_dbl_nan: return 1;
77       }
78     }
79     case vctrs_dbl_nan: {
80       switch (y_class) {
81       case vctrs_dbl_number: return -1;
82       case vctrs_dbl_missing: return -1;
83       case vctrs_dbl_nan: return 0;
84       }
85     }
86     }
87   } else {
88     return (isnan(xi) || isnan(yj)) ? NA_INTEGER : dcmp(xi, yj);
89   }
90 
91   never_reached("dbl_compare_scalar");
92 }
93 
chr_compare_scalar(const SEXP * x,const SEXP * y,bool na_equal)94 static inline int chr_compare_scalar(const SEXP* x, const SEXP* y, bool na_equal) {
95   const SEXP xi = *x;
96   const SEXP yj = *y;
97 
98   if (na_equal) {
99     if (xi == NA_STRING) {
100       return (yj == NA_STRING) ? 0 : -1;
101     } else {
102       return (yj == NA_STRING) ? 1 : scmp(xi, yj);
103     }
104   } else {
105     return (xi == NA_STRING || yj == NA_STRING) ? NA_INTEGER : scmp(xi, yj);
106   }
107 }
108 
// Lexicographic comparison of row `i` of data frame `x` with row `j` of data
// frame `y`, both assumed to have `n_col` columns. Columns are compared left
// to right via `compare_scalar()`. The first non-zero column result (which may
// be `NA_INTEGER` when `na_equal` is false) is returned. Returns 0 when every
// column compares equal, including when `n_col` is 0.
static inline int df_compare_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal, int n_col) {
  for (int k = 0; k < n_col; ++k) {
    SEXP col_x = VECTOR_ELT(x, k);
    SEXP col_y = VECTOR_ELT(y, k);

    int cmp = compare_scalar(col_x, i, col_y, j, na_equal);

    if (cmp != 0) {
      return cmp;
    }
  }

  // All columns equal. Returning a literal instead of the last loop value
  // avoids reading an uninitialized variable when `n_col` is 0.
  return 0;
}
125 
126 // -----------------------------------------------------------------------------
127 
128 // [[ include("vctrs.h") ]]
compare_scalar(SEXP x,R_len_t i,SEXP y,R_len_t j,bool na_equal)129 int compare_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal) {
130   switch (TYPEOF(x)) {
131   case LGLSXP: return lgl_compare_scalar(LOGICAL_RO(x) + i, LOGICAL_RO(y) + j, na_equal);
132   case INTSXP: return int_compare_scalar(INTEGER_RO(x) + i, INTEGER_RO(y) + j, na_equal);
133   case REALSXP: return dbl_compare_scalar(REAL_RO(x) + i, REAL_RO(y) + j, na_equal);
134   case STRSXP: return chr_compare_scalar(STRING_PTR_RO(x) + i, STRING_PTR_RO(y) + j, na_equal);
135   default: break;
136   }
137 
138   switch (vec_proxy_typeof(x)) {
139   case vctrs_type_list: stop_not_comparable(x, y, "lists are not comparable");
140   case vctrs_type_dataframe: {
141     int n_col = Rf_length(x);
142 
143     if (n_col != Rf_length(y)) {
144       stop_not_comparable(x, y, "must have the same number of columns");
145     }
146 
147     if (n_col == 0) {
148       stop_not_comparable(x, y, "data frame with zero columns");
149     }
150 
151     return df_compare_scalar(x, i, y, j, na_equal, n_col);
152   }
153   default: break;
154   }
155 
156   Rf_errorcall(R_NilValue, "Unsupported type %s", Rf_type2char(TYPEOF(x)));
157 }
158 
159 // -----------------------------------------------------------------------------
160 
161 static SEXP df_compare(SEXP x, SEXP y, bool na_equal, R_len_t size);
162 
163 #define COMPARE(CTYPE, CONST_DEREF, SCALAR_COMPARE)     \
164 do {                                                    \
165   SEXP out = PROTECT(Rf_allocVector(INTSXP, size));     \
166   int* p_out = INTEGER(out);                            \
167                                                         \
168   const CTYPE* p_x = CONST_DEREF(x);                    \
169   const CTYPE* p_y = CONST_DEREF(y);                    \
170                                                         \
171   for (R_len_t i = 0; i < size; ++i, ++p_x, ++p_y) {    \
172     p_out[i] = SCALAR_COMPARE(p_x, p_y, na_equal);      \
173   }                                                     \
174                                                         \
175   UNPROTECT(3);                                         \
176   return out;                                           \
177 }                                                       \
178 while (0)
179 
180 // [[ register() ]]
vctrs_compare(SEXP x,SEXP y,SEXP na_equal_)181 SEXP vctrs_compare(SEXP x, SEXP y, SEXP na_equal_) {
182   bool na_equal = r_bool_as_int(na_equal_);
183 
184   R_len_t size = vec_size(x);
185 
186   enum vctrs_type type = vec_proxy_typeof(x);
187   if (type != vec_proxy_typeof(y) || size != vec_size(y)) {
188     stop_not_comparable(x, y, "must have the same types and lengths");
189   }
190 
191   x = PROTECT(vec_normalize_encoding(x));
192   y = PROTECT(vec_normalize_encoding(y));
193 
194   switch (type) {
195   case vctrs_type_logical:   COMPARE(int, LOGICAL_RO, lgl_compare_scalar);
196   case vctrs_type_integer:   COMPARE(int, INTEGER_RO, int_compare_scalar);
197   case vctrs_type_double:    COMPARE(double, REAL_RO, dbl_compare_scalar);
198   case vctrs_type_character: COMPARE(SEXP, STRING_PTR_RO, chr_compare_scalar);
199   case vctrs_type_dataframe: {
200     SEXP out = df_compare(x, y, na_equal, size);
201     UNPROTECT(2);
202     return out;
203   }
204   case vctrs_type_scalar:    Rf_errorcall(R_NilValue, "Can't compare scalars with `vctrs_compare()`");
205   case vctrs_type_list:      Rf_errorcall(R_NilValue, "Can't compare lists with `vctrs_compare()`");
206   default:                   Rf_error("Unimplemented type in `vctrs_compare()`");
207   }
208 }
209 
210 #undef COMPARE
211 
212 // -----------------------------------------------------------------------------
213 
214 static void vec_compare_col(int* p_out,
215                             struct df_short_circuit_info* p_info,
216                             SEXP x,
217                             SEXP y,
218                             bool na_equal);
219 
220 static void df_compare_impl(int* p_out,
221                             struct df_short_circuit_info* p_info,
222                             SEXP x,
223                             SEXP y,
224                             bool na_equal);
225 
df_compare(SEXP x,SEXP y,bool na_equal,R_len_t size)226 static SEXP df_compare(SEXP x, SEXP y, bool na_equal, R_len_t size) {
227   int nprot = 0;
228 
229   SEXP out = PROTECT_N(Rf_allocVector(INTSXP, size), &nprot);
230   int* p_out = INTEGER(out);
231 
232   // Initialize to "equality" value
233   // and only change if we learn that it differs
234   memset(p_out, 0, size * sizeof(int));
235 
236   struct df_short_circuit_info info = new_df_short_circuit_info(size, false);
237   struct df_short_circuit_info* p_info = &info;
238   PROTECT_DF_SHORT_CIRCUIT_INFO(p_info, &nprot);
239 
240   df_compare_impl(p_out, p_info, x, y, na_equal);
241 
242   UNPROTECT(nprot);
243   return out;
244 }
245 
// Compare data frames `x` and `y` column by column. For each row not yet
// marked as decided in `p_info`, write the result into `p_out`. Data frame
// columns recurse back here through `vec_compare_col()`. Errors if `x` has no
// columns or if `x` and `y` differ in column count.
static void df_compare_impl(int* p_out,
                            struct df_short_circuit_info* p_info,
                            SEXP x,
                            SEXP y,
                            bool na_equal) {
  const int n_col = Rf_length(x);

  if (n_col == 0) {
    stop_not_comparable(x, y, "data frame with zero columns");
  }
  if (n_col != Rf_length(y)) {
    stop_not_comparable(x, y, "must have the same number of columns");
  }

  for (R_len_t col = 0; col < n_col; ++col) {
    vec_compare_col(p_out, p_info, VECTOR_ELT(x, col), VECTOR_ELT(y, col), na_equal);

    // Checked after, not before, the call: the first column is always
    // visited (and its type validated) even when no rows remain undecided.
    if (p_info->remaining == 0) {
      return;
    }
  }
}
273 
274 // -----------------------------------------------------------------------------
275 
276 #define COMPARE_COL(CTYPE, CONST_DEREF, SCALAR_COMPARE)              \
277 do {                                                                 \
278   const CTYPE* p_x = CONST_DEREF(x);                                 \
279   const CTYPE* p_y = CONST_DEREF(y);                                 \
280                                                                      \
281   for (R_len_t i = 0; i < p_info->size; ++i, ++p_x, ++p_y) {         \
282     if (p_info->p_row_known[i]) {                                    \
283       continue;                                                      \
284     }                                                                \
285                                                                      \
286     int cmp = SCALAR_COMPARE(p_x, p_y, na_equal);                    \
287                                                                      \
288     if (cmp != 0) {                                                  \
289       p_out[i] = cmp;                                                \
290       p_info->p_row_known[i] = true;                                 \
291       --p_info->remaining;                                           \
292                                                                      \
293       if (p_info->remaining == 0) {                                  \
294         break;                                                       \
295       }                                                              \
296     }                                                                \
297   }                                                                  \
298 }                                                                    \
299 while (0)
300 
vec_compare_col(int * p_out,struct df_short_circuit_info * p_info,SEXP x,SEXP y,bool na_equal)301 static void vec_compare_col(int* p_out,
302                             struct df_short_circuit_info* p_info,
303                             SEXP x,
304                             SEXP y,
305                             bool na_equal) {
306   switch (vec_proxy_typeof(x)) {
307   case vctrs_type_logical:   COMPARE_COL(int, LOGICAL_RO, lgl_compare_scalar); break;
308   case vctrs_type_integer:   COMPARE_COL(int, INTEGER_RO, int_compare_scalar); break;
309   case vctrs_type_double:    COMPARE_COL(double, REAL_RO, dbl_compare_scalar); break;
310   case vctrs_type_character: COMPARE_COL(SEXP, STRING_PTR_RO, chr_compare_scalar); break;
311   case vctrs_type_dataframe: df_compare_impl(p_out, p_info, x, y, na_equal); break;
312   case vctrs_type_scalar:    Rf_errorcall(R_NilValue, "Can't compare scalars with `vctrs_compare()`");
313   case vctrs_type_list:      Rf_errorcall(R_NilValue, "Can't compare lists with `vctrs_compare()`");
314   default:                   Rf_error("Unimplemented type in `vctrs_compare()`");
315   }
316 }
317 
318 #undef COMPARE_COL
319