1 #include "module_automaton_load.h"
2
3 #include "../../../Automaton.h"
4 #include "loadbuffer.h"
5
6
7 // --- public -----------------------------------------------------------
8
9 static bool
10 automaton_load_impl(Automaton* automaton, const char* path, PyObject* deserializer);
11
12 PyObject*
module_automaton_load(PyObject * module,PyObject * args)13 module_automaton_load(PyObject* module, PyObject* args) {
14
15 SaveLoadParameters params;
16 Automaton* automaton;
17 int ret;
18
19 automaton = (Automaton*)automaton_create();
20 if (UNLIKELY(automaton == NULL)) {
21 return NULL;
22 }
23
24 if (UNLIKELY(!automaton_save_load_parse_args(automaton->store, args, ¶ms))) {
25 Py_DECREF(automaton);
26 return NULL;
27 }
28
29 ret = automaton_load_impl(automaton, PyBytes_AsString(params.path), params.callback);
30 Py_DECREF(params.path);
31
32 if (LIKELY(ret))
33 return (PyObject*)automaton;
34 else
35 return NULL;
36 }
37
38 // ----private ----------------------------------------------------------
39
40 static bool
41 automaton_load_node(LoadBuffer* input);
42
43 static TrieNode*
44 automaton_load_fixup_pointers(LoadBuffer* input);
45
46 static bool
automaton_load_impl(Automaton * automaton,const char * path,PyObject * deserializer)47 automaton_load_impl(Automaton* automaton, const char* path, PyObject* deserializer) {
48
49 TrieNode* root;
50 LoadBuffer input;
51 CustompickleHeader header;
52 CustompickleFooter footer;
53 size_t i;
54
55 if (!loadbuffer_open(&input, path, deserializer)) {
56 return false;
57 }
58
59 if (!loadbuffer_init(&input, &header, &footer)) {
60 goto exception;
61 }
62
63 if (header.data.kind == TRIE || header.data.kind == AHOCORASICK) {
64 for (i=0; i < input.capacity; i++) {
65 if (UNLIKELY(!automaton_load_node(&input))) {
66 goto exception;
67 }
68 }
69
70 root = automaton_load_fixup_pointers(&input);
71 if (UNLIKELY(root == NULL)) {
72 goto exception;
73 }
74 } else if (header.data.kind == EMPTY) {
75
76 root = NULL;
77
78 } else {
79 PyErr_SetString(PyExc_ValueError, "automaton kind save in file is invalid");
80 goto exception;
81 }
82
83 loadbuffer_close(&input);
84
85 // setup object
86 automaton->kind = header.data.kind;
87 automaton->store = header.data.store;
88 automaton->key_type = header.data.key_type;
89 automaton->count = header.data.words_count;
90 automaton->longest_word = header.data.longest_word;
91 automaton->version = 0;
92 automaton->stats.version = -1;
93 automaton->root = root;
94
95 return true;
96
97 exception:
98 loadbuffer_close(&input);
99 return false;
100 }
101
102 static bool
automaton_load_node(LoadBuffer * input)103 automaton_load_node(LoadBuffer* input) {
104
105 PyObject* bytes; // XXX: it might be reused (i.e. be part of input)
106 PyObject* object;
107 TrieNode* original;
108 TrieNode* node;
109 size_t size;
110 int ret;
111
112 // 1. get original address of upcoming node
113 ret = loadbuffer_loadinto(input, &original, TrieNode*);
114 if (UNLIKELY(!ret)) {
115 return false;
116 }
117
118 // 2. load node data
119 node = (TrieNode*)memory_alloc(sizeof(TrieNode));
120 if (UNLIKELY(node == NULL)) {
121 PyErr_NoMemory();
122 return false;
123 }
124
125 ret = loadbuffer_load(input, (char*)node, PICKLE_TRIENODE_SIZE);
126 if (UNLIKELY(!ret)) {
127 memory_free(node);
128 return false;
129 }
130
131 node->next = NULL;
132
133 // 3. load next pointers
134 if (node->n > 0) {
135 size = sizeof(Pair) * node->n;
136 node->next = (Pair*)memory_alloc(size);
137 if (UNLIKELY(node->next == NULL)) {
138 PyErr_NoMemory();
139 goto exception;
140 }
141
142 ret = loadbuffer_load(input, (char*)(node->next), size);
143 if (UNLIKELY(!ret)) {
144 goto exception;
145 }
146 }
147
148 // 4. load custom python object
149 if (node->eow && input->store == STORE_ANY) {
150 size = (size_t)(node->output.integer);
151 bytes = F(PyBytes_FromStringAndSize)(NULL, size);
152 if (UNLIKELY(bytes == NULL)) {
153 goto exception;
154 }
155
156 ret = loadbuffer_load(input, PyBytes_AS_STRING(bytes), size);
157 if (UNLIKELY(!ret)) {
158 Py_DECREF(bytes);
159 goto exception;
160 }
161
162 object = F(PyObject_CallFunction)(input->deserializer, "O", bytes);
163 if (UNLIKELY(object == NULL)) {
164 Py_DECREF(bytes);
165 goto exception;
166 }
167
168 node->output.object = object;
169 Py_DECREF(bytes);
170 }
171
172 input->lookup[input->size].original = original;
173 input->lookup[input->size].current = node;
174 input->size += 1;
175
176 return true;
177
178 exception:
179 memory_safefree(node->next);
180 memory_free(node);
181
182 return false;
183 }
184
185
186 static int
addresspair_cmp(const void * a,const void * b)187 addresspair_cmp(const void* a, const void *b) {
188 const TrieNode* Aptr;
189 const TrieNode* Bptr;
190 uintptr_t A;
191 uintptr_t B;
192
193 Aptr = ((AddressPair*)a)->original;
194 Bptr = ((AddressPair*)b)->original;
195
196 A = (uintptr_t)Aptr;
197 B = (uintptr_t)Bptr;
198
199 if (A < B) {
200 return -1;
201 } else if (A > B) {
202 return +1;
203 } else {
204 return 0;
205 }
206 }
207
208
209 static TrieNode*
lookup_address(LoadBuffer * input,TrieNode * original)210 lookup_address(LoadBuffer* input, TrieNode* original) {
211
212 AddressPair* pair;
213
214 pair = (AddressPair*)bsearch(&original,
215 input->lookup,
216 input->size,
217 sizeof(AddressPair),
218 addresspair_cmp);
219
220 if (LIKELY(pair != NULL)) {
221 return pair->current;
222 } else {
223 return NULL;
224 }
225 }
226
227
228 static bool
automaton_load_fixup_node(LoadBuffer * input,TrieNode * node)229 automaton_load_fixup_node(LoadBuffer* input, TrieNode* node) {
230
231 size_t i;
232
233 if (input->kind == AHOCORASICK && node->fail != NULL) {
234 node->fail = lookup_address(input, node->fail);
235 if (UNLIKELY(node->fail == NULL)) {
236 return false;
237 }
238 }
239
240 if (node->n > 0) {
241 for (i=0; i < node->n; i++) {
242 node->next[i].child = lookup_address(input, node->next[i].child);
243 if (UNLIKELY(node->next[i].child == NULL)) {
244 return false;
245 }
246 }
247 }
248
249 return true;
250 }
251
252
253 static TrieNode*
automaton_load_fixup_pointers(LoadBuffer * input)254 automaton_load_fixup_pointers(LoadBuffer* input) {
255
256 TrieNode* root;
257 TrieNode* node;
258 size_t i;
259
260 ASSERT(input != NULL);
261
262 // 1. root is the first node stored in the array
263 root = input->lookup[0].current;
264
265 // 2. sort array to make it bsearch-able
266 qsort(input->lookup, input->size, sizeof(AddressPair), addresspair_cmp);
267
268 // 3. convert all next and fail pointers to current pointers
269 for (i=0; i < input->size; i++) {
270 node = input->lookup[i].current;
271 if (UNLIKELY(!automaton_load_fixup_node(input, node))) {
272 PyErr_Format(PyExc_ValueError, "Detected malformed pointer during unpickling node %lu", i);
273 return NULL;
274 }
275 }
276
277 loadbuffer_invalidate(input);
278
279 return root;
280 }
281