1 /* Copyright (C) 2012  Olga Yakovleva <yakovleva.o.v@gmail.com> */
2 
3 /* This program is free software: you can redistribute it and/or modify */
4 /* it under the terms of the GNU Lesser General Public License as published by */
5 /* the Free Software Foundation, either version 2.1 of the License, or */
6 /* (at your option) any later version. */
7 
8 /* This program is distributed in the hope that it will be useful, */
9 /* but WITHOUT ANY WARRANTY; without even the implied warranty of */
10 /* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the */
11 /* GNU Lesser General Public License for more details. */
12 
13 /* You should have received a copy of the GNU Lesser General Public License */
14 /* along with this program.  If not, see <http://www.gnu.org/licenses/>. */
15 
16 #include <sstream>
17 #include "core/exception.hpp"
18 #include "core/dtree.hpp"
19 
20 namespace RHVoice
21 {
22   namespace
23   {
24     const unsigned int value_string=0;
25     const unsigned int value_number=1;
26     const unsigned int value_list=2;
27     const unsigned int condition_equal=1;
28     const unsigned int condition_less=2;
29     const unsigned int condition_grater=3;
30     const unsigned int condition_in=4;
31     const std::string err_msg("Incorrect format of the decision tree file");
32 
read_number(std::istream & in)33     inline unsigned int read_number(std::istream& in)
34     {
35       uint8_t n;
36       if(!io::read_integer(in,n))
37         throw file_format_error(err_msg);
38       return n;
39     }
40 
read_string(std::istream & in)41     inline std::string read_string(std::istream& in)
42     {
43       std::string s;
44       if(!io::read_string(in,s))
45         throw file_format_error(err_msg);
46       return s;
47     }
48   }
49 
num_equal(unsigned int num)50   dtree::num_equal::num_equal(unsigned int num):
51     as_number(num)
52   {
53     std::ostringstream os;
54     os << as_number;
55     as_string=os.str();
56   }
57 
test(const value & val) const58   bool dtree::num_equal::test(const value& val) const
59   {
60     if(val.empty())
61       return (as_number==0);
62     else
63       {
64         if(val.is<std::string>())
65           return (val.as<std::string>()==as_string);
66         else
67           return (val.as<unsigned int>()==as_number);
68       }
69   }
70 
in_list(std::istream & in)71   dtree::in_list::in_list(std::istream& in)
72   {
73     unsigned int size=read_number(in);
74     if(size==0)
75       throw file_format_error(err_msg);
76     tests.reserve(size);
77     for(unsigned int i=0;i<size;++i)
78       {
79         switch(read_number(in))
80           {
81           case value_string:
82             tests.push_back(std::shared_ptr<condition>(new str_equal(read_string(in))));
83             break;
84           case value_number:
85             tests.push_back(std::shared_ptr<condition>(new num_equal(read_number(in))));
86             break;
87           default:
88             throw file_format_error(err_msg);
89           }
90       }
91   }
92 
test(const value & val) const93   bool dtree::in_list::test(const value& val) const
94   {
95     for(std::vector<std::shared_ptr<condition> >::const_iterator it(tests.begin());it!=tests.end();++it)
96       {
97         if((*it)->test(val))
98           return true;
99       }
100     return false;
101   }
102 
leaf_node(std::istream & in)103   dtree::leaf_node::leaf_node(std::istream& in)
104   {
105     unsigned int type=read_number(in);
106     switch(type)
107       {
108       case value_string:
109         answer=read_string(in);
110         break;
111       case value_number:
112         answer=read_number(in);
113         break;
114       default:
115         throw file_format_error(err_msg);
116       }
117   }
118 
internal_node(std::istream & in,unsigned int qtype)119   dtree::internal_node::internal_node(std::istream& in,unsigned int qtype):
120     feature_name(read_string(in))
121   {
122     unsigned int vtype=read_number(in);
123     switch(qtype)
124       {
125       case condition_equal:
126         switch(vtype)
127           {
128           case value_string:
129             question.reset(new str_equal(read_string(in)));
130             break;
131           case value_number:
132             question.reset(new num_equal(read_number(in)));
133             break;
134           default:
135             throw file_format_error(err_msg);
136           }
137         break;
138       case condition_less:
139         if(vtype!=value_number)
140           throw file_format_error(err_msg);
141         question.reset(new less(read_number(in)));
142         break;
143       case condition_grater:
144         if(vtype!=value_number)
145           throw file_format_error(err_msg);
146         question.reset(new grater(read_number(in)));
147         break;
148       case condition_in:
149         if(vtype!=value_list)
150           throw file_format_error(err_msg);
151         question.reset(new in_list(in));
152         break;
153       default:
154         throw file_format_error(err_msg);
155       }
156     unsigned int next_type=read_number(in);
157     if(next_type==0)
158       yes_node.reset(new leaf_node(in));
159     else
160       yes_node.reset(new internal_node(in,next_type));
161     next_type=read_number(in);
162     if(next_type==0)
163       no_node.reset(new leaf_node(in));
164     else
165       no_node.reset(new internal_node(in,next_type));
166   }
167 
get_next_node(const dtree::features & f) const168   const dtree::node* dtree::internal_node::get_next_node(const dtree::features& f) const
169   {
170     return (((question->test(f.eval(feature_name)))?yes_node:no_node).get());
171   }
172 
load(std::istream & in)173   void dtree::load(std::istream& in)
174   {
175     unsigned int type=read_number(in);
176     if(type==0)
177       root.reset(new leaf_node(in));
178     else
179       root.reset(new internal_node(in,type));
180   }
181 
predict(const dtree::features & f) const182   const value& dtree::predict(const dtree::features& f) const
183   {
184     const node* cur_node=root.get();
185     while(!(cur_node->is_leaf()))
186       {
187         cur_node=cur_node->get_next_node(f);
188       }
189     return *(cur_node->get_answer());
190   }
191 }
192