1 #include "language_classifier.h"
2 
3 #include <float.h>
4 
5 #include "language_features.h"
6 #include "minibatch.h"
7 #include "normalize.h"
8 #include "token_types.h"
9 #include "unicode_scripts.h"
10 
11 #define LANGUAGE_CLASSIFIER_SIGNATURE 0xCCCCCCCC
12 #define LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE 0xC0C0C0C0
13 
14 #define LANGUAGE_CLASSIFIER_SETUP_ERROR "language_classifier not loaded, run libpostal_setup_language_classifier()\n"
15 
16 #define MIN_PROB (0.05 - DBL_EPSILON)
17 
18 static language_classifier_t *language_classifier = NULL;
19 
/*
 * Release a classifier together with the members it holds: the feature
 * trie, the label array and the weight matrix (dense or sparse, chosen by
 * weights_type). Members that are NULL are skipped. Passing NULL is a no-op.
 */
void language_classifier_destroy(language_classifier_t *self) {
    if (!self) {
        return;
    }

    if (self->features) {
        trie_destroy(self->features);
    }
    if (self->labels) {
        cstring_array_destroy(self->labels);
    }

    // Only the weights member matching weights_type is released.
    if (self->weights_type == MATRIX_DENSE) {
        if (self->weights.dense) {
            double_matrix_destroy(self->weights.dense);
        }
    } else if (self->weights_type == MATRIX_SPARSE) {
        if (self->weights.sparse) {
            sparse_matrix_destroy(self->weights.sparse);
        }
    }

    free(self);
}
39 
language_classifier_new(void)40 language_classifier_t *language_classifier_new(void) {
41     language_classifier_t *language_classifier = calloc(1, sizeof(language_classifier_t));
42     return language_classifier;
43 }
44 
get_language_classifier(void)45 language_classifier_t *get_language_classifier(void) {
46     return language_classifier;
47 }
48 
/*
 * Free a classification response: the languages pointer array, the probs
 * array and the response struct itself. The individual language strings are
 * not freed here. Passing NULL is a no-op.
 */
void language_classifier_response_destroy(language_classifier_response_t *self) {
    if (!self) {
        return;
    }

    // free(NULL) is a no-op, so the members need no individual guards.
    free(self->languages);
    free(self->probs);
    free(self);
}
61 
classify_languages(char * address)62 language_classifier_response_t *classify_languages(char *address) {
63     language_classifier_t *classifier = get_language_classifier();
64 
65     if (classifier == NULL) {
66         log_error(LANGUAGE_CLASSIFIER_SETUP_ERROR);
67         return NULL;
68     }
69 
70     char *normalized = language_classifier_normalize_string(address);
71 
72     token_array *tokens = token_array_new();
73     char_array *feature_array = char_array_new();
74 
75     khash_t(str_double) *feature_counts = extract_language_features(normalized, NULL, tokens, feature_array);
76     if (feature_counts == NULL || kh_size(feature_counts) == 0) {
77         token_array_destroy(tokens);
78         char_array_destroy(feature_array);
79         if (feature_counts != NULL) {
80             kh_destroy(str_double, feature_counts);
81         }
82         free(normalized);
83         return NULL;
84     }
85 
86     sparse_matrix_t *x = feature_vector(classifier->features, feature_counts);
87 
88     size_t n = classifier->num_labels;
89     double_matrix_t *p_y = double_matrix_new_zeros(1, n);
90 
91     language_classifier_response_t *response = NULL;
92     bool model_exp = false;
93     if (classifier->weights_type == MATRIX_DENSE) {
94         model_exp = logistic_regression_model_expectation(classifier->weights.dense, x, p_y);
95     } else if (classifier->weights_type == MATRIX_SPARSE) {
96         model_exp = logistic_regression_model_expectation_sparse(classifier->weights.sparse, x, p_y);
97     }
98 
99     if (model_exp) {
100         double *predictions = double_matrix_get_row(p_y, 0);
101         size_t *indices = double_array_argsort(predictions, n);
102         size_t num_languages = 0;
103         size_t i;
104         double prob;
105 
106         double min_prob = 1.0 / n;
107         if (min_prob < MIN_PROB) min_prob = MIN_PROB;
108 
109         for (i = 0; i < n; i++) {
110             size_t idx = indices[n - i - 1];
111             prob = predictions[idx];
112 
113             if (i == 0 || prob > min_prob) {
114                 num_languages++;
115             } else {
116                 break;
117             }
118         }
119         char **languages = malloc(sizeof(char *) * num_languages);
120         double *probs = malloc(sizeof(double) * num_languages);
121 
122         for (i = 0; i < num_languages; i++) {
123             size_t idx = indices[n - i - 1];
124             char *lang = cstring_array_get_string(classifier->labels, (uint32_t)idx);
125             prob = predictions[idx];
126             languages[i] = lang;
127             probs[i] = prob;
128         }
129 
130         free(indices);
131 
132         response = malloc(sizeof(language_classifier_response_t));
133         response->num_languages = num_languages;
134         response->languages = languages;
135         response->probs = probs;
136     }
137 
138     sparse_matrix_destroy(x);
139     double_matrix_destroy(p_y);
140     token_array_destroy(tokens);
141     char_array_destroy(feature_array);
142     const char *key;
143     kh_foreach_key(feature_counts, key, {
144         free((char *)key);
145     })
146     kh_destroy(str_double, feature_counts);
147     free(normalized);
148     return response;
149 
150 }
151 
language_classifier_read(FILE * f)152 language_classifier_t *language_classifier_read(FILE *f) {
153     if (f == NULL) return NULL;
154     long save_pos = ftell(f);
155 
156     uint32_t signature;
157 
158     if (!file_read_uint32(f, &signature)) {
159         goto exit_file_read;
160     }
161 
162     if (signature != LANGUAGE_CLASSIFIER_SIGNATURE && signature != LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
163         goto exit_file_read;
164     }
165 
166     language_classifier_t *classifier = language_classifier_new();
167     if (classifier == NULL) {
168         goto exit_file_read;
169     }
170 
171     trie_t *features = trie_read(f);
172     if (features == NULL) {
173         goto exit_classifier_created;
174     }
175     classifier->features = features;
176     uint64_t num_features;
177     if (!file_read_uint64(f, &num_features)) {
178         goto exit_classifier_created;
179     }
180     classifier->num_features = (size_t)num_features;
181 
182     uint64_t labels_str_len;
183 
184     if (!file_read_uint64(f, &labels_str_len)) {
185         goto exit_classifier_created;
186     }
187 
188     char_array *array = char_array_new_size(labels_str_len);
189 
190     if (array == NULL) {
191         goto exit_classifier_created;
192     }
193 
194     if (!file_read_chars(f, array->a, labels_str_len)) {
195         char_array_destroy(array);
196         goto exit_classifier_created;
197     }
198 
199     array->n = labels_str_len;
200 
201     classifier->labels = cstring_array_from_char_array(array);
202     if (classifier->labels == NULL) {
203         goto exit_classifier_created;
204     }
205     classifier->num_labels = cstring_array_num_strings(classifier->labels);
206 
207     if (signature == LANGUAGE_CLASSIFIER_SIGNATURE) {
208         double_matrix_t *weights = double_matrix_read(f);
209         if (weights == NULL) {
210             goto exit_classifier_created;
211         }
212         classifier->weights_type = MATRIX_DENSE;
213         classifier->weights.dense = weights;
214     } else if (signature == LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
215         sparse_matrix_t *sparse_weights = sparse_matrix_read(f);
216         if (sparse_weights == NULL) {
217             goto exit_classifier_created;
218         }
219         classifier->weights_type = MATRIX_SPARSE;
220         classifier->weights.sparse = sparse_weights;
221     }
222 
223     return classifier;
224 
225 exit_classifier_created:
226     language_classifier_destroy(classifier);
227 exit_file_read:
228     fseek(f, save_pos, SEEK_SET);
229     return NULL;
230 }
231 
232 
language_classifier_load(char * path)233 language_classifier_t *language_classifier_load(char *path) {
234     FILE *f;
235 
236     f = fopen(path, "rb");
237     if (!f) return NULL;
238 
239     language_classifier_t *classifier = language_classifier_read(f);
240 
241     fclose(f);
242     return classifier;
243 }
244 
/*
 * Serialize a classifier to an open stream.
 *
 * Writes, in order: the uint32 signature matching weights_type, the feature
 * trie, the feature count as uint64, the label block length as uint64, the
 * label bytes, and the weight matrix (dense or sparse).
 *
 * Returns true when every write succeeds. Returns false if `self` or `f` is
 * NULL, if the classifier has no labels, if weights_type is neither
 * MATRIX_DENSE nor MATRIX_SPARSE, or if any write fails.
 */
bool language_classifier_write(language_classifier_t *self, FILE *f) {
    if (f == NULL || self == NULL) return false;

    // The label block is dereferenced below; a classifier without one
    // (e.g. freshly zero-allocated) cannot be serialized.
    if (self->labels == NULL || self->labels->str == NULL) return false;

    // Pick the signature up front so an unrecognised weights_type is an
    // error instead of producing a file with no signature.
    uint32_t signature;
    if (self->weights_type == MATRIX_DENSE) {
        signature = LANGUAGE_CLASSIFIER_SIGNATURE;
    } else if (self->weights_type == MATRIX_SPARSE) {
        signature = LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE;
    } else {
        return false;
    }

    if (!file_write_uint32(f, signature)) {
        return false;
    }

    if (!trie_write(self->features, f) ||
        !file_write_uint64(f, self->num_features) ||
        !file_write_uint64(f, self->labels->str->n) ||
        !file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n)) {
        return false;
    }

    if (self->weights_type == MATRIX_DENSE) {
        if (!double_matrix_write(self->weights.dense, f)) {
            return false;
        }
    } else {
        // weights_type was validated above, so this is MATRIX_SPARSE.
        if (!sparse_matrix_write(self->weights.sparse, f)) {
            return false;
        }
    }

    return true;
}
269 
/*
 * Write a classifier to the file at `path`, creating or truncating it.
 *
 * Returns true only if the file could be opened, language_classifier_write()
 * succeeded, and the file was closed without error. Returns false if `self`
 * or `path` is NULL.
 */
bool language_classifier_save(language_classifier_t *self, char *path) {
    if (self == NULL || path == NULL) return false;

    FILE *f = fopen(path, "wb");
    if (f == NULL) return false;

    bool result = language_classifier_write(self, f);

    // fclose() flushes buffered output; a failure here means the data may
    // not have reached the file, so it must count as a failed save.
    if (fclose(f) != 0) {
        result = false;
    }

    return result;
}
281 
282 // Module setup/teardown
283 
language_classifier_module_setup(char * dir)284 bool language_classifier_module_setup(char *dir) {
285     if (language_classifier != NULL) {
286         return true;
287     }
288 
289     if (dir == NULL) {
290         dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;
291     }
292 
293     char *classifier_path;
294 
295     char_array *path = char_array_new_size(strlen(dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_FILENAME));
296     if (language_classifier == NULL) {
297         char_array_cat_joined(path, PATH_SEPARATOR, true, 2, dir, LANGUAGE_CLASSIFIER_FILENAME);
298         classifier_path = char_array_get_string(path);
299 
300         language_classifier = language_classifier_load(classifier_path);
301 
302     }
303 
304     char_array_destroy(path);
305     return true;
306 }
307 
language_classifier_module_teardown(void)308 void language_classifier_module_teardown(void) {
309     if (language_classifier != NULL) {
310         language_classifier_destroy(language_classifier);
311     }
312     language_classifier = NULL;
313 }
314 
315