1 /* Copyright (C) 2012 Olga Yakovleva <yakovleva.o.v@gmail.com> */ 2 3 /* This program is free software: you can redistribute it and/or modify */ 4 /* it under the terms of the GNU Lesser General Public License as published by */ 5 /* the Free Software Foundation, either version 2.1 of the License, or */ 6 /* (at your option) any later version. */ 7 8 /* This program is distributed in the hope that it will be useful, */ 9 /* but WITHOUT ANY WARRANTY; without even the implied warranty of */ 10 /* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the */ 11 /* GNU Lesser General Public License for more details. */ 12 13 /* You should have received a copy of the GNU Lesser General Public License */ 14 /* along with this program. If not, see <http://www.gnu.org/licenses/>. */ 15 16 #ifndef RHVOICE_FST_HPP 17 #define RHVOICE_FST_HPP 18 19 #include <stdint.h> 20 #include <iostream> 21 #include <utility> 22 #include <string> 23 #include <vector> 24 #include <map> 25 #include <iterator> 26 #include <algorithm> 27 #include <functional> 28 #include "utf8.h" 29 #include "exception.hpp" 30 #include "item.hpp" 31 32 namespace RHVoice 33 { 34 class fst 35 { 36 public: 37 explicit fst(const std::string& path); 38 39 private: 40 typedef uint32_t state_id; 41 typedef uint16_t symbol_id; 42 43 class symbol_not_found: public lookup_error 44 { 45 public: symbol_not_found()46 symbol_not_found(): 47 lookup_error("Symbol not found") 48 { 49 } 50 }; 51 52 class alphabet 53 { 54 public: 55 void load(std::istream& in); 56 std::string name(symbol_id id) const; 57 std::string name(symbol_id id,const std::string& default_name) const; 58 symbol_id id(const std::string& name) const; 59 symbol_id id(const std::string& name,symbol_id default_id) const; 60 private: 61 typedef std::map<std::string,symbol_id> symbol_map; 62 std::vector<std::string> ids_to_names; 63 symbol_map names_to_ids; 64 }; 65 66 struct arc 67 { 68 state_id target; 69 symbol_id isymbol,osymbol; arcRHVoice::fst::arc70 arc(): 71 target(0), 72 isymbol(0), 73 osymbol(0) 74 { 75 } 76 arcRHVoice::fst::arc77 arc(state_id t,symbol_id i,symbol_id o): 78 target(t), 79 isymbol(i), 80 osymbol(o) 81 { 82 } 83 84 explicit arc(std::istream& in); 85 }; 86 87 class state 88 { 89 private: 90 struct compare_arcs: public std::binary_function<bool,const arc&,const arc&> 91 { operator ()RHVoice::fst::state::compare_arcs92 bool operator()(const arc& arc1,const arc& arc2) const 93 { 94 return (arc1.isymbol<arc2.isymbol); 95 } 96 }; 97 98 public: 99 explicit state(std::istream& in); 100 is_final() const101 bool is_final() const 102 { 103 return final; 104 } 105 state()106 state(): 107 final(false) 108 { 109 } 110 111 typedef std::vector<arc>::const_iterator arc_iterator; 112 begin() const113 arc_iterator begin() const 114 { 115 return arcs.begin(); 116 } 117 find_arc(symbol_id id) const118 arc_iterator find_arc(symbol_id id) const 119 { 120 arc_iterator it=std::lower_bound(arcs.begin(),arcs.end(),arc(0,id,0),compare_arcs()); 121 return (((it!=arcs.end())&&(it->isymbol==id))?it:arcs.end()); 122 } 123 end() const124 arc_iterator end() const 125 { 126 return arcs.end(); 127 } 128 private: 129 bool final; 130 std::vector<arc> arcs; 131 }; 132 133 typedef state::arc_iterator arc_iterator; 134 135 std::vector<state> states; 136 alphabet symbols; 137 138 typedef std::vector<state>::const_iterator state_iterator; 139 typedef std::vector<std::pair<std::string,symbol_id> > input_symbols; 140 141 class arc_filter 142 { 143 public: arc_filter(state_iterator sstate,symbol_id isymbol)144 arc_filter(state_iterator sstate,symbol_id isymbol): 145 source_state(sstate), 146 current_arc(source_state->find_arc(isymbol)) 147 { 148 if(current_arc==source_state->end()) 149 current_arc=source_state->find_arc(0); 150 } 151 get() const152 const arc& get() const 153 { 154 return *current_arc; 155 } 156 157 void next(); 158 done() const159 bool done() const 160 { 161 return (current_arc==source_state->end()); 162 } 163 private: 164 state_iterator source_state; 165 arc_iterator current_arc; 166 }; 167 168 template<class output_iterator> bool do_translate(const input_symbols& input,output_iterator output) const; append_input_symbol(const std::string & name,input_symbols & dest) const169 void append_input_symbol(const std::string& name,input_symbols& dest) const 170 { 171 dest.push_back(input_symbols::value_type(name,symbols.id(name,1))); 172 } 173 append_input_symbol(utf8::uint32_t chr,input_symbols & dest) const174 void append_input_symbol(utf8::uint32_t chr,input_symbols& dest) const 175 { 176 std::string name; 177 utf8::append(chr,std::back_inserter(name)); 178 append_input_symbol(name,dest); 179 } 180 append_input_symbol(const item & i,input_symbols & dest) const181 void append_input_symbol(const item& i,input_symbols& dest) const 182 { 183 append_input_symbol(i.get("name").as<std::string>(),dest); 184 } 185 186 public: 187 template<class input_iterator,class output_iterator> bool translate(input_iterator first,input_iterator last,output_iterator output) const; 188 }; 189 190 template<class output_iterator> do_translate(const input_symbols & input,output_iterator output) const191 bool fst::do_translate(const input_symbols& input,output_iterator output) const 192 { 193 if(states.empty()) 194 return false; 195 input_symbols::const_iterator pos=input.begin(); 196 if(pos==input.end()) 197 return false; 198 arc_filter f(states.begin(),pos->second); 199 if(f.done()) 200 return false; 201 std::vector<arc_filter> path; 202 path.push_back(f); 203 if(f.get().isymbol!=0) 204 ++pos; 205 while(!path.empty()) 206 { 207 if(pos==input.end()) 208 { 209 if(states[path.back().get().target].is_final()) 210 break; 211 else 212 f=arc_filter(states.begin()+path.back().get().target,0); 213 } 214 else 215 f=arc_filter(states.begin()+path.back().get().target,pos->second); 216 if(f.done()) 217 { 218 while(!path.empty()) 219 { 220 if(path.back().get().isymbol!=0) 221 --pos; 222 path.back().next(); 223 if(path.back().done()) 224 path.pop_back(); 225 else 226 { 227 if(path.back().get().isymbol!=0) 228 ++pos; 229 break; 230 } 231 } 232 } 233 else 234 { 235 path.push_back(f); 236 if(f.get().isymbol!=0) 237 ++pos; 238 } 239 } 240 if((pos!=input.end())||path.empty()||(!states[path.back().get().target].is_final())) 241 return false; 242 pos=input.begin(); 243 for(std::vector<arc_filter>::const_iterator it=path.begin();it!=path.end();++it) 244 { 245 if(it->get().osymbol!=0) 246 { 247 if(it->get().osymbol==1) 248 { 249 *output=pos->first; 250 ++output; 251 } 252 else 253 { 254 *output=symbols.name(it->get().osymbol); 255 ++output; 256 } 257 } 258 if(it->get().isymbol!=0) 259 ++pos; 260 } 261 return true; 262 } 263 264 template<class input_iterator,class output_iterator> translate(input_iterator first,input_iterator last,output_iterator output) const265 bool fst::translate(input_iterator first,input_iterator last,output_iterator output) const 266 { 267 input_symbols input; 268 for(input_iterator it=first;it!=last;++it) 269 { 270 append_input_symbol(*it,input); 271 } 272 return do_translate(input,output); 273 } 274 } 275 #endif 276