1 /* Copyright (C) 2012 Olga Yakovleva <yakovleva.o.v@gmail.com> */ 2 3 /* This program is free software: you can redistribute it and/or modify */ 4 /* it under the terms of the GNU Lesser General Public License as published by */ 5 /* the Free Software Foundation, either version 2.1 of the License, or */ 6 /* (at your option) any later version. */ 7 8 /* This program is distributed in the hope that it will be useful, */ 9 /* but WITHOUT ANY WARRANTY; without even the implied warranty of */ 10 /* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the */ 11 /* GNU Lesser General Public License for more details. */ 12 13 /* You should have received a copy of the GNU Lesser General Public License */ 14 /* along with this program. If not, see <http://www.gnu.org/licenses/>. */ 15 16 #include <sstream> 17 #include "core/exception.hpp" 18 #include "core/dtree.hpp" 19 20 namespace RHVoice 21 { 22 namespace 23 { 24 const unsigned int value_string=0; 25 const unsigned int value_number=1; 26 const unsigned int value_list=2; 27 const unsigned int condition_equal=1; 28 const unsigned int condition_less=2; 29 const unsigned int condition_grater=3; 30 const unsigned int condition_in=4; 31 const std::string err_msg("Incorrect format of the decision tree file"); 32 read_number(std::istream & in)33 inline unsigned int read_number(std::istream& in) 34 { 35 uint8_t n; 36 if(!io::read_integer(in,n)) 37 throw file_format_error(err_msg); 38 return n; 39 } 40 read_string(std::istream & in)41 inline std::string read_string(std::istream& in) 42 { 43 std::string s; 44 if(!io::read_string(in,s)) 45 throw file_format_error(err_msg); 46 return s; 47 } 48 } 49 num_equal(unsigned int num)50 dtree::num_equal::num_equal(unsigned int num): 51 as_number(num) 52 { 53 std::ostringstream os; 54 os << as_number; 55 as_string=os.str(); 56 } 57 test(const value & val) const58 bool dtree::num_equal::test(const value& val) const 59 { 60 if(val.empty()) 61 return (as_number==0); 62 else 63 { 64 if(val.is<std::string>()) 65 return (val.as<std::string>()==as_string); 66 else 67 return (val.as<unsigned int>()==as_number); 68 } 69 } 70 in_list(std::istream & in)71 dtree::in_list::in_list(std::istream& in) 72 { 73 unsigned int size=read_number(in); 74 if(size==0) 75 throw file_format_error(err_msg); 76 tests.reserve(size); 77 for(unsigned int i=0;i<size;++i) 78 { 79 switch(read_number(in)) 80 { 81 case value_string: 82 tests.push_back(std::shared_ptr<condition>(new str_equal(read_string(in)))); 83 break; 84 case value_number: 85 tests.push_back(std::shared_ptr<condition>(new num_equal(read_number(in)))); 86 break; 87 default: 88 throw file_format_error(err_msg); 89 } 90 } 91 } 92 test(const value & val) const93 bool dtree::in_list::test(const value& val) const 94 { 95 for(std::vector<std::shared_ptr<condition> >::const_iterator it(tests.begin());it!=tests.end();++it) 96 { 97 if((*it)->test(val)) 98 return true; 99 } 100 return false; 101 } 102 leaf_node(std::istream & in)103 dtree::leaf_node::leaf_node(std::istream& in) 104 { 105 unsigned int type=read_number(in); 106 switch(type) 107 { 108 case value_string: 109 answer=read_string(in); 110 break; 111 case value_number: 112 answer=read_number(in); 113 break; 114 default: 115 throw file_format_error(err_msg); 116 } 117 } 118 internal_node(std::istream & in,unsigned int qtype)119 dtree::internal_node::internal_node(std::istream& in,unsigned int qtype): 120 feature_name(read_string(in)) 121 { 122 unsigned int vtype=read_number(in); 123 switch(qtype) 124 { 125 case condition_equal: 126 switch(vtype) 127 { 128 case value_string: 129 question.reset(new str_equal(read_string(in))); 130 break; 131 case value_number: 132 question.reset(new num_equal(read_number(in))); 133 break; 134 default: 135 throw file_format_error(err_msg); 136 } 137 break; 138 case condition_less: 139 if(vtype!=value_number) 140 throw file_format_error(err_msg); 141 question.reset(new less(read_number(in))); 142 break; 143 case condition_grater: 144 if(vtype!=value_number) 145 throw file_format_error(err_msg); 146 question.reset(new grater(read_number(in))); 147 break; 148 case condition_in: 149 if(vtype!=value_list) 150 throw file_format_error(err_msg); 151 question.reset(new in_list(in)); 152 break; 153 default: 154 throw file_format_error(err_msg); 155 } 156 unsigned int next_type=read_number(in); 157 if(next_type==0) 158 yes_node.reset(new leaf_node(in)); 159 else 160 yes_node.reset(new internal_node(in,next_type)); 161 next_type=read_number(in); 162 if(next_type==0) 163 no_node.reset(new leaf_node(in)); 164 else 165 no_node.reset(new internal_node(in,next_type)); 166 } 167 get_next_node(const dtree::features & f) const168 const dtree::node* dtree::internal_node::get_next_node(const dtree::features& f) const 169 { 170 return (((question->test(f.eval(feature_name)))?yes_node:no_node).get()); 171 } 172 load(std::istream & in)173 void dtree::load(std::istream& in) 174 { 175 unsigned int type=read_number(in); 176 if(type==0) 177 root.reset(new leaf_node(in)); 178 else 179 root.reset(new internal_node(in,type)); 180 } 181 predict(const dtree::features & f) const182 const value& dtree::predict(const dtree::features& f) const 183 { 184 const node* cur_node=root.get(); 185 while(!(cur_node->is_leaf())) 186 { 187 cur_node=cur_node->get_next_node(f); 188 } 189 return *(cur_node->get_answer()); 190 } 191 } 192