/* $Id$
 *
 * Name:    flattenMul.cpp
 * Author:  Pietro Belotti
 * Purpose: flatten multiplication expression tree into monomial
 *          c*\prod_{k\in K} x_{i_k}^{p_k}
 *
 * (C) Carnegie-Mellon University, 2007-10.
 * This file is licensed under the Eclipse Public License (EPL)
 */
11 
12 #include <stdio.h>
13 
14 #include "CouenneProblem.hpp"
15 #include "CouenneExprAux.hpp"
16 
17 using namespace Couenne;
18 
19 /// re-organizes multiplication and stores indices (and exponents) of
20 /// its variables
flattenMul(expression * mul,CouNumber & coe,std::map<int,CouNumber> & indices)21 void CouenneProblem::flattenMul (expression *mul, CouNumber &coe,
22 				 std::map <int, CouNumber> &indices) {
23 
24   if (jnlst_ -> ProduceOutput (Ipopt::J_ALL, J_REFORMULATE)) {
25     printf ("flatten %d ---> ", mul -> code ()); mul -> print ();
26     printf ("\n");
27   }
28 
29   if (mul -> code () != COU_EXPRMUL) {
30 
31     exprAux *aux = mul -> standardize (this);
32 
33     int ind = (aux) ? aux -> Index () : mul -> Index ();
34 
35     std::map <int, CouNumber>::iterator
36       where = indices.find (ind);
37 
38     if (where == indices.end ())
39       indices.insert (std::pair <int, CouNumber> (ind, 1));
40     else ++ (where -> second);
41 
42     return;
43   }
44 
45   int nargs = mul -> nArgs ();
46   expression **al = mul -> ArgList ();
47 
48   // for each factor (variable, function, or constant) of the product
49   for (int i=0; i < nargs; i++) {
50 
51     expression
52       *arg = al [i],
53       *simpl = arg -> simplify ();
54 
55     if (simpl)
56       al [i] = arg = simpl;
57 
58     if (jnlst_ -> ProduceOutput (Ipopt::J_ALL, J_REFORMULATE)) {
59       printf ("  flatten arg %d ---> ", arg -> code ());
60       arg -> print ();
61       printf ("\n");
62     }
63 
64     switch (arg -> code ()) {
65 
66     case COU_EXPRCONST: // change scalar multiplier
67 
68       coe *= arg -> Value ();
69       break;
70 
71     case COU_EXPRMUL:  // apply recursively
72 
73       flattenMul (arg, coe, indices);
74       break;
75 
76     case COU_EXPRVAR: { // insert index or increment
77 
78       std::map <int, CouNumber>::iterator
79 	where = indices.find (arg -> Index ());
80 
81       if (where == indices.end ())
82 	indices.insert (std::pair <int, CouNumber> (arg -> Index (), 1));
83       else ++ (where -> second);
84     } break;
85 
86     case COU_EXPROPP: // equivalent to multiplying by -1
87 
88       coe = -coe;
89 
90       if (arg -> Argument () -> Type () == N_ARY) {
91 	flattenMul (arg -> Argument (), coe, indices);
92 	break;
93       } else arg = arg -> Argument ();
94 
95     case COU_EXPRPOW:
96 
97       if (arg -> code () == COU_EXPRPOW) { // re-check as it could come from above
98 
99 	expression
100 	  *base     = arg -> ArgList () [0],
101 	  *exponent = arg -> ArgList () [1];
102 
103 	if (exponent -> Type () == CONST) { // could be of the form k x^2
104 
105 	  double expnum = exponent -> Value ();
106 
107 	  expression *aux = base -> standardize (this);
108 
109 	  if (!aux)
110 	    aux = base;
111 
112 	  std::map <int, CouNumber>::iterator
113 	    where = indices.find (aux -> Index ());
114 
115 	  if (where == indices.end ())
116 	    indices.insert (std::pair <int, CouNumber> (aux -> Index (), expnum));
117 	  else (where -> second += expnum);
118 
119 	  break;
120 	}  // otherwise, revert to default
121       }
122 
123     case COU_EXPRSUM: // well, only if there is one element
124 
125       if ((arg -> code  () == COU_EXPRSUM) && // re-check as it could come from above
126 	  (arg -> nArgs () == 1)) {
127 
128 	flattenMul (arg, coe, indices);
129 	break;
130 
131       } // otherwise, continue into default case
132 
133     default: { // for all other expression, add associated new auxiliary
134 
135       exprAux *aux = arg -> standardize (this);
136 
137       int ind = (aux) ? aux -> Index () : arg -> Index ();
138 
139       std::map <int, CouNumber>::iterator
140 	where = indices.find (ind);
141 
142       if (where == indices.end ())
143 	indices.insert (std::pair <int, CouNumber> (ind, 1));
144       else ++ (where -> second);
145     }
146     }
147   }
148 }
149