1 #include "module_automaton_load.h"
2 
3 #include "../../../Automaton.h"
4 #include "loadbuffer.h"
5 
6 
7 // --- public -----------------------------------------------------------
8 
9 static bool
10 automaton_load_impl(Automaton* automaton, const char* path, PyObject* deserializer);
11 
12 PyObject*
module_automaton_load(PyObject * module,PyObject * args)13 module_automaton_load(PyObject* module, PyObject* args) {
14 
15     SaveLoadParameters params;
16     Automaton* automaton;
17     int ret;
18 
19     automaton = (Automaton*)automaton_create();
20     if (UNLIKELY(automaton == NULL)) {
21         return NULL;
22     }
23 
24     if (UNLIKELY(!automaton_save_load_parse_args(automaton->store, args, &params))) {
25         Py_DECREF(automaton);
26         return NULL;
27     }
28 
29     ret = automaton_load_impl(automaton, PyBytes_AsString(params.path), params.callback);
30     Py_DECREF(params.path);
31 
32     if (LIKELY(ret))
33         return (PyObject*)automaton;
34     else
35         return NULL;
36 }
37 
38 // ----private ----------------------------------------------------------
39 
40 static bool
41 automaton_load_node(LoadBuffer* input);
42 
43 static TrieNode*
44 automaton_load_fixup_pointers(LoadBuffer* input);
45 
46 static bool
automaton_load_impl(Automaton * automaton,const char * path,PyObject * deserializer)47 automaton_load_impl(Automaton* automaton, const char* path, PyObject* deserializer) {
48 
49     TrieNode* root;
50     LoadBuffer input;
51     CustompickleHeader header;
52     CustompickleFooter footer;
53     size_t i;
54 
55     if (!loadbuffer_open(&input, path, deserializer)) {
56         return false;
57     }
58 
59     if (!loadbuffer_init(&input, &header, &footer)) {
60         goto exception;
61     }
62 
63     if (header.data.kind == TRIE || header.data.kind == AHOCORASICK) {
64         for (i=0; i < input.capacity; i++) {
65             if (UNLIKELY(!automaton_load_node(&input))) {
66                 goto exception;
67             }
68         }
69 
70         root = automaton_load_fixup_pointers(&input);
71         if (UNLIKELY(root == NULL)) {
72             goto exception;
73         }
74     } else if (header.data.kind == EMPTY) {
75 
76         root = NULL;
77 
78     } else {
79         PyErr_SetString(PyExc_ValueError, "automaton kind save in file is invalid");
80         goto exception;
81     }
82 
83     loadbuffer_close(&input);
84 
85     // setup object
86     automaton->kind          = header.data.kind;
87     automaton->store         = header.data.store;
88     automaton->key_type      = header.data.key_type;
89     automaton->count         = header.data.words_count;
90     automaton->longest_word  = header.data.longest_word;
91     automaton->version       = 0;
92     automaton->stats.version = -1;
93     automaton->root          = root;
94 
95     return true;
96 
97 exception:
98     loadbuffer_close(&input);
99     return false;
100 }
101 
102 static bool
automaton_load_node(LoadBuffer * input)103 automaton_load_node(LoadBuffer* input) {
104 
105     PyObject* bytes; // XXX: it might be reused (i.e. be part of input)
106     PyObject* object;
107     TrieNode* original;
108     TrieNode* node;
109     size_t size;
110     int ret;
111 
112     // 1. get original address of upcoming node
113     ret = loadbuffer_loadinto(input, &original, TrieNode*);
114     if (UNLIKELY(!ret)) {
115         return false;
116     }
117 
118     // 2. load node data
119     node = (TrieNode*)memory_alloc(sizeof(TrieNode));
120     if (UNLIKELY(node == NULL)) {
121         PyErr_NoMemory();
122         return false;
123     }
124 
125     ret = loadbuffer_load(input, (char*)node, PICKLE_TRIENODE_SIZE);
126     if (UNLIKELY(!ret)) {
127         memory_free(node);
128         return false;
129     }
130 
131     node->next = NULL;
132 
133     // 3. load next pointers
134     if (node->n > 0) {
135         size = sizeof(Pair) * node->n;
136         node->next = (Pair*)memory_alloc(size);
137         if (UNLIKELY(node->next == NULL)) {
138             PyErr_NoMemory();
139             goto exception;
140         }
141 
142         ret = loadbuffer_load(input, (char*)(node->next), size);
143         if (UNLIKELY(!ret)) {
144             goto exception;
145         }
146     }
147 
148     // 4. load custom python object
149     if (node->eow && input->store == STORE_ANY) {
150         size = (size_t)(node->output.integer);
151         bytes = F(PyBytes_FromStringAndSize)(NULL, size);
152         if (UNLIKELY(bytes == NULL)) {
153             goto exception;
154         }
155 
156         ret = loadbuffer_load(input, PyBytes_AS_STRING(bytes), size);
157         if (UNLIKELY(!ret)) {
158             Py_DECREF(bytes);
159             goto exception;
160         }
161 
162         object = F(PyObject_CallFunction)(input->deserializer, "O", bytes);
163         if (UNLIKELY(object == NULL)) {
164             Py_DECREF(bytes);
165             goto exception;
166         }
167 
168         node->output.object = object;
169         Py_DECREF(bytes);
170     }
171 
172     input->lookup[input->size].original = original;
173     input->lookup[input->size].current  = node;
174     input->size += 1;
175 
176     return true;
177 
178 exception:
179     memory_safefree(node->next);
180     memory_free(node);
181 
182     return false;
183 }
184 
185 
186 static int
addresspair_cmp(const void * a,const void * b)187 addresspair_cmp(const void* a, const void *b) {
188     const TrieNode* Aptr;
189     const TrieNode* Bptr;
190     uintptr_t A;
191     uintptr_t B;
192 
193     Aptr = ((AddressPair*)a)->original;
194     Bptr = ((AddressPair*)b)->original;
195 
196     A = (uintptr_t)Aptr;
197     B = (uintptr_t)Bptr;
198 
199     if (A < B) {
200         return -1;
201     } else if (A > B) {
202         return +1;
203     } else {
204         return 0;
205     }
206 }
207 
208 
209 static TrieNode*
lookup_address(LoadBuffer * input,TrieNode * original)210 lookup_address(LoadBuffer* input, TrieNode* original) {
211 
212     AddressPair* pair;
213 
214     pair = (AddressPair*)bsearch(&original,
215                                  input->lookup,
216                                  input->size,
217                                  sizeof(AddressPair),
218                                  addresspair_cmp);
219 
220     if (LIKELY(pair != NULL)) {
221         return pair->current;
222     } else {
223         return NULL;
224     }
225 }
226 
227 
228 static bool
automaton_load_fixup_node(LoadBuffer * input,TrieNode * node)229 automaton_load_fixup_node(LoadBuffer* input, TrieNode* node) {
230 
231     size_t i;
232 
233     if (input->kind == AHOCORASICK && node->fail != NULL) {
234         node->fail = lookup_address(input, node->fail);
235         if (UNLIKELY(node->fail == NULL)) {
236             return false;
237         }
238     }
239 
240     if (node->n > 0) {
241         for (i=0; i < node->n; i++) {
242             node->next[i].child = lookup_address(input, node->next[i].child);
243             if (UNLIKELY(node->next[i].child == NULL)) {
244                 return false;
245             }
246         }
247     }
248 
249     return true;
250 }
251 
252 
253 static TrieNode*
automaton_load_fixup_pointers(LoadBuffer * input)254 automaton_load_fixup_pointers(LoadBuffer* input) {
255 
256     TrieNode* root;
257     TrieNode* node;
258     size_t i;
259 
260     ASSERT(input != NULL);
261 
262     // 1. root is the first node stored in the array
263     root = input->lookup[0].current;
264 
265     // 2. sort array to make it bsearch-able
266     qsort(input->lookup, input->size, sizeof(AddressPair), addresspair_cmp);
267 
268     // 3. convert all next and fail pointers to current pointers
269     for (i=0; i < input->size; i++) {
270         node = input->lookup[i].current;
271         if (UNLIKELY(!automaton_load_fixup_node(input, node))) {
272             PyErr_Format(PyExc_ValueError, "Detected malformed pointer during unpickling node %lu", i);
273             return NULL;
274         }
275     }
276 
277     loadbuffer_invalidate(input);
278 
279     return root;
280 }
281