1 /*
2  * BinaryOPNode.cpp
3  *
4  *  Created on: 6 Nov 2013
5  *      Author: s0965328
6  */
7 
8 #include "auto_diff_types.h"
9 #include "BinaryOPNode.h"
10 #include "PNode.h"
11 #include "Stack.h"
12 #include "Tape.h"
13 #include "EdgeSet.h"
14 #include "Node.h"
15 #include "VNode.h"
16 #include "OPNode.h"
17 #include "ActNode.h"
18 #include "EdgeSet.h"
19 
20 namespace AutoDiff {
21 
BinaryOPNode(OPCODE op_,Node * left_,Node * right_)22 BinaryOPNode::BinaryOPNode(OPCODE op_, Node* left_, Node* right_):OPNode(op_,left_),right(right_)
23 {
24 }
25 
createBinaryOpNode(OPCODE op,Node * left,Node * right)26 OPNode* BinaryOPNode::createBinaryOpNode(OPCODE op, Node* left, Node* right)
27 {
28 	assert(left!=NULL && right!=NULL);
29 	OPNode* node = NULL;
30 	node = new BinaryOPNode(op,left,right);
31 	return node;
32 }
33 
~BinaryOPNode()34 BinaryOPNode::~BinaryOPNode() {
35 	if(right->getType()!=VNode_Type)
36 	{
37 		delete right;
38 		right = NULL;
39 	}
40 }
41 
inorder_visit(int level,ostream & oss)42 void BinaryOPNode::inorder_visit(int level,ostream& oss){
43 	if(left!=NULL){
44 		left->inorder_visit(level+1,oss);
45 	}
46 	oss<<this->toString(level)<<endl;
47 	if(right!=NULL){
48 		right->inorder_visit(level+1,oss);
49 	}
50 }
51 
collect_vnodes(boost::unordered_set<Node * > & nodes,unsigned int & total)52 void BinaryOPNode::collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total){
53 	total++;
54 	if (left != NULL) {
55 		left->collect_vnodes(nodes,total);
56 	}
57 	if (right != NULL) {
58 		right->collect_vnodes(nodes,total);
59 	}
60 
61 }
62 
eval_function()63 void BinaryOPNode::eval_function()
64 {
65 	assert(left!=NULL && right!=NULL);
66 	left->eval_function();
67 	right->eval_function();
68 	this->calc_eval_function();
69 }
70 
calc_eval_function()71 void BinaryOPNode::calc_eval_function()
72 {
73 	double x = NaN_Double;
74 	double rx = SV->pop_back();
75 	double lx = SV->pop_back();
76 	switch (op)
77 	{
78 	case OP_PLUS:
79 		x = lx + rx;
80 		break;
81 	case OP_MINUS:
82 		x = lx - rx;
83 		break;
84 	case OP_TIMES:
85 		x = lx * rx;
86 		break;
87 	case OP_DIVID:
88 		x = lx / rx;
89 		break;
90 	case OP_POW:
91 		x = pow(lx,rx);
92 		break;
93 	default:
94 		cerr<<"op["<<op<<"] not yet implemented!!"<<endl;
95 		assert(false);
96 		break;
97 	}
98 	SV->push_back(x);
99 }
100 
101 
102 //1. visiting left if not NULL
103 //2. then, visiting right if not NULL
104 //3. calculating the immediate derivative hu and hv
grad_reverse_0()105 void BinaryOPNode::grad_reverse_0()
106 {
107 	assert(left!=NULL && right != NULL);
108 	this->adj = 0;
109 	left->grad_reverse_0();
110 	right->grad_reverse_0();
111 	this->calc_grad_reverse_0();
112 }
113 
114 //right left - right most traversal
grad_reverse_1()115 void BinaryOPNode::grad_reverse_1()
116 {
117 	assert(right!=NULL && left!=NULL);
118 	double r_adj = SD->pop_back()*this->adj;
119 	right->update_adj(r_adj);
120 	double l_adj = SD->pop_back()*this->adj;
121 	left->update_adj(l_adj);
122 
123 	right->grad_reverse_1();
124 	left->grad_reverse_1();
125 }
126 
calc_grad_reverse_0()127 void BinaryOPNode::calc_grad_reverse_0()
128 {
129 	assert(left!=NULL && right != NULL);
130 	double l_dh = NaN_Double;
131 	double r_dh = NaN_Double;
132 	double rx = SV->pop_back();
133 	double lx = SV->pop_back();
134 	double x = NaN_Double;
135 	switch (op)
136 	{
137 	case OP_PLUS:
138 		x = lx + rx;
139 		l_dh = 1;
140 		r_dh = 1;
141 		break;
142 	case OP_MINUS:
143 		x = lx - rx;
144 		l_dh = 1;
145 		r_dh = -1;
146 		break;
147 	case OP_TIMES:
148 		x = lx * rx;
149 		l_dh = rx;
150 		r_dh = lx;
151 		break;
152 	case OP_DIVID:
153 		x = lx / rx;
154 		l_dh = 1 / rx;
155 		r_dh = -(lx) / pow(rx, 2);
156 		break;
157 	case OP_POW:
158 		if(right->getType()==PNode_Type){
159 			x = pow(lx,rx);
160 			l_dh = rx*pow(lx,(rx-1));
161 			r_dh = 0;
162 		}
163 		else{
164 			assert(lx>0.0); //otherwise log(lx) is not defined in read number
165 			x = pow(lx,rx);
166 			l_dh = rx*pow(lx,(rx-1));
167 			r_dh = pow(lx,rx)*log(lx);  //this is for x1^x2 when x1=0 cause r_dh become +inf, however d(0^x2)/d(x2) = 0
168 		}
169 		break;
170 	default:
171 		cerr<<"error op not impl"<<endl;
172 		break;
173 	}
174 	SV->push_back(x);
175 	SD->push_back(l_dh);
176 	SD->push_back(r_dh);
177 }
178 
hess_reverse_0_init_n_in_arcs()179 void BinaryOPNode::hess_reverse_0_init_n_in_arcs()
180 {
181 	this->left->hess_reverse_0_init_n_in_arcs();
182 	this->right->hess_reverse_0_init_n_in_arcs();
183 	this->Node::hess_reverse_0_init_n_in_arcs();
184 }
185 
hess_reverse_1_clear_index()186 void BinaryOPNode::hess_reverse_1_clear_index()
187 {
188 	this->left->hess_reverse_1_clear_index();
189 	this->right->hess_reverse_1_clear_index();
190 	this->Node::hess_reverse_1_clear_index();
191 }
192 
hess_reverse_0()193 unsigned int BinaryOPNode::hess_reverse_0()
194 {
195 	assert(this->left!=NULL && right!=NULL);
196 	if(index==0)
197 	{
198 		unsigned int lindex=0, rindex=0;
199 		lindex = left->hess_reverse_0();
200 		rindex = right->hess_reverse_0();
201 		assert(lindex!=0 && rindex !=0);
202 		II->set(lindex);
203 		II->set(rindex);
204 		double rx,rx_bar,rw,rw_bar;
205 		double lx,lx_bar,lw,lw_bar;
206 		double x,x_bar,w,w_bar;
207 		double r_dh, l_dh;
208 		right->hess_reverse_0_get_values(rindex,rx,rx_bar,rw,rw_bar);
209 		left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar);
210 		switch(op)
211 		{
212 		case OP_PLUS:
213 	//		cout<<"lindex="<<lindex<<"\trindex="<<rindex<<"\tI="<<I<<endl;
214 			x = lx + rx;
215 	//		cout<<lx<<"\t+"<<rx<<"\t="<<x<<"\t\t"<<toString(0)<<endl;
216 			x_bar = 0;
217 			l_dh = 1;
218 			r_dh = 1;
219 			w = lw * l_dh  + rw * r_dh;
220 	//		cout<<lw<<"\t+"<<rw<<"\t="<<w<<"\t\t"<<toString(0)<<endl;
221 			w_bar = 0;
222 			break;
223 		case OP_MINUS:
224 			x = lx - rx;
225 			x_bar = 0;
226 			l_dh = 1;
227 			r_dh = -1;
228 			w = lw * l_dh  + rw * r_dh;
229 			w_bar = 0;
230 			break;
231 		case OP_TIMES:
232 			x = lx * rx;
233 			x_bar = 0;
234 			l_dh = rx;
235 			r_dh = lx;
236 			w = lw * l_dh  + rw * r_dh;
237 			w_bar = 0;
238 			break;
239 		case OP_DIVID:
240 			x = lx / rx;
241 			x_bar = 0;
242 			l_dh = 1/rx;
243 			r_dh = -lx/pow(rx,2);
244 			w = lw * l_dh  + rw * r_dh;
245 			w_bar = 0;
246 			break;
247 		case OP_POW:
248 			if(right->getType()==PNode_Type)
249 			{
250 				x = pow(lx,rx);
251 				x_bar = 0;
252 				l_dh = rx*pow(lx,(rx-1));
253 				r_dh = 0;
254 				w = lw * l_dh + rw * r_dh;
255 				w_bar = 0;
256 			}
257 			else
258 			{
259 				assert(lx>0.0); //otherwise log(lx) undefined in real number
260 				x = pow(lx,rx);
261 				x_bar = 0;
262 				l_dh = rx*pow(lx,(rx-1));
263 				r_dh = pow(lx,rx)*log(lx);   //log(lx) cause -inf when lx=0;
264 				w = lw * l_dh  + rw * r_dh;
265 				w_bar = 0;
266 			}
267 			break;
268 		default:
269 			cerr<<"op["<<op<<"] not yet implemented!"<<endl;
270 			assert(false);
271 			break;
272 		}
273 		TT->set(x);
274 		TT->set(x_bar);
275 		TT->set(w);
276 		TT->set(w_bar);
277 		TT->set(l_dh);
278 		TT->set(r_dh);
279 		assert(TT->index == TT->index);
280 		index = TT->index;
281 	}
282 	return index;
283 }
284 
hess_reverse_0_get_values(unsigned int i,double & x,double & x_bar,double & w,double & w_bar)285 void BinaryOPNode::hess_reverse_0_get_values(unsigned int i,double& x, double& x_bar, double& w, double& w_bar)
286 {
287 	--i; // skip the r_dh (ie, dh/du)
288 	--i; // skip the l_dh (ie. dh/dv)
289 	w_bar = TT->get(--i);
290 	w = TT->get(--i);
291 	x_bar = TT->get(--i);
292 	x = TT->get(--i);
293 }
294 
hess_reverse_1(unsigned int i)295 void BinaryOPNode::hess_reverse_1(unsigned int i)
296 {
297 	n_in_arcs--;
298 	if(n_in_arcs==0)
299 	{
300 		assert(right!=NULL && left!=NULL);
301 		unsigned int rindex = II->get(--(II->index));
302 		unsigned int lindex = II->get(--(II->index));
303 	//	cout<<"ri["<<rindex<<"]\tli["<<lindex<<"]\t"<<this->toString(0)<<endl;
304 		double r_dh = TT->get(--i);
305 		double l_dh = TT->get(--i);
306 		double w_bar = TT->get(--i);
307 		--i; //skip w
308 		double x_bar = TT->get(--i);
309 		--i; //skip x
310 
311 		double lw_bar=0,rw_bar=0;
312 		double lw=0,lx=0; left->hess_reverse_1_get_xw(lindex,lw,lx);
313 		double rw=0,rx=0; right->hess_reverse_1_get_xw(rindex,rw,rx);
314 		switch(op)
315 		{
316 		case OP_PLUS:
317 			assert(l_dh==1);
318 			assert(r_dh==1);
319 			lw_bar += w_bar*l_dh;
320 			rw_bar += w_bar*r_dh;
321 			break;
322 		case OP_MINUS:
323 			assert(l_dh==1);
324 			assert(r_dh==-1);
325 			lw_bar += w_bar*l_dh;
326 			rw_bar += w_bar*r_dh;
327 			break;
328 		case OP_TIMES:
329 			assert(rx == l_dh);
330 			assert(lx == r_dh);
331 			lw_bar += w_bar*rx;
332 			lw_bar += x_bar*lw*0 + x_bar*rw*1;
333 			rw_bar += w_bar*lx;
334 			rw_bar += x_bar*lw*1 + x_bar*rw*0;
335 			break;
336 		case OP_DIVID:
337 			lw_bar += w_bar*l_dh;
338 			lw_bar += x_bar*lw*0 + x_bar*rw*-1/(pow(rx,2));
339 			rw_bar += w_bar*r_dh;
340 			rw_bar += x_bar*lw*-1/pow(rx,2) + x_bar*rw*2*lx/pow(rx,3);
341 			break;
342 		case OP_POW:
343 			if(right->getType()==PNode_Type){
344 				lw_bar += w_bar*l_dh;
345 				lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + 0;
346 				rw_bar += w_bar*r_dh; assert(r_dh==0.0);
347 				rw_bar += 0;
348 			}
349 			else{
350 				assert(lx>0.0); //otherwise log(lx) is not define in Real
351 				lw_bar += w_bar*l_dh;
352 				lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + x_bar*rw*pow(lx,rx-1)*(rx*log(lx)+1);  //cause log(lx)=-inf when
353 				rw_bar += w_bar*r_dh;
354 				rw_bar += x_bar*lw*pow(lx,rx-1)*(rx*log(lx)+1) + x_bar*rw*pow(lx,rx)*pow(log(lx),2);
355 			}
356 			break;
357 		default:
358 			cerr<<"op["<<op<<"] not yet implemented !"<<endl;
359 			assert(false);
360 			break;
361 		}
362 		double rx_bar = x_bar*r_dh;
363 		double lx_bar = x_bar*l_dh;
364 		right->update_x_bar(rindex,rx_bar);
365 		left->update_x_bar(lindex,lx_bar);
366 		right->update_w_bar(rindex,rw_bar);
367 		left->update_w_bar(lindex,lw_bar);
368 
369 
370 		this->right->hess_reverse_1(rindex);
371 		this->left->hess_reverse_1(lindex);
372 	}
373 }
hess_reverse_1_init_x_bar(unsigned int i)374 void BinaryOPNode::hess_reverse_1_init_x_bar(unsigned int i)
375 {
376 	TT->at(i-5) = 1;
377 }
update_x_bar(unsigned int i,double v)378 void BinaryOPNode::update_x_bar(unsigned int i ,double v)
379 {
380 	TT->at(i-5) += v;
381 }
update_w_bar(unsigned int i,double v)382 void BinaryOPNode::update_w_bar(unsigned int i ,double v)
383 {
384 	TT->at(i-3) += v;
385 }
hess_reverse_1_get_xw(unsigned int i,double & w,double & x)386 void BinaryOPNode::hess_reverse_1_get_xw(unsigned int i,double& w,double& x)
387 {
388 	w = TT->get(i-4);
389 	x = TT->get(i-6);
390 }
hess_reverse_get_x(unsigned int i,double & x)391 void BinaryOPNode::hess_reverse_get_x(unsigned int i,double& x)
392 {
393 	x = TT->get(i-6);
394 }
395 
396 
nonlinearEdges(EdgeSet & edges)397 void BinaryOPNode::nonlinearEdges(EdgeSet& edges)
398 {
399 	for(list<Edge>::iterator it=edges.edges.begin();it!=edges.edges.end();)
400 	{
401 		Edge e = *it;
402 		if(e.a==this || e.b == this){
403 			if(e.a == this && e.b == this)
404 			{
405 				Edge e1(left,left);
406 				Edge e2(right,right);
407 				Edge e3(left,right);
408 				edges.insertEdge(e1);
409 				edges.insertEdge(e2);
410 				edges.insertEdge(e3);
411 			}
412 			else
413 			{
414 				Node* o = e.a==this? e.b: e.a;
415 				Edge e1(left,o);
416 				Edge e2(right,o);
417 				edges.insertEdge(e1);
418 				edges.insertEdge(e2);
419 			}
420 			it = edges.edges.erase(it);
421 		}
422 		else
423 		{
424 			it++;
425 		}
426 	}
427 
428 	Edge e1(left,right);
429 	Edge e2(left,left);
430 	Edge e3(right,right);
431 	switch(op)
432 	{
433 	case OP_PLUS:
434 	case OP_MINUS:
435 		//do nothing for linear operator
436 		break;
437 	case OP_TIMES:
438 		edges.insertEdge(e1);
439 		break;
440 	case OP_DIVID:
441 		edges.insertEdge(e1);
442 		edges.insertEdge(e3);
443 		break;
444 	case OP_POW:
445 		edges.insertEdge(e1);
446 		edges.insertEdge(e2);
447 		edges.insertEdge(e3);
448 		break;
449 	default:
450 		cerr<<"op["<<op<<"] not yet implmented !"<<endl;
451 		assert(false);
452 		break;
453 	}
454 	left->nonlinearEdges(edges);
455 	right->nonlinearEdges(edges);
456 }
457 
458 #if FORWARD_ENABLED
459 
hess_forward(unsigned int len,double ** ret_vec)460 void BinaryOPNode::hess_forward(unsigned int len, double** ret_vec)
461 {
462 	double* lvec = NULL;
463 	double* rvec = NULL;
464 	if(left!=NULL){
465 		left->hess_forward(len,&lvec);
466 	}
467 	if(right!=NULL){
468 		right->hess_forward(len,&rvec);
469 	}
470 
471 	*ret_vec = new double[len];
472 	hess_forward_calc0(len,lvec,rvec,*ret_vec);
473 	//delete lvec, rvec
474 	delete[] lvec;
475 	delete[] rvec;
476 }
477 
hess_forward_calc0(unsigned int & len,double * lvec,double * rvec,double * ret_vec)478 void BinaryOPNode::hess_forward_calc0(unsigned int& len, double* lvec, double* rvec, double* ret_vec)
479 {
480 	double hu = NaN_Double, hv= NaN_Double;
481 	double lval = NaN_Double, rval = NaN_Double;
482 	double val = NaN_Double;
483 	unsigned int index = 0;
484 	switch (op)
485 	{
486 	case OP_PLUS:
487 		rval = SV->pop_back();
488 		lval = SV->pop_back();
489 		val = lval + rval;
490 		SV->push_back(val);
491 		//calculate the first order derivatives
492 		for(unsigned int i=0;i<AutoDiff::num_var;++i)
493 		{
494 			ret_vec[i] = lvec[i]+rvec[i];
495 		}
496 		//calculate the second order
497 		index = AutoDiff::num_var;
498 		for(unsigned int i=0;i<AutoDiff::num_var;++i)
499 		{
500 			for(unsigned int j=i;j<AutoDiff::num_var;++j){
501 				ret_vec[index] = lvec[index] + 0 + rvec[index] + 0;
502 				++index;
503 			}
504 		}
505 		assert(index==len);
506 		break;
507 	case OP_MINUS:
508 		rval = SV->pop_back();
509 		lval = SV->pop_back();
510 		val = lval + rval;
511 		SV->push_back(val);
512 		//calculate the first order derivatives
513 		for(unsigned int i=0;i<AutoDiff::num_var;++i)
514 		{
515 			ret_vec[i] = lvec[i] - rvec[i];
516 		}
517 		//calculate the second order
518 		index = AutoDiff::num_var;
519 		for(unsigned int i=0;i<AutoDiff::num_var;++i)
520 		{
521 			for(unsigned int j=i;j<AutoDiff::num_var;++j){
522 				ret_vec[index] = lvec[index] + 0 - rvec[index] + 0;
523 				++index;
524 			}
525 		}
526 		assert(index==len);
527 		break;
528 	case OP_TIMES:
529 		rval = SV->pop_back();
530 		lval = SV->pop_back();
531 		val = lval * rval;
532 		SV->push_back(val);
533 		hu = rval;
534 		hv = lval;
535 		//calculate the first order derivatives
536 		for(unsigned int i =0;i<AutoDiff::num_var;++i)
537 		{
538 			ret_vec[i] = hu*lvec[i] + hv*rvec[i];
539 		}
540 		//calculate the second order
541 		index = AutoDiff::num_var;
542 		for(unsigned int i=0;i<AutoDiff::num_var;++i)
543 		{
544 			for(unsigned int j=i;j<AutoDiff::num_var;++j)
545 			{
546 				ret_vec[index] = hu * lvec[index] + lvec[i] * rvec[j]+hv * rvec[index] + rvec[i] * lvec[j];
547 				++index;
548 			}
549 		}
550 		assert(index==len);
551 		break;
552 	case OP_POW:
553 		rval = SV->pop_back();
554 		lval = SV->pop_back();
555 		val = pow(lval,rval);
556 		SV->push_back(val);
557 		if(left->getType()==PNode_Type && right->getType()==PNode_Type)
558 		{
559 			std::fill_n(ret_vec,len,0);
560 		}
561 		else
562 		{
563 			hu = rval*pow(lval,(rval-1));
564 			hv = pow(lval,rval)*log(lval);
565 			if(left->getType()==PNode_Type)
566 			{
567 				double coeff = pow(log(lval),2)*pow(lval,rval);
568 				//calculate the first order derivatives
569 				for(unsigned int i =0;i<AutoDiff::num_var;++i)
570 				{
571 					ret_vec[i] = hu*lvec[i] + hv*rvec[i];
572 				}
573 				//calculate the second order
574 				index = AutoDiff::num_var;
575 				for(unsigned int i=0;i<AutoDiff::num_var;++i)
576 				{
577 					for(unsigned int j=i;j<AutoDiff::num_var;++j)
578 					{
579 						ret_vec[index] = 0 + 0 + hv * rvec[index] + rvec[i] * coeff * rvec[j];
580 						++index;
581 					}
582 				}
583 			}
584 			else if(right->getType()==PNode_Type)
585 			{
586 				double coeff = rval*(rval-1)*pow(lval,rval-2);
587 				//calculate the first order derivatives
588 				for(unsigned int i =0;i<AutoDiff::num_var;++i)
589 				{
590 					ret_vec[i] = hu*lvec[i] + hv*rvec[i];
591 				}
592 				//calculate the second order
593 				index = AutoDiff::num_var;
594 				for(unsigned int i=0;i<AutoDiff::num_var;++i)
595 				{
596 					for(unsigned int j=i;j<AutoDiff::num_var;++j)
597 					{
598 						ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
599 						++index;
600 					}
601 				}
602 			}
603 			else
604 			{
605 				assert(false);
606 			}
607 		}
608 		assert(index==len);
609 		break;
610 	case OP_SIN: //TODO should move to UnaryOPNode.cpp?
611 		assert(left!=NULL&&right==NULL);
612 		lval = SV->pop_back();
613 		val = sin(lval);
614 		SV->push_back(val);
615 		hu = cos(lval);
616 
617 		double coeff;
618 		coeff = -val; //=sin(left->val); -- and avoid cross initialisation
619 		//calculate the first order derivatives
620 		for(unsigned int i =0;i<AutoDiff::num_var;++i)
621 		{
622 			ret_vec[i] = hu*lvec[i] + 0;
623 		}
624 		//calculate the second order
625 		index = AutoDiff::num_var;
626 		for(unsigned int i=0;i<AutoDiff::num_var;++i)
627 		{
628 			for(unsigned int j=i;j<AutoDiff::num_var;++j)
629 			{
630 				ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
631 				++index;
632 			}
633 		}
634 		assert(index==len);
635 		break;
636 	default:
637 		cerr<<"op["<<op<<"] not yet implemented!";
638 		break;
639 	}
640 }
641 #endif
642 
643 
toString(int level)644 string BinaryOPNode::toString(int level){
645 	ostringstream oss;
646 	string s(level,'\t');
647 	oss<<s<<"[BinaryOPNode]("<<op<<")";
648 	return oss.str();
649 }
650 
651 } /* namespace AutoDiff */
652