1 /*
2 * BinaryOPNode.cpp
3 *
4 * Created on: 6 Nov 2013
5 * Author: s0965328
6 */
7
8 #include "auto_diff_types.h"
9 #include "BinaryOPNode.h"
10 #include "PNode.h"
11 #include "Stack.h"
12 #include "Tape.h"
13 #include "EdgeSet.h"
14 #include "Node.h"
15 #include "VNode.h"
16 #include "OPNode.h"
17 #include "ActNode.h"
18 #include "EdgeSet.h"
19
20 namespace AutoDiff {
21
BinaryOPNode(OPCODE op_,Node * left_,Node * right_)22 BinaryOPNode::BinaryOPNode(OPCODE op_, Node* left_, Node* right_):OPNode(op_,left_),right(right_)
23 {
24 }
25
createBinaryOpNode(OPCODE op,Node * left,Node * right)26 OPNode* BinaryOPNode::createBinaryOpNode(OPCODE op, Node* left, Node* right)
27 {
28 assert(left!=NULL && right!=NULL);
29 OPNode* node = NULL;
30 node = new BinaryOPNode(op,left,right);
31 return node;
32 }
33
~BinaryOPNode()34 BinaryOPNode::~BinaryOPNode() {
35 if(right->getType()!=VNode_Type)
36 {
37 delete right;
38 right = NULL;
39 }
40 }
41
inorder_visit(int level,ostream & oss)42 void BinaryOPNode::inorder_visit(int level,ostream& oss){
43 if(left!=NULL){
44 left->inorder_visit(level+1,oss);
45 }
46 oss<<this->toString(level)<<endl;
47 if(right!=NULL){
48 right->inorder_visit(level+1,oss);
49 }
50 }
51
collect_vnodes(boost::unordered_set<Node * > & nodes,unsigned int & total)52 void BinaryOPNode::collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total){
53 total++;
54 if (left != NULL) {
55 left->collect_vnodes(nodes,total);
56 }
57 if (right != NULL) {
58 right->collect_vnodes(nodes,total);
59 }
60
61 }
62
eval_function()63 void BinaryOPNode::eval_function()
64 {
65 assert(left!=NULL && right!=NULL);
66 left->eval_function();
67 right->eval_function();
68 this->calc_eval_function();
69 }
70
calc_eval_function()71 void BinaryOPNode::calc_eval_function()
72 {
73 double x = NaN_Double;
74 double rx = SV->pop_back();
75 double lx = SV->pop_back();
76 switch (op)
77 {
78 case OP_PLUS:
79 x = lx + rx;
80 break;
81 case OP_MINUS:
82 x = lx - rx;
83 break;
84 case OP_TIMES:
85 x = lx * rx;
86 break;
87 case OP_DIVID:
88 x = lx / rx;
89 break;
90 case OP_POW:
91 x = pow(lx,rx);
92 break;
93 default:
94 cerr<<"op["<<op<<"] not yet implemented!!"<<endl;
95 assert(false);
96 break;
97 }
98 SV->push_back(x);
99 }
100
101
102 //1. visiting left if not NULL
103 //2. then, visiting right if not NULL
104 //3. calculating the immediate derivative hu and hv
grad_reverse_0()105 void BinaryOPNode::grad_reverse_0()
106 {
107 assert(left!=NULL && right != NULL);
108 this->adj = 0;
109 left->grad_reverse_0();
110 right->grad_reverse_0();
111 this->calc_grad_reverse_0();
112 }
113
114 //right left - right most traversal
grad_reverse_1()115 void BinaryOPNode::grad_reverse_1()
116 {
117 assert(right!=NULL && left!=NULL);
118 double r_adj = SD->pop_back()*this->adj;
119 right->update_adj(r_adj);
120 double l_adj = SD->pop_back()*this->adj;
121 left->update_adj(l_adj);
122
123 right->grad_reverse_1();
124 left->grad_reverse_1();
125 }
126
calc_grad_reverse_0()127 void BinaryOPNode::calc_grad_reverse_0()
128 {
129 assert(left!=NULL && right != NULL);
130 double l_dh = NaN_Double;
131 double r_dh = NaN_Double;
132 double rx = SV->pop_back();
133 double lx = SV->pop_back();
134 double x = NaN_Double;
135 switch (op)
136 {
137 case OP_PLUS:
138 x = lx + rx;
139 l_dh = 1;
140 r_dh = 1;
141 break;
142 case OP_MINUS:
143 x = lx - rx;
144 l_dh = 1;
145 r_dh = -1;
146 break;
147 case OP_TIMES:
148 x = lx * rx;
149 l_dh = rx;
150 r_dh = lx;
151 break;
152 case OP_DIVID:
153 x = lx / rx;
154 l_dh = 1 / rx;
155 r_dh = -(lx) / pow(rx, 2);
156 break;
157 case OP_POW:
158 if(right->getType()==PNode_Type){
159 x = pow(lx,rx);
160 l_dh = rx*pow(lx,(rx-1));
161 r_dh = 0;
162 }
163 else{
164 assert(lx>0.0); //otherwise log(lx) is not defined in read number
165 x = pow(lx,rx);
166 l_dh = rx*pow(lx,(rx-1));
167 r_dh = pow(lx,rx)*log(lx); //this is for x1^x2 when x1=0 cause r_dh become +inf, however d(0^x2)/d(x2) = 0
168 }
169 break;
170 default:
171 cerr<<"error op not impl"<<endl;
172 break;
173 }
174 SV->push_back(x);
175 SD->push_back(l_dh);
176 SD->push_back(r_dh);
177 }
178
hess_reverse_0_init_n_in_arcs()179 void BinaryOPNode::hess_reverse_0_init_n_in_arcs()
180 {
181 this->left->hess_reverse_0_init_n_in_arcs();
182 this->right->hess_reverse_0_init_n_in_arcs();
183 this->Node::hess_reverse_0_init_n_in_arcs();
184 }
185
hess_reverse_1_clear_index()186 void BinaryOPNode::hess_reverse_1_clear_index()
187 {
188 this->left->hess_reverse_1_clear_index();
189 this->right->hess_reverse_1_clear_index();
190 this->Node::hess_reverse_1_clear_index();
191 }
192
hess_reverse_0()193 unsigned int BinaryOPNode::hess_reverse_0()
194 {
195 assert(this->left!=NULL && right!=NULL);
196 if(index==0)
197 {
198 unsigned int lindex=0, rindex=0;
199 lindex = left->hess_reverse_0();
200 rindex = right->hess_reverse_0();
201 assert(lindex!=0 && rindex !=0);
202 II->set(lindex);
203 II->set(rindex);
204 double rx,rx_bar,rw,rw_bar;
205 double lx,lx_bar,lw,lw_bar;
206 double x,x_bar,w,w_bar;
207 double r_dh, l_dh;
208 right->hess_reverse_0_get_values(rindex,rx,rx_bar,rw,rw_bar);
209 left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar);
210 switch(op)
211 {
212 case OP_PLUS:
213 // cout<<"lindex="<<lindex<<"\trindex="<<rindex<<"\tI="<<I<<endl;
214 x = lx + rx;
215 // cout<<lx<<"\t+"<<rx<<"\t="<<x<<"\t\t"<<toString(0)<<endl;
216 x_bar = 0;
217 l_dh = 1;
218 r_dh = 1;
219 w = lw * l_dh + rw * r_dh;
220 // cout<<lw<<"\t+"<<rw<<"\t="<<w<<"\t\t"<<toString(0)<<endl;
221 w_bar = 0;
222 break;
223 case OP_MINUS:
224 x = lx - rx;
225 x_bar = 0;
226 l_dh = 1;
227 r_dh = -1;
228 w = lw * l_dh + rw * r_dh;
229 w_bar = 0;
230 break;
231 case OP_TIMES:
232 x = lx * rx;
233 x_bar = 0;
234 l_dh = rx;
235 r_dh = lx;
236 w = lw * l_dh + rw * r_dh;
237 w_bar = 0;
238 break;
239 case OP_DIVID:
240 x = lx / rx;
241 x_bar = 0;
242 l_dh = 1/rx;
243 r_dh = -lx/pow(rx,2);
244 w = lw * l_dh + rw * r_dh;
245 w_bar = 0;
246 break;
247 case OP_POW:
248 if(right->getType()==PNode_Type)
249 {
250 x = pow(lx,rx);
251 x_bar = 0;
252 l_dh = rx*pow(lx,(rx-1));
253 r_dh = 0;
254 w = lw * l_dh + rw * r_dh;
255 w_bar = 0;
256 }
257 else
258 {
259 assert(lx>0.0); //otherwise log(lx) undefined in real number
260 x = pow(lx,rx);
261 x_bar = 0;
262 l_dh = rx*pow(lx,(rx-1));
263 r_dh = pow(lx,rx)*log(lx); //log(lx) cause -inf when lx=0;
264 w = lw * l_dh + rw * r_dh;
265 w_bar = 0;
266 }
267 break;
268 default:
269 cerr<<"op["<<op<<"] not yet implemented!"<<endl;
270 assert(false);
271 break;
272 }
273 TT->set(x);
274 TT->set(x_bar);
275 TT->set(w);
276 TT->set(w_bar);
277 TT->set(l_dh);
278 TT->set(r_dh);
279 assert(TT->index == TT->index);
280 index = TT->index;
281 }
282 return index;
283 }
284
hess_reverse_0_get_values(unsigned int i,double & x,double & x_bar,double & w,double & w_bar)285 void BinaryOPNode::hess_reverse_0_get_values(unsigned int i,double& x, double& x_bar, double& w, double& w_bar)
286 {
287 --i; // skip the r_dh (ie, dh/du)
288 --i; // skip the l_dh (ie. dh/dv)
289 w_bar = TT->get(--i);
290 w = TT->get(--i);
291 x_bar = TT->get(--i);
292 x = TT->get(--i);
293 }
294
hess_reverse_1(unsigned int i)295 void BinaryOPNode::hess_reverse_1(unsigned int i)
296 {
297 n_in_arcs--;
298 if(n_in_arcs==0)
299 {
300 assert(right!=NULL && left!=NULL);
301 unsigned int rindex = II->get(--(II->index));
302 unsigned int lindex = II->get(--(II->index));
303 // cout<<"ri["<<rindex<<"]\tli["<<lindex<<"]\t"<<this->toString(0)<<endl;
304 double r_dh = TT->get(--i);
305 double l_dh = TT->get(--i);
306 double w_bar = TT->get(--i);
307 --i; //skip w
308 double x_bar = TT->get(--i);
309 --i; //skip x
310
311 double lw_bar=0,rw_bar=0;
312 double lw=0,lx=0; left->hess_reverse_1_get_xw(lindex,lw,lx);
313 double rw=0,rx=0; right->hess_reverse_1_get_xw(rindex,rw,rx);
314 switch(op)
315 {
316 case OP_PLUS:
317 assert(l_dh==1);
318 assert(r_dh==1);
319 lw_bar += w_bar*l_dh;
320 rw_bar += w_bar*r_dh;
321 break;
322 case OP_MINUS:
323 assert(l_dh==1);
324 assert(r_dh==-1);
325 lw_bar += w_bar*l_dh;
326 rw_bar += w_bar*r_dh;
327 break;
328 case OP_TIMES:
329 assert(rx == l_dh);
330 assert(lx == r_dh);
331 lw_bar += w_bar*rx;
332 lw_bar += x_bar*lw*0 + x_bar*rw*1;
333 rw_bar += w_bar*lx;
334 rw_bar += x_bar*lw*1 + x_bar*rw*0;
335 break;
336 case OP_DIVID:
337 lw_bar += w_bar*l_dh;
338 lw_bar += x_bar*lw*0 + x_bar*rw*-1/(pow(rx,2));
339 rw_bar += w_bar*r_dh;
340 rw_bar += x_bar*lw*-1/pow(rx,2) + x_bar*rw*2*lx/pow(rx,3);
341 break;
342 case OP_POW:
343 if(right->getType()==PNode_Type){
344 lw_bar += w_bar*l_dh;
345 lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + 0;
346 rw_bar += w_bar*r_dh; assert(r_dh==0.0);
347 rw_bar += 0;
348 }
349 else{
350 assert(lx>0.0); //otherwise log(lx) is not define in Real
351 lw_bar += w_bar*l_dh;
352 lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + x_bar*rw*pow(lx,rx-1)*(rx*log(lx)+1); //cause log(lx)=-inf when
353 rw_bar += w_bar*r_dh;
354 rw_bar += x_bar*lw*pow(lx,rx-1)*(rx*log(lx)+1) + x_bar*rw*pow(lx,rx)*pow(log(lx),2);
355 }
356 break;
357 default:
358 cerr<<"op["<<op<<"] not yet implemented !"<<endl;
359 assert(false);
360 break;
361 }
362 double rx_bar = x_bar*r_dh;
363 double lx_bar = x_bar*l_dh;
364 right->update_x_bar(rindex,rx_bar);
365 left->update_x_bar(lindex,lx_bar);
366 right->update_w_bar(rindex,rw_bar);
367 left->update_w_bar(lindex,lw_bar);
368
369
370 this->right->hess_reverse_1(rindex);
371 this->left->hess_reverse_1(lindex);
372 }
373 }
hess_reverse_1_init_x_bar(unsigned int i)374 void BinaryOPNode::hess_reverse_1_init_x_bar(unsigned int i)
375 {
376 TT->at(i-5) = 1;
377 }
update_x_bar(unsigned int i,double v)378 void BinaryOPNode::update_x_bar(unsigned int i ,double v)
379 {
380 TT->at(i-5) += v;
381 }
update_w_bar(unsigned int i,double v)382 void BinaryOPNode::update_w_bar(unsigned int i ,double v)
383 {
384 TT->at(i-3) += v;
385 }
hess_reverse_1_get_xw(unsigned int i,double & w,double & x)386 void BinaryOPNode::hess_reverse_1_get_xw(unsigned int i,double& w,double& x)
387 {
388 w = TT->get(i-4);
389 x = TT->get(i-6);
390 }
hess_reverse_get_x(unsigned int i,double & x)391 void BinaryOPNode::hess_reverse_get_x(unsigned int i,double& x)
392 {
393 x = TT->get(i-6);
394 }
395
396
nonlinearEdges(EdgeSet & edges)397 void BinaryOPNode::nonlinearEdges(EdgeSet& edges)
398 {
399 for(list<Edge>::iterator it=edges.edges.begin();it!=edges.edges.end();)
400 {
401 Edge e = *it;
402 if(e.a==this || e.b == this){
403 if(e.a == this && e.b == this)
404 {
405 Edge e1(left,left);
406 Edge e2(right,right);
407 Edge e3(left,right);
408 edges.insertEdge(e1);
409 edges.insertEdge(e2);
410 edges.insertEdge(e3);
411 }
412 else
413 {
414 Node* o = e.a==this? e.b: e.a;
415 Edge e1(left,o);
416 Edge e2(right,o);
417 edges.insertEdge(e1);
418 edges.insertEdge(e2);
419 }
420 it = edges.edges.erase(it);
421 }
422 else
423 {
424 it++;
425 }
426 }
427
428 Edge e1(left,right);
429 Edge e2(left,left);
430 Edge e3(right,right);
431 switch(op)
432 {
433 case OP_PLUS:
434 case OP_MINUS:
435 //do nothing for linear operator
436 break;
437 case OP_TIMES:
438 edges.insertEdge(e1);
439 break;
440 case OP_DIVID:
441 edges.insertEdge(e1);
442 edges.insertEdge(e3);
443 break;
444 case OP_POW:
445 edges.insertEdge(e1);
446 edges.insertEdge(e2);
447 edges.insertEdge(e3);
448 break;
449 default:
450 cerr<<"op["<<op<<"] not yet implmented !"<<endl;
451 assert(false);
452 break;
453 }
454 left->nonlinearEdges(edges);
455 right->nonlinearEdges(edges);
456 }
457
458 #if FORWARD_ENABLED
459
hess_forward(unsigned int len,double ** ret_vec)460 void BinaryOPNode::hess_forward(unsigned int len, double** ret_vec)
461 {
462 double* lvec = NULL;
463 double* rvec = NULL;
464 if(left!=NULL){
465 left->hess_forward(len,&lvec);
466 }
467 if(right!=NULL){
468 right->hess_forward(len,&rvec);
469 }
470
471 *ret_vec = new double[len];
472 hess_forward_calc0(len,lvec,rvec,*ret_vec);
473 //delete lvec, rvec
474 delete[] lvec;
475 delete[] rvec;
476 }
477
hess_forward_calc0(unsigned int & len,double * lvec,double * rvec,double * ret_vec)478 void BinaryOPNode::hess_forward_calc0(unsigned int& len, double* lvec, double* rvec, double* ret_vec)
479 {
480 double hu = NaN_Double, hv= NaN_Double;
481 double lval = NaN_Double, rval = NaN_Double;
482 double val = NaN_Double;
483 unsigned int index = 0;
484 switch (op)
485 {
486 case OP_PLUS:
487 rval = SV->pop_back();
488 lval = SV->pop_back();
489 val = lval + rval;
490 SV->push_back(val);
491 //calculate the first order derivatives
492 for(unsigned int i=0;i<AutoDiff::num_var;++i)
493 {
494 ret_vec[i] = lvec[i]+rvec[i];
495 }
496 //calculate the second order
497 index = AutoDiff::num_var;
498 for(unsigned int i=0;i<AutoDiff::num_var;++i)
499 {
500 for(unsigned int j=i;j<AutoDiff::num_var;++j){
501 ret_vec[index] = lvec[index] + 0 + rvec[index] + 0;
502 ++index;
503 }
504 }
505 assert(index==len);
506 break;
507 case OP_MINUS:
508 rval = SV->pop_back();
509 lval = SV->pop_back();
510 val = lval + rval;
511 SV->push_back(val);
512 //calculate the first order derivatives
513 for(unsigned int i=0;i<AutoDiff::num_var;++i)
514 {
515 ret_vec[i] = lvec[i] - rvec[i];
516 }
517 //calculate the second order
518 index = AutoDiff::num_var;
519 for(unsigned int i=0;i<AutoDiff::num_var;++i)
520 {
521 for(unsigned int j=i;j<AutoDiff::num_var;++j){
522 ret_vec[index] = lvec[index] + 0 - rvec[index] + 0;
523 ++index;
524 }
525 }
526 assert(index==len);
527 break;
528 case OP_TIMES:
529 rval = SV->pop_back();
530 lval = SV->pop_back();
531 val = lval * rval;
532 SV->push_back(val);
533 hu = rval;
534 hv = lval;
535 //calculate the first order derivatives
536 for(unsigned int i =0;i<AutoDiff::num_var;++i)
537 {
538 ret_vec[i] = hu*lvec[i] + hv*rvec[i];
539 }
540 //calculate the second order
541 index = AutoDiff::num_var;
542 for(unsigned int i=0;i<AutoDiff::num_var;++i)
543 {
544 for(unsigned int j=i;j<AutoDiff::num_var;++j)
545 {
546 ret_vec[index] = hu * lvec[index] + lvec[i] * rvec[j]+hv * rvec[index] + rvec[i] * lvec[j];
547 ++index;
548 }
549 }
550 assert(index==len);
551 break;
552 case OP_POW:
553 rval = SV->pop_back();
554 lval = SV->pop_back();
555 val = pow(lval,rval);
556 SV->push_back(val);
557 if(left->getType()==PNode_Type && right->getType()==PNode_Type)
558 {
559 std::fill_n(ret_vec,len,0);
560 }
561 else
562 {
563 hu = rval*pow(lval,(rval-1));
564 hv = pow(lval,rval)*log(lval);
565 if(left->getType()==PNode_Type)
566 {
567 double coeff = pow(log(lval),2)*pow(lval,rval);
568 //calculate the first order derivatives
569 for(unsigned int i =0;i<AutoDiff::num_var;++i)
570 {
571 ret_vec[i] = hu*lvec[i] + hv*rvec[i];
572 }
573 //calculate the second order
574 index = AutoDiff::num_var;
575 for(unsigned int i=0;i<AutoDiff::num_var;++i)
576 {
577 for(unsigned int j=i;j<AutoDiff::num_var;++j)
578 {
579 ret_vec[index] = 0 + 0 + hv * rvec[index] + rvec[i] * coeff * rvec[j];
580 ++index;
581 }
582 }
583 }
584 else if(right->getType()==PNode_Type)
585 {
586 double coeff = rval*(rval-1)*pow(lval,rval-2);
587 //calculate the first order derivatives
588 for(unsigned int i =0;i<AutoDiff::num_var;++i)
589 {
590 ret_vec[i] = hu*lvec[i] + hv*rvec[i];
591 }
592 //calculate the second order
593 index = AutoDiff::num_var;
594 for(unsigned int i=0;i<AutoDiff::num_var;++i)
595 {
596 for(unsigned int j=i;j<AutoDiff::num_var;++j)
597 {
598 ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
599 ++index;
600 }
601 }
602 }
603 else
604 {
605 assert(false);
606 }
607 }
608 assert(index==len);
609 break;
610 case OP_SIN: //TODO should move to UnaryOPNode.cpp?
611 assert(left!=NULL&&right==NULL);
612 lval = SV->pop_back();
613 val = sin(lval);
614 SV->push_back(val);
615 hu = cos(lval);
616
617 double coeff;
618 coeff = -val; //=sin(left->val); -- and avoid cross initialisation
619 //calculate the first order derivatives
620 for(unsigned int i =0;i<AutoDiff::num_var;++i)
621 {
622 ret_vec[i] = hu*lvec[i] + 0;
623 }
624 //calculate the second order
625 index = AutoDiff::num_var;
626 for(unsigned int i=0;i<AutoDiff::num_var;++i)
627 {
628 for(unsigned int j=i;j<AutoDiff::num_var;++j)
629 {
630 ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
631 ++index;
632 }
633 }
634 assert(index==len);
635 break;
636 default:
637 cerr<<"op["<<op<<"] not yet implemented!";
638 break;
639 }
640 }
641 #endif
642
643
toString(int level)644 string BinaryOPNode::toString(int level){
645 ostringstream oss;
646 string s(level,'\t');
647 oss<<s<<"[BinaryOPNode]("<<op<<")";
648 return oss.str();
649 }
650
651 } /* namespace AutoDiff */
652