1 /* -*- mode: C++; c-basic-offset: 2; indent-tabs-mode: nil -*- */
2
3 /*
4 * Main authors:
5 * Guido Tack <guido.tack@monash.edu>
6 */
7
8 /* This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this
10 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
11
12 #include <minizinc/flat_exp.hh>
13
14 namespace MiniZinc {
15
get_conjuncts(Expression * start)16 std::vector<Expression*> get_conjuncts(Expression* start) {
17 std::vector<Expression*> conj_stack;
18 std::vector<Expression*> conjuncts;
19 conj_stack.push_back(start);
20 while (!conj_stack.empty()) {
21 Expression* e = conj_stack.back();
22 conj_stack.pop_back();
23 if (auto* bo = e->dynamicCast<BinOp>()) {
24 if (bo->op() == BOT_AND) {
25 conj_stack.push_back(bo->rhs());
26 conj_stack.push_back(bo->lhs());
27 } else {
28 conjuncts.push_back(e);
29 }
30 } else {
31 conjuncts.push_back(e);
32 }
33 }
34 return conjuncts;
35 }
36
classify_conjunct(Expression * e,IdMap<int> & eq_occurrences,IdMap<std::pair<Expression *,Expression * >> & eq_branches,std::vector<Expression * > & other_branches)37 void classify_conjunct(Expression* e, IdMap<int>& eq_occurrences,
38 IdMap<std::pair<Expression*, Expression*>>& eq_branches,
39 std::vector<Expression*>& other_branches) {
40 if (auto* bo = e->dynamicCast<BinOp>()) {
41 if (bo->op() == BOT_EQ) {
42 if (Id* ident = bo->lhs()->dynamicCast<Id>()) {
43 if (eq_branches.find(ident) == eq_branches.end()) {
44 auto it = eq_occurrences.find(ident);
45 if (it == eq_occurrences.end()) {
46 eq_occurrences.insert(ident, 1);
47 } else {
48 eq_occurrences.get(ident)++;
49 }
50 eq_branches.insert(ident, std::make_pair(bo->rhs(), bo));
51 return;
52 }
53 } else if (Id* ident = bo->rhs()->dynamicCast<Id>()) {
54 if (eq_branches.find(ident) == eq_branches.end()) {
55 auto it = eq_occurrences.find(ident);
56 if (it == eq_occurrences.end()) {
57 eq_occurrences.insert(ident, 1);
58 } else {
59 eq_occurrences.get(ident)++;
60 }
61 eq_branches.insert(ident, std::make_pair(bo->lhs(), bo));
62 return;
63 }
64 }
65 }
66 }
67 other_branches.push_back(e);
68 }
69
flatten_ite(EnvI & env,const Ctx & ctx,Expression * e,VarDecl * r,VarDecl * b)70 EE flatten_ite(EnvI& env, const Ctx& ctx, Expression* e, VarDecl* r, VarDecl* b) {
71 CallStackItem _csi(env, e);
72 ITE* ite = e->cast<ITE>();
73
74 // The conditions of each branch of the if-then-else
75 std::vector<KeepAlive> conditions;
76 // Whether the right hand side of each branch is defined
77 std::vector<std::vector<KeepAlive>> defined;
78 // The right hand side of each branch
79 std::vector<std::vector<KeepAlive>> branches;
80 // Whether all branches are fixed
81 std::vector<bool> allBranchesPar;
82
83 // Compute bounds of result as union bounds of all branches
84 std::vector<std::vector<IntBounds>> r_bounds_int;
85 std::vector<bool> r_bounds_valid_int;
86 std::vector<std::vector<IntSetVal*>> r_bounds_set;
87 std::vector<bool> r_bounds_valid_set;
88 std::vector<std::vector<FloatBounds>> r_bounds_float;
89 std::vector<bool> r_bounds_valid_float;
90
91 bool allConditionsPar = true;
92 bool allDefined = true;
93
94 // The result variables of each generated conditional
95 std::vector<VarDecl*> results;
96 // The then-expressions of each generated conditional
97 std::vector<std::vector<KeepAlive>> e_then;
98 // The else-expressions of each generated conditional
99 std::vector<KeepAlive> e_else;
100
101 bool noOtherBranches = true;
102 if (ite->type() == Type::varbool() && ctx.b == C_ROOT && r == constants().varTrue) {
103 // Check if all branches are of the form x1=e1 /\ ... /\ xn=en
104 IdMap<int> eq_occurrences;
105 std::vector<IdMap<std::pair<Expression*, Expression*>>> eq_branches(ite->size() + 1);
106 std::vector<std::vector<Expression*>> other_branches(ite->size() + 1);
107 for (int i = 0; i < ite->size(); i++) {
108 auto conjuncts = get_conjuncts(ite->thenExpr(i));
109 for (auto* c : conjuncts) {
110 classify_conjunct(c, eq_occurrences, eq_branches[i], other_branches[i]);
111 }
112 noOtherBranches = noOtherBranches && other_branches[i].empty();
113 }
114 {
115 auto conjuncts = get_conjuncts(ite->elseExpr());
116 for (auto* c : conjuncts) {
117 classify_conjunct(c, eq_occurrences, eq_branches[ite->size()], other_branches[ite->size()]);
118 }
119 noOtherBranches = noOtherBranches && other_branches[ite->size()].empty();
120 }
121 for (auto& eq : eq_occurrences) {
122 if (eq.second >= ite->size()) {
123 // Any identifier that occurs in all or all but one branch gets its own conditional
124 results.push_back(eq.first->decl());
125 e_then.emplace_back();
126 for (int i = 0; i < ite->size(); i++) {
127 auto it = eq_branches[i].find(eq.first);
128 if (it == eq_branches[i].end()) {
129 // not found, simply push x=x
130 e_then.back().push_back(eq.first);
131 } else {
132 e_then.back().push_back(it->second.first);
133 }
134 }
135 {
136 auto it = eq_branches[ite->size()].find(eq.first);
137 if (it == eq_branches[ite->size()].end()) {
138 // not found, simply push x=x
139 e_else.emplace_back(eq.first);
140 } else {
141 e_else.emplace_back(it->second.first);
142 }
143 }
144 } else {
145 // All other identifiers are put in the vector of "other" branches
146 for (int i = 0; i <= ite->size(); i++) {
147 auto it = eq_branches[i].find(eq.first);
148 if (it != eq_branches[i].end()) {
149 other_branches[i].push_back(it->second.second);
150 noOtherBranches = false;
151 eq_branches[i].remove(eq.first);
152 }
153 }
154 }
155 }
156 if (!noOtherBranches) {
157 results.push_back(r);
158 e_then.emplace_back();
159 for (int i = 0; i < ite->size(); i++) {
160 if (eq_branches[i].size() == 0) {
161 e_then.back().push_back(ite->thenExpr(i));
162 } else if (other_branches[i].empty()) {
163 e_then.back().push_back(constants().literalTrue);
164 } else if (other_branches[i].size() == 1) {
165 e_then.back().push_back(other_branches[i][0]);
166 } else {
167 GCLock lock;
168 auto* al = new ArrayLit(Location().introduce(), other_branches[i]);
169 al->type(Type::varbool(1));
170 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
171 forall->decl(env.model->matchFn(env, forall, false));
172 forall->type(forall->decl()->rtype(env, {al}, false));
173 e_then.back().push_back(forall);
174 }
175 }
176 {
177 if (eq_branches[ite->size()].size() == 0) {
178 e_else.emplace_back(ite->elseExpr());
179 } else if (other_branches[ite->size()].empty()) {
180 e_else.emplace_back(constants().literalTrue);
181 } else if (other_branches[ite->size()].size() == 1) {
182 e_else.emplace_back(other_branches[ite->size()][0]);
183 } else {
184 GCLock lock;
185 auto* al = new ArrayLit(Location().introduce(), other_branches[ite->size()]);
186 al->type(Type::varbool(1));
187 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
188 forall->decl(env.model->matchFn(env, forall, false));
189 forall->type(forall->decl()->rtype(env, {al}, false));
190 e_else.emplace_back(forall);
191 }
192 }
193 }
194 } else {
195 noOtherBranches = false;
196 results.push_back(r);
197 e_then.emplace_back();
198 for (int i = 0; i < ite->size(); i++) {
199 e_then.back().push_back(ite->thenExpr(i));
200 }
201 e_else.emplace_back(ite->elseExpr());
202 }
203 allBranchesPar.resize(results.size());
204 r_bounds_valid_int.resize(results.size());
205 r_bounds_int.resize(results.size());
206 r_bounds_valid_float.resize(results.size());
207 r_bounds_float.resize(results.size());
208 r_bounds_valid_set.resize(results.size());
209 r_bounds_set.resize(results.size());
210 defined.resize(results.size());
211 branches.resize(results.size());
212 for (unsigned int i = 0; i < results.size(); i++) {
213 allBranchesPar[i] = true;
214 r_bounds_valid_int[i] = true;
215 r_bounds_valid_float[i] = true;
216 r_bounds_valid_set[i] = true;
217 }
218
219 Ctx cmix;
220 cmix.b = C_MIX;
221 cmix.i = C_MIX;
222 cmix.neg = ctx.neg;
223
224 bool foundTrueBranch = false;
225 for (int i = 0; i < ite->size() && !foundTrueBranch; i++) {
226 bool cond = true;
227 EE e_if;
228 if (ite->ifExpr(i)->isa<Call>() &&
229 ite->ifExpr(i)->cast<Call>()->id() == "mzn_in_root_context") {
230 e_if = EE(constants().boollit(ctx.b == C_ROOT), constants().literalTrue);
231 } else {
232 Ctx cmix_not_negated;
233 cmix_not_negated.b = C_MIX;
234 cmix_not_negated.i = C_MIX;
235 e_if = flat_exp(env, cmix_not_negated, ite->ifExpr(i), nullptr, constants().varTrue);
236 }
237 if (e_if.r()->type() == Type::parbool()) {
238 {
239 GCLock lock;
240 cond = eval_bool(env, e_if.r());
241 }
242 if (cond) {
243 if (allConditionsPar) {
244 // no var conditions before this one, so we can simply emit
245 // the then branch
246 return flat_exp(env, ctx, ite->thenExpr(i), r, b);
247 }
248 // had var conditions, so we have to take them into account
249 // and emit new conditional clause
250 // add another condition and definedness variable
251 conditions.emplace_back(constants().literalTrue);
252 for (unsigned int j = 0; j < results.size(); j++) {
253 EE ethen = flat_exp(env, cmix, e_then[j][i](), nullptr, nullptr);
254 assert(ethen.b());
255 defined[j].push_back(ethen.b);
256 allDefined = allDefined && (ethen.b() == constants().literalTrue);
257 branches[j].push_back(ethen.r);
258 if (ethen.r()->type().isvar()) {
259 allBranchesPar[j] = false;
260 }
261 }
262 foundTrueBranch = true;
263 } else {
264 GCLock lock;
265 conditions.emplace_back(constants().literalFalse);
266 for (unsigned int j = 0; j < results.size(); j++) {
267 defined[j].push_back(constants().literalTrue);
268 branches[j].push_back(create_dummy_value(env, e_then[j][i]()->type()));
269 }
270 }
271 } else {
272 allConditionsPar = false;
273 // add current condition and definedness variable
274 conditions.push_back(e_if.r);
275
276 for (unsigned int j = 0; j < results.size(); j++) {
277 // flatten the then branch
278 EE ethen = flat_exp(env, cmix, e_then[j][i](), nullptr, nullptr);
279
280 assert(ethen.b());
281 defined[j].push_back(ethen.b);
282 allDefined = allDefined && (ethen.b() == constants().literalTrue);
283 branches[j].push_back(ethen.r);
284 if (ethen.r()->type().isvar()) {
285 allBranchesPar[j] = false;
286 }
287 }
288 }
289 // update bounds
290
291 if (cond) {
292 for (unsigned int j = 0; j < results.size(); j++) {
293 if (r_bounds_valid_int[j] && e_then[j][i]()->type().isint()) {
294 GCLock lock;
295 IntBounds ib_then = compute_int_bounds(env, branches[j][i]());
296 if (ib_then.valid) {
297 r_bounds_int[j].push_back(ib_then);
298 }
299 r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_then.valid;
300 } else if (r_bounds_valid_set[j] && e_then[j][i]()->type().isIntSet()) {
301 GCLock lock;
302 IntSetVal* isv = compute_intset_bounds(env, branches[j][i]());
303 if (isv != nullptr) {
304 r_bounds_set[j].push_back(isv);
305 }
306 r_bounds_valid_set[j] = r_bounds_valid_set[j] && (isv != nullptr);
307 } else if (r_bounds_valid_float[j] && e_then[j][i]()->type().isfloat()) {
308 GCLock lock;
309 FloatBounds fb_then = compute_float_bounds(env, branches[j][i]());
310 if (fb_then.valid) {
311 r_bounds_float[j].push_back(fb_then);
312 }
313 r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_then.valid;
314 }
315 }
316 }
317 }
318
319 if (allConditionsPar) {
320 // no var condition, and all par conditions were false,
321 // so simply emit else branch
322 return flat_exp(env, ctx, ite->elseExpr(), r, b);
323 }
324
325 for (auto& result : results) {
326 if (result == nullptr) {
327 // need to introduce new result variable
328 GCLock lock;
329 auto* ti = new TypeInst(Location().introduce(), ite->type(), nullptr);
330 result = new_vardecl(env, Ctx(), ti, nullptr, nullptr, nullptr);
331 }
332 }
333
334 if (conditions.back()() != constants().literalTrue) {
335 // The last condition wasn't fixed to true, we need to look at the else branch
336 conditions.emplace_back(constants().literalTrue);
337
338 for (unsigned int j = 0; j < results.size(); j++) {
339 // flatten else branch
340 EE eelse = flat_exp(env, cmix, e_else[j](), nullptr, nullptr);
341 assert(eelse.b());
342 defined[j].push_back(eelse.b);
343 allDefined = allDefined && (eelse.b() == constants().literalTrue);
344 branches[j].push_back(eelse.r);
345 if (eelse.r()->type().isvar()) {
346 allBranchesPar[j] = false;
347 }
348
349 if (r_bounds_valid_int[j] && e_else[j]()->type().isint()) {
350 GCLock lock;
351 IntBounds ib_else = compute_int_bounds(env, eelse.r());
352 if (ib_else.valid) {
353 r_bounds_int[j].push_back(ib_else);
354 }
355 r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_else.valid;
356 } else if (r_bounds_valid_set[j] && e_else[j]()->type().isIntSet()) {
357 GCLock lock;
358 IntSetVal* isv = compute_intset_bounds(env, eelse.r());
359 if (isv != nullptr) {
360 r_bounds_set[j].push_back(isv);
361 }
362 r_bounds_valid_set[j] = r_bounds_valid_set[j] && (isv != nullptr);
363 } else if (r_bounds_valid_float[j] && e_else[j]()->type().isfloat()) {
364 GCLock lock;
365 FloatBounds fb_else = compute_float_bounds(env, eelse.r());
366 if (fb_else.valid) {
367 r_bounds_float[j].push_back(fb_else);
368 }
369 r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_else.valid;
370 }
371 }
372 }
373
374 // update domain of result variable with bounds from all branches
375
376 for (unsigned int j = 0; j < results.size(); j++) {
377 VarDecl* nr = results[j];
378 GCLock lock;
379 if (r_bounds_valid_int[j] && ite->type().isint()) {
380 IntVal lb = IntVal::infinity();
381 IntVal ub = -IntVal::infinity();
382 for (auto& i : r_bounds_int[j]) {
383 lb = std::min(lb, i.l);
384 ub = std::max(ub, i.u);
385 }
386 if (nr->ti()->domain() != nullptr) {
387 IntSetVal* isv = eval_intset(env, nr->ti()->domain());
388 Ranges::Const<IntVal> ite_r(lb, ub);
389 IntSetRanges isv_r(isv);
390 Ranges::Inter<IntVal, Ranges::Const<IntVal>, IntSetRanges> inter(ite_r, isv_r);
391 IntSetVal* isv_new = IntSetVal::ai(inter);
392 if (isv_new->card() != isv->card()) {
393 auto* r_dom = new SetLit(Location().introduce(), isv_new);
394 nr->ti()->domain(r_dom);
395 }
396 } else {
397 auto* r_dom = new SetLit(Location().introduce(), IntSetVal::a(lb, ub));
398 nr->ti()->domain(r_dom);
399 nr->ti()->setComputedDomain(true);
400 }
401 } else if (r_bounds_valid_set[j] && ite->type().isIntSet()) {
402 IntSetVal* isv_branches = IntSetVal::a();
403 for (auto& i : r_bounds_set[j]) {
404 IntSetRanges i0(isv_branches);
405 IntSetRanges i1(i);
406 Ranges::Union<IntVal, IntSetRanges, IntSetRanges> u(i0, i1);
407 isv_branches = IntSetVal::ai(u);
408 }
409 if (nr->ti()->domain() != nullptr) {
410 IntSetVal* isv = eval_intset(env, nr->ti()->domain());
411 IntSetRanges isv_r(isv);
412 IntSetRanges isv_branches_r(isv_branches);
413 Ranges::Inter<IntVal, IntSetRanges, IntSetRanges> inter(isv_branches_r, isv_r);
414 IntSetVal* isv_new = IntSetVal::ai(inter);
415 if (isv_new->card() != isv->card()) {
416 auto* r_dom = new SetLit(Location().introduce(), isv_new);
417 nr->ti()->domain(r_dom);
418 }
419 } else {
420 auto* r_dom = new SetLit(Location().introduce(), isv_branches);
421 nr->ti()->domain(r_dom);
422 nr->ti()->setComputedDomain(true);
423 }
424 } else if (r_bounds_valid_float[j] && ite->type().isfloat()) {
425 FloatVal lb = FloatVal::infinity();
426 FloatVal ub = -FloatVal::infinity();
427 for (auto& i : r_bounds_float[j]) {
428 lb = std::min(lb, i.l);
429 ub = std::max(ub, i.u);
430 }
431 if (nr->ti()->domain() != nullptr) {
432 FloatSetVal* isv = eval_floatset(env, nr->ti()->domain());
433 Ranges::Const<FloatVal> ite_r(lb, ub);
434 FloatSetRanges isv_r(isv);
435 Ranges::Inter<FloatVal, Ranges::Const<FloatVal>, FloatSetRanges> inter(ite_r, isv_r);
436 FloatSetVal* fsv_new = FloatSetVal::ai(inter);
437 auto* r_dom = new SetLit(Location().introduce(), fsv_new);
438 nr->ti()->domain(r_dom);
439 } else {
440 auto* r_dom = new SetLit(Location().introduce(), FloatSetVal::a(lb, ub));
441 nr->ti()->domain(r_dom);
442 nr->ti()->setComputedDomain(true);
443 }
444 }
445 }
446
447 // Create ite predicate calls
448 GCLock lock;
449 auto* al_cond = new ArrayLit(Location().introduce(), conditions);
450 al_cond->type(Type::varbool(1));
451 for (unsigned int j = 0; j < results.size(); j++) {
452 auto* al_branches = new ArrayLit(Location().introduce(), branches[j]);
453 Type branches_t = results[j]->type();
454 branches_t.dim(1);
455 branches_t.ti(allBranchesPar[j] ? Type::TI_PAR : Type::TI_VAR);
456 al_branches->type(branches_t);
457 Call* ite_pred = new Call(ite->loc().introduce(), ASTString("if_then_else"),
458 {al_cond, al_branches, results[j]->id()});
459 ite_pred->decl(env.model->matchFn(env, ite_pred, false));
460 ite_pred->type(Type::varbool());
461 (void)flat_exp(env, Ctx(), ite_pred, constants().varTrue, constants().varTrue);
462 }
463 EE ret;
464 if (noOtherBranches) {
465 ret.r = constants().varTrue->id();
466 } else {
467 ret.r = results.back()->id();
468 }
469 if (allDefined) {
470 bind(env, Ctx(), b, constants().literalTrue);
471 ret.b = constants().literalTrue;
472 } else {
473 // Otherwise, constraint linking conditions, b and the definedness variables
474 if (b == nullptr) {
475 CallStackItem _csi(env, new StringLit(Location().introduce(), "b"));
476 b = new_vardecl(env, Ctx(), new TypeInst(Location().introduce(), Type::varbool()), nullptr,
477 nullptr, nullptr);
478 }
479 ret.b = b->id();
480
481 std::vector<Expression*> defined_conjunctions(ite->size() + 1);
482 for (unsigned int i = 0; i < ite->size() + 1; i++) {
483 std::vector<Expression*> def_i;
484 for (auto& j : defined) {
485 assert(j.size() > i);
486 if (j[i]() != constants().literalTrue) {
487 def_i.push_back(j[i]());
488 }
489 }
490 if (def_i.empty()) {
491 defined_conjunctions[i] = constants().literalTrue;
492 } else if (def_i.size() == 1) {
493 defined_conjunctions[i] = def_i[0];
494 } else {
495 auto* al = new ArrayLit(Location().introduce(), def_i);
496 al->type(Type::varbool(1));
497 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
498 forall->decl(env.model->matchFn(env, forall, false));
499 forall->type(forall->decl()->rtype(env, {al}, false));
500 defined_conjunctions[i] = forall;
501 }
502 }
503 auto* al_defined = new ArrayLit(Location().introduce(), defined_conjunctions);
504 al_defined->type(Type::varbool(1));
505 Call* ite_defined_pred = new Call(ite->loc().introduce(), ASTString("if_then_else_partiality"),
506 {al_cond, al_defined, b->id()});
507 ite_defined_pred->decl(env.model->matchFn(env, ite_defined_pred, false));
508 ite_defined_pred->type(Type::varbool());
509 (void)flat_exp(env, Ctx(), ite_defined_pred, constants().varTrue, constants().varTrue);
510 }
511
512 return ret;
513 }
514
515 } // namespace MiniZinc
516