1 /* -*- mode: C++; c-basic-offset: 2; indent-tabs-mode: nil -*- */
2 
3 /*
4  *  Main authors:
5  *     Guido Tack <guido.tack@monash.edu>
6  */
7 
8 /* This Source Code Form is subject to the terms of the Mozilla Public
9  * License, v. 2.0. If a copy of the MPL was not distributed with this
10  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
11 
12 #include <minizinc/flat_exp.hh>
13 
14 namespace MiniZinc {
15 
get_conjuncts(Expression * start)16 std::vector<Expression*> get_conjuncts(Expression* start) {
17   std::vector<Expression*> conj_stack;
18   std::vector<Expression*> conjuncts;
19   conj_stack.push_back(start);
20   while (!conj_stack.empty()) {
21     Expression* e = conj_stack.back();
22     conj_stack.pop_back();
23     if (auto* bo = e->dynamicCast<BinOp>()) {
24       if (bo->op() == BOT_AND) {
25         conj_stack.push_back(bo->rhs());
26         conj_stack.push_back(bo->lhs());
27       } else {
28         conjuncts.push_back(e);
29       }
30     } else {
31       conjuncts.push_back(e);
32     }
33   }
34   return conjuncts;
35 }
36 
classify_conjunct(Expression * e,IdMap<int> & eq_occurrences,IdMap<std::pair<Expression *,Expression * >> & eq_branches,std::vector<Expression * > & other_branches)37 void classify_conjunct(Expression* e, IdMap<int>& eq_occurrences,
38                        IdMap<std::pair<Expression*, Expression*>>& eq_branches,
39                        std::vector<Expression*>& other_branches) {
40   if (auto* bo = e->dynamicCast<BinOp>()) {
41     if (bo->op() == BOT_EQ) {
42       if (Id* ident = bo->lhs()->dynamicCast<Id>()) {
43         if (eq_branches.find(ident) == eq_branches.end()) {
44           auto it = eq_occurrences.find(ident);
45           if (it == eq_occurrences.end()) {
46             eq_occurrences.insert(ident, 1);
47           } else {
48             eq_occurrences.get(ident)++;
49           }
50           eq_branches.insert(ident, std::make_pair(bo->rhs(), bo));
51           return;
52         }
53       } else if (Id* ident = bo->rhs()->dynamicCast<Id>()) {
54         if (eq_branches.find(ident) == eq_branches.end()) {
55           auto it = eq_occurrences.find(ident);
56           if (it == eq_occurrences.end()) {
57             eq_occurrences.insert(ident, 1);
58           } else {
59             eq_occurrences.get(ident)++;
60           }
61           eq_branches.insert(ident, std::make_pair(bo->lhs(), bo));
62           return;
63         }
64       }
65     }
66   }
67   other_branches.push_back(e);
68 }
69 
flatten_ite(EnvI & env,const Ctx & ctx,Expression * e,VarDecl * r,VarDecl * b)70 EE flatten_ite(EnvI& env, const Ctx& ctx, Expression* e, VarDecl* r, VarDecl* b) {
71   CallStackItem _csi(env, e);
72   ITE* ite = e->cast<ITE>();
73 
74   // The conditions of each branch of the if-then-else
75   std::vector<KeepAlive> conditions;
76   // Whether the right hand side of each branch is defined
77   std::vector<std::vector<KeepAlive>> defined;
78   // The right hand side of each branch
79   std::vector<std::vector<KeepAlive>> branches;
80   // Whether all branches are fixed
81   std::vector<bool> allBranchesPar;
82 
83   // Compute bounds of result as union bounds of all branches
84   std::vector<std::vector<IntBounds>> r_bounds_int;
85   std::vector<bool> r_bounds_valid_int;
86   std::vector<std::vector<IntSetVal*>> r_bounds_set;
87   std::vector<bool> r_bounds_valid_set;
88   std::vector<std::vector<FloatBounds>> r_bounds_float;
89   std::vector<bool> r_bounds_valid_float;
90 
91   bool allConditionsPar = true;
92   bool allDefined = true;
93 
94   // The result variables of each generated conditional
95   std::vector<VarDecl*> results;
96   // The then-expressions of each generated conditional
97   std::vector<std::vector<KeepAlive>> e_then;
98   // The else-expressions of each generated conditional
99   std::vector<KeepAlive> e_else;
100 
101   bool noOtherBranches = true;
102   if (ite->type() == Type::varbool() && ctx.b == C_ROOT && r == constants().varTrue) {
103     // Check if all branches are of the form x1=e1 /\ ... /\ xn=en
104     IdMap<int> eq_occurrences;
105     std::vector<IdMap<std::pair<Expression*, Expression*>>> eq_branches(ite->size() + 1);
106     std::vector<std::vector<Expression*>> other_branches(ite->size() + 1);
107     for (int i = 0; i < ite->size(); i++) {
108       auto conjuncts = get_conjuncts(ite->thenExpr(i));
109       for (auto* c : conjuncts) {
110         classify_conjunct(c, eq_occurrences, eq_branches[i], other_branches[i]);
111       }
112       noOtherBranches = noOtherBranches && other_branches[i].empty();
113     }
114     {
115       auto conjuncts = get_conjuncts(ite->elseExpr());
116       for (auto* c : conjuncts) {
117         classify_conjunct(c, eq_occurrences, eq_branches[ite->size()], other_branches[ite->size()]);
118       }
119       noOtherBranches = noOtherBranches && other_branches[ite->size()].empty();
120     }
121     for (auto& eq : eq_occurrences) {
122       if (eq.second >= ite->size()) {
123         // Any identifier that occurs in all or all but one branch gets its own conditional
124         results.push_back(eq.first->decl());
125         e_then.emplace_back();
126         for (int i = 0; i < ite->size(); i++) {
127           auto it = eq_branches[i].find(eq.first);
128           if (it == eq_branches[i].end()) {
129             // not found, simply push x=x
130             e_then.back().push_back(eq.first);
131           } else {
132             e_then.back().push_back(it->second.first);
133           }
134         }
135         {
136           auto it = eq_branches[ite->size()].find(eq.first);
137           if (it == eq_branches[ite->size()].end()) {
138             // not found, simply push x=x
139             e_else.emplace_back(eq.first);
140           } else {
141             e_else.emplace_back(it->second.first);
142           }
143         }
144       } else {
145         // All other identifiers are put in the vector of "other" branches
146         for (int i = 0; i <= ite->size(); i++) {
147           auto it = eq_branches[i].find(eq.first);
148           if (it != eq_branches[i].end()) {
149             other_branches[i].push_back(it->second.second);
150             noOtherBranches = false;
151             eq_branches[i].remove(eq.first);
152           }
153         }
154       }
155     }
156     if (!noOtherBranches) {
157       results.push_back(r);
158       e_then.emplace_back();
159       for (int i = 0; i < ite->size(); i++) {
160         if (eq_branches[i].size() == 0) {
161           e_then.back().push_back(ite->thenExpr(i));
162         } else if (other_branches[i].empty()) {
163           e_then.back().push_back(constants().literalTrue);
164         } else if (other_branches[i].size() == 1) {
165           e_then.back().push_back(other_branches[i][0]);
166         } else {
167           GCLock lock;
168           auto* al = new ArrayLit(Location().introduce(), other_branches[i]);
169           al->type(Type::varbool(1));
170           Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
171           forall->decl(env.model->matchFn(env, forall, false));
172           forall->type(forall->decl()->rtype(env, {al}, false));
173           e_then.back().push_back(forall);
174         }
175       }
176       {
177         if (eq_branches[ite->size()].size() == 0) {
178           e_else.emplace_back(ite->elseExpr());
179         } else if (other_branches[ite->size()].empty()) {
180           e_else.emplace_back(constants().literalTrue);
181         } else if (other_branches[ite->size()].size() == 1) {
182           e_else.emplace_back(other_branches[ite->size()][0]);
183         } else {
184           GCLock lock;
185           auto* al = new ArrayLit(Location().introduce(), other_branches[ite->size()]);
186           al->type(Type::varbool(1));
187           Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
188           forall->decl(env.model->matchFn(env, forall, false));
189           forall->type(forall->decl()->rtype(env, {al}, false));
190           e_else.emplace_back(forall);
191         }
192       }
193     }
194   } else {
195     noOtherBranches = false;
196     results.push_back(r);
197     e_then.emplace_back();
198     for (int i = 0; i < ite->size(); i++) {
199       e_then.back().push_back(ite->thenExpr(i));
200     }
201     e_else.emplace_back(ite->elseExpr());
202   }
203   allBranchesPar.resize(results.size());
204   r_bounds_valid_int.resize(results.size());
205   r_bounds_int.resize(results.size());
206   r_bounds_valid_float.resize(results.size());
207   r_bounds_float.resize(results.size());
208   r_bounds_valid_set.resize(results.size());
209   r_bounds_set.resize(results.size());
210   defined.resize(results.size());
211   branches.resize(results.size());
212   for (unsigned int i = 0; i < results.size(); i++) {
213     allBranchesPar[i] = true;
214     r_bounds_valid_int[i] = true;
215     r_bounds_valid_float[i] = true;
216     r_bounds_valid_set[i] = true;
217   }
218 
219   Ctx cmix;
220   cmix.b = C_MIX;
221   cmix.i = C_MIX;
222   cmix.neg = ctx.neg;
223 
224   bool foundTrueBranch = false;
225   for (int i = 0; i < ite->size() && !foundTrueBranch; i++) {
226     bool cond = true;
227     EE e_if;
228     if (ite->ifExpr(i)->isa<Call>() &&
229         ite->ifExpr(i)->cast<Call>()->id() == "mzn_in_root_context") {
230       e_if = EE(constants().boollit(ctx.b == C_ROOT), constants().literalTrue);
231     } else {
232       Ctx cmix_not_negated;
233       cmix_not_negated.b = C_MIX;
234       cmix_not_negated.i = C_MIX;
235       e_if = flat_exp(env, cmix_not_negated, ite->ifExpr(i), nullptr, constants().varTrue);
236     }
237     if (e_if.r()->type() == Type::parbool()) {
238       {
239         GCLock lock;
240         cond = eval_bool(env, e_if.r());
241       }
242       if (cond) {
243         if (allConditionsPar) {
244           // no var conditions before this one, so we can simply emit
245           // the then branch
246           return flat_exp(env, ctx, ite->thenExpr(i), r, b);
247         }
248         // had var conditions, so we have to take them into account
249         // and emit new conditional clause
250         // add another condition and definedness variable
251         conditions.emplace_back(constants().literalTrue);
252         for (unsigned int j = 0; j < results.size(); j++) {
253           EE ethen = flat_exp(env, cmix, e_then[j][i](), nullptr, nullptr);
254           assert(ethen.b());
255           defined[j].push_back(ethen.b);
256           allDefined = allDefined && (ethen.b() == constants().literalTrue);
257           branches[j].push_back(ethen.r);
258           if (ethen.r()->type().isvar()) {
259             allBranchesPar[j] = false;
260           }
261         }
262         foundTrueBranch = true;
263       } else {
264         GCLock lock;
265         conditions.emplace_back(constants().literalFalse);
266         for (unsigned int j = 0; j < results.size(); j++) {
267           defined[j].push_back(constants().literalTrue);
268           branches[j].push_back(create_dummy_value(env, e_then[j][i]()->type()));
269         }
270       }
271     } else {
272       allConditionsPar = false;
273       // add current condition and definedness variable
274       conditions.push_back(e_if.r);
275 
276       for (unsigned int j = 0; j < results.size(); j++) {
277         // flatten the then branch
278         EE ethen = flat_exp(env, cmix, e_then[j][i](), nullptr, nullptr);
279 
280         assert(ethen.b());
281         defined[j].push_back(ethen.b);
282         allDefined = allDefined && (ethen.b() == constants().literalTrue);
283         branches[j].push_back(ethen.r);
284         if (ethen.r()->type().isvar()) {
285           allBranchesPar[j] = false;
286         }
287       }
288     }
289     // update bounds
290 
291     if (cond) {
292       for (unsigned int j = 0; j < results.size(); j++) {
293         if (r_bounds_valid_int[j] && e_then[j][i]()->type().isint()) {
294           GCLock lock;
295           IntBounds ib_then = compute_int_bounds(env, branches[j][i]());
296           if (ib_then.valid) {
297             r_bounds_int[j].push_back(ib_then);
298           }
299           r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_then.valid;
300         } else if (r_bounds_valid_set[j] && e_then[j][i]()->type().isIntSet()) {
301           GCLock lock;
302           IntSetVal* isv = compute_intset_bounds(env, branches[j][i]());
303           if (isv != nullptr) {
304             r_bounds_set[j].push_back(isv);
305           }
306           r_bounds_valid_set[j] = r_bounds_valid_set[j] && (isv != nullptr);
307         } else if (r_bounds_valid_float[j] && e_then[j][i]()->type().isfloat()) {
308           GCLock lock;
309           FloatBounds fb_then = compute_float_bounds(env, branches[j][i]());
310           if (fb_then.valid) {
311             r_bounds_float[j].push_back(fb_then);
312           }
313           r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_then.valid;
314         }
315       }
316     }
317   }
318 
319   if (allConditionsPar) {
320     // no var condition, and all par conditions were false,
321     // so simply emit else branch
322     return flat_exp(env, ctx, ite->elseExpr(), r, b);
323   }
324 
325   for (auto& result : results) {
326     if (result == nullptr) {
327       // need to introduce new result variable
328       GCLock lock;
329       auto* ti = new TypeInst(Location().introduce(), ite->type(), nullptr);
330       result = new_vardecl(env, Ctx(), ti, nullptr, nullptr, nullptr);
331     }
332   }
333 
334   if (conditions.back()() != constants().literalTrue) {
335     // The last condition wasn't fixed to true, we need to look at the else branch
336     conditions.emplace_back(constants().literalTrue);
337 
338     for (unsigned int j = 0; j < results.size(); j++) {
339       // flatten else branch
340       EE eelse = flat_exp(env, cmix, e_else[j](), nullptr, nullptr);
341       assert(eelse.b());
342       defined[j].push_back(eelse.b);
343       allDefined = allDefined && (eelse.b() == constants().literalTrue);
344       branches[j].push_back(eelse.r);
345       if (eelse.r()->type().isvar()) {
346         allBranchesPar[j] = false;
347       }
348 
349       if (r_bounds_valid_int[j] && e_else[j]()->type().isint()) {
350         GCLock lock;
351         IntBounds ib_else = compute_int_bounds(env, eelse.r());
352         if (ib_else.valid) {
353           r_bounds_int[j].push_back(ib_else);
354         }
355         r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_else.valid;
356       } else if (r_bounds_valid_set[j] && e_else[j]()->type().isIntSet()) {
357         GCLock lock;
358         IntSetVal* isv = compute_intset_bounds(env, eelse.r());
359         if (isv != nullptr) {
360           r_bounds_set[j].push_back(isv);
361         }
362         r_bounds_valid_set[j] = r_bounds_valid_set[j] && (isv != nullptr);
363       } else if (r_bounds_valid_float[j] && e_else[j]()->type().isfloat()) {
364         GCLock lock;
365         FloatBounds fb_else = compute_float_bounds(env, eelse.r());
366         if (fb_else.valid) {
367           r_bounds_float[j].push_back(fb_else);
368         }
369         r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_else.valid;
370       }
371     }
372   }
373 
374   // update domain of result variable with bounds from all branches
375 
376   for (unsigned int j = 0; j < results.size(); j++) {
377     VarDecl* nr = results[j];
378     GCLock lock;
379     if (r_bounds_valid_int[j] && ite->type().isint()) {
380       IntVal lb = IntVal::infinity();
381       IntVal ub = -IntVal::infinity();
382       for (auto& i : r_bounds_int[j]) {
383         lb = std::min(lb, i.l);
384         ub = std::max(ub, i.u);
385       }
386       if (nr->ti()->domain() != nullptr) {
387         IntSetVal* isv = eval_intset(env, nr->ti()->domain());
388         Ranges::Const<IntVal> ite_r(lb, ub);
389         IntSetRanges isv_r(isv);
390         Ranges::Inter<IntVal, Ranges::Const<IntVal>, IntSetRanges> inter(ite_r, isv_r);
391         IntSetVal* isv_new = IntSetVal::ai(inter);
392         if (isv_new->card() != isv->card()) {
393           auto* r_dom = new SetLit(Location().introduce(), isv_new);
394           nr->ti()->domain(r_dom);
395         }
396       } else {
397         auto* r_dom = new SetLit(Location().introduce(), IntSetVal::a(lb, ub));
398         nr->ti()->domain(r_dom);
399         nr->ti()->setComputedDomain(true);
400       }
401     } else if (r_bounds_valid_set[j] && ite->type().isIntSet()) {
402       IntSetVal* isv_branches = IntSetVal::a();
403       for (auto& i : r_bounds_set[j]) {
404         IntSetRanges i0(isv_branches);
405         IntSetRanges i1(i);
406         Ranges::Union<IntVal, IntSetRanges, IntSetRanges> u(i0, i1);
407         isv_branches = IntSetVal::ai(u);
408       }
409       if (nr->ti()->domain() != nullptr) {
410         IntSetVal* isv = eval_intset(env, nr->ti()->domain());
411         IntSetRanges isv_r(isv);
412         IntSetRanges isv_branches_r(isv_branches);
413         Ranges::Inter<IntVal, IntSetRanges, IntSetRanges> inter(isv_branches_r, isv_r);
414         IntSetVal* isv_new = IntSetVal::ai(inter);
415         if (isv_new->card() != isv->card()) {
416           auto* r_dom = new SetLit(Location().introduce(), isv_new);
417           nr->ti()->domain(r_dom);
418         }
419       } else {
420         auto* r_dom = new SetLit(Location().introduce(), isv_branches);
421         nr->ti()->domain(r_dom);
422         nr->ti()->setComputedDomain(true);
423       }
424     } else if (r_bounds_valid_float[j] && ite->type().isfloat()) {
425       FloatVal lb = FloatVal::infinity();
426       FloatVal ub = -FloatVal::infinity();
427       for (auto& i : r_bounds_float[j]) {
428         lb = std::min(lb, i.l);
429         ub = std::max(ub, i.u);
430       }
431       if (nr->ti()->domain() != nullptr) {
432         FloatSetVal* isv = eval_floatset(env, nr->ti()->domain());
433         Ranges::Const<FloatVal> ite_r(lb, ub);
434         FloatSetRanges isv_r(isv);
435         Ranges::Inter<FloatVal, Ranges::Const<FloatVal>, FloatSetRanges> inter(ite_r, isv_r);
436         FloatSetVal* fsv_new = FloatSetVal::ai(inter);
437         auto* r_dom = new SetLit(Location().introduce(), fsv_new);
438         nr->ti()->domain(r_dom);
439       } else {
440         auto* r_dom = new SetLit(Location().introduce(), FloatSetVal::a(lb, ub));
441         nr->ti()->domain(r_dom);
442         nr->ti()->setComputedDomain(true);
443       }
444     }
445   }
446 
447   // Create ite predicate calls
448   GCLock lock;
449   auto* al_cond = new ArrayLit(Location().introduce(), conditions);
450   al_cond->type(Type::varbool(1));
451   for (unsigned int j = 0; j < results.size(); j++) {
452     auto* al_branches = new ArrayLit(Location().introduce(), branches[j]);
453     Type branches_t = results[j]->type();
454     branches_t.dim(1);
455     branches_t.ti(allBranchesPar[j] ? Type::TI_PAR : Type::TI_VAR);
456     al_branches->type(branches_t);
457     Call* ite_pred = new Call(ite->loc().introduce(), ASTString("if_then_else"),
458                               {al_cond, al_branches, results[j]->id()});
459     ite_pred->decl(env.model->matchFn(env, ite_pred, false));
460     ite_pred->type(Type::varbool());
461     (void)flat_exp(env, Ctx(), ite_pred, constants().varTrue, constants().varTrue);
462   }
463   EE ret;
464   if (noOtherBranches) {
465     ret.r = constants().varTrue->id();
466   } else {
467     ret.r = results.back()->id();
468   }
469   if (allDefined) {
470     bind(env, Ctx(), b, constants().literalTrue);
471     ret.b = constants().literalTrue;
472   } else {
473     // Otherwise, constraint linking conditions, b and the definedness variables
474     if (b == nullptr) {
475       CallStackItem _csi(env, new StringLit(Location().introduce(), "b"));
476       b = new_vardecl(env, Ctx(), new TypeInst(Location().introduce(), Type::varbool()), nullptr,
477                       nullptr, nullptr);
478     }
479     ret.b = b->id();
480 
481     std::vector<Expression*> defined_conjunctions(ite->size() + 1);
482     for (unsigned int i = 0; i < ite->size() + 1; i++) {
483       std::vector<Expression*> def_i;
484       for (auto& j : defined) {
485         assert(j.size() > i);
486         if (j[i]() != constants().literalTrue) {
487           def_i.push_back(j[i]());
488         }
489       }
490       if (def_i.empty()) {
491         defined_conjunctions[i] = constants().literalTrue;
492       } else if (def_i.size() == 1) {
493         defined_conjunctions[i] = def_i[0];
494       } else {
495         auto* al = new ArrayLit(Location().introduce(), def_i);
496         al->type(Type::varbool(1));
497         Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
498         forall->decl(env.model->matchFn(env, forall, false));
499         forall->type(forall->decl()->rtype(env, {al}, false));
500         defined_conjunctions[i] = forall;
501       }
502     }
503     auto* al_defined = new ArrayLit(Location().introduce(), defined_conjunctions);
504     al_defined->type(Type::varbool(1));
505     Call* ite_defined_pred = new Call(ite->loc().introduce(), ASTString("if_then_else_partiality"),
506                                       {al_cond, al_defined, b->id()});
507     ite_defined_pred->decl(env.model->matchFn(env, ite_defined_pred, false));
508     ite_defined_pred->type(Type::varbool());
509     (void)flat_exp(env, Ctx(), ite_defined_pred, constants().varTrue, constants().varTrue);
510   }
511 
512   return ret;
513 }
514 
515 }  // namespace MiniZinc
516