1 #include <stdint.h>
2
3 #include "address_parser.h"
4 #include "address_parser_io.h"
5 #include "address_dictionary.h"
6 #include "averaged_perceptron_trainer.h"
7 #include "crf_trainer_averaged_perceptron.h"
8 #include "collections.h"
9 #include "constants.h"
10 #include "file_utils.h"
11 #include "graph.h"
12 #include "graph_builder.h"
13 #include "shuffle.h"
14 #include "transliterate.h"
15
16 #include "log/log.h"
17
// Aggregate statistics for a single candidate phrase, accumulated while
// scanning the training data during vocabulary construction.
typedef struct phrase_stats {
    khash_t(int_uint32) *class_counts;  // boundary class id => number of times the phrase was labeled with that class
    uint16_t components;                // bitwise OR of the address component flags the phrase has appeared as
} phrase_stats_t;
22
KHASH_MAP_INIT_STR(phrase_stats,phrase_stats_t)23 KHASH_MAP_INIT_STR(phrase_stats, phrase_stats_t)
24 KHASH_MAP_INIT_STR(postal_code_context_phrases, khash_t(str_set) *)
25 KHASH_MAP_INIT_STR(phrase_types, address_parser_types_t)
26
// Training

// Default number of training passes over the data set
#define DEFAULT_ITERATIONS 5
// Default minimum update count — presumably a feature-pruning threshold
// enforced by the trainer (usage not visible in this file)
#define DEFAULT_MIN_UPDATES 5
// CRF is the default parser model type
#define DEFAULT_MODEL_TYPE ADDRESS_PARSER_TYPE_CRF

// Tokens seen fewer than this many times are pruned from the vocabulary
#define MIN_VOCAB_COUNT 5
// Phrases seen fewer than this many times are pruned from the phrase table
#define MIN_PHRASE_COUNT 1
35
36 static inline bool is_postal_code(char *label) {
37 return string_equals(label, ADDRESS_PARSER_LABEL_POSTAL_CODE);
38 }
39
is_admin_component(char * label)40 static inline bool is_admin_component(char *label) {
41 return (string_equals(label, ADDRESS_PARSER_LABEL_SUBURB) ||
42 string_equals(label, ADDRESS_PARSER_LABEL_CITY_DISTRICT) ||
43 string_equals(label, ADDRESS_PARSER_LABEL_CITY) ||
44 string_equals(label, ADDRESS_PARSER_LABEL_STATE_DISTRICT) ||
45 string_equals(label, ADDRESS_PARSER_LABEL_ISLAND) ||
46 string_equals(label, ADDRESS_PARSER_LABEL_STATE) ||
47 string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY_REGION) ||
48 string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY) ||
49 string_equals(label, ADDRESS_PARSER_LABEL_WORLD_REGION));
50 }
51
// Scratch buffers reused across training examples while building the
// vocabulary, so the per-example work does not have to allocate.
typedef struct vocab_context {
    char_array *token_builder;                    // normalized form of the current token
    char_array *postal_code_token_builder;        // postal-code-specific normalization buffer
    char_array *sub_token_builder;                // phrase rebuilt with postal-code sub-phrases stripped
    char_array *phrase_builder;                   // accumulates space-joined tokens of the current phrase
    phrase_array *dictionary_phrases;             // dictionary phrase matches for the current string
    int64_array *phrase_memberships;              // token index => index into dictionary_phrases (or null membership)
    phrase_array *postal_code_dictionary_phrases; // dictionary matches inside a postal code token
    token_array *sub_tokens;                      // sub-tokenization scratch
} vocab_context_t;
62
/**
 * Extracts parallel (phrase, label) arrays from one labeled training example.
 *
 * Runs of consecutive tokens that share the same label are joined into a
 * single space-separated phrase. Multi-word dictionary phrases whose tokens
 * all carry the same label are emitted as one phrase. Postal code tokens get
 * special normalization, including stripping a leading dictionary phrase such
 * as "CP"/"CEP" concatenated onto the front of the code.
 *
 * On success, phrases and phrase_labels are cleared and refilled as parallel
 * arrays. Scratch buffers come from ctx and are reused across calls.
 *
 * Returns false only when data_set has no tokenized string.
 */
bool address_phrases_and_labels(address_parser_data_set_t *data_set, cstring_array *phrases, cstring_array *phrase_labels, vocab_context_t *ctx) {
    tokenized_string_t *tokenized_str = data_set->tokenized_str;
    if (tokenized_str == NULL) {
        log_error("tokenized_str == NULL\n");
        return false;
    }

    // Unknown/ambiguous languages are treated as "no language" for
    // dictionary lookups
    char *language = char_array_get_string(data_set->language);
    if (string_equals(language, UNKNOWN_LANGUAGE) || string_equals(language, AMBIGUOUS_LANGUAGE)) {
        language = NULL;
    }

    char_array *token_builder = ctx->token_builder;
    char_array *postal_code_token_builder = ctx->postal_code_token_builder;
    char_array *sub_token_builder = ctx->sub_token_builder;
    char_array *phrase_builder = ctx->phrase_builder;
    phrase_array *dictionary_phrases = ctx->dictionary_phrases;
    int64_array *phrase_memberships = ctx->phrase_memberships;
    phrase_array *postal_code_dictionary_phrases = ctx->postal_code_dictionary_phrases;
    token_array *sub_tokens = ctx->sub_tokens;

    uint32_t i = 0;
    uint32_t j = 0;

    char *normalized;
    char *phrase;

    char *label;
    char *prev_label;

    const char *token;

    char *str = tokenized_str->str;
    token_array *tokens = tokenized_str->tokens;

    prev_label = NULL;

    size_t num_strings = cstring_array_num_strings(tokenized_str->strings);

    cstring_array_clear(phrases);
    cstring_array_clear(phrase_labels);

    bool is_admin = false;
    bool is_postal = false;

    bool last_was_separator = false;

    int64_array_clear(phrase_memberships);
    phrase_array_clear(dictionary_phrases);
    char_array_clear(postal_code_token_builder);

    // Find all multi-word dictionary phrases in the string and record, for
    // each token, which phrase (if any) it belongs to
    bool have_dictionary_phrases = search_address_dictionaries_tokens_with_phrases(tokenized_str->str, tokenized_str->tokens, language, &dictionary_phrases);
    token_phrase_memberships(dictionary_phrases, phrase_memberships, tokenized_str->tokens->n);

    cstring_array_foreach(tokenized_str->strings, i, token, {
        token_t t = tokens->a[i];

        label = cstring_array_get_string(data_set->labels, i);
        if (label == NULL) {
            continue;
        }

        char_array_clear(token_builder);

        is_admin = is_admin_component(label);
        is_postal = !is_admin && is_postal_code(label);

        // Admin names and postal codes get a more aggressive normalization
        uint64_t normalize_token_options = ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS;

        if (is_admin || is_postal) {
            normalize_token_options = ADDRESS_PARSER_NORMALIZE_ADMIN_TOKEN_OPTIONS;
        }

        add_normalized_token(token_builder, str, t, normalize_token_options);
        if (token_builder->n == 0) {
            continue;
        }

        normalized = char_array_get_string(token_builder);

        int64_t phrase_membership = NULL_PHRASE_MEMBERSHIP;

        if (!is_admin && !is_postal) {
            // Check if this is a (potentially multi-word) dictionary phrase
            phrase_membership = phrase_memberships->a[i];
            if (phrase_membership != NULL_PHRASE_MEMBERSHIP) {
                phrase_t current_phrase = dictionary_phrases->a[phrase_membership];

                if (current_phrase.start == i) {
                    char_array_clear(phrase_builder);
                    char *first_label = label;
                    bool invalid_phrase = false;
                    // On the start of every phrase, check that all its tokens have the
                    // same label, otherwise set the memberships to the null phrase
                    for (j = current_phrase.start + 1; j < current_phrase.start + current_phrase.len; j++) {
                        char *token_label = cstring_array_get_string(data_set->labels, j);
                        if (!string_equals(token_label, first_label)) {
                            for (j = current_phrase.start; j < current_phrase.start + current_phrase.len; j++) {
                                phrase_memberships->a[j] = NULL_PHRASE_MEMBERSHIP;
                            }
                            invalid_phrase = true;
                            break;
                        }
                    }
                    // If the phrase was invalid, add the single word
                    if (invalid_phrase) {
                        cstring_array_add_string(phrases, normalized);
                        cstring_array_add_string(phrase_labels, label);
                    }
                    // NOTE(review): on an invalid phrase, execution still
                    // falls through to the phrase_builder append below for
                    // this token — confirm this matches intended behavior
                }
                // If we're in a valid phrase, add the current word to the phrase
                char_array_cat(phrase_builder, normalized);
                if (i < current_phrase.start + current_phrase.len - 1) {
                    char_array_cat(phrase_builder, " ");
                } else {
                    // If we're at the end of a phrase, add entire phrase as a string
                    normalized = char_array_get_string(phrase_builder);
                    cstring_array_add_string(phrases, normalized);
                    cstring_array_add_string(phrase_labels, label);
                }

            } else {

                cstring_array_add_string(phrases, normalized);
                cstring_array_add_string(phrase_labels, label);
            }

            prev_label = NULL;

            continue;
        }

        if (is_postal) {
            add_normalized_token(postal_code_token_builder, str, t, ADDRESS_PARSER_NORMALIZE_POSTAL_CODE_TOKEN_OPTIONS);
            char *postal_code_normalized = char_array_get_string(postal_code_token_builder);

            token_array_clear(sub_tokens);
            phrase_array_clear(postal_code_dictionary_phrases);
            tokenize_add_tokens(sub_tokens, postal_code_normalized, strlen(postal_code_normalized), false);

            // One specific case where "CP" or "CEP" can be concatenated onto the front of the token
            if (sub_tokens->n > 1 && search_address_dictionaries_tokens_with_phrases(postal_code_normalized, sub_tokens, language, &postal_code_dictionary_phrases) && postal_code_dictionary_phrases->n > 0) {
                phrase_t first_postal_code_phrase = postal_code_dictionary_phrases->a[0];
                address_expansion_value_t *value = address_dictionary_get_expansions(first_postal_code_phrase.data);
                if (value != NULL && value->components & LIBPOSTAL_ADDRESS_POSTAL_CODE) {
                    // Keep only what follows the "CP"/"CEP"-style prefix
                    char_array_clear(token_builder);
                    size_t first_real_token_index = first_postal_code_phrase.start + first_postal_code_phrase.len;
                    token_t first_real_token = sub_tokens->a[first_real_token_index];
                    char_array_cat(token_builder, postal_code_normalized + first_real_token.offset);
                    normalized = char_array_get_string(token_builder);
                }
            }
        }

        // A field-internal separator ends the current phrase unless we're
        // continuing a postal code
        bool last_was_postal = string_equals(prev_label, ADDRESS_PARSER_LABEL_POSTAL_CODE);
        bool same_as_previous_label = string_equals(label, prev_label) && (!last_was_separator || last_was_postal);

        if (prev_label == NULL || !same_as_previous_label || i == num_strings - 1) {
            if (i == num_strings - 1 && (same_as_previous_label || prev_label == NULL)) {
                if (prev_label != NULL) {
                    char_array_cat(phrase_builder, " ");
                }

                char_array_cat(phrase_builder, normalized);
            }

            // End of phrase, add to hashtable
            if (prev_label != NULL) {

                phrase = char_array_get_string(phrase_builder);

                if (last_was_postal) {
                    // Strip any embedded postal-code dictionary phrases
                    // (e.g. "CP"/"CEP" markers) out of the finished phrase
                    token_array_clear(sub_tokens);
                    phrase_array_clear(dictionary_phrases);

                    tokenize_add_tokens(sub_tokens, phrase, strlen(phrase), false);

                    if (sub_tokens->n > 0 && search_address_dictionaries_tokens_with_phrases(phrase, sub_tokens, language, &dictionary_phrases) && dictionary_phrases->n > 0) {
                        char_array_clear(sub_token_builder);

                        phrase_t current_phrase = NULL_PHRASE;
                        phrase_t prev_phrase = NULL_PHRASE;
                        token_t current_sub_token;

                        for (size_t pc = 0; pc < dictionary_phrases->n; pc++) {
                            current_phrase = dictionary_phrases->a[pc];

                            address_expansion_value_t *phrase_value = address_dictionary_get_expansions(current_phrase.data);
                            size_t current_phrase_end = current_phrase.start + current_phrase.len;
                            if (phrase_value != NULL && phrase_value->components & LIBPOSTAL_ADDRESS_POSTAL_CODE) {
                                current_phrase_end = current_phrase.start;
                            }

                            for (size_t j = prev_phrase.start + prev_phrase.len; j < current_phrase_end; j++) {
                                current_sub_token = sub_tokens->a[j];

                                char_array_cat_len(sub_token_builder, phrase + current_sub_token.offset, current_sub_token.len);

                                if (j < sub_tokens->n - 1) {
                                    char_array_cat(sub_token_builder, " ");
                                }
                            }
                            prev_phrase = current_phrase;
                        }

                        // Copy any tokens remaining after the last phrase
                        if (prev_phrase.len > 0) {
                            for (size_t j = prev_phrase.start + prev_phrase.len; j < sub_tokens->n; j++) {
                                current_sub_token = sub_tokens->a[j];

                                char_array_cat_len(sub_token_builder, phrase + current_sub_token.offset, current_sub_token.len);

                                if (j < sub_tokens->n - 1) {
                                    char_array_cat(sub_token_builder, " ");
                                }
                            }
                        }

                        phrase = char_array_get_string(sub_token_builder);
                    }
                }

                cstring_array_add_string(phrases, phrase);
                cstring_array_add_string(phrase_labels, prev_label);

            }

            // Last token starts a new (single-token) phrase of its own
            if (i == num_strings - 1 && !same_as_previous_label && prev_label != NULL) {
                cstring_array_add_string(phrases, normalized);
                cstring_array_add_string(phrase_labels, label);
            }

            char_array_clear(phrase_builder);
        } else if (prev_label != NULL) {
            char_array_cat(phrase_builder, " ");
        }

        char_array_cat(phrase_builder, normalized);

        prev_label = label;

        last_was_separator = data_set->separators->a[i] == ADDRESS_SEPARATOR_FIELD_INTERNAL;

    })

    return true;
}
311
address_parser_init(char * filename)312 address_parser_t *address_parser_init(char *filename) {
313 if (filename == NULL) {
314 log_error("Filename was NULL\n");
315 return NULL;
316 }
317
318 address_parser_data_set_t *data_set = address_parser_data_set_init(filename);
319
320 if (data_set == NULL) {
321 log_error("Error initializing data set\n");
322 return NULL;
323 }
324
325 address_parser_t *parser = address_parser_new();
326 if (parser == NULL) {
327 log_error("Error allocating parser\n");
328 return NULL;
329 }
330
331 address_parser_context_t *context = address_parser_context_new();
332 if (context == NULL) {
333 log_error("Error allocating context\n");
334 return NULL;
335 }
336 parser->context = context;
337
338 khash_t(str_uint32) *vocab = kh_init(str_uint32);
339 if (vocab == NULL) {
340 log_error("Could not allocate vocab\n");
341 return NULL;
342 }
343
344 khash_t(str_uint32) *phrase_counts = kh_init(str_uint32);
345 if (vocab == NULL) {
346 log_error("Could not allocate vocab\n");
347 return NULL;
348 }
349
350 khash_t(str_uint32) *class_counts = kh_init(str_uint32);
351 if (class_counts == NULL) {
352 log_error("Could not allocate class_counts\n");
353 return NULL;
354 }
355
356 khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats);
357 if (phrase_stats == NULL) {
358 log_error("Could not allocate phrase_stats\n");
359 return NULL;
360 }
361
362 khash_t(phrase_types) *phrase_types = kh_init(phrase_types);
363 if (phrase_types == NULL) {
364 log_error("Could not allocate phrase_types\n");
365 return NULL;
366 }
367
368 khash_t(str_uint32) *postal_code_counts = kh_init(str_uint32);
369 if (postal_code_counts == NULL) {
370 log_error("Could not allocate postal_code_counts\n");
371 return NULL;
372 }
373
374 khash_t(postal_code_context_phrases) *postal_code_admin_contexts = kh_init(postal_code_context_phrases);
375 if (postal_code_admin_contexts == NULL) {
376 log_error("Could not allocate postal_code_admin_contexts\n");
377 return NULL;
378 }
379
380 khiter_t k;
381 char *str;
382
383 uint32_t i, j;
384
385 phrase_stats_t stats;
386 khash_t(int_uint32) *place_class_counts;
387
388 size_t examples = 0;
389
390 const char *token;
391 char *normalized;
392 uint32_t count;
393
394 char *key;
395 int ret = 0;
396
397 postal_code_context_value_t pc_ctx;
398
399 bool is_postal = false;
400
401 char *label;
402 char *prev_label;
403
404 vocab_context_t *vocab_context = malloc(sizeof(vocab_context_t));
405 if (vocab_context == NULL) {
406 log_error("Error allocationg vocab_context\n");
407 return NULL;
408 }
409
410 vocab_context->token_builder = char_array_new();
411 vocab_context->postal_code_token_builder = char_array_new();
412 vocab_context->sub_token_builder = char_array_new();
413 vocab_context->phrase_builder = char_array_new();
414 vocab_context->dictionary_phrases = phrase_array_new();
415 vocab_context->phrase_memberships = int64_array_new();
416 vocab_context->postal_code_dictionary_phrases = phrase_array_new();
417 vocab_context->sub_tokens = token_array_new();
418
419 if (vocab_context->token_builder == NULL ||
420 vocab_context->postal_code_token_builder == NULL ||
421 vocab_context->sub_token_builder == NULL ||
422 vocab_context->phrase_builder == NULL ||
423 vocab_context->dictionary_phrases == NULL ||
424 vocab_context->phrase_memberships == NULL ||
425 vocab_context->postal_code_dictionary_phrases == NULL ||
426 vocab_context->sub_tokens == NULL) {
427 log_error("Error initializing vocab_context\n");
428 return NULL;
429 }
430
431 cstring_array *phrases = cstring_array_new();
432 cstring_array *phrase_labels = cstring_array_new();
433
434 if (phrases == NULL || phrase_labels == NULL) {
435 log_error("Error setting up arrays for vocab building\n");
436 return NULL;
437 }
438
439 char *phrase;
440
441 trie_t *phrase_counts_trie = NULL;
442
443 tokenized_string_t *tokenized_str;
444 token_array *tokens;
445
446 while (address_parser_data_set_next(data_set)) {
447 tokenized_str = data_set->tokenized_str;
448
449 if (tokenized_str == NULL) {
450 log_error("tokenized str is NULL\n");
451 goto exit_hashes_allocated;
452 }
453
454 if (!address_phrases_and_labels(data_set, phrases, phrase_labels, vocab_context)) {
455 log_error("Error in address phrases and labels\n");
456 goto exit_hashes_allocated;
457 }
458
459 // Iterate through one time to see if there is a postal code in the string
460 bool have_postal_code = false;
461 char *postal_code_phrase = NULL;
462
463 cstring_array_foreach(phrases, i, phrase, {
464 if (phrase == NULL) continue;
465 char *phrase_label = cstring_array_get_string(phrase_labels, i);
466
467 if (is_postal_code(phrase_label)) {
468 have_postal_code = true;
469 postal_code_phrase = phrase;
470 break;
471 }
472 })
473
474 cstring_array_foreach(phrase_labels, i, label, {
475 if (!str_uint32_hash_incr(class_counts, label)) {
476 log_error("Error in hash_incr for class_counts\n");
477 goto exit_hashes_allocated;
478 }
479 })
480
481 cstring_array_foreach(phrases, i, phrase, {
482 if (phrase == NULL) continue;
483
484 uint32_t class_id;
485 uint32_t component = 0;
486
487 char *phrase_label = cstring_array_get_string(phrase_labels, i);
488 if (phrase_label == NULL) continue;
489
490 is_postal = false;
491
492 // Too many variations on these
493 if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_CITY)) {
494 class_id = ADDRESS_PARSER_BOUNDARY_CITY;
495 component = ADDRESS_COMPONENT_CITY;
496 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_STATE)) {
497 class_id = ADDRESS_PARSER_BOUNDARY_STATE;
498 component = ADDRESS_COMPONENT_STATE;
499 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_COUNTRY)) {
500 class_id = ADDRESS_PARSER_BOUNDARY_COUNTRY;
501 component = ADDRESS_COMPONENT_COUNTRY;
502 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_POSTAL_CODE)) {
503 is_postal = true;
504
505 char_array *token_builder = vocab_context->token_builder;
506 token_array *sub_tokens = vocab_context->sub_tokens;
507 tokenize_add_tokens(sub_tokens, phrase, strlen(phrase), false);
508
509 char_array_clear(token_builder);
510
511 for (j = 0; j < sub_tokens->n; j++) {
512 token_array_clear(sub_tokens);
513 token_t t = sub_tokens->a[j];
514 add_normalized_token(token_builder, phrase, t, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS);
515
516 if (token_builder->n == 0) {
517 continue;
518 }
519
520 char *sub_token = char_array_get_string(token_builder);
521 if (!str_uint32_hash_incr(vocab, sub_token)) {
522 log_error("Error in str_uint32_hash_incr\n");
523 goto exit_hashes_allocated;
524 }
525
526 }
527
528 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_COUNTRY_REGION)) {
529 class_id = ADDRESS_PARSER_BOUNDARY_COUNTRY_REGION;
530 component = ADDRESS_COMPONENT_COUNTRY_REGION;
531 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_STATE_DISTRICT)) {
532 class_id = ADDRESS_PARSER_BOUNDARY_STATE_DISTRICT;
533 component = ADDRESS_COMPONENT_STATE_DISTRICT;
534 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_SUBURB)) {
535 class_id = ADDRESS_PARSER_BOUNDARY_SUBURB;
536 component = ADDRESS_COMPONENT_SUBURB;
537 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_CITY_DISTRICT)) {
538 class_id = ADDRESS_PARSER_BOUNDARY_CITY_DISTRICT;
539 component = ADDRESS_COMPONENT_CITY_DISTRICT;
540 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_WORLD_REGION)) {
541 class_id = ADDRESS_PARSER_BOUNDARY_WORLD_REGION;
542 component = ADDRESS_COMPONENT_WORLD_REGION;
543 } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_ISLAND)) {
544 class_id = ADDRESS_PARSER_BOUNDARY_ISLAND;
545 component = ADDRESS_COMPONENT_ISLAND;
546 } else {
547 bool in_vocab = false;
548 if (!str_uint32_hash_incr_exists(vocab, phrase, &in_vocab)) {
549 log_error("Error in str_uint32_hash_incr\n");
550 goto exit_hashes_allocated;
551 }
552 continue;
553 }
554
555 char *normalized_phrase = NULL;
556
557 if (!is_postal && string_contains_hyphen(phrase)) {
558 normalized_phrase = normalize_string_utf8(phrase, NORMALIZE_STRING_REPLACE_HYPHENS);
559 }
560
561 char *phrases[2];
562 phrases[0] = phrase;
563 phrases[1] = normalized_phrase;
564
565 for (size_t p_i = 0; p_i < sizeof(phrases) / sizeof(char *); p_i++) {
566 phrase = phrases[p_i];
567 if (phrase == NULL) continue;
568
569 if (is_postal) {
570 if (!str_uint32_hash_incr(postal_code_counts, phrase)) {
571 log_error("Error in str_uint32_hash_incr for postal_code_counts\n");
572 goto exit_hashes_allocated;
573 }
574 continue;
575 }
576
577 if (have_postal_code && !is_postal) {
578 khash_t(str_set) *context_postal_codes = NULL;
579
580 k = kh_get(postal_code_context_phrases, postal_code_admin_contexts, postal_code_phrase);
581 if (k == kh_end(postal_code_admin_contexts)) {
582 key = strdup(postal_code_phrase);
583 ret = 0;
584 k = kh_put(postal_code_context_phrases, postal_code_admin_contexts, key, &ret);
585
586 if (ret < 0) {
587 log_error("Error in kh_put in postal_code_admin_contexts\n");
588 free(key);
589 goto exit_hashes_allocated;
590 }
591 context_postal_codes = kh_init(str_set);
592 if (context_postal_codes == NULL) {
593 log_error("Error in kh_init for context_postal_codes\n");
594 free(key);
595 goto exit_hashes_allocated;
596 }
597 kh_value(postal_code_admin_contexts, k) = context_postal_codes;
598 } else {
599 context_postal_codes = kh_value(postal_code_admin_contexts, k);
600 }
601
602 k = kh_get(str_set, context_postal_codes, phrase);
603 if (k == kh_end(context_postal_codes)) {
604 char *context_key = strdup(phrase);
605 k = kh_put(str_set, context_postal_codes, context_key, &ret);
606 if (ret < 0) {
607 log_error("Error in kh_put in context_postal_codes\n");
608 free(context_key);
609 goto exit_hashes_allocated;
610 }
611 }
612 }
613
614 k = kh_get(phrase_stats, phrase_stats, phrase);
615
616 if (k == kh_end(phrase_stats)) {
617 key = strdup(phrase);
618 ret = 0;
619 k = kh_put(phrase_stats, phrase_stats, key, &ret);
620 if (ret < 0) {
621 log_error("Error in kh_put in phrase_stats\n");
622 free(key);
623 goto exit_hashes_allocated;
624 }
625 place_class_counts = kh_init(int_uint32);
626
627 stats.class_counts = place_class_counts;
628 stats.components = component;
629
630 kh_value(phrase_stats, k) = stats;
631 } else {
632 stats = kh_value(phrase_stats, k);
633 place_class_counts = stats.class_counts;
634 stats.components |= component;
635 kh_value(phrase_stats, k) = stats;
636 }
637
638 if (!int_uint32_hash_incr(place_class_counts, (khint_t)class_id)) {
639 log_error("Error in int_uint32_hash_incr in class_counts\n");
640 goto exit_hashes_allocated;
641 }
642
643 if (!str_uint32_hash_incr(phrase_counts, phrase)) {
644 log_error("Error in str_uint32_hash_incr in phrase_counts\n");
645 goto exit_hashes_allocated;
646 }
647
648 }
649
650 if (normalized_phrase != NULL) {
651 free(normalized_phrase);
652 normalized_phrase = NULL;
653 }
654
655 })
656
657 tokenized_string_destroy(tokenized_str);
658 examples++;
659 if (examples % 10000 == 0 && examples != 0) {
660 log_info("Counting vocab: did %zu examples\n", examples);
661 }
662
663 }
664
665 log_info("Done with vocab, total size=%" PRIkh32 "\n", kh_size(vocab));
666
667 for (k = kh_begin(vocab); k != kh_end(vocab); ++k) {
668 token = (char *)kh_key(vocab, k);
669 if (!kh_exist(vocab, k)) {
670 continue;
671 }
672 uint32_t count = kh_value(vocab, k);
673 if (count < MIN_VOCAB_COUNT) {
674 kh_del(str_uint32, vocab, k);
675 free((char *)token);
676 }
677 }
678
679 log_info("After pruning vocab size=%" PRIkh32 "\n", kh_size(vocab));
680
681
682 log_info("Creating phrases trie\n");
683
684
685 phrase_counts_trie = trie_new_from_hash(phrase_counts);
686
687 log_info("Calculating phrase types\n");
688
689 size_t num_classes = kh_size(class_counts);
690 log_info("num_classes = %zu\n", num_classes);
691 parser->num_classes = num_classes;
692
693 log_info("Creating vocab trie\n");
694
695 parser->vocab = trie_new_from_hash(vocab);
696 if (parser->vocab == NULL) {
697 log_error("Error initializing vocabulary\n");
698 address_parser_destroy(parser);
699 parser = NULL;
700 goto exit_hashes_allocated;
701 }
702
703 kh_foreach(phrase_counts, token, count, {
704 if (!str_uint32_hash_incr_by(vocab, token, count)) {
705 log_error("Error adding phrases to vocabulary\n");
706 address_parser_destroy(parser);
707 parser = NULL;
708 goto exit_hashes_allocated;
709 }
710 })
711
712 kh_foreach(postal_code_counts, token, count, {
713 if (!str_uint32_hash_incr_by(vocab, token, count)) {
714 log_error("Error adding postal_codes to vocabulary\n");
715 address_parser_destroy(parser);
716 parser = NULL;
717 goto exit_hashes_allocated;
718 }
719 })
720
721 size_t hash_size;
722 const char *context_token;
723 bool sort_reverse = true;
724
725 log_info("Creating phrase_types trie\n");
726
727 sort_reverse = true;
728 char **phrase_keys = str_uint32_hash_sort_keys_by_value(phrase_counts, sort_reverse);
729 if (phrase_keys == NULL) {
730 log_error("phrase_keys == NULL\n");
731 address_parser_destroy(parser);
732 parser = NULL;
733 goto exit_hashes_allocated;
734 }
735
736 hash_size = kh_size(phrase_counts);
737 address_parser_types_array *phrase_types_array = address_parser_types_array_new_size(hash_size);
738
739 for (size_t idx = 0; idx < hash_size; idx++) {
740 char *phrase_key = phrase_keys[idx];
741 khiter_t pk = kh_get(str_uint32, phrase_counts, phrase_key);
742 if (pk == kh_end(phrase_counts)) {
743 log_error("Key %zu did not exist in phrase_counts: %s\n", idx, phrase_key);
744 address_parser_destroy(parser);
745 parser = NULL;
746 goto exit_hashes_allocated;
747 }
748
749 uint32_t phrase_count = kh_value(phrase_counts, pk);
750 if (phrase_count < MIN_PHRASE_COUNT) {
751 token = (char *)kh_key(phrase_counts, pk);
752 kh_del(str_uint32, phrase_counts, pk);
753 free((char *)token);
754 continue;
755 }
756
757 k = kh_get(phrase_stats, phrase_stats, phrase_key);
758
759 if (k == kh_end(phrase_stats)) {
760 log_error("Key %zu did not exist in phrase_stats: %s\n", idx, phrase_key);
761 address_parser_destroy(parser);
762 parser = NULL;
763 goto exit_hashes_allocated;
764 }
765
766 stats = kh_value(phrase_stats, k);
767
768 place_class_counts = stats.class_counts;
769 int32_t most_common = -1;
770 uint32_t max_count = 0;
771 uint32_t total = 0;
772 for (uint32_t i = 0; i < NUM_ADDRESS_PARSER_BOUNDARY_TYPES; i++) {
773 k = kh_get(int_uint32, place_class_counts, (khint_t)i);
774 if (k != kh_end(place_class_counts)) {
775 count = kh_value(place_class_counts, k);
776
777 if (count > max_count) {
778 max_count = count;
779 most_common = i;
780 }
781 total += count;
782 }
783 }
784
785 if (most_common > -1) {
786 address_parser_types_t types;
787 types.components = stats.components;
788 types.most_common = (uint16_t)most_common;
789
790 kh_value(phrase_counts, pk) = (uint32_t)phrase_types_array->n;
791 address_parser_types_array_push(phrase_types_array, types);
792 }
793 }
794
795 if (phrase_keys != NULL) {
796 free(phrase_keys);
797 }
798
799 log_info("Creating phrases trie\n");
800
801 parser->phrases = trie_new_from_hash(phrase_counts);
802 if (parser->phrases == NULL) {
803 log_error("Error converting phrase_counts to trie\n");
804 address_parser_destroy(parser);
805 parser = NULL;
806 goto exit_hashes_allocated;
807 }
808
809 if (phrase_types_array == NULL) {
810 log_error("phrase_types_array is NULL\n");
811 address_parser_destroy(parser);
812 parser = NULL;
813 goto exit_hashes_allocated;
814 }
815
816 parser->phrase_types = phrase_types_array;
817
818 char **postal_code_keys = str_uint32_hash_sort_keys_by_value(postal_code_counts, true);
819 if (postal_code_keys == NULL) {
820 log_error("postal_code_keys == NULL\n");
821 free(phrase_keys);
822 address_parser_destroy(parser);
823 parser = NULL;
824 goto exit_hashes_allocated;
825 }
826
827 log_info("Creating postal codes trie\n");
828
829 hash_size = kh_size(postal_code_counts);
830 for (size_t idx = 0; idx < hash_size; idx++) {
831 char *phrase_key = postal_code_keys[idx];
832
833 k = kh_get(str_uint32, postal_code_counts, phrase_key);
834 if (k == kh_end(postal_code_counts)) {
835 log_error("Key %zu did not exist in postal_code_counts: %s\n", idx, phrase_key);
836 address_parser_destroy(parser);
837 parser = NULL;
838 goto exit_hashes_allocated;
839 }
840 uint32_t pc_count = kh_value(postal_code_counts, k);
841 kh_value(postal_code_counts, k) = (uint32_t)idx;
842 }
843
844 if (postal_code_keys != NULL) {
845 free(postal_code_keys);
846 }
847
848 parser->postal_codes = trie_new_from_hash(postal_code_counts);
849 if (parser->postal_codes == NULL) {
850 log_error("Error converting postal_code_counts to trie\n");
851 address_parser_destroy(parser);
852 parser = NULL;
853 goto exit_hashes_allocated;
854 }
855
856 log_info("Building postal code contexts\n");
857
858 bool fixed_rows = false;
859 graph_builder_t *postal_code_contexts_builder = graph_builder_new(GRAPH_BIPARTITE, fixed_rows);
860
861 uint32_t postal_code_id;
862 uint32_t context_phrase_id;
863
864 khash_t(str_set) *context_phrases;
865
866 kh_foreach(postal_code_admin_contexts, token, context_phrases, {
867 if (!trie_get_data(parser->postal_codes, (char *)token, &postal_code_id)) {
868 log_error("Key %s did not exist in parser->postal_codes\n", (char *)token);
869 address_parser_destroy(parser);
870 parser = NULL;
871 goto exit_hashes_allocated;
872 }
873 kh_foreach_key(context_phrases, context_token, {
874 if (!trie_get_data(parser->phrases, (char *)context_token, &context_phrase_id)) {
875 log_error("Key %s did not exist in phrases trie\n", (char *)context_token);
876 address_parser_destroy(parser);
877 parser = NULL;
878 goto exit_hashes_allocated;
879 }
880
881 graph_builder_add_edge(postal_code_contexts_builder, postal_code_id, context_phrase_id);
882 })
883 })
884
885 bool sort_edges = true;
886 bool remove_duplicates = true;
887 graph_t *postal_code_contexts = graph_builder_finalize(postal_code_contexts_builder, sort_edges, remove_duplicates);
888
889 // NOTE: don't destroy this during deallocation
890 if (postal_code_contexts == NULL) {
891 log_error("postal_code_contexts is NULL\n");
892 address_parser_destroy(parser);
893 parser = NULL;
894 goto exit_hashes_allocated;
895 }
896 parser->postal_code_contexts = postal_code_contexts;
897
898 log_info("Freeing memory from initialization\n");
899
900 exit_hashes_allocated:
901 // Free memory for hashtables, etc.
902 if (vocab_context != NULL) {
903 char_array_destroy(vocab_context->token_builder);
904 char_array_destroy(vocab_context->postal_code_token_builder);
905 char_array_destroy(vocab_context->sub_token_builder);
906 char_array_destroy(vocab_context->phrase_builder);
907 phrase_array_destroy(vocab_context->dictionary_phrases);
908 int64_array_destroy(vocab_context->phrase_memberships);
909 phrase_array_destroy(vocab_context->postal_code_dictionary_phrases);
910 token_array_destroy(vocab_context->sub_tokens);
911 free(vocab_context);
912 }
913
914 cstring_array_destroy(phrases);
915 cstring_array_destroy(phrase_labels);
916
917 address_parser_data_set_destroy(data_set);
918
919 if (phrase_counts_trie != NULL) {
920 trie_destroy(phrase_counts_trie);
921 }
922
923 kh_foreach_key(vocab, token, {
924 free((char *)token);
925 })
926 kh_destroy(str_uint32, vocab);
927
928 kh_foreach_key(class_counts, token, {
929 free((char *)token);
930 })
931 kh_destroy(str_uint32, class_counts);
932
933 kh_foreach(phrase_stats, token, stats, {
934 kh_destroy(int_uint32, stats.class_counts);
935 free((char *)token);
936 })
937
938 kh_destroy(phrase_stats, phrase_stats);
939
940 kh_foreach_key(phrase_counts, token, {
941 free((char *)token);
942 })
943
944 kh_destroy(str_uint32, phrase_counts);
945
946 kh_foreach_key(phrase_types, token, {
947 free((char *)token);
948 })
949 kh_destroy(phrase_types, phrase_types);
950
951 khash_t(str_set) *pc_set;
952
953 kh_foreach(postal_code_admin_contexts, token, pc_set, {
954 if (pc_set != NULL) {
955 kh_foreach_key(pc_set, context_token, {
956 free((char *)context_token);
957 })
958 kh_destroy(str_set, pc_set);
959 }
960 free((char *)token);
961 })
962
963 kh_destroy(postal_code_context_phrases, postal_code_admin_contexts);
964
965 kh_foreach_key(postal_code_counts, token, {
966 free((char *)token);
967 })
968 kh_destroy(str_uint32, postal_code_counts);
969
970 return parser;
971 }
972
/**
 * Trains the model on one labeled example, dispatching on self->model_type.
 * trainer must match the model type (averaged perceptron or CRF trainer).
 * Returns false on trainer error or unknown model type.
 */
static inline bool address_parser_train_example(address_parser_t *self, void *trainer, address_parser_context_t *context, address_parser_data_set_t *data_set) {
    if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
        return averaged_perceptron_trainer_train_example((averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
    } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
        return crf_averaged_perceptron_trainer_train_example((crf_averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
    } else {
        log_error("Parser model is of unknown type\n");
    }
    return false;
}
983
/* Release the trainer through the destructor that matches the parser's model
 * type. Unknown model types are silently ignored (nothing to free). */
static inline void address_parser_trainer_destroy(address_parser_t *self, void *trainer) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            averaged_perceptron_trainer_destroy((averaged_perceptron_trainer_t *)trainer);
            break;
        case ADDRESS_PARSER_TYPE_CRF:
            crf_averaged_perceptron_trainer_destroy((crf_averaged_perceptron_trainer_t *)trainer);
            break;
        default:
            break;
    }
}
991
/* Finalize (average) the trained weights and install the resulting model on
 * the parser. Returns true only when finalization produced a non-NULL model. */
static inline bool address_parser_finalize_model(address_parser_t *self, void *trainer) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            self->model.ap = averaged_perceptron_trainer_finalize((averaged_perceptron_trainer_t *)trainer);
            return self->model.ap != NULL;
        case ADDRESS_PARSER_TYPE_CRF:
            self->model.crf = crf_averaged_perceptron_trainer_finalize((crf_averaged_perceptron_trainer_t *)trainer);
            return self->model.crf != NULL;
        default:
            log_error("Parser model is of unknown type\n");
            return false;
    }
}
1004
/* Read the trainer's current iteration counter; returns 0 for an unknown
 * model type. */
static inline uint32_t address_parser_train_num_iterations(address_parser_t *self, void *trainer) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            return ((averaged_perceptron_trainer_t *)trainer)->iterations;
        case ADDRESS_PARSER_TYPE_CRF:
            return ((crf_averaged_perceptron_trainer_t *)trainer)->iterations;
        default:
            log_error("Parser model is of unknown type\n");
            return 0;
    }
}
1017
/* Store the iteration number on the trainer so per-epoch logging and
 * averaging see the correct epoch index. No-op logging for unknown types. */
static inline void address_parser_train_set_iterations(address_parser_t *self, void *trainer, uint32_t iterations) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            ((averaged_perceptron_trainer_t *)trainer)->iterations = iterations;
            break;
        case ADDRESS_PARSER_TYPE_CRF:
            ((crf_averaged_perceptron_trainer_t *)trainer)->iterations = iterations;
            break;
        default:
            log_error("Parser model is of unknown type\n");
            break;
    }
}
1029
/* Read the trainer's cumulative update (error) count; returns 0 for an
 * unknown model type. */
static inline uint64_t address_parser_train_num_errors(address_parser_t *self, void *trainer) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            return ((averaged_perceptron_trainer_t *)trainer)->num_updates;
        case ADDRESS_PARSER_TYPE_CRF:
            return ((crf_averaged_perceptron_trainer_t *)trainer)->num_updates;
        default:
            log_error("Parser model is of unknown type\n");
            return 0;
    }
}
1042
/* Run one training pass over every example in filename.
 *
 * Fixes over the previous version:
 *   - returns false when an example fails to train (previously the error path
 *     fell through to `return true`, so callers never saw epoch failures)
 *   - frees data_set->tokenized_str before bailing out on a failed example
 *     (previously the goto skipped tokenized_string_destroy and leaked it)
 *   - removes a dead duplicate `if (!example_success)` check and the unused
 *     `logged` variable
 *   - logs the uint32_t iteration with PRIu32 instead of %d
 *
 * Returns true when the whole epoch completed, false on any error. */
bool address_parser_train_epoch(address_parser_t *self, void *trainer, char *filename) {
    if (filename == NULL) {
        log_error("Filename was NULL\n");
        return false;
    }

    address_parser_data_set_t *data_set = address_parser_data_set_init(filename);
    if (data_set == NULL) {
        log_error("Error initializing data set\n");
        return false;
    }

    address_parser_context_t *context = self->context;

    bool success = true;
    size_t examples = 0;
    uint64_t errors = address_parser_train_num_errors(self, trainer);
    uint32_t iteration = address_parser_train_num_iterations(self, trainer);

    while (address_parser_data_set_next(data_set)) {
        char *language = char_array_get_string(data_set->language);
        // Unknown/ambiguous languages are treated as "no language hint"
        if (string_equals(language, UNKNOWN_LANGUAGE) || string_equals(language, AMBIGUOUS_LANGUAGE)) {
            language = NULL;
        }
        char *country = char_array_get_string(data_set->country);

        address_parser_context_fill(context, self, data_set->tokenized_str, language, country);

        bool example_success = address_parser_train_example(self, trainer, context, data_set);

        // Always release the tokenized string, even when training failed
        tokenized_string_destroy(data_set->tokenized_str);
        data_set->tokenized_str = NULL;

        if (!example_success) {
            log_error("Error training example\n");
            success = false;
            break;
        }

        examples++;
        if (examples % 1000 == 0) {
            uint64_t prev_errors = errors;
            errors = address_parser_train_num_errors(self, trainer);
            log_info("Iter %" PRIu32 ": Did %zu examples with %" PRIu64 " errors\n", iteration, examples, errors - prev_errors);
        }
    }

    address_parser_data_set_destroy(data_set);

    return success;
}
1102
1103
/* Train the parser over num_iterations epochs of the examples in filename.
 *
 * Fixes over the previous version: `trainer` was read uninitialized
 * (undefined behavior) when model_type was neither known value, and the
 * result of the trainer constructors was never checked for allocation
 * failure. Both cases now fail fast with an error.
 *
 * Returns true on success; on failure the trainer is cleaned up and false is
 * returned. On success the finalized model is installed on self. */
bool address_parser_train(address_parser_t *self, char *filename, address_parser_model_type_t model_type, uint32_t num_iterations, size_t min_updates) {
    self->model_type = model_type;

    void *trainer = NULL;
    if (model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
        trainer = (void *)averaged_perceptron_trainer_new(min_updates);
    } else if (model_type == ADDRESS_PARSER_TYPE_CRF) {
        trainer = (void *)crf_averaged_perceptron_trainer_new(self->num_classes, min_updates);
    } else {
        log_error("Parser model is of unknown type\n");
        return false;
    }

    if (trainer == NULL) {
        log_error("Error allocating trainer\n");
        return false;
    }

    for (uint32_t iter = 0; iter < num_iterations; iter++) {
        log_info("Doing epoch %d\n", iter);

        address_parser_train_set_iterations(self, trainer, iter);

        #if defined(HAVE_SHUF) || defined(HAVE_GSHUF)
        // Shuffle the training file between epochs so the online trainer
        // doesn't see examples in a fixed order
        log_info("Shuffling\n");

        if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) {
            log_error("Error in shuffle\n");
            address_parser_trainer_destroy(self, trainer);
            return false;
        }

        log_info("Shuffle complete\n");
        #endif

        if (!address_parser_train_epoch(self, trainer, filename)) {
            log_error("Error in epoch\n");
            address_parser_trainer_destroy(self, trainer);
            return false;
        }
    }

    log_debug("Done with training, averaging weights\n");

    // Finalization consumes the trainer and installs the model on self
    if (!address_parser_finalize_model(self, trainer)) {
        log_error("model was NULL\n");
        return false;
    }

    return true;
}
1148
/* State machine for command-line parsing in main: which keyword argument (if
 * any) the next positional token should be interpreted as. */
typedef enum {
    ADDRESS_PARSER_TRAIN_POSITIONAL_ARG,
    ADDRESS_PARSER_TRAIN_ARG_ITERATIONS,
    ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES,
    ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE
} address_parser_train_keyword_arg_t;

#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number --min-updates number --model (crf|greedyap)]\n"
1157
main(int argc,char ** argv)1158 int main(int argc, char **argv) {
1159 if (argc < 3) {
1160 printf(USAGE);
1161 exit(EXIT_FAILURE);
1162 }
1163
1164 #if !defined(HAVE_SHUF) && !defined(HAVE_GSHUF)
1165 log_warn("shuf must be installed to train address parser effectively. If this is a production machine, please install shuf. No shuffling will be performed.\n");
1166 #endif
1167
1168 int pos_args = 1;
1169
1170 address_parser_train_keyword_arg_t kwarg = ADDRESS_PARSER_TRAIN_POSITIONAL_ARG;
1171
1172 size_t num_iterations = DEFAULT_ITERATIONS;
1173 uint64_t min_updates = DEFAULT_MIN_UPDATES;
1174 size_t position = 0;
1175
1176 ssize_t arg_iterations;
1177 uint64_t arg_min_updates;
1178
1179 char *filename = NULL;
1180 char *output_dir = NULL;
1181
1182 address_parser_model_type_t model_type = DEFAULT_MODEL_TYPE;
1183
1184 for (int i = pos_args; i < argc; i++) {
1185 char *arg = argv[i];
1186
1187 if (string_equals(arg, "--iterations")) {
1188 kwarg = ADDRESS_PARSER_TRAIN_ARG_ITERATIONS;
1189 continue;
1190 }
1191
1192 if (string_equals(arg, "--min-updates")) {
1193 kwarg = ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES;
1194 continue;
1195 }
1196
1197 if (string_equals(arg, "--model")) {
1198 kwarg = ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE;
1199 continue;
1200 }
1201
1202 if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) {
1203 if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) {
1204 log_error("Bad arg for --iterations: %s\n", arg);
1205 exit(EXIT_FAILURE);
1206 }
1207 num_iterations = (size_t)arg_iterations;
1208 } else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES) {
1209 if (sscanf(arg, "%llu", &arg_min_updates) != 1) {
1210 log_error("Bad arg for --min-updates: %s\n", arg);
1211 exit(EXIT_FAILURE);
1212 }
1213 min_updates = arg_min_updates;
1214 log_info("min_updates = %" PRIu64 "\n", min_updates);
1215 } else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE) {
1216 if (string_equals(arg, "crf")) {
1217 model_type = ADDRESS_PARSER_TYPE_CRF;
1218 } else if (string_equals(arg, "greedyap")) {
1219 model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON;
1220 } else {
1221 log_error("Bad arg for --model, valid values are [crf, greedyap]\n");
1222 exit(EXIT_FAILURE);
1223 }
1224 } else if (position == 0) {
1225 filename = arg;
1226 position++;
1227 } else if (position == 1) {
1228 output_dir = arg;
1229 position++;
1230 }
1231 kwarg = ADDRESS_PARSER_TRAIN_POSITIONAL_ARG;
1232
1233 }
1234
1235 if (filename == NULL || output_dir == NULL) {
1236 printf(USAGE);
1237 exit(EXIT_FAILURE);
1238 }
1239
1240 if (!address_dictionary_module_setup(NULL)) {
1241 log_error("Could not load address dictionaries\n");
1242 exit(EXIT_FAILURE);
1243 }
1244
1245 log_info("address dictionary module loaded\n");
1246
1247 // Needs to load for normalization
1248 if (!transliteration_module_setup(NULL)) {
1249 log_error("Could not load transliteration module\n");
1250 exit(EXIT_FAILURE);
1251 }
1252
1253 log_info("transliteration module loaded\n");
1254
1255 address_parser_t *parser = address_parser_init(filename);
1256
1257 if (parser == NULL) {
1258 log_error("Could not initialize parser\n");
1259 exit(EXIT_FAILURE);
1260 }
1261
1262 log_info("Finished initialization\n");
1263
1264 if (!address_parser_train(parser, filename, model_type, num_iterations, min_updates)) {
1265 log_error("Error in training\n");
1266 exit(EXIT_FAILURE);
1267 }
1268
1269 log_debug("Finished training\n");
1270
1271 if (!address_parser_save(parser, output_dir)) {
1272 log_error("Error saving address parser\n");
1273 exit(EXIT_FAILURE);
1274 }
1275
1276 address_parser_destroy(parser);
1277
1278 address_dictionary_module_teardown();
1279 log_debug("Done\n");
1280 }
1281