1 /*
2  * Copyright (C) 2013 Dávid Márk Nemeskey
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License as
6  * published by the Free Software Foundation; either version 2 of the
7  * License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful, but
10  * WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, see <https://www.gnu.org/licenses/>.
16  */
17 
18 #include <lttoolbox/att_compiler.h>
19 #include <lttoolbox/alphabet.h>
20 #include <lttoolbox/transducer.h>
21 #include <lttoolbox/compression.h>
22 #include <lttoolbox/string_to_wostream.h>
23 #include <algorithm>
24 #include <stack>
25 
26 using namespace std;
27 
AttCompiler()28 AttCompiler::AttCompiler() :
29 starting_state(0),
30 default_weight(0.0000)
31 {
32 }
33 
~AttCompiler()34 AttCompiler::~AttCompiler()
35 {
36 }
37 
38 void
clear()39 AttCompiler::clear()
40 {
41   for (auto& it : graph)
42   {
43     delete it.second;
44   }
45   graph.clear();
46   alphabet = Alphabet();
47 }
48 
49 /**
50  * Converts symbols like @0@ to epsilon, @_SPACE_@ to space, etc.
51  * @todo Are there other special symbols? If so, add them, and maybe use a map
52  *       for conversion?
53  */
54 void
convert_hfst(wstring & symbol)55 AttCompiler::convert_hfst(wstring& symbol)
56 {
57   if (symbol == L"@0@" || symbol == L"ε")
58   {
59     symbol = L"";
60   }
61   else if (symbol == L"@_SPACE_@")
62   {
63     symbol = L" ";
64   }
65 }
66 
67 bool
is_word_punct(wchar_t symbol)68 AttCompiler::is_word_punct(wchar_t symbol)
69 {
70   // https://en.wikipedia.org/wiki/Combining_character#Unicode_ranges
71   if((symbol >= 0x0300 && symbol <= 0x036F) // Combining Diacritics
72   || (symbol >= 0x1AB0 && symbol <= 0x1AFF) // ... Extended
73   || (symbol >= 0x1DC0 && symbol <= 0x1DFF) // ... Supplement
74   || (symbol >= 0x20D0 && symbol <= 0x20FF) // ... for Symbols
75   || (symbol >= 0xFE20 && symbol <= 0xFE2F)) // Combining Half Marks
76   {
77     return true;
78   }
79 
80   return false;
81 }
82 
83 /**
84  * Returns the code of the symbol in the alphabet. Run after convert_hfst has
85  * run.
86  *
87  * Also adds all non-multicharacter symbols (letters) to the @p letters set.
88  *
89  * @return the code of the symbol, if @p symbol is multichar; its first (and
90  *         only) character otherwise.
91  */
92 int
symbol_code(const wstring & symbol)93 AttCompiler::symbol_code(const wstring& symbol)
94 {
95   if (symbol.length() > 1) {
96     alphabet.includeSymbol(symbol);
97     return alphabet(symbol);
98   } else if (symbol == L"") {
99     return 0;
100   } else if ((iswpunct(symbol[0]) || iswspace(symbol[0])) && !is_word_punct(symbol[0])) {
101     return symbol[0];
102   } else {
103     letters.insert(symbol[0]);
104     if(iswlower(symbol[0]))
105     {
106       letters.insert(towupper(symbol[0]));
107     }
108     else if(iswupper(symbol[0]))
109     {
110       letters.insert(towlower(symbol[0]));
111     }
112     return symbol[0];
113   }
114 }
115 
116 bool
has_multiple_fsts(string const & file_name)117 AttCompiler::has_multiple_fsts(string const &file_name)
118 {
119   wifstream infile(file_name.c_str());  // TODO: error checking
120   wstring line;
121 
122   while(getline(infile, line)){
123     if (line.find('-') == 0)
124       return true;
125   }
126 
127   return false;
128 }
129 
130 void
parse(string const & file_name,wstring const & dir)131 AttCompiler::parse(string const &file_name, wstring const &dir)
132 {
133   clear();
134 
135   wifstream infile(file_name.c_str());  // TODO: error checking
136   vector<wstring> tokens;
137   wstring line;
138   bool first_line_in_fst = true;       // First line -- see below
139   int state_id_offset = 0;
140   int largest_seen_state_id = 0;
141 
142   if (has_multiple_fsts(file_name)){
143     wcerr << "Warning: Multiple fsts in '" << file_name << "' will be disjuncted." << endl;
144 
145     // Set the starting state to 0 (Epsilon transtions will be added later)
146     starting_state = 0;
147     state_id_offset = 1;
148   }
149 
150   while (getline(infile, line))
151   {
152     tokens.clear();
153     int from, to;
154     wstring upper, lower;
155     double weight;
156 
157     if (line.length() == 0 && first_line_in_fst)
158     {
159       wcerr << "Error: empty file '" << file_name << "'." << endl;
160       exit(EXIT_FAILURE);
161     }
162     if (first_line_in_fst && line.find(L"\t") == wstring::npos)
163     {
164       wcerr << "Error: invalid format '" << file_name << "'." << endl;
165       exit(EXIT_FAILURE);
166     }
167 
168     /* Empty line. */
169     if (line.length() == 0)
170     {
171       continue;
172     }
173     split(line, L'\t', tokens);
174 
175     if (tokens[0].find('-') == 0)
176     {
177       // Update the offset for the new FST
178       state_id_offset = largest_seen_state_id + 1;
179       first_line_in_fst = true;
180       continue;
181     }
182 
183     from = stoi(tokens[0]) + state_id_offset;
184     largest_seen_state_id = max(largest_seen_state_id, from);
185 
186     AttNode* source = get_node(from);
187     /* First line: the initial state is of both types. */
188     if (first_line_in_fst)
189     {
190       // If the file has a single FST - No need for state id mapping
191       if (state_id_offset == 0)
192         starting_state = from;
193       else{
194         AttNode * starting_node = get_node(starting_state);
195 
196         // Add an Epsilon transition from the new starting state
197         starting_node->transductions.push_back(
198           Transduction(from, L"", L"",
199             alphabet(symbol_code(L""), symbol_code(L"")),
200             default_weight));
201       }
202       first_line_in_fst = false;
203     }
204 
205     /* Final state. */
206     if (tokens.size() <= 2)
207     {
208       if (tokens.size() > 1)
209       {
210         weight = stod(tokens[1]);
211       }
212       else
213       {
214         weight = default_weight;
215       }
216       finals.insert(pair <int, double>(from, weight));
217     }
218     else
219     {
220       to = stoi(tokens[1]) + state_id_offset;
221       largest_seen_state_id = max(largest_seen_state_id, to);
222       if(dir == L"RL")
223       {
224         upper = tokens[3];
225         lower = tokens[2];
226       }
227       else
228       {
229         upper = tokens[2];
230         lower = tokens[3];
231       }
232       convert_hfst(upper);
233       convert_hfst(lower);
234       int tag = alphabet(symbol_code(upper), symbol_code(lower));
235       if(tokens.size() > 4)
236       {
237         weight = stod(tokens[4]);
238       }
239       else
240       {
241         weight = default_weight;
242       }
243       source->transductions.push_back(Transduction(to, upper, lower, tag, weight));
244       classify_single_transition(source->transductions.back());
245 
246       get_node(to);
247     }
248   }
249 
250   /* Classify the nodes of the graph. */
251   classify_forwards();
252   set<int> path;
253   classify_backwards(starting_state, path);
254 
255   infile.close();
256 }
257 
258 /** Extracts the sub-transducer made of states of type @p type. */
259 Transducer
extract_transducer(TransducerType type)260 AttCompiler::extract_transducer(TransducerType type)
261 {
262   Transducer transducer;
263   /* Correlation between the graph's state ids and those in the transducer. */
264   map<int, int> corr;
265   set<int> visited;
266 
267   corr[starting_state] = transducer.getInitial();
268   _extract_transducer(type, starting_state, transducer, corr, visited);
269 
270   /* The final states. */
271   bool noFinals = true;
272   for (auto& f : finals)
273   {
274     if (corr.find(f.first) != corr.end())
275     {
276       transducer.setFinal(corr[f.first], f.second);
277       noFinals = false;
278     }
279   }
280 
281 /*
282   if(noFinals)
283   {
284     wcerr << L"No final states (" << type << ")" << endl;
285     wcerr << L"  were:" << endl;
286     wcerr << L"\t" ;
287     for (auto& f : finals)
288     {
289       wcerr << f.first << L" ";
290     }
291     wcerr << endl;
292   }
293 */
294   return transducer;
295 }
296 
/**
 * Recursively fills @p transducer (and @p corr) -- helper method called by
 * extract_transducer().
 *
 * Performs a depth-first walk from graph state @p from. It follows only the
 * transitions whose type contains every bit of @p type. For each such
 * transition it either links to an already-known transducer state or creates
 * a new one, then recurses into the target.
 *
 * @param type       type bits a transition must carry to be copied.
 * @param from       graph state id to expand.
 * @param transducer transducer under construction.
 * @param corr       graph state id -> transducer state id.
 * @param visited    graph states already expanded; each is expanded once.
 *
 * NOTE(review): recursion depth can grow up to the number of reachable states,
 * so very long chains could exhaust the call stack.
 */
void
AttCompiler::_extract_transducer(TransducerType type, int from,
                                 Transducer& transducer, map<int, int>& corr,
                                 set<int>& visited)
{
  if (visited.find(from) != visited.end())
  {
    return;
  }
  else
  {
    visited.insert(from);
  }

  AttNode* source = get_node(from);

  /* Is the source state new? */
  // NOTE(review): when reached via extract_transducer(), @p from is always
  // already in corr. The initial state is seeded there, and the recursion
  // below only happens after corr[it.to] has been set. So this appears to be
  // false in practice.
  bool new_from = corr.find(from) == corr.end();
  int from_t, to_t;

  for (auto& it : source->transductions)
  {
    if ((it.type & type) != type)
    {
      continue;  // Not the right type
    }
    /* Is the target state new? */
    bool new_to = corr.find(it.to) == corr.end();

    if (new_from)
    {
      // Predicts the id the source would get in the transducer -- presumably
      // one past the state that a new target is about to take; confirm
      // against Transducer::insertNewSingleTransduction().
      corr[from] = transducer.size() + (new_to ? 1 : 0);
    }
    from_t = corr[from];

    /* Now with the target state: */
    if (!new_to)
    {
      /* We already know it, possibly by a different name: link them! */
      to_t = corr[it.to];
      transducer.linkStates(from_t, to_t, it.tag, it.weight);
    }
    else
    {
      /* We haven't seen it yet: add a new state! */
      to_t = transducer.insertNewSingleTransduction(it.tag, from_t, it.weight);
      corr[it.to] = to_t;
    }
    _extract_transducer(type, it.to, transducer, corr, visited);
  }  // for
}
352 
353 void
classify_single_transition(Transduction & t)354 AttCompiler::classify_single_transition(Transduction& t)
355 {
356   if (t.upper.length() == 1) {
357     if (letters.find(t.upper[0]) != letters.end()) {
358       t.type |= WORD;
359     }
360     if (iswpunct(t.upper[0])) {
361       t.type |= PUNCT;
362     }
363   }
364 }
365 
366 /**
367  * Propagate edge types forwards.
368  */
369 void
classify_forwards()370 AttCompiler::classify_forwards()
371 {
372   stack<int> todo;
373   set<int> done;
374   todo.push(starting_state);
375   while(!todo.empty()) {
376     int next = todo.top();
377     todo.pop();
378     if(done.find(next) != done.end()) continue;
379     AttNode* n1 = get_node(next);
380     for(auto& t1 : n1->transductions) {
381       AttNode* n2 = get_node(t1.to);
382       for(auto& t2 : n2->transductions) {
383 	t2.type |= t1.type;
384       }
385       if(done.find(t1.to) == done.end()) {
386 	todo.push(t1.to);
387       }
388     }
389     done.insert(next);
390   }
391 }
392 
/**
 * Recursively determine the types of transitions that are still UNDECIDED
 * after classify_forwards() (presumably the initial epsilon transitions).
 * Also check for epsilon loops or epsilon transitions to final states.
 *
 * Each UNDECIDED transition leaving @p state is followed recursively and
 * assigned the union of the types found beyond it. Already-typed transitions
 * only contribute their type.
 *
 * Exits with EXIT_FAILURE if
 *   - @p state is final, or
 *   - an UNDECIDED transition leads back to a state already on @p path.
 *
 * NOTE(review): the final-state check also fires on the very first call, i.e.
 * when starting_state itself is final, even without any epsilon transition.
 * NOTE(review): a sub-path that resolves to UNDECIDED is stored as UNDECIDED
 * again, so it may be re-explored each time it is reached.
 *
 * @param state the state to examine
 * @param path  targets of the UNDECIDED transitions currently being followed
 *              (the caller passes an empty set, so @p state itself is not in
 *              it initially)
 * @return the union of the types of all transitions leaving @p state
 */
TransducerType
AttCompiler::classify_backwards(int state, set<int>& path)
{
  if(finals.find(state) != finals.end()) {
    wcerr << L"ERROR: Transducer contains epsilon transition to a final state. Aborting." << endl;
    exit(EXIT_FAILURE);
  }
  AttNode* node = get_node(state);
  TransducerType type = UNDECIDED;
  for(auto& t1 : node->transductions) {
    if(t1.type != UNDECIDED) {
      type |= t1.type;
    } else if(path.find(t1.to) != path.end()) {
      wcerr << L"ERROR: Transducer contains initial epsilon loop. Aborting." << endl;
      exit(EXIT_FAILURE);
    } else {
      // Follow the untyped transition and adopt whatever lies beyond it.
      path.insert(t1.to);
      t1.type = classify_backwards(t1.to, path);
      type |= t1.type;
      path.erase(t1.to);
    }
  }
  // Note: if type is still UNDECIDED at this point, then we have a dead-end
  // path, which is fine since it will be discarded by _extract_transducer()
  return type;
}
425 
426 
427 /** Writes the transducer to @p file_name in lt binary format. */
428 void
write(FILE * output)429 AttCompiler::write(FILE *output)
430 {
431 //  FILE* output = fopen(file_name, "wb");
432   fwrite(HEADER_LTTOOLBOX, 1, 4, output);
433   uint64_t features = 0;
434   write_le(output, features);
435 
436   Transducer punct_fst = extract_transducer(PUNCT);
437 
438   /* Non-multichar symbols. */
439   Compression::wstring_write(wstring(letters.begin(), letters.end()), output);
440   /* Multichar symbols. */
441   alphabet.write(output);
442   /* And now the FST. */
443   if(punct_fst.numberOfTransitions() == 0)
444   {
445     Compression::multibyte_write(1, output);
446   }
447   else
448   {
449     Compression::multibyte_write(2, output);
450   }
451   Compression::wstring_write(L"main@standard", output);
452   Transducer word_fst = extract_transducer(WORD);
453   word_fst.write(output);
454   wcout << L"main@standard" << " " << word_fst.size();
455   wcout << " " << word_fst.numberOfTransitions() << endl;
456   Compression::wstring_write(L"final@inconditional", output);
457   if(punct_fst.numberOfTransitions() != 0)
458   {
459     punct_fst.write(output);
460     wcout << L"final@inconditional" << " " << punct_fst.size();
461     wcout << " " << punct_fst.numberOfTransitions() << endl;
462   }
463 //  fclose(output);
464 }
465