1 /*
2 * UaryOPNode.cpp
3 *
4 * Created on: 6 Nov 2013
5 * Author: s0965328
6 */
7
8 #include "UaryOPNode.h"
9 #include "BinaryOPNode.h"
10 #include "PNode.h"
11 #include "Stack.h"
12 #include "Tape.h"
13 #include "Edge.h"
14 #include "EdgeSet.h"
15 #include "auto_diff_types.h"
16
17 #include <list>
18
19 using namespace std;
20
21 namespace AutoDiff {
22
UaryOPNode(OPCODE op_,Node * left)23 UaryOPNode::UaryOPNode(OPCODE op_, Node* left): OPNode(op_,left) {
24 }
25
createUnaryOpNode(OPCODE op,Node * left)26 OPNode* UaryOPNode::createUnaryOpNode(OPCODE op, Node* left)
27 {
28 assert(left!=NULL);
29 OPNode* node = NULL;
30 if(op == OP_SQRT)
31 {
32 double param = 0.5;
33 node = BinaryOPNode::createBinaryOpNode(OP_POW,left,new PNode(param));
34 }
35 else if(op == OP_NEG)
36 {
37 double param = -1;
38 node = BinaryOPNode::createBinaryOpNode(OP_TIMES,left,new PNode(param));
39 }
40 else
41 {
42 node = new UaryOPNode(op,left);
43 }
44 return node;
45 }
46
~UaryOPNode()47 UaryOPNode::~UaryOPNode() {
48
49 }
50
51
inorder_visit(int level,ostream & oss)52 void UaryOPNode::inorder_visit(int level,ostream& oss){
53 if(left!=NULL){
54 left->inorder_visit(level+1,oss);
55 }
56 oss<<this->toString(level)<<endl;
57 }
58
collect_vnodes(boost::unordered_set<Node * > & nodes,unsigned int & total)59 void UaryOPNode::collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total)
60 {
61 total++;
62 if(left!=NULL){
63 left->collect_vnodes(nodes,total);
64 }
65 }
66
eval_function()67 void UaryOPNode::eval_function()
68 {
69 if(left!=NULL){
70 left->eval_function();
71 }
72 this->calc_eval_function();
73 }
74
75 //1. visiting left if not NULL
76 //2. then, visiting right if not NULL
77 //3. calculating the immediate derivative hu and hv
grad_reverse_0()78 void UaryOPNode::grad_reverse_0(){
79 assert(left!=NULL);
80 this->adj = 0;
81 left->grad_reverse_0();
82 this->calc_grad_reverse_0();
83 }
84
85 //right left - right most traversal
grad_reverse_1()86 void UaryOPNode::grad_reverse_1()
87 {
88 assert(left!=NULL);
89 double l_adj = SD->pop_back()*this->adj;
90 left->update_adj(l_adj);
91 left->grad_reverse_1();
92 }
93
calc_grad_reverse_0()94 void UaryOPNode::calc_grad_reverse_0()
95 {
96 assert(left!=NULL);
97 double hu = NaN_Double;
98 double lval = SV->pop_back();
99 double val = NaN_Double;
100 switch (op)
101 {
102 case OP_SIN:
103 val = sin(lval);
104 hu = cos(lval);
105 break;
106 case OP_COS:
107 val = cos(lval);
108 hu = -sin(lval);
109 break;
110 default:
111 cerr<<"error op not impl"<<endl;
112 break;
113 }
114 SV->push_back(val);
115 SD->push_back(hu);
116 }
117
calc_eval_function()118 void UaryOPNode::calc_eval_function()
119 {
120 double lval = SV->pop_back();
121 double val = NaN_Double;
122 switch (op)
123 {
124 case OP_SIN:
125 assert(left!=NULL);
126 val = sin(lval);
127 break;
128 case OP_COS:
129 assert(left!=NULL);
130 val = cos(lval);
131 break;
132 default:
133 cerr<<"op["<<op<<"] not yet implemented!!"<<endl;
134 assert(false);
135 break;
136 }
137 SV->push_back(val);
138 }
139
hess_reverse_0_init_n_in_arcs()140 void UaryOPNode::hess_reverse_0_init_n_in_arcs()
141 {
142 this->left->hess_reverse_0_init_n_in_arcs();
143 this->Node::hess_reverse_0_init_n_in_arcs();
144 }
145
hess_reverse_1_clear_index()146 void UaryOPNode::hess_reverse_1_clear_index()
147 {
148 this->left->hess_reverse_1_clear_index();
149 this->Node::hess_reverse_1_clear_index();
150 }
151
hess_reverse_0()152 unsigned int UaryOPNode::hess_reverse_0()
153 {
154 assert(left!=NULL);
155 if(index==0)
156 {
157 unsigned int lindex=0;
158 lindex = left->hess_reverse_0();
159 assert(lindex!=0);
160 II->set(lindex);
161 double lx,lx_bar,lw,lw_bar;
162 double x,x_bar,w,w_bar;
163 double l_dh;
164 switch(op)
165 {
166 case OP_SIN:
167 assert(left != NULL);
168 left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar);
169 x = sin(lx);
170 x_bar = 0;
171 l_dh = cos(lx);
172 w = lw*l_dh;
173 w_bar = 0;
174 break;
175 case OP_COS:
176 assert(left!=NULL);
177 left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar);
178 x = cos(lx);
179 x_bar = 0;
180 l_dh = -sin(lx);
181 w = lw*l_dh;
182 w_bar = 0;
183 break;
184 default:
185 cerr<<"op["<<op<<"] not yet implemented!"<<endl;
186 assert(false);
187 break;
188 }
189 TT->set(x);
190 TT->set(x_bar);
191 TT->set(w);
192 TT->set(w_bar);
193 TT->set(l_dh);
194 assert(TT->index == TT->index);
195 index = TT->index;
196 }
197 return index;
198 }
199
hess_reverse_0_get_values(unsigned int i,double & x,double & x_bar,double & w,double & w_bar)200 void UaryOPNode::hess_reverse_0_get_values(unsigned int i,double& x, double& x_bar, double& w, double& w_bar)
201 {
202 --i; // skip the l_dh (ie, dh/du)
203 w_bar = TT->get(--i);
204 w = TT->get(--i);
205 x_bar = TT->get(--i);
206 x = TT->get(--i);
207 }
208
hess_reverse_1(unsigned int i)209 void UaryOPNode::hess_reverse_1(unsigned int i)
210 {
211 n_in_arcs--;
212 if(n_in_arcs==0)
213 {
214 double lindex = II->get(--(II->index));
215 // cout<<"li["<<lindex<<"]\t"<<this->toString(0)<<endl;
216 double l_dh = TT->get(--i);
217 double w_bar = TT->get(--i);
218 --i; //skip w
219 double x_bar = TT->get(--i);
220 --i; //skip x
221 // cout<<"i["<<i<<"]"<<endl;
222
223 assert(left!=NULL);
224 left->update_x_bar(lindex,x_bar*l_dh);
225 double lw_bar = 0;
226 double lw = 0,lx = 0;
227 left->hess_reverse_1_get_xw(lindex,lw,lx);
228 switch(op)
229 {
230 case OP_SIN:
231 assert(l_dh == cos(lx));
232 lw_bar += w_bar*l_dh;
233 lw_bar += x_bar*lw*(-sin(lx));
234 break;
235 case OP_COS:
236 assert(l_dh == -sin(lx));
237 lw_bar += w_bar*l_dh;
238 lw_bar += x_bar*lw*(-cos(lx));
239 break;
240 default:
241 cerr<<"op["<<op<<"] not yet implemented!"<<endl;
242 break;
243 }
244 left->update_w_bar(lindex,lw_bar);
245 left->hess_reverse_1(lindex);
246 }
247 }
248
hess_reverse_1_init_x_bar(unsigned int i)249 void UaryOPNode::hess_reverse_1_init_x_bar(unsigned int i)
250 {
251 TT->at(i-4) = 1;
252 }
253
update_x_bar(unsigned int i,double v)254 void UaryOPNode::update_x_bar(unsigned int i ,double v)
255 {
256 TT->at(i-4) += v;
257 }
update_w_bar(unsigned int i,double v)258 void UaryOPNode::update_w_bar(unsigned int i ,double v)
259 {
260 TT->at(i-2) += v;
261 }
hess_reverse_1_get_xw(unsigned int i,double & w,double & x)262 void UaryOPNode::hess_reverse_1_get_xw(unsigned int i,double& w,double& x)
263 {
264 w = TT->get(i-3);
265 x = TT->get(i-5);
266 }
hess_reverse_get_x(unsigned int i,double & x)267 void UaryOPNode::hess_reverse_get_x(unsigned int i, double& x)
268 {
269 x = TT->get(i-5);
270 }
271
nonlinearEdges(EdgeSet & edges)272 void UaryOPNode::nonlinearEdges(EdgeSet& edges)
273 {
274 for(list<Edge>::iterator it=edges.edges.begin();it!=edges.edges.end();)
275 {
276 Edge& e = *it;
277 if(e.a == this || e.b == this){
278 if(e.a == this && e.b == this)
279 {
280 Edge e1(left,left);
281 edges.insertEdge(e1);
282 }
283 else{
284 Node* o = e.a==this?e.b:e.a;
285 Edge e1(left,o);
286 edges.insertEdge(e1);
287 }
288 it = edges.edges.erase(it);
289 }
290 else
291 {
292 it++;
293 }
294 }
295
296 Edge e1(left,left);
297 switch(op)
298 {
299 case OP_SIN:
300 edges.insertEdge(e1);
301 break;
302 case OP_COS:
303 edges.insertEdge(e1);
304 break;
305 default:
306 cerr<<"op["<<op<<"] is not yet implemented !"<<endl;
307 assert(false);
308 break;
309 }
310 left->nonlinearEdges(edges);
311 }
312
313 #if FORWARD_ENABLED
hess_forward(unsigned int len,double ** ret_vec)314 void UaryOPNode::hess_forward(unsigned int len, double** ret_vec)
315 {
316 double* lvec = NULL;
317 if(left!=NULL){
318 left->hess_forward(len,&lvec);
319 }
320
321 *ret_vec = new double[len];
322 this->hess_forward_calc0(len,lvec,*ret_vec);
323 delete[] lvec;
324 }
325
hess_forward_calc0(unsigned int & len,double * lvec,double * ret_vec)326 void UaryOPNode::hess_forward_calc0(unsigned int& len, double* lvec, double* ret_vec)
327 {
328 double hu = NaN_Double;
329 double lval = NaN_Double;
330 double val = NaN_Double;
331 unsigned int index = 0;
332 switch (op)
333 {
334 case OP_SIN:
335 assert(left!=NULL);
336 lval = SV->pop_back();
337 val = sin(lval);
338 SV->push_back(val);
339 hu = cos(lval);
340
341 double coeff;
342 coeff = -val; //=sin(left->val); -- and avoid cross initialisation
343 //calculate the first order derivatives
344 for(unsigned int i =0;i<AutoDiff::num_var;++i)
345 {
346 ret_vec[i] = hu*lvec[i] + 0;
347 }
348 //calculate the second order
349 index = AutoDiff::num_var;
350 for(unsigned int i=0;i<AutoDiff::num_var;++i)
351 {
352 for(unsigned int j=i;j<AutoDiff::num_var;++j)
353 {
354 ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
355 ++index;
356 }
357 }
358 assert(index==len);
359 break;
360 default:
361 cerr<<"op["<<op<<"] not yet implemented!";
362 break;
363 }
364 }
365 #endif
366
toString(int level)367 string UaryOPNode::toString(int level)
368 {
369 ostringstream oss;
370 string s(level,'\t');
371 oss<<s<<"[UaryOPNode]("<<op<<")";
372 return oss.str();
373 }
374
375 } /* namespace AutoDiff */
376