1 /* Copyright (C) 2012  Olga Yakovleva <yakovleva.o.v@gmail.com> */
2 
3 /* This program is free software: you can redistribute it and/or modify */
4 /* it under the terms of the GNU Lesser General Public License as published by */
5 /* the Free Software Foundation, either version 2.1 of the License, or */
6 /* (at your option) any later version. */
7 
8 /* This program is distributed in the hope that it will be useful, */
9 /* but WITHOUT ANY WARRANTY; without even the implied warranty of */
10 /* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the */
11 /* GNU Lesser General Public License for more details. */
12 
13 /* You should have received a copy of the GNU Lesser General Public License */
14 /* along with this program.  If not, see <http://www.gnu.org/licenses/>. */
15 
16 #ifndef RHVOICE_FST_HPP
17 #define RHVOICE_FST_HPP
18 
19 #include <stdint.h>
20 #include <iostream>
21 #include <utility>
22 #include <string>
23 #include <vector>
24 #include <map>
25 #include <iterator>
26 #include <algorithm>
27 #include <functional>
28 #include "utf8.h"
29 #include "exception.hpp"
30 #include "item.hpp"
31 
32 namespace RHVoice
33 {
34   class fst
35   {
36   public:
37     explicit fst(const std::string& path);
38 
39   private:
40     typedef uint32_t state_id;
41     typedef uint16_t symbol_id;
42 
43     class symbol_not_found: public lookup_error
44     {
45     public:
symbol_not_found()46       symbol_not_found():
47         lookup_error("Symbol not found")
48       {
49       }
50     };
51 
52     class alphabet
53     {
54     public:
55       void load(std::istream& in);
56       std::string name(symbol_id id) const;
57       std::string name(symbol_id id,const std::string& default_name) const;
58       symbol_id id(const std::string& name) const;
59       symbol_id id(const std::string& name,symbol_id default_id) const;
60     private:
61       typedef std::map<std::string,symbol_id> symbol_map;
62       std::vector<std::string> ids_to_names;
63       symbol_map names_to_ids;
64     };
65 
66     struct arc
67     {
68       state_id target;
69       symbol_id isymbol,osymbol;
arcRHVoice::fst::arc70       arc():
71         target(0),
72         isymbol(0),
73         osymbol(0)
74       {
75       }
76 
arcRHVoice::fst::arc77       arc(state_id t,symbol_id i,symbol_id o):
78         target(t),
79         isymbol(i),
80         osymbol(o)
81       {
82       }
83 
84       explicit arc(std::istream& in);
85     };
86 
87     class state
88     {
89     private:
90       struct compare_arcs: public std::binary_function<bool,const arc&,const arc&>
91       {
operator ()RHVoice::fst::state::compare_arcs92         bool operator()(const arc& arc1,const arc& arc2) const
93         {
94           return (arc1.isymbol<arc2.isymbol);
95         }
96       };
97 
98     public:
99       explicit state(std::istream& in);
100 
is_final() const101       bool is_final() const
102       {
103         return final;
104       }
105 
state()106       state():
107         final(false)
108       {
109       }
110 
111       typedef std::vector<arc>::const_iterator arc_iterator;
112 
begin() const113       arc_iterator begin() const
114       {
115         return arcs.begin();
116       }
117 
find_arc(symbol_id id) const118       arc_iterator find_arc(symbol_id id) const
119       {
120         arc_iterator it=std::lower_bound(arcs.begin(),arcs.end(),arc(0,id,0),compare_arcs());
121         return (((it!=arcs.end())&&(it->isymbol==id))?it:arcs.end());
122       }
123 
end() const124       arc_iterator end() const
125       {
126         return arcs.end();
127       }
128     private:
129       bool final;
130       std::vector<arc> arcs;
131     };
132 
133     typedef state::arc_iterator arc_iterator;
134 
135     std::vector<state> states;
136     alphabet symbols;
137 
138     typedef std::vector<state>::const_iterator state_iterator;
139     typedef std::vector<std::pair<std::string,symbol_id> > input_symbols;
140 
141     class arc_filter
142     {
143     public:
arc_filter(state_iterator sstate,symbol_id isymbol)144       arc_filter(state_iterator sstate,symbol_id isymbol):
145         source_state(sstate),
146         current_arc(source_state->find_arc(isymbol))
147       {
148         if(current_arc==source_state->end())
149           current_arc=source_state->find_arc(0);
150       }
151 
get() const152       const arc& get() const
153       {
154         return *current_arc;
155       }
156 
157       void next();
158 
done() const159       bool done() const
160       {
161         return (current_arc==source_state->end());
162       }
163     private:
164       state_iterator source_state;
165       arc_iterator current_arc;
166     };
167 
168     template<class output_iterator> bool do_translate(const input_symbols& input,output_iterator output) const;
append_input_symbol(const std::string & name,input_symbols & dest) const169     void append_input_symbol(const std::string& name,input_symbols& dest) const
170     {
171       dest.push_back(input_symbols::value_type(name,symbols.id(name,1)));
172     }
173 
append_input_symbol(utf8::uint32_t chr,input_symbols & dest) const174     void append_input_symbol(utf8::uint32_t chr,input_symbols& dest) const
175     {
176       std::string name;
177       utf8::append(chr,std::back_inserter(name));
178       append_input_symbol(name,dest);
179     }
180 
append_input_symbol(const item & i,input_symbols & dest) const181     void append_input_symbol(const item& i,input_symbols& dest) const
182     {
183       append_input_symbol(i.get("name").as<std::string>(),dest);
184     }
185 
186   public:
187     template<class input_iterator,class output_iterator> bool translate(input_iterator first,input_iterator last,output_iterator output) const;
188   };
189 
190   template<class output_iterator>
do_translate(const input_symbols & input,output_iterator output) const191   bool fst::do_translate(const input_symbols& input,output_iterator output) const
192   {
193     if(states.empty())
194       return false;
195     input_symbols::const_iterator pos=input.begin();
196     if(pos==input.end())
197       return false;
198     arc_filter f(states.begin(),pos->second);
199     if(f.done())
200       return false;
201     std::vector<arc_filter> path;
202     path.push_back(f);
203     if(f.get().isymbol!=0)
204       ++pos;
205     while(!path.empty())
206       {
207         if(pos==input.end())
208           {
209             if(states[path.back().get().target].is_final())
210               break;
211             else
212               f=arc_filter(states.begin()+path.back().get().target,0);
213           }
214         else
215           f=arc_filter(states.begin()+path.back().get().target,pos->second);
216         if(f.done())
217           {
218             while(!path.empty())
219               {
220                 if(path.back().get().isymbol!=0)
221                   --pos;
222                 path.back().next();
223                 if(path.back().done())
224                   path.pop_back();
225                 else
226                   {
227                     if(path.back().get().isymbol!=0)
228                       ++pos;
229                     break;
230                   }
231               }
232           }
233         else
234           {
235             path.push_back(f);
236             if(f.get().isymbol!=0)
237               ++pos;
238           }
239       }
240     if((pos!=input.end())||path.empty()||(!states[path.back().get().target].is_final()))
241       return false;
242     pos=input.begin();
243     for(std::vector<arc_filter>::const_iterator it=path.begin();it!=path.end();++it)
244       {
245         if(it->get().osymbol!=0)
246           {
247             if(it->get().osymbol==1)
248               {
249               *output=pos->first;
250               ++output;
251               }
252             else
253               {
254                 *output=symbols.name(it->get().osymbol);
255                 ++output;
256               }
257           }
258         if(it->get().isymbol!=0)
259           ++pos;
260       }
261     return true;
262   }
263 
264   template<class input_iterator,class output_iterator>
translate(input_iterator first,input_iterator last,output_iterator output) const265   bool fst::translate(input_iterator first,input_iterator last,output_iterator output) const
266   {
267     input_symbols input;
268     for(input_iterator it=first;it!=last;++it)
269       {
270         append_input_symbol(*it,input);
271       }
272     return do_translate(input,output);
273   }
274 }
275 #endif
276