1 #include "language_classifier.h"
2
3 #include <float.h>
4
5 #include "language_features.h"
6 #include "minibatch.h"
7 #include "normalize.h"
8 #include "token_types.h"
9 #include "unicode_scripts.h"
10
/* Magic numbers stored as the first uint32 of a serialized classifier.
 * The signature also selects how the weights section is encoded:
 *   LANGUAGE_CLASSIFIER_SIGNATURE        -> dense double matrix
 *   LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE -> sparse matrix */
#define LANGUAGE_CLASSIFIER_SIGNATURE 0xCCCCCCCC
#define LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE 0xC0C0C0C0

/* Logged by classify_languages() when no classifier has been loaded. */
#define LANGUAGE_CLASSIFIER_SETUP_ERROR "language_classifier not loaded, run libpostal_setup_language_classifier()\n"

/* Floor for the probability a non-top language must strictly exceed to be
 * included in a classify_languages() response; the threshold actually used
 * is max(1 / num_labels, MIN_PROB). Subtracting DBL_EPSILON presumably lets a
 * probability of exactly 0.05 pass the strict '>' comparison. */
#define MIN_PROB (0.05 - DBL_EPSILON)

/* Module-wide classifier, set by language_classifier_module_setup() and
 * released by language_classifier_module_teardown(). NULL when not loaded. */
static language_classifier_t *language_classifier = NULL;
19
/* Releases a classifier together with everything it owns: the feature trie,
 * the label array and whichever weights matrix weights_type selects.
 * Accepts NULL, and tolerates members that were never filled in (NULL). */
void language_classifier_destroy(language_classifier_t *self) {
    if (self == NULL) {
        return;
    }

    if (self->features != NULL) trie_destroy(self->features);
    if (self->labels != NULL) cstring_array_destroy(self->labels);

    /* weights is a union; only the member named by weights_type is valid. */
    if (self->weights_type == MATRIX_DENSE) {
        if (self->weights.dense != NULL) double_matrix_destroy(self->weights.dense);
    } else if (self->weights_type == MATRIX_SPARSE) {
        if (self->weights.sparse != NULL) sparse_matrix_destroy(self->weights.sparse);
    }

    free(self);
}
39
language_classifier_new(void)40 language_classifier_t *language_classifier_new(void) {
41 language_classifier_t *language_classifier = calloc(1, sizeof(language_classifier_t));
42 return language_classifier;
43 }
44
get_language_classifier(void)45 language_classifier_t *get_language_classifier(void) {
46 return language_classifier;
47 }
48
/* Frees a response returned by classify_languages(). Only the two arrays and
 * the struct itself are released: the language strings are borrowed from the
 * classifier's label storage and must not be freed here. Accepts NULL. */
void language_classifier_response_destroy(language_classifier_response_t *self) {
    if (self == NULL) return;

    /* free(NULL) is a no-op, so the member arrays need no guards. */
    free(self->languages);
    free(self->probs);
    free(self);
}
61
classify_languages(char * address)62 language_classifier_response_t *classify_languages(char *address) {
63 language_classifier_t *classifier = get_language_classifier();
64
65 if (classifier == NULL) {
66 log_error(LANGUAGE_CLASSIFIER_SETUP_ERROR);
67 return NULL;
68 }
69
70 char *normalized = language_classifier_normalize_string(address);
71
72 token_array *tokens = token_array_new();
73 char_array *feature_array = char_array_new();
74
75 khash_t(str_double) *feature_counts = extract_language_features(normalized, NULL, tokens, feature_array);
76 if (feature_counts == NULL || kh_size(feature_counts) == 0) {
77 token_array_destroy(tokens);
78 char_array_destroy(feature_array);
79 if (feature_counts != NULL) {
80 kh_destroy(str_double, feature_counts);
81 }
82 free(normalized);
83 return NULL;
84 }
85
86 sparse_matrix_t *x = feature_vector(classifier->features, feature_counts);
87
88 size_t n = classifier->num_labels;
89 double_matrix_t *p_y = double_matrix_new_zeros(1, n);
90
91 language_classifier_response_t *response = NULL;
92 bool model_exp = false;
93 if (classifier->weights_type == MATRIX_DENSE) {
94 model_exp = logistic_regression_model_expectation(classifier->weights.dense, x, p_y);
95 } else if (classifier->weights_type == MATRIX_SPARSE) {
96 model_exp = logistic_regression_model_expectation_sparse(classifier->weights.sparse, x, p_y);
97 }
98
99 if (model_exp) {
100 double *predictions = double_matrix_get_row(p_y, 0);
101 size_t *indices = double_array_argsort(predictions, n);
102 size_t num_languages = 0;
103 size_t i;
104 double prob;
105
106 double min_prob = 1.0 / n;
107 if (min_prob < MIN_PROB) min_prob = MIN_PROB;
108
109 for (i = 0; i < n; i++) {
110 size_t idx = indices[n - i - 1];
111 prob = predictions[idx];
112
113 if (i == 0 || prob > min_prob) {
114 num_languages++;
115 } else {
116 break;
117 }
118 }
119 char **languages = malloc(sizeof(char *) * num_languages);
120 double *probs = malloc(sizeof(double) * num_languages);
121
122 for (i = 0; i < num_languages; i++) {
123 size_t idx = indices[n - i - 1];
124 char *lang = cstring_array_get_string(classifier->labels, (uint32_t)idx);
125 prob = predictions[idx];
126 languages[i] = lang;
127 probs[i] = prob;
128 }
129
130 free(indices);
131
132 response = malloc(sizeof(language_classifier_response_t));
133 response->num_languages = num_languages;
134 response->languages = languages;
135 response->probs = probs;
136 }
137
138 sparse_matrix_destroy(x);
139 double_matrix_destroy(p_y);
140 token_array_destroy(tokens);
141 char_array_destroy(feature_array);
142 const char *key;
143 kh_foreach_key(feature_counts, key, {
144 free((char *)key);
145 })
146 kh_destroy(str_double, feature_counts);
147 free(normalized);
148 return response;
149
150 }
151
language_classifier_read(FILE * f)152 language_classifier_t *language_classifier_read(FILE *f) {
153 if (f == NULL) return NULL;
154 long save_pos = ftell(f);
155
156 uint32_t signature;
157
158 if (!file_read_uint32(f, &signature)) {
159 goto exit_file_read;
160 }
161
162 if (signature != LANGUAGE_CLASSIFIER_SIGNATURE && signature != LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
163 goto exit_file_read;
164 }
165
166 language_classifier_t *classifier = language_classifier_new();
167 if (classifier == NULL) {
168 goto exit_file_read;
169 }
170
171 trie_t *features = trie_read(f);
172 if (features == NULL) {
173 goto exit_classifier_created;
174 }
175 classifier->features = features;
176 uint64_t num_features;
177 if (!file_read_uint64(f, &num_features)) {
178 goto exit_classifier_created;
179 }
180 classifier->num_features = (size_t)num_features;
181
182 uint64_t labels_str_len;
183
184 if (!file_read_uint64(f, &labels_str_len)) {
185 goto exit_classifier_created;
186 }
187
188 char_array *array = char_array_new_size(labels_str_len);
189
190 if (array == NULL) {
191 goto exit_classifier_created;
192 }
193
194 if (!file_read_chars(f, array->a, labels_str_len)) {
195 char_array_destroy(array);
196 goto exit_classifier_created;
197 }
198
199 array->n = labels_str_len;
200
201 classifier->labels = cstring_array_from_char_array(array);
202 if (classifier->labels == NULL) {
203 goto exit_classifier_created;
204 }
205 classifier->num_labels = cstring_array_num_strings(classifier->labels);
206
207 if (signature == LANGUAGE_CLASSIFIER_SIGNATURE) {
208 double_matrix_t *weights = double_matrix_read(f);
209 if (weights == NULL) {
210 goto exit_classifier_created;
211 }
212 classifier->weights_type = MATRIX_DENSE;
213 classifier->weights.dense = weights;
214 } else if (signature == LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
215 sparse_matrix_t *sparse_weights = sparse_matrix_read(f);
216 if (sparse_weights == NULL) {
217 goto exit_classifier_created;
218 }
219 classifier->weights_type = MATRIX_SPARSE;
220 classifier->weights.sparse = sparse_weights;
221 }
222
223 return classifier;
224
225 exit_classifier_created:
226 language_classifier_destroy(classifier);
227 exit_file_read:
228 fseek(f, save_pos, SEEK_SET);
229 return NULL;
230 }
231
232
language_classifier_load(char * path)233 language_classifier_t *language_classifier_load(char *path) {
234 FILE *f;
235
236 f = fopen(path, "rb");
237 if (!f) return NULL;
238
239 language_classifier_t *classifier = language_classifier_read(f);
240
241 fclose(f);
242 return classifier;
243 }
244
/* Serializes `self` to `f` in the format read by language_classifier_read():
 * signature, feature trie, num_features, packed label buffer (length +
 * bytes), then the dense or sparse weights matrix.
 *
 * Returns false on any write failure, or when the classifier is incomplete
 * (missing features, labels or weights) or has an unrecognized weights_type;
 * in those cases nothing is written, so no unreadable file is produced. */
bool language_classifier_write(language_classifier_t *self, FILE *f) {
    if (f == NULL || self == NULL || self->features == NULL || self->labels == NULL) {
        return false;
    }

    uint32_t signature;
    if (self->weights_type == MATRIX_DENSE) {
        if (self->weights.dense == NULL) return false;
        signature = LANGUAGE_CLASSIFIER_SIGNATURE;
    } else if (self->weights_type == MATRIX_SPARSE) {
        if (self->weights.sparse == NULL) return false;
        signature = LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE;
    } else {
        /* Previously this wrote neither a signature nor weights and still
         * reported success. */
        return false;
    }

    if (!file_write_uint32(f, signature) ||
        !trie_write(self->features, f) ||
        !file_write_uint64(f, self->num_features) ||
        !file_write_uint64(f, self->labels->str->n) ||
        !file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n)) {
        return false;
    }

    if (self->weights_type == MATRIX_DENSE) {
        return double_matrix_write(self->weights.dense, f);
    }
    return sparse_matrix_write(self->weights.sparse, f);
}
269
/* Writes `self` to a new file at `path` (truncating any existing file).
 * Returns true only if every write succeeded AND the file was closed (and
 * therefore flushed) successfully. */
bool language_classifier_save(language_classifier_t *self, char *path) {
    if (self == NULL || path == NULL) return false;

    FILE *f = fopen(path, "wb");
    if (f == NULL) return false;

    bool result = language_classifier_write(self, f);

    /* fclose() flushes buffered output; if it fails, the file on disk may be
     * truncated even though each individual write call succeeded. */
    if (fclose(f) != 0) {
        result = false;
    }

    return result;
}
281
282 // Module setup/teardown
283
language_classifier_module_setup(char * dir)284 bool language_classifier_module_setup(char *dir) {
285 if (language_classifier != NULL) {
286 return true;
287 }
288
289 if (dir == NULL) {
290 dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;
291 }
292
293 char *classifier_path;
294
295 char_array *path = char_array_new_size(strlen(dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_FILENAME));
296 if (language_classifier == NULL) {
297 char_array_cat_joined(path, PATH_SEPARATOR, true, 2, dir, LANGUAGE_CLASSIFIER_FILENAME);
298 classifier_path = char_array_get_string(path);
299
300 language_classifier = language_classifier_load(classifier_path);
301
302 }
303
304 char_array_destroy(path);
305 return true;
306 }
307
language_classifier_module_teardown(void)308 void language_classifier_module_teardown(void) {
309 if (language_classifier != NULL) {
310 language_classifier_destroy(language_classifier);
311 }
312 language_classifier = NULL;
313 }
314
315