1 #include <stdint.h>
2 
3 #include "address_parser.h"
4 #include "address_parser_io.h"
5 #include "address_dictionary.h"
6 #include "averaged_perceptron_trainer.h"
7 #include "crf_trainer_averaged_perceptron.h"
8 #include "collections.h"
9 #include "constants.h"
10 #include "file_utils.h"
11 #include "graph.h"
12 #include "graph_builder.h"
13 #include "shuffle.h"
14 #include "transliterate.h"
15 
16 #include "log/log.h"
17 
// Statistics collected for each admin phrase (city, state, country, ...) while
// counting the vocabulary in address_parser_init().
typedef struct phrase_stats {
    // Maps a boundary class id (ADDRESS_PARSER_BOUNDARY_*) to the number of
    // times the phrase was seen labeled with that class.
    khash_t(int_uint32) *class_counts;
    // Bitwise OR of every ADDRESS_COMPONENT_* flag the phrase has been seen as.
    uint16_t components;
} phrase_stats_t;

// phrase string -> phrase_stats_t
KHASH_MAP_INIT_STR(phrase_stats, phrase_stats_t)
// postal code phrase -> set of admin phrases that appeared in the same example
KHASH_MAP_INIT_STR(postal_code_context_phrases, khash_t(str_set) *)
// phrase string -> address_parser_types_t
KHASH_MAP_INIT_STR(phrase_types, address_parser_types_t)
26 
27 // Training
28 
// Training defaults. They are not used in this part of the file; presumably
// they are the number of passes over the data, the minimum number of updates
// a feature needs, and the model type to train. TODO confirm against callers.
#define DEFAULT_ITERATIONS 5
#define DEFAULT_MIN_UPDATES 5
#define DEFAULT_MODEL_TYPE ADDRESS_PARSER_TYPE_CRF

// Vocabulary entries counted fewer than this many times are deleted from the
// vocab hash before the vocab trie is built.
#define MIN_VOCAB_COUNT 5
// Phrases counted fewer than this many times are deleted from phrase_counts
// before the phrases trie is built. With a value of 1 nothing is pruned.
#define MIN_PHRASE_COUNT 1
35 
36 static inline bool is_postal_code(char *label) {
37     return string_equals(label, ADDRESS_PARSER_LABEL_POSTAL_CODE);
38 }
39 
is_admin_component(char * label)40 static inline bool is_admin_component(char *label) {
41     return (string_equals(label, ADDRESS_PARSER_LABEL_SUBURB) ||
42            string_equals(label, ADDRESS_PARSER_LABEL_CITY_DISTRICT) ||
43            string_equals(label, ADDRESS_PARSER_LABEL_CITY) ||
44            string_equals(label, ADDRESS_PARSER_LABEL_STATE_DISTRICT) ||
45            string_equals(label, ADDRESS_PARSER_LABEL_ISLAND) ||
46            string_equals(label, ADDRESS_PARSER_LABEL_STATE) ||
47            string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY_REGION) ||
48            string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY) ||
49            string_equals(label, ADDRESS_PARSER_LABEL_WORLD_REGION));
50 }
51 
// Scratch buffers for address_phrases_and_labels(). They are allocated once in
// address_parser_init() and passed to every call, so their contents are only
// meaningful during a single call.
typedef struct vocab_context {
    char_array *token_builder;                    // normalized form of the current token
    char_array *postal_code_token_builder;        // postal-code-normalized form of the current token
    char_array *sub_token_builder;                // postal code phrase rebuilt without prefix words
    char_array *phrase_builder;                   // multi-token phrase being accumulated
    phrase_array *dictionary_phrases;             // results of address-dictionary phrase searches
    int64_array *phrase_memberships;              // per token: index into dictionary_phrases, or NULL_PHRASE_MEMBERSHIP
    phrase_array *postal_code_dictionary_phrases; // dictionary phrases found inside a postal code token
    token_array *sub_tokens;                      // re-tokenization of a postal code token or phrase
} vocab_context_t;
62 
address_phrases_and_labels(address_parser_data_set_t * data_set,cstring_array * phrases,cstring_array * phrase_labels,vocab_context_t * ctx)63 bool address_phrases_and_labels(address_parser_data_set_t *data_set, cstring_array *phrases, cstring_array *phrase_labels, vocab_context_t *ctx) {
64     tokenized_string_t *tokenized_str = data_set->tokenized_str;
65     if (tokenized_str == NULL) {
66         log_error("tokenized_str == NULL\n");
67         return false;
68     }
69 
70     char *language = char_array_get_string(data_set->language);
71     if (string_equals(language, UNKNOWN_LANGUAGE) || string_equals(language, AMBIGUOUS_LANGUAGE)) {
72         language = NULL;
73     }
74 
75     char_array *token_builder = ctx->token_builder;
76     char_array *postal_code_token_builder = ctx->postal_code_token_builder;
77     char_array *sub_token_builder = ctx->sub_token_builder;
78     char_array *phrase_builder = ctx->phrase_builder;
79     phrase_array *dictionary_phrases = ctx->dictionary_phrases;
80     int64_array *phrase_memberships = ctx->phrase_memberships;
81     phrase_array *postal_code_dictionary_phrases = ctx->postal_code_dictionary_phrases;
82     token_array *sub_tokens = ctx->sub_tokens;
83 
84     uint32_t i = 0;
85     uint32_t j = 0;
86 
87     char *normalized;
88     char *phrase;
89 
90     char *label;
91     char *prev_label;
92 
93     const char *token;
94 
95     char *str = tokenized_str->str;
96     token_array *tokens = tokenized_str->tokens;
97 
98     prev_label = NULL;
99 
100     size_t num_strings = cstring_array_num_strings(tokenized_str->strings);
101 
102     cstring_array_clear(phrases);
103     cstring_array_clear(phrase_labels);
104 
105     bool is_admin = false;
106     bool is_postal = false;
107     bool have_postal_code = false;
108 
109     bool last_was_separator = false;
110 
111     int64_array_clear(phrase_memberships);
112     phrase_array_clear(dictionary_phrases);
113     char_array_clear(postal_code_token_builder);
114 
115     // One specific case where "CP" or "CEP" can be concatenated onto the front of the token
116     bool have_dictionary_phrases = search_address_dictionaries_tokens_with_phrases(tokenized_str->str, tokenized_str->tokens, language, &dictionary_phrases);
117     token_phrase_memberships(dictionary_phrases, phrase_memberships, tokenized_str->tokens->n);
118 
119     cstring_array_foreach(tokenized_str->strings, i, token, {
120         token_t t = tokens->a[i];
121 
122         label = cstring_array_get_string(data_set->labels, i);
123         if (label == NULL) {
124             continue;
125         }
126 
127         char_array_clear(token_builder);
128 
129         is_admin = is_admin_component(label);
130         is_postal = !is_admin && is_postal_code(label);
131 
132         uint64_t normalize_token_options = ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS;
133 
134         if (is_admin || is_postal) {
135             normalize_token_options = ADDRESS_PARSER_NORMALIZE_ADMIN_TOKEN_OPTIONS;
136         }
137 
138         add_normalized_token(token_builder, str, t, normalize_token_options);
139         if (token_builder->n == 0) {
140             continue;
141         }
142 
143         normalized = char_array_get_string(token_builder);
144 
145         int64_t phrase_membership = NULL_PHRASE_MEMBERSHIP;
146 
147         if (!is_admin && !is_postal) {
148             // Check if this is a (potentially multi-word) dictionary phrase
149             phrase_membership = phrase_memberships->a[i];
150             if (phrase_membership != NULL_PHRASE_MEMBERSHIP) {
151                 phrase_t current_phrase = dictionary_phrases->a[phrase_membership];
152 
153                 if (current_phrase.start == i) {
154                     char_array_clear(phrase_builder);
155                     char *first_label = label;
156                     bool invalid_phrase = false;
157                     // On the start of every phrase, check that all its tokens have the
158                     // same label, otherwise set to memberships to the null phrase
159                     for (j = current_phrase.start + 1; j < current_phrase.start + current_phrase.len; j++) {
160                         char *token_label = cstring_array_get_string(data_set->labels, j);
161                         if (!string_equals(token_label, first_label)) {
162                             for (j = current_phrase.start; j < current_phrase.start + current_phrase.len; j++) {
163                                 phrase_memberships->a[j] = NULL_PHRASE_MEMBERSHIP;
164                             }
165                             invalid_phrase = true;
166                             break;
167                         }
168                     }
169                     // If the phrase was invalid, add the single word
170                     if (invalid_phrase) {
171                         cstring_array_add_string(phrases, normalized);
172                         cstring_array_add_string(phrase_labels, label);
173                      }
174                 }
175                 // If we're in a valid phrase, add the current word to the phrase
176                 char_array_cat(phrase_builder, normalized);
177                 if (i < current_phrase.start + current_phrase.len - 1) {
178                     char_array_cat(phrase_builder, " ");
179                 } else {
180                     // If we're at the end of a phrase, add entire phrase as a string
181                     normalized = char_array_get_string(phrase_builder);
182                     cstring_array_add_string(phrases, normalized);
183                     cstring_array_add_string(phrase_labels, label);
184                 }
185 
186             } else {
187 
188                 cstring_array_add_string(phrases, normalized);
189                 cstring_array_add_string(phrase_labels, label);
190             }
191 
192             prev_label = NULL;
193 
194             continue;
195         }
196 
197         if (is_postal) {
198             add_normalized_token(postal_code_token_builder, str, t, ADDRESS_PARSER_NORMALIZE_POSTAL_CODE_TOKEN_OPTIONS);
199             char *postal_code_normalized = char_array_get_string(postal_code_token_builder);
200 
201             token_array_clear(sub_tokens);
202             phrase_array_clear(postal_code_dictionary_phrases);
203             tokenize_add_tokens(sub_tokens, postal_code_normalized, strlen(postal_code_normalized), false);
204 
205             // One specific case where "CP" or "CEP" can be concatenated onto the front of the token
206             if (sub_tokens->n > 1 && search_address_dictionaries_tokens_with_phrases(postal_code_normalized, sub_tokens, language, &postal_code_dictionary_phrases) && postal_code_dictionary_phrases->n > 0) {
207                 phrase_t first_postal_code_phrase = postal_code_dictionary_phrases->a[0];
208                 address_expansion_value_t *value = address_dictionary_get_expansions(first_postal_code_phrase.data);
209                 if (value != NULL && value->components & LIBPOSTAL_ADDRESS_POSTAL_CODE) {
210                     char_array_clear(token_builder);
211                     size_t first_real_token_index = first_postal_code_phrase.start + first_postal_code_phrase.len;
212                     token_t first_real_token =  sub_tokens->a[first_real_token_index];
213                     char_array_cat(token_builder, postal_code_normalized + first_real_token.offset);
214                     normalized = char_array_get_string(token_builder);
215                 }
216             }
217         }
218 
219         bool last_was_postal = string_equals(prev_label, ADDRESS_PARSER_LABEL_POSTAL_CODE);
220         bool same_as_previous_label = string_equals(label, prev_label) && (!last_was_separator || last_was_postal);
221 
222         if (prev_label == NULL || !same_as_previous_label || i == num_strings - 1) {
223             if (i == num_strings - 1 && (same_as_previous_label || prev_label == NULL)) {
224                 if (prev_label != NULL) {
225                    char_array_cat(phrase_builder, " ");
226                 }
227 
228                 char_array_cat(phrase_builder, normalized);
229             }
230 
231             // End of phrase, add to hashtable
232             if (prev_label != NULL) {
233 
234                 phrase = char_array_get_string(phrase_builder);
235 
236                 if (last_was_postal) {
237                     token_array_clear(sub_tokens);
238                     phrase_array_clear(dictionary_phrases);
239 
240                     tokenize_add_tokens(sub_tokens, phrase, strlen(phrase), false);
241 
242                     if (sub_tokens->n > 0 && search_address_dictionaries_tokens_with_phrases(phrase, sub_tokens, language, &dictionary_phrases) && dictionary_phrases->n > 0) {
243                         char_array_clear(sub_token_builder);
244 
245                         phrase_t current_phrase = NULL_PHRASE;
246                         phrase_t prev_phrase = NULL_PHRASE;
247                         token_t current_sub_token;
248 
249                         for (size_t pc = 0; pc < dictionary_phrases->n; pc++) {
250                             current_phrase = dictionary_phrases->a[pc];
251 
252                             address_expansion_value_t *phrase_value = address_dictionary_get_expansions(current_phrase.data);
253                             size_t current_phrase_end = current_phrase.start + current_phrase.len;
254                             if (phrase_value != NULL && phrase_value->components & LIBPOSTAL_ADDRESS_POSTAL_CODE) {
255                                 current_phrase_end = current_phrase.start;
256                             }
257 
258                             for (size_t j = prev_phrase.start + prev_phrase.len; j < current_phrase_end; j++) {
259                                 current_sub_token = sub_tokens->a[j];
260 
261                                 char_array_cat_len(sub_token_builder, phrase + current_sub_token.offset, current_sub_token.len);
262 
263                                 if (j < sub_tokens->n - 1) {
264                                     char_array_cat(sub_token_builder, " ");
265                                 }
266                             }
267                             prev_phrase = current_phrase;
268                         }
269 
270                         if (prev_phrase.len > 0) {
271                             for (size_t j = prev_phrase.start + prev_phrase.len; j < sub_tokens->n; j++) {
272                                 current_sub_token = sub_tokens->a[j];
273 
274                                 char_array_cat_len(sub_token_builder, phrase + current_sub_token.offset, current_sub_token.len);
275 
276                                 if (j < sub_tokens->n - 1) {
277                                     char_array_cat(sub_token_builder, " ");
278                                 }
279                             }
280                         }
281 
282                         phrase = char_array_get_string(sub_token_builder);
283                     }
284                 }
285 
286                 cstring_array_add_string(phrases, phrase);
287                 cstring_array_add_string(phrase_labels, prev_label);
288 
289             }
290 
291             if (i == num_strings - 1 && !same_as_previous_label && prev_label != NULL) {
292                 cstring_array_add_string(phrases, normalized);
293                 cstring_array_add_string(phrase_labels, label);
294             }
295 
296             char_array_clear(phrase_builder);
297         } else if (prev_label != NULL) {
298             char_array_cat(phrase_builder, " ");
299         }
300 
301         char_array_cat(phrase_builder, normalized);
302 
303         prev_label = label;
304 
305         last_was_separator = data_set->separators->a[i] == ADDRESS_SEPARATOR_FIELD_INTERNAL;
306 
307     })
308 
309     return true;
310 }
311 
address_parser_init(char * filename)312 address_parser_t *address_parser_init(char *filename) {
313     if (filename == NULL) {
314         log_error("Filename was NULL\n");
315         return NULL;
316     }
317 
318     address_parser_data_set_t *data_set = address_parser_data_set_init(filename);
319 
320     if (data_set == NULL) {
321         log_error("Error initializing data set\n");
322         return NULL;
323     }
324 
325     address_parser_t *parser = address_parser_new();
326     if (parser == NULL) {
327         log_error("Error allocating parser\n");
328         return NULL;
329     }
330 
331     address_parser_context_t *context = address_parser_context_new();
332     if (context == NULL) {
333         log_error("Error allocating context\n");
334         return NULL;
335     }
336     parser->context = context;
337 
338     khash_t(str_uint32) *vocab = kh_init(str_uint32);
339     if (vocab == NULL) {
340         log_error("Could not allocate vocab\n");
341         return NULL;
342     }
343 
344     khash_t(str_uint32) *phrase_counts = kh_init(str_uint32);
345     if (vocab == NULL) {
346         log_error("Could not allocate vocab\n");
347         return NULL;
348     }
349 
350     khash_t(str_uint32) *class_counts = kh_init(str_uint32);
351     if (class_counts == NULL) {
352         log_error("Could not allocate class_counts\n");
353         return NULL;
354     }
355 
356     khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats);
357     if (phrase_stats == NULL) {
358         log_error("Could not allocate phrase_stats\n");
359         return NULL;
360     }
361 
362     khash_t(phrase_types) *phrase_types = kh_init(phrase_types);
363     if (phrase_types == NULL) {
364         log_error("Could not allocate phrase_types\n");
365         return NULL;
366     }
367 
368     khash_t(str_uint32) *postal_code_counts = kh_init(str_uint32);
369     if (postal_code_counts == NULL) {
370         log_error("Could not allocate postal_code_counts\n");
371         return NULL;
372     }
373 
374     khash_t(postal_code_context_phrases) *postal_code_admin_contexts = kh_init(postal_code_context_phrases);
375     if (postal_code_admin_contexts == NULL) {
376         log_error("Could not allocate postal_code_admin_contexts\n");
377         return NULL;
378     }
379 
380     khiter_t k;
381     char *str;
382 
383     uint32_t i, j;
384 
385     phrase_stats_t stats;
386     khash_t(int_uint32) *place_class_counts;
387 
388     size_t examples = 0;
389 
390     const char *token;
391     char *normalized;
392     uint32_t count;
393 
394     char *key;
395     int ret = 0;
396 
397     postal_code_context_value_t pc_ctx;
398 
399     bool is_postal = false;
400 
401     char *label;
402     char *prev_label;
403 
404     vocab_context_t *vocab_context = malloc(sizeof(vocab_context_t));
405     if (vocab_context == NULL) {
406         log_error("Error allocationg vocab_context\n");
407         return NULL;
408     }
409 
410     vocab_context->token_builder = char_array_new();
411     vocab_context->postal_code_token_builder = char_array_new();
412     vocab_context->sub_token_builder = char_array_new();
413     vocab_context->phrase_builder = char_array_new();
414     vocab_context->dictionary_phrases = phrase_array_new();
415     vocab_context->phrase_memberships = int64_array_new();
416     vocab_context->postal_code_dictionary_phrases = phrase_array_new();
417     vocab_context->sub_tokens = token_array_new();
418 
419     if (vocab_context->token_builder == NULL ||
420         vocab_context->postal_code_token_builder == NULL ||
421         vocab_context->sub_token_builder == NULL ||
422         vocab_context->phrase_builder == NULL ||
423         vocab_context->dictionary_phrases == NULL ||
424         vocab_context->phrase_memberships == NULL ||
425         vocab_context->postal_code_dictionary_phrases == NULL ||
426         vocab_context->sub_tokens == NULL) {
427         log_error("Error initializing vocab_context\n");
428         return NULL;
429     }
430 
431     cstring_array *phrases = cstring_array_new();
432     cstring_array *phrase_labels = cstring_array_new();
433 
434     if (phrases == NULL || phrase_labels == NULL) {
435         log_error("Error setting up arrays for vocab building\n");
436         return NULL;
437     }
438 
439     char *phrase;
440 
441     trie_t *phrase_counts_trie = NULL;
442 
443     tokenized_string_t *tokenized_str;
444     token_array *tokens;
445 
446     while (address_parser_data_set_next(data_set)) {
447         tokenized_str = data_set->tokenized_str;
448 
449         if (tokenized_str == NULL) {
450             log_error("tokenized str is NULL\n");
451             goto exit_hashes_allocated;
452         }
453 
454         if (!address_phrases_and_labels(data_set, phrases, phrase_labels, vocab_context)) {
455             log_error("Error in address phrases and labels\n");
456             goto exit_hashes_allocated;
457         }
458 
459         // Iterate through one time to see if there is a postal code in the string
460         bool have_postal_code = false;
461         char *postal_code_phrase = NULL;
462 
463         cstring_array_foreach(phrases, i, phrase, {
464             if (phrase == NULL) continue;
465             char *phrase_label = cstring_array_get_string(phrase_labels, i);
466 
467             if (is_postal_code(phrase_label)) {
468                 have_postal_code = true;
469                 postal_code_phrase = phrase;
470                 break;
471             }
472         })
473 
474         cstring_array_foreach(phrase_labels, i, label, {
475             if (!str_uint32_hash_incr(class_counts, label)) {
476                 log_error("Error in hash_incr for class_counts\n");
477                 goto exit_hashes_allocated;
478             }
479         })
480 
481         cstring_array_foreach(phrases, i, phrase, {
482             if (phrase == NULL) continue;
483 
484             uint32_t class_id;
485             uint32_t component = 0;
486 
487             char *phrase_label = cstring_array_get_string(phrase_labels, i);
488             if (phrase_label == NULL) continue;
489 
490             is_postal = false;
491 
492             // Too many variations on these
493             if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_CITY)) {
494                 class_id = ADDRESS_PARSER_BOUNDARY_CITY;
495                 component = ADDRESS_COMPONENT_CITY;
496             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_STATE)) {
497                 class_id = ADDRESS_PARSER_BOUNDARY_STATE;
498                 component = ADDRESS_COMPONENT_STATE;
499             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_COUNTRY)) {
500                 class_id = ADDRESS_PARSER_BOUNDARY_COUNTRY;
501                 component = ADDRESS_COMPONENT_COUNTRY;
502             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_POSTAL_CODE)) {
503                 is_postal = true;
504 
505                 char_array *token_builder = vocab_context->token_builder;
506                 token_array *sub_tokens = vocab_context->sub_tokens;
507                 tokenize_add_tokens(sub_tokens, phrase, strlen(phrase), false);
508 
509                 char_array_clear(token_builder);
510 
511                 for (j = 0; j < sub_tokens->n; j++) {
512                     token_array_clear(sub_tokens);
513                     token_t t = sub_tokens->a[j];
514                     add_normalized_token(token_builder, phrase, t, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS);
515 
516                     if (token_builder->n == 0) {
517                         continue;
518                     }
519 
520                     char *sub_token = char_array_get_string(token_builder);
521                     if (!str_uint32_hash_incr(vocab, sub_token)) {
522                         log_error("Error in str_uint32_hash_incr\n");
523                         goto exit_hashes_allocated;
524                     }
525 
526                 }
527 
528             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_COUNTRY_REGION)) {
529                 class_id = ADDRESS_PARSER_BOUNDARY_COUNTRY_REGION;
530                 component = ADDRESS_COMPONENT_COUNTRY_REGION;
531             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_STATE_DISTRICT)) {
532                 class_id = ADDRESS_PARSER_BOUNDARY_STATE_DISTRICT;
533                 component = ADDRESS_COMPONENT_STATE_DISTRICT;
534             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_SUBURB)) {
535                 class_id = ADDRESS_PARSER_BOUNDARY_SUBURB;
536                 component = ADDRESS_COMPONENT_SUBURB;
537             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_CITY_DISTRICT)) {
538                 class_id = ADDRESS_PARSER_BOUNDARY_CITY_DISTRICT;
539                 component = ADDRESS_COMPONENT_CITY_DISTRICT;
540             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_WORLD_REGION)) {
541                 class_id = ADDRESS_PARSER_BOUNDARY_WORLD_REGION;
542                 component = ADDRESS_COMPONENT_WORLD_REGION;
543             } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_ISLAND)) {
544                 class_id = ADDRESS_PARSER_BOUNDARY_ISLAND;
545                 component = ADDRESS_COMPONENT_ISLAND;
546             } else {
547                 bool in_vocab = false;
548                 if (!str_uint32_hash_incr_exists(vocab, phrase, &in_vocab)) {
549                     log_error("Error in str_uint32_hash_incr\n");
550                     goto exit_hashes_allocated;
551                 }
552                 continue;
553             }
554 
555             char *normalized_phrase = NULL;
556 
557             if (!is_postal && string_contains_hyphen(phrase)) {
558                 normalized_phrase = normalize_string_utf8(phrase, NORMALIZE_STRING_REPLACE_HYPHENS);
559             }
560 
561             char *phrases[2];
562             phrases[0] = phrase;
563             phrases[1] = normalized_phrase;
564 
565             for (size_t p_i = 0; p_i < sizeof(phrases) / sizeof(char *); p_i++) {
566                 phrase = phrases[p_i];
567                 if (phrase == NULL) continue;
568 
569                 if (is_postal) {
570                     if (!str_uint32_hash_incr(postal_code_counts, phrase)) {
571                         log_error("Error in str_uint32_hash_incr for postal_code_counts\n");
572                         goto exit_hashes_allocated;
573                     }
574                     continue;
575                 }
576 
577                 if (have_postal_code && !is_postal) {
578                     khash_t(str_set) *context_postal_codes = NULL;
579 
580                     k = kh_get(postal_code_context_phrases, postal_code_admin_contexts, postal_code_phrase);
581                     if (k == kh_end(postal_code_admin_contexts)) {
582                         key = strdup(postal_code_phrase);
583                         ret = 0;
584                         k = kh_put(postal_code_context_phrases, postal_code_admin_contexts, key, &ret);
585 
586                         if (ret < 0) {
587                             log_error("Error in kh_put in postal_code_admin_contexts\n");
588                             free(key);
589                             goto exit_hashes_allocated;
590                         }
591                         context_postal_codes = kh_init(str_set);
592                         if (context_postal_codes == NULL) {
593                             log_error("Error in kh_init for context_postal_codes\n");
594                             free(key);
595                             goto exit_hashes_allocated;
596                         }
597                         kh_value(postal_code_admin_contexts, k) = context_postal_codes;
598                     } else {
599                         context_postal_codes = kh_value(postal_code_admin_contexts, k);
600                     }
601 
602                     k = kh_get(str_set, context_postal_codes, phrase);
603                     if (k == kh_end(context_postal_codes)) {
604                         char *context_key = strdup(phrase);
605                         k = kh_put(str_set, context_postal_codes, context_key, &ret);
606                         if (ret < 0) {
607                             log_error("Error in kh_put in context_postal_codes\n");
608                             free(context_key);
609                             goto exit_hashes_allocated;
610                         }
611                     }
612                 }
613 
614                 k = kh_get(phrase_stats, phrase_stats, phrase);
615 
616                 if (k == kh_end(phrase_stats)) {
617                     key = strdup(phrase);
618                     ret = 0;
619                     k = kh_put(phrase_stats, phrase_stats, key, &ret);
620                     if (ret < 0) {
621                         log_error("Error in kh_put in phrase_stats\n");
622                         free(key);
623                         goto exit_hashes_allocated;
624                     }
625                     place_class_counts = kh_init(int_uint32);
626 
627                     stats.class_counts = place_class_counts;
628                     stats.components = component;
629 
630                     kh_value(phrase_stats, k) = stats;
631                 } else {
632                     stats = kh_value(phrase_stats, k);
633                     place_class_counts = stats.class_counts;
634                     stats.components |= component;
635                     kh_value(phrase_stats, k) = stats;
636                 }
637 
638                 if (!int_uint32_hash_incr(place_class_counts, (khint_t)class_id)) {
639                     log_error("Error in int_uint32_hash_incr in class_counts\n");
640                     goto exit_hashes_allocated;
641                 }
642 
643                 if (!str_uint32_hash_incr(phrase_counts, phrase)) {
644                     log_error("Error in str_uint32_hash_incr in phrase_counts\n");
645                     goto exit_hashes_allocated;
646                 }
647 
648             }
649 
650             if (normalized_phrase != NULL) {
651                 free(normalized_phrase);
652                 normalized_phrase = NULL;
653             }
654 
655         })
656 
657         tokenized_string_destroy(tokenized_str);
658         examples++;
659         if (examples % 10000 == 0 && examples != 0) {
660             log_info("Counting vocab: did %zu examples\n", examples);
661         }
662 
663     }
664 
665     log_info("Done with vocab, total size=%" PRIkh32 "\n", kh_size(vocab));
666 
667     for (k = kh_begin(vocab); k != kh_end(vocab); ++k) {
668         token = (char *)kh_key(vocab, k);
669         if (!kh_exist(vocab, k)) {
670             continue;
671         }
672         uint32_t count = kh_value(vocab, k);
673         if (count < MIN_VOCAB_COUNT) {
674             kh_del(str_uint32, vocab, k);
675             free((char *)token);
676         }
677     }
678 
679     log_info("After pruning vocab size=%" PRIkh32 "\n", kh_size(vocab));
680 
681 
682     log_info("Creating phrases trie\n");
683 
684 
685     phrase_counts_trie = trie_new_from_hash(phrase_counts);
686 
687     log_info("Calculating phrase types\n");
688 
689     size_t num_classes = kh_size(class_counts);
690     log_info("num_classes = %zu\n", num_classes);
691     parser->num_classes = num_classes;
692 
693     log_info("Creating vocab trie\n");
694 
695     parser->vocab = trie_new_from_hash(vocab);
696     if (parser->vocab == NULL) {
697         log_error("Error initializing vocabulary\n");
698         address_parser_destroy(parser);
699         parser = NULL;
700         goto exit_hashes_allocated;
701     }
702 
703     kh_foreach(phrase_counts, token, count, {
704         if (!str_uint32_hash_incr_by(vocab, token, count)) {
705             log_error("Error adding phrases to vocabulary\n");
706             address_parser_destroy(parser);
707             parser = NULL;
708             goto exit_hashes_allocated;
709         }
710     })
711 
712     kh_foreach(postal_code_counts, token, count, {
713         if (!str_uint32_hash_incr_by(vocab, token, count)) {
714             log_error("Error adding postal_codes to vocabulary\n");
715             address_parser_destroy(parser);
716             parser = NULL;
717             goto exit_hashes_allocated;
718         }
719     })
720 
721     size_t hash_size;
722     const char *context_token;
723     bool sort_reverse = true;
724 
725     log_info("Creating phrase_types trie\n");
726 
727     sort_reverse = true;
728     char **phrase_keys = str_uint32_hash_sort_keys_by_value(phrase_counts, sort_reverse);
729     if (phrase_keys == NULL) {
730         log_error("phrase_keys == NULL\n");
731         address_parser_destroy(parser);
732         parser = NULL;
733         goto exit_hashes_allocated;
734     }
735 
736     hash_size = kh_size(phrase_counts);
737     address_parser_types_array *phrase_types_array = address_parser_types_array_new_size(hash_size);
738 
739     for (size_t idx = 0; idx < hash_size; idx++) {
740         char *phrase_key = phrase_keys[idx];
741         khiter_t pk = kh_get(str_uint32, phrase_counts, phrase_key);
742         if (pk == kh_end(phrase_counts)) {
743             log_error("Key %zu did not exist in phrase_counts: %s\n", idx, phrase_key);
744             address_parser_destroy(parser);
745             parser = NULL;
746             goto exit_hashes_allocated;
747         }
748 
749         uint32_t phrase_count = kh_value(phrase_counts, pk);
750         if (phrase_count < MIN_PHRASE_COUNT) {
751             token = (char *)kh_key(phrase_counts, pk);
752             kh_del(str_uint32, phrase_counts, pk);
753             free((char *)token);
754             continue;
755         }
756 
757         k = kh_get(phrase_stats, phrase_stats, phrase_key);
758 
759         if (k == kh_end(phrase_stats)) {
760             log_error("Key %zu did not exist in phrase_stats: %s\n", idx, phrase_key);
761             address_parser_destroy(parser);
762             parser = NULL;
763             goto exit_hashes_allocated;
764         }
765 
766         stats = kh_value(phrase_stats, k);
767 
768         place_class_counts = stats.class_counts;
769         int32_t most_common = -1;
770         uint32_t max_count = 0;
771         uint32_t total = 0;
772         for (uint32_t i = 0; i < NUM_ADDRESS_PARSER_BOUNDARY_TYPES; i++) {
773             k = kh_get(int_uint32, place_class_counts, (khint_t)i);
774             if (k != kh_end(place_class_counts)) {
775                 count = kh_value(place_class_counts, k);
776 
777                 if (count > max_count) {
778                     max_count = count;
779                     most_common = i;
780                 }
781                 total += count;
782             }
783         }
784 
785         if (most_common > -1) {
786             address_parser_types_t types;
787             types.components = stats.components;
788             types.most_common = (uint16_t)most_common;
789 
790             kh_value(phrase_counts, pk) = (uint32_t)phrase_types_array->n;
791             address_parser_types_array_push(phrase_types_array, types);
792         }
793     }
794 
795     if (phrase_keys != NULL) {
796         free(phrase_keys);
797     }
798 
799     log_info("Creating phrases trie\n");
800 
801     parser->phrases = trie_new_from_hash(phrase_counts);
802     if (parser->phrases == NULL) {
803         log_error("Error converting phrase_counts to trie\n");
804         address_parser_destroy(parser);
805         parser = NULL;
806         goto exit_hashes_allocated;
807     }
808 
809     if (phrase_types_array == NULL) {
810         log_error("phrase_types_array is NULL\n");
811         address_parser_destroy(parser);
812         parser = NULL;
813         goto exit_hashes_allocated;
814     }
815 
816     parser->phrase_types = phrase_types_array;
817 
818     char **postal_code_keys = str_uint32_hash_sort_keys_by_value(postal_code_counts, true);
819     if (postal_code_keys == NULL) {
820         log_error("postal_code_keys == NULL\n");
821         free(phrase_keys);
822         address_parser_destroy(parser);
823         parser = NULL;
824         goto exit_hashes_allocated;
825     }
826 
827     log_info("Creating postal codes trie\n");
828 
829     hash_size = kh_size(postal_code_counts);
830     for (size_t idx = 0; idx < hash_size; idx++) {
831         char *phrase_key = postal_code_keys[idx];
832 
833         k = kh_get(str_uint32, postal_code_counts, phrase_key);
834         if (k == kh_end(postal_code_counts)) {
835             log_error("Key %zu did not exist in postal_code_counts: %s\n", idx, phrase_key);
836             address_parser_destroy(parser);
837             parser = NULL;
838             goto exit_hashes_allocated;
839         }
840         uint32_t pc_count = kh_value(postal_code_counts, k);
841         kh_value(postal_code_counts, k) = (uint32_t)idx;
842     }
843 
844     if (postal_code_keys != NULL) {
845         free(postal_code_keys);
846     }
847 
848     parser->postal_codes = trie_new_from_hash(postal_code_counts);
849     if (parser->postal_codes == NULL) {
850         log_error("Error converting postal_code_counts to trie\n");
851         address_parser_destroy(parser);
852         parser = NULL;
853         goto exit_hashes_allocated;
854     }
855 
856     log_info("Building postal code contexts\n");
857 
858     bool fixed_rows = false;
859     graph_builder_t *postal_code_contexts_builder = graph_builder_new(GRAPH_BIPARTITE, fixed_rows);
860 
861     uint32_t postal_code_id;
862     uint32_t context_phrase_id;
863 
864     khash_t(str_set) *context_phrases;
865 
866     kh_foreach(postal_code_admin_contexts, token, context_phrases, {
867         if (!trie_get_data(parser->postal_codes, (char *)token, &postal_code_id)) {
868             log_error("Key %s did not exist in parser->postal_codes\n", (char *)token);
869             address_parser_destroy(parser);
870             parser = NULL;
871             goto exit_hashes_allocated;
872         }
873         kh_foreach_key(context_phrases, context_token, {
874             if (!trie_get_data(parser->phrases, (char *)context_token, &context_phrase_id)) {
875                 log_error("Key %s did not exist in phrases trie\n", (char *)context_token);
876                 address_parser_destroy(parser);
877                 parser = NULL;
878                 goto exit_hashes_allocated;
879             }
880 
881             graph_builder_add_edge(postal_code_contexts_builder, postal_code_id, context_phrase_id);
882         })
883     })
884 
885     bool sort_edges = true;
886     bool remove_duplicates = true;
887     graph_t *postal_code_contexts = graph_builder_finalize(postal_code_contexts_builder, sort_edges, remove_duplicates);
888 
889     // NOTE: don't destroy this during deallocation
890     if (postal_code_contexts == NULL) {
891         log_error("postal_code_contexts is NULL\n");
892         address_parser_destroy(parser);
893         parser = NULL;
894         goto exit_hashes_allocated;
895     }
896     parser->postal_code_contexts = postal_code_contexts;
897 
898     log_info("Freeing memory from initialization\n");
899 
900 exit_hashes_allocated:
901     // Free memory for hashtables, etc.
902     if (vocab_context != NULL) {
903         char_array_destroy(vocab_context->token_builder);
904         char_array_destroy(vocab_context->postal_code_token_builder);
905         char_array_destroy(vocab_context->sub_token_builder);
906         char_array_destroy(vocab_context->phrase_builder);
907         phrase_array_destroy(vocab_context->dictionary_phrases);
908         int64_array_destroy(vocab_context->phrase_memberships);
909         phrase_array_destroy(vocab_context->postal_code_dictionary_phrases);
910         token_array_destroy(vocab_context->sub_tokens);
911         free(vocab_context);
912     }
913 
914     cstring_array_destroy(phrases);
915     cstring_array_destroy(phrase_labels);
916 
917     address_parser_data_set_destroy(data_set);
918 
919     if (phrase_counts_trie != NULL) {
920         trie_destroy(phrase_counts_trie);
921     }
922 
923     kh_foreach_key(vocab, token, {
924         free((char *)token);
925     })
926     kh_destroy(str_uint32, vocab);
927 
928     kh_foreach_key(class_counts, token, {
929         free((char *)token);
930     })
931     kh_destroy(str_uint32, class_counts);
932 
933     kh_foreach(phrase_stats, token, stats, {
934         kh_destroy(int_uint32, stats.class_counts);
935         free((char *)token);
936     })
937 
938     kh_destroy(phrase_stats, phrase_stats);
939 
940     kh_foreach_key(phrase_counts, token, {
941         free((char *)token);
942     })
943 
944     kh_destroy(str_uint32, phrase_counts);
945 
946     kh_foreach_key(phrase_types, token, {
947         free((char *)token);
948     })
949     kh_destroy(phrase_types, phrase_types);
950 
951     khash_t(str_set) *pc_set;
952 
953     kh_foreach(postal_code_admin_contexts, token, pc_set, {
954         if (pc_set != NULL) {
955             kh_foreach_key(pc_set, context_token, {
956                 free((char *)context_token);
957             })
958             kh_destroy(str_set, pc_set);
959         }
960         free((char *)token);
961     })
962 
963     kh_destroy(postal_code_context_phrases, postal_code_admin_contexts);
964 
965     kh_foreach_key(postal_code_counts, token, {
966         free((char *)token);
967     })
968     kh_destroy(str_uint32, postal_code_counts);
969 
970     return parser;
971 }
972 
address_parser_train_example(address_parser_t * self,void * trainer,address_parser_context_t * context,address_parser_data_set_t * data_set)973 static inline bool address_parser_train_example(address_parser_t *self, void *trainer, address_parser_context_t *context, address_parser_data_set_t *data_set) {
974     if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
975         return averaged_perceptron_trainer_train_example((averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
976     } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
977         return crf_averaged_perceptron_trainer_train_example((crf_averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
978     } else {
979         log_error("Parser model is of unknown type\n");
980     }
981     return false;
982 }
983 
address_parser_trainer_destroy(address_parser_t * self,void * trainer)984 static inline void address_parser_trainer_destroy(address_parser_t *self, void *trainer) {
985     if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
986         averaged_perceptron_trainer_destroy((averaged_perceptron_trainer_t *)trainer);
987     } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
988         crf_averaged_perceptron_trainer_destroy((crf_averaged_perceptron_trainer_t *)trainer);
989     }
990 }
991 
address_parser_finalize_model(address_parser_t * self,void * trainer)992 static inline bool address_parser_finalize_model(address_parser_t *self, void *trainer) {
993     if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
994         self->model.ap = averaged_perceptron_trainer_finalize((averaged_perceptron_trainer_t *)trainer);
995         return self->model.ap != NULL;
996     } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
997         self->model.crf = crf_averaged_perceptron_trainer_finalize((crf_averaged_perceptron_trainer_t *)trainer);
998         return self->model.crf != NULL;
999     } else {
1000         log_error("Parser model is of unknown type\n");
1001     }
1002     return false;
1003 }
1004 
address_parser_train_num_iterations(address_parser_t * self,void * trainer)1005 static inline uint32_t address_parser_train_num_iterations(address_parser_t *self, void *trainer) {
1006     if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
1007         averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
1008         return ap_trainer->iterations;
1009     } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
1010         crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
1011         return crf_trainer->iterations;
1012     } else {
1013         log_error("Parser model is of unknown type\n");
1014     }
1015     return 0;
1016 }
1017 
address_parser_train_set_iterations(address_parser_t * self,void * trainer,uint32_t iterations)1018 static inline void address_parser_train_set_iterations(address_parser_t *self, void *trainer, uint32_t iterations) {
1019     if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
1020         averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
1021         ap_trainer->iterations = iterations;
1022     } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
1023         crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
1024         crf_trainer->iterations = iterations;
1025     } else {
1026         log_error("Parser model is of unknown type\n");
1027     }
1028 }
1029 
address_parser_train_num_errors(address_parser_t * self,void * trainer)1030 static inline uint64_t address_parser_train_num_errors(address_parser_t *self, void *trainer) {
1031     if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
1032         averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
1033         return ap_trainer->num_updates;
1034     } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
1035         crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
1036         return crf_trainer->num_updates;
1037     } else {
1038         log_error("Parser model is of unknown type\n");
1039     }
1040     return 0;
1041 }
1042 
address_parser_train_epoch(address_parser_t * self,void * trainer,char * filename)1043 bool address_parser_train_epoch(address_parser_t *self, void *trainer, char *filename) {
1044     if (filename == NULL) {
1045         log_error("Filename was NULL\n");
1046         return false;
1047     }
1048 
1049     address_parser_data_set_t *data_set = address_parser_data_set_init(filename);
1050     if (data_set == NULL) {
1051         log_error("Error initializing data set\n");
1052         return false;
1053     }
1054 
1055     address_parser_context_t *context = self->context;
1056 
1057     size_t examples = 0;
1058     uint64_t errors = address_parser_train_num_errors(self, trainer);
1059 
1060     uint32_t iteration = address_parser_train_num_iterations(self, trainer);
1061 
1062     bool logged = false;
1063 
1064     while (address_parser_data_set_next(data_set)) {
1065         char *language = char_array_get_string(data_set->language);
1066         if (string_equals(language, UNKNOWN_LANGUAGE) || string_equals(language, AMBIGUOUS_LANGUAGE)) {
1067             language = NULL;
1068         }
1069         char *country = char_array_get_string(data_set->country);
1070 
1071         address_parser_context_fill(context, self, data_set->tokenized_str, language, country);
1072 
1073         bool example_success = address_parser_train_example(self, trainer, context, data_set);
1074 
1075         if (!example_success) {
1076             log_error("Error training example\n");
1077             goto exit_epoch_training_started;
1078         }
1079 
1080         tokenized_string_destroy(data_set->tokenized_str);
1081         data_set->tokenized_str = NULL;
1082 
1083         if (!example_success) {
1084             log_error("Error training example without country/language\n");
1085             goto exit_epoch_training_started;
1086         }
1087 
1088         examples++;
1089         if (examples % 1000 == 0 && examples > 0) {
1090             uint64_t prev_errors = errors;
1091             errors = address_parser_train_num_errors(self, trainer);
1092 
1093             log_info("Iter %d: Did %zu examples with %" PRIu64 " errors\n", iteration, examples, errors - prev_errors);
1094         }
1095     }
1096 
1097 exit_epoch_training_started:
1098     address_parser_data_set_destroy(data_set);
1099 
1100     return true;
1101 }
1102 
1103 
address_parser_train(address_parser_t * self,char * filename,address_parser_model_type_t model_type,uint32_t num_iterations,size_t min_updates)1104 bool address_parser_train(address_parser_t *self, char *filename, address_parser_model_type_t model_type, uint32_t num_iterations, size_t min_updates) {
1105     self->model_type = model_type;
1106     void *trainer;
1107     if (model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
1108         averaged_perceptron_trainer_t *ap_trainer = averaged_perceptron_trainer_new(min_updates);
1109         trainer = (void *)ap_trainer;
1110     } else if (model_type == ADDRESS_PARSER_TYPE_CRF) {
1111         crf_averaged_perceptron_trainer_t *crf_trainer = crf_averaged_perceptron_trainer_new(self->num_classes, min_updates);
1112         trainer = (void *)crf_trainer;
1113     }
1114 
1115     for (uint32_t iter = 0; iter < num_iterations; iter++) {
1116         log_info("Doing epoch %d\n", iter);
1117 
1118         address_parser_train_set_iterations(self, trainer, iter);
1119 
1120         #if defined(HAVE_SHUF) || defined(HAVE_GSHUF)
1121         log_info("Shuffling\n");
1122 
1123         if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) {
1124             log_error("Error in shuffle\n");
1125             address_parser_trainer_destroy(self, trainer);
1126             return false;
1127         }
1128 
1129         log_info("Shuffle complete\n");
1130         #endif
1131 
1132         if (!address_parser_train_epoch(self, trainer, filename)) {
1133             log_error("Error in epoch\n");
1134             address_parser_trainer_destroy(self, trainer);
1135             return false;
1136         }
1137     }
1138 
1139     log_debug("Done with training, averaging weights\n");
1140 
1141     if (!address_parser_finalize_model(self, trainer)) {
1142         log_error("model was NULL\n");
1143         return false;
1144     }
1145 
1146     return true;
1147 }
1148 
/*
 * Records which keyword flag, if any, appeared in the previous argv entry.
 * main() uses it to decide how to interpret the next argument. It is reset
 * to ADDRESS_PARSER_TRAIN_POSITIONAL_ARG after each value is consumed.
 */
typedef enum {
    ADDRESS_PARSER_TRAIN_POSITIONAL_ARG = 0,    /* next arg is positional: filename, then output_dir */
    ADDRESS_PARSER_TRAIN_ARG_ITERATIONS = 1,    /* next arg is the value for --iterations */
    ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES = 2,   /* next arg is the value for --min-updates */
    ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE = 3     /* next arg is the value for --model */
} address_parser_train_keyword_arg_t;

/* Printed by main() when the required positional arguments are missing. */
#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number --min-updates number --model (crf|greedyap)]\n"
1157 
main(int argc,char ** argv)1158 int main(int argc, char **argv) {
1159     if (argc < 3) {
1160         printf(USAGE);
1161         exit(EXIT_FAILURE);
1162     }
1163 
1164     #if !defined(HAVE_SHUF) && !defined(HAVE_GSHUF)
1165     log_warn("shuf must be installed to train address parser effectively. If this is a production machine, please install shuf. No shuffling will be performed.\n");
1166     #endif
1167 
1168     int pos_args = 1;
1169 
1170     address_parser_train_keyword_arg_t kwarg = ADDRESS_PARSER_TRAIN_POSITIONAL_ARG;
1171 
1172     size_t num_iterations = DEFAULT_ITERATIONS;
1173     uint64_t min_updates = DEFAULT_MIN_UPDATES;
1174     size_t position = 0;
1175 
1176     ssize_t arg_iterations;
1177     uint64_t arg_min_updates;
1178 
1179     char *filename = NULL;
1180     char *output_dir = NULL;
1181 
1182     address_parser_model_type_t model_type = DEFAULT_MODEL_TYPE;
1183 
1184     for (int i = pos_args; i < argc; i++) {
1185         char *arg = argv[i];
1186 
1187         if (string_equals(arg, "--iterations")) {
1188             kwarg = ADDRESS_PARSER_TRAIN_ARG_ITERATIONS;
1189             continue;
1190         }
1191 
1192         if (string_equals(arg, "--min-updates")) {
1193             kwarg = ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES;
1194             continue;
1195         }
1196 
1197         if (string_equals(arg, "--model")) {
1198             kwarg = ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE;
1199             continue;
1200         }
1201 
1202         if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) {
1203             if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) {
1204                 log_error("Bad arg for --iterations: %s\n", arg);
1205                 exit(EXIT_FAILURE);
1206             }
1207             num_iterations = (size_t)arg_iterations;
1208         } else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES) {
1209             if (sscanf(arg, "%llu", &arg_min_updates) != 1) {
1210                 log_error("Bad arg for --min-updates: %s\n", arg);
1211                 exit(EXIT_FAILURE);
1212             }
1213             min_updates = arg_min_updates;
1214             log_info("min_updates = %" PRIu64 "\n", min_updates);
1215         } else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE) {
1216             if (string_equals(arg, "crf")) {
1217                 model_type = ADDRESS_PARSER_TYPE_CRF;
1218             } else if (string_equals(arg, "greedyap"))  {
1219                 model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON;
1220             } else {
1221                 log_error("Bad arg for --model, valid values are [crf, greedyap]\n");
1222                 exit(EXIT_FAILURE);
1223             }
1224         } else if (position == 0) {
1225             filename = arg;
1226             position++;
1227         } else if (position == 1) {
1228             output_dir = arg;
1229             position++;
1230         }
1231         kwarg = ADDRESS_PARSER_TRAIN_POSITIONAL_ARG;
1232 
1233     }
1234 
1235     if (filename == NULL || output_dir == NULL) {
1236         printf(USAGE);
1237         exit(EXIT_FAILURE);
1238     }
1239 
1240     if (!address_dictionary_module_setup(NULL)) {
1241         log_error("Could not load address dictionaries\n");
1242         exit(EXIT_FAILURE);
1243     }
1244 
1245     log_info("address dictionary module loaded\n");
1246 
1247     // Needs to load for normalization
1248     if (!transliteration_module_setup(NULL)) {
1249         log_error("Could not load transliteration module\n");
1250         exit(EXIT_FAILURE);
1251     }
1252 
1253     log_info("transliteration module loaded\n");
1254 
1255     address_parser_t *parser = address_parser_init(filename);
1256 
1257     if (parser == NULL) {
1258         log_error("Could not initialize parser\n");
1259         exit(EXIT_FAILURE);
1260     }
1261 
1262     log_info("Finished initialization\n");
1263 
1264     if (!address_parser_train(parser, filename, model_type, num_iterations, min_updates)) {
1265         log_error("Error in training\n");
1266         exit(EXIT_FAILURE);
1267     }
1268 
1269     log_debug("Finished training\n");
1270 
1271     if (!address_parser_save(parser, output_dir)) {
1272         log_error("Error saving address parser\n");
1273         exit(EXIT_FAILURE);
1274     }
1275 
1276     address_parser_destroy(parser);
1277 
1278     address_dictionary_module_teardown();
1279     log_debug("Done\n");
1280 }
1281