1 /*
2 * Copyright (C) 2013 Dávid Márk Nemeskey
3 *
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License as
6 * published by the Free Software Foundation; either version 2 of the
7 * License, or (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful, but
10 * WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 * General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, see <https://www.gnu.org/licenses/>.
16 */
17
18 #include <lttoolbox/att_compiler.h>
19 #include <lttoolbox/alphabet.h>
20 #include <lttoolbox/transducer.h>
21 #include <lttoolbox/compression.h>
22 #include <lttoolbox/string_to_wostream.h>
23 #include <algorithm>
24 #include <stack>
25
26 using namespace std;
27
AttCompiler()28 AttCompiler::AttCompiler() :
29 starting_state(0),
30 default_weight(0.0000)
31 {
32 }
33
~AttCompiler()34 AttCompiler::~AttCompiler()
35 {
36 }
37
38 void
clear()39 AttCompiler::clear()
40 {
41 for (auto& it : graph)
42 {
43 delete it.second;
44 }
45 graph.clear();
46 alphabet = Alphabet();
47 }
48
49 /**
50 * Converts symbols like @0@ to epsilon, @_SPACE_@ to space, etc.
51 * @todo Are there other special symbols? If so, add them, and maybe use a map
52 * for conversion?
53 */
54 void
convert_hfst(wstring & symbol)55 AttCompiler::convert_hfst(wstring& symbol)
56 {
57 if (symbol == L"@0@" || symbol == L"ε")
58 {
59 symbol = L"";
60 }
61 else if (symbol == L"@_SPACE_@")
62 {
63 symbol = L" ";
64 }
65 }
66
67 bool
is_word_punct(wchar_t symbol)68 AttCompiler::is_word_punct(wchar_t symbol)
69 {
70 // https://en.wikipedia.org/wiki/Combining_character#Unicode_ranges
71 if((symbol >= 0x0300 && symbol <= 0x036F) // Combining Diacritics
72 || (symbol >= 0x1AB0 && symbol <= 0x1AFF) // ... Extended
73 || (symbol >= 0x1DC0 && symbol <= 0x1DFF) // ... Supplement
74 || (symbol >= 0x20D0 && symbol <= 0x20FF) // ... for Symbols
75 || (symbol >= 0xFE20 && symbol <= 0xFE2F)) // Combining Half Marks
76 {
77 return true;
78 }
79
80 return false;
81 }
82
83 /**
84 * Returns the code of the symbol in the alphabet. Run after convert_hfst has
85 * run.
86 *
87 * Also adds all non-multicharacter symbols (letters) to the @p letters set.
88 *
89 * @return the code of the symbol, if @p symbol is multichar; its first (and
90 * only) character otherwise.
91 */
92 int
symbol_code(const wstring & symbol)93 AttCompiler::symbol_code(const wstring& symbol)
94 {
95 if (symbol.length() > 1) {
96 alphabet.includeSymbol(symbol);
97 return alphabet(symbol);
98 } else if (symbol == L"") {
99 return 0;
100 } else if ((iswpunct(symbol[0]) || iswspace(symbol[0])) && !is_word_punct(symbol[0])) {
101 return symbol[0];
102 } else {
103 letters.insert(symbol[0]);
104 if(iswlower(symbol[0]))
105 {
106 letters.insert(towupper(symbol[0]));
107 }
108 else if(iswupper(symbol[0]))
109 {
110 letters.insert(towlower(symbol[0]));
111 }
112 return symbol[0];
113 }
114 }
115
116 bool
has_multiple_fsts(string const & file_name)117 AttCompiler::has_multiple_fsts(string const &file_name)
118 {
119 wifstream infile(file_name.c_str()); // TODO: error checking
120 wstring line;
121
122 while(getline(infile, line)){
123 if (line.find('-') == 0)
124 return true;
125 }
126
127 return false;
128 }
129
130 void
parse(string const & file_name,wstring const & dir)131 AttCompiler::parse(string const &file_name, wstring const &dir)
132 {
133 clear();
134
135 wifstream infile(file_name.c_str()); // TODO: error checking
136 vector<wstring> tokens;
137 wstring line;
138 bool first_line_in_fst = true; // First line -- see below
139 int state_id_offset = 0;
140 int largest_seen_state_id = 0;
141
142 if (has_multiple_fsts(file_name)){
143 wcerr << "Warning: Multiple fsts in '" << file_name << "' will be disjuncted." << endl;
144
145 // Set the starting state to 0 (Epsilon transtions will be added later)
146 starting_state = 0;
147 state_id_offset = 1;
148 }
149
150 while (getline(infile, line))
151 {
152 tokens.clear();
153 int from, to;
154 wstring upper, lower;
155 double weight;
156
157 if (line.length() == 0 && first_line_in_fst)
158 {
159 wcerr << "Error: empty file '" << file_name << "'." << endl;
160 exit(EXIT_FAILURE);
161 }
162 if (first_line_in_fst && line.find(L"\t") == wstring::npos)
163 {
164 wcerr << "Error: invalid format '" << file_name << "'." << endl;
165 exit(EXIT_FAILURE);
166 }
167
168 /* Empty line. */
169 if (line.length() == 0)
170 {
171 continue;
172 }
173 split(line, L'\t', tokens);
174
175 if (tokens[0].find('-') == 0)
176 {
177 // Update the offset for the new FST
178 state_id_offset = largest_seen_state_id + 1;
179 first_line_in_fst = true;
180 continue;
181 }
182
183 from = stoi(tokens[0]) + state_id_offset;
184 largest_seen_state_id = max(largest_seen_state_id, from);
185
186 AttNode* source = get_node(from);
187 /* First line: the initial state is of both types. */
188 if (first_line_in_fst)
189 {
190 // If the file has a single FST - No need for state id mapping
191 if (state_id_offset == 0)
192 starting_state = from;
193 else{
194 AttNode * starting_node = get_node(starting_state);
195
196 // Add an Epsilon transition from the new starting state
197 starting_node->transductions.push_back(
198 Transduction(from, L"", L"",
199 alphabet(symbol_code(L""), symbol_code(L"")),
200 default_weight));
201 }
202 first_line_in_fst = false;
203 }
204
205 /* Final state. */
206 if (tokens.size() <= 2)
207 {
208 if (tokens.size() > 1)
209 {
210 weight = stod(tokens[1]);
211 }
212 else
213 {
214 weight = default_weight;
215 }
216 finals.insert(pair <int, double>(from, weight));
217 }
218 else
219 {
220 to = stoi(tokens[1]) + state_id_offset;
221 largest_seen_state_id = max(largest_seen_state_id, to);
222 if(dir == L"RL")
223 {
224 upper = tokens[3];
225 lower = tokens[2];
226 }
227 else
228 {
229 upper = tokens[2];
230 lower = tokens[3];
231 }
232 convert_hfst(upper);
233 convert_hfst(lower);
234 int tag = alphabet(symbol_code(upper), symbol_code(lower));
235 if(tokens.size() > 4)
236 {
237 weight = stod(tokens[4]);
238 }
239 else
240 {
241 weight = default_weight;
242 }
243 source->transductions.push_back(Transduction(to, upper, lower, tag, weight));
244 classify_single_transition(source->transductions.back());
245
246 get_node(to);
247 }
248 }
249
250 /* Classify the nodes of the graph. */
251 classify_forwards();
252 set<int> path;
253 classify_backwards(starting_state, path);
254
255 infile.close();
256 }
257
258 /** Extracts the sub-transducer made of states of type @p type. */
259 Transducer
extract_transducer(TransducerType type)260 AttCompiler::extract_transducer(TransducerType type)
261 {
262 Transducer transducer;
263 /* Correlation between the graph's state ids and those in the transducer. */
264 map<int, int> corr;
265 set<int> visited;
266
267 corr[starting_state] = transducer.getInitial();
268 _extract_transducer(type, starting_state, transducer, corr, visited);
269
270 /* The final states. */
271 bool noFinals = true;
272 for (auto& f : finals)
273 {
274 if (corr.find(f.first) != corr.end())
275 {
276 transducer.setFinal(corr[f.first], f.second);
277 noFinals = false;
278 }
279 }
280
281 /*
282 if(noFinals)
283 {
284 wcerr << L"No final states (" << type << ")" << endl;
285 wcerr << L" were:" << endl;
286 wcerr << L"\t" ;
287 for (auto& f : finals)
288 {
289 wcerr << f.first << L" ";
290 }
291 wcerr << endl;
292 }
293 */
294 return transducer;
295 }
296
/**
 * Recursively fills @p transducer (and @p corr) -- helper method called by
 * extract_transducer().
 *
 * Walks depth-first from graph state @p from over the transitions whose type
 * contains every bit of @p type, copying each such transition into
 * @p transducer and recursing into its target.
 *
 * @param type       only transitions with (it.type & type) == type are copied.
 * @param from       graph state id to expand.
 * @param transducer transducer being built.
 * @param corr       maps graph state ids to state ids in @p transducer.
 * @param visited    graph states already expanded; each is expanded once.
 *
 * NOTE(review): recursion depth grows with the length of the paths walked,
 * so very deep graphs could exhaust the call stack.
 */
void
AttCompiler::_extract_transducer(TransducerType type, int from,
                                 Transducer& transducer, map<int, int>& corr,
                                 set<int>& visited)
{
  // Expand every graph state at most once.
  if (visited.find(from) != visited.end())
  {
    return;
  }
  else
  {
    visited.insert(from);
  }

  AttNode* source = get_node(from);

  /* Is the source state new? */
  // NOTE(review): extract_transducer() maps starting_state before the first
  // call, and every recursive call below happens after corr[it.to] has been
  // set, so with the visible callers new_from appears to always be false.
  // If it were true, corr[from] would be reassigned on every matching
  // transition below.
  bool new_from = corr.find(from) == corr.end();
  int from_t, to_t;

  for (auto& it : source->transductions)
  {
    if ((it.type & type) != type)
    {
      continue; // Not the right type
    }
    /* Is the target state new? */
    bool new_to = corr.find(it.to) == corr.end();

    if (new_from)
    {
      corr[from] = transducer.size() + (new_to ? 1 : 0);
    }
    from_t = corr[from];

    /* Now with the target state: */
    if (!new_to)
    {
      /* We already know it, possibly by a different name: link them! */
      to_t = corr[it.to];
      transducer.linkStates(from_t, to_t, it.tag, it.weight);
    }
    else
    {
      /* We haven't seen it yet: add a new state! */
      to_t = transducer.insertNewSingleTransduction(it.tag, from_t, it.weight);
      corr[it.to] = to_t;
    }
    // Recurse even if it.to was already mapped; the visited check at the top
    // stops repeated expansion.
    _extract_transducer(type, it.to, transducer, corr, visited);
  } // for
}
352
353 void
classify_single_transition(Transduction & t)354 AttCompiler::classify_single_transition(Transduction& t)
355 {
356 if (t.upper.length() == 1) {
357 if (letters.find(t.upper[0]) != letters.end()) {
358 t.type |= WORD;
359 }
360 if (iswpunct(t.upper[0])) {
361 t.type |= PUNCT;
362 }
363 }
364 }
365
366 /**
367 * Propagate edge types forwards.
368 */
369 void
classify_forwards()370 AttCompiler::classify_forwards()
371 {
372 stack<int> todo;
373 set<int> done;
374 todo.push(starting_state);
375 while(!todo.empty()) {
376 int next = todo.top();
377 todo.pop();
378 if(done.find(next) != done.end()) continue;
379 AttNode* n1 = get_node(next);
380 for(auto& t1 : n1->transductions) {
381 AttNode* n2 = get_node(t1.to);
382 for(auto& t2 : n2->transductions) {
383 t2.type |= t1.type;
384 }
385 if(done.find(t1.to) == done.end()) {
386 todo.push(t1.to);
387 }
388 }
389 done.insert(next);
390 }
391 }
392
393 /**
394 * Recursively determine edge types of initial epsilon transitions
395 * Also check for epsilon loops or epsilon transitions to final states
396 * @param state the state to examine
397 * @param path the path we took to get here
398 */
399 TransducerType
classify_backwards(int state,set<int> & path)400 AttCompiler::classify_backwards(int state, set<int>& path)
401 {
402 if(finals.find(state) != finals.end()) {
403 wcerr << L"ERROR: Transducer contains epsilon transition to a final state. Aborting." << endl;
404 exit(EXIT_FAILURE);
405 }
406 AttNode* node = get_node(state);
407 TransducerType type = UNDECIDED;
408 for(auto& t1 : node->transductions) {
409 if(t1.type != UNDECIDED) {
410 type |= t1.type;
411 } else if(path.find(t1.to) != path.end()) {
412 wcerr << L"ERROR: Transducer contains initial epsilon loop. Aborting." << endl;
413 exit(EXIT_FAILURE);
414 } else {
415 path.insert(t1.to);
416 t1.type = classify_backwards(t1.to, path);
417 type |= t1.type;
418 path.erase(t1.to);
419 }
420 }
421 // Note: if type is still UNDECIDED at this point, then we have a dead-end
422 // path, which is fine since it will be discarded by _extract_transducer()
423 return type;
424 }
425
426
427 /** Writes the transducer to @p file_name in lt binary format. */
428 void
write(FILE * output)429 AttCompiler::write(FILE *output)
430 {
431 // FILE* output = fopen(file_name, "wb");
432 fwrite(HEADER_LTTOOLBOX, 1, 4, output);
433 uint64_t features = 0;
434 write_le(output, features);
435
436 Transducer punct_fst = extract_transducer(PUNCT);
437
438 /* Non-multichar symbols. */
439 Compression::wstring_write(wstring(letters.begin(), letters.end()), output);
440 /* Multichar symbols. */
441 alphabet.write(output);
442 /* And now the FST. */
443 if(punct_fst.numberOfTransitions() == 0)
444 {
445 Compression::multibyte_write(1, output);
446 }
447 else
448 {
449 Compression::multibyte_write(2, output);
450 }
451 Compression::wstring_write(L"main@standard", output);
452 Transducer word_fst = extract_transducer(WORD);
453 word_fst.write(output);
454 wcout << L"main@standard" << " " << word_fst.size();
455 wcout << " " << word_fst.numberOfTransitions() << endl;
456 Compression::wstring_write(L"final@inconditional", output);
457 if(punct_fst.numberOfTransitions() != 0)
458 {
459 punct_fst.write(output);
460 wcout << L"final@inconditional" << " " << punct_fst.size();
461 wcout << " " << punct_fst.numberOfTransitions() << endl;
462 }
463 // fclose(output);
464 }
465