1 #ifndef CRF_AVERAGED_PERCEPTRON_TRAINER_H
2 #define CRF_AVERAGED_PERCEPTRON_TRAINER_H
3 
4 #include <stdio.h>
5 #include <stdlib.h>
6 
7 #include "averaged_perceptron_trainer.h"
8 #include "crf.h"
9 #include "crf_trainer.h"
10 #include "collections.h"
11 #include "string_utils.h"
12 #include "tokens.h"
13 #include "trie.h"
14 #include "trie_utils.h"
15 
/*
 * A (previous tag id, current tag id) pair that can also be viewed as a
 * single uint64_t through `value` (presumably so it can serve as the key of
 * the 64-bit-keyed hash maps in this header — confirm in the implementation).
 *
 * The ids are plain uint32_t members rather than `uint32_t x:32` bit-fields:
 *  - bit-fields of types other than _Bool / int / unsigned int are
 *    implementation-defined (C11 6.7.2.1p5);
 *  - how bit-fields are placed in their allocation units is
 *    implementation-defined;
 *  - a bit-field cannot have its address taken.
 * Plain members have the same total size, well-defined member order and are
 * addressable.
 *
 * NOTE: which half of `value` holds which id depends on the platform's byte
 * order. Do not rely on a particular bit position of either id inside
 * `value` if it is ever persisted or shared across platforms.
 */
typedef union tag_bigram {
    uint64_t value;
    struct {
        uint32_t prev_class_id;
        uint32_t class_id;
    };
} tag_bigram_t;

#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
/* The two ids must exactly overlay `value`, with no padding. */
_Static_assert(sizeof(tag_bigram_t) == sizeof(uint64_t),
               "tag_bigram_t must pack exactly into its 64-bit value");
#endif
23 
// Hash map from a 64-bit integer key to a class_weight_t. The key is
// presumably a packed tag_bigram_t.value (previous tag id, current tag id) —
// confirm in the implementation. This instantiation is the value type of the
// feature_prev_tag_class_weights map below and the type of the trainer's
// trans_weights field.
KHASH_MAP_INIT_INT64(prev_tag_class_weights, class_weight_t)

// Hash map from a 32-bit integer key (presumably a feature id — confirm) to a
// pointer to a prev_tag_class_weights map. It is the type of the trainer's
// prev_tag_weights field.
KHASH_MAP_INIT_INT(feature_prev_tag_class_weights, khash_t(prev_tag_class_weights) *)
27 
28 typedef struct crf_averaged_perceptron_trainer {
29     crf_trainer_t *base_trainer;
30     uint64_t num_updates;
31     uint64_t num_errors;
32     uint32_t iterations;
33     uint64_t min_updates;
34     // {feature_id => {class_id => class_weight_t}}
35     khash_t(feature_class_weights) *weights;
36     khash_t(feature_prev_tag_class_weights) *prev_tag_weights;
37     khash_t(prev_tag_class_weights) *trans_weights;
38     uint64_array *update_counts;
39     uint64_array *prev_tag_update_counts;
40     cstring_array *sequence_features;
41     uint32_array *sequence_features_indptr;
42     cstring_array *sequence_prev_tag_features;
43     uint32_array *sequence_prev_tag_features_indptr;
44     uint32_array *label_ids;
45     uint32_array *viterbi;
46 } crf_averaged_perceptron_trainer_t;
47 
48 crf_averaged_perceptron_trainer_t *crf_averaged_perceptron_trainer_new(size_t num_classes, size_t min_updates);
49 
50 uint32_t crf_averaged_perceptron_trainer_predict(crf_averaged_perceptron_trainer_t *self, cstring_array *features);
51 
52 bool crf_averaged_perceptron_trainer_train_example(crf_averaged_perceptron_trainer_t *self,
53                                                    void *tagger,
54                                                    void *context,
55                                                    cstring_array *features,
56                                                    cstring_array *prev_tag_features,
57                                                    tagger_feature_function feature_function,
58                                                    tokenized_string_t *tokenized,
59                                                    cstring_array *labels
60                                                    );
61 
62 crf_t *crf_averaged_perceptron_trainer_finalize(crf_averaged_perceptron_trainer_t *self);
63 
64 void crf_averaged_perceptron_trainer_destroy(crf_averaged_perceptron_trainer_t *self);
65 
66 
67 #endif
68