1 //===- WatchedLiteralsSolver.cpp --------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file defines a SAT solver implementation that can be used by dataflow
10 //  analyses.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <cassert>
15 #include <cstddef>
16 #include <cstdint>
17 #include <queue>
18 #include <vector>
19 
20 #include "clang/Analysis/FlowSensitive/Formula.h"
21 #include "clang/Analysis/FlowSensitive/Solver.h"
22 #include "clang/Analysis/FlowSensitive/WatchedLiteralsSolver.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/DenseSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 
29 
30 namespace clang {
31 namespace dataflow {
32 
33 // `WatchedLiteralsSolver` is an implementation of Algorithm D from Knuth's
34 // The Art of Computer Programming Volume 4: Satisfiability, Fascicle 6. It is
35 // based on the backtracking DPLL algorithm [1], keeps references to a single
36 // "watched" literal per clause, and uses a set of "active" variables to perform
37 // unit propagation.
38 //
39 // The solver expects that its input is a boolean formula in conjunctive normal
40 // form that consists of clauses of at least one literal. A literal is either a
41 // boolean variable or its negation. Below we define types, data structures, and
42 // utilities that are used to represent boolean formulas in conjunctive normal
43 // form.
44 //
45 // [1] https://en.wikipedia.org/wiki/DPLL_algorithm
46 
/// A boolean variable, identified by a strictly positive integer.
using Variable = uint32_t;

/// Placeholder variable identifier. It does not denote a variable of any
/// formula and serves as a sentinel in various data structures and algorithms.
static constexpr Variable NullVar = 0;

/// A literal, i.e. a boolean variable or its negation, encoded as a positive
/// integer. For a variable `V` whose identifier is `I`, the positive literal
/// `V` is encoded as `2*I` and the negative literal `!V` as `2*I+1`. Put
/// differently, the lowest bit holds the polarity (0 = positive, 1 = negative)
/// and the remaining bits hold the variable identifier.
using Literal = uint32_t;

/// Placeholder literal. It does not denote a literal of any formula and serves
/// as a sentinel in various data structures and algorithms.
[[maybe_unused]] static constexpr Literal NullLit = 0;

/// Returns the positive literal `V`.
static constexpr Literal posLit(Variable V) { return V << 1; }

/// Returns true if and only if `L` is a positive literal.
static constexpr bool isPosLit(Literal L) { return (L & 1) == 0; }

/// Returns true if and only if `L` is a negative literal.
static constexpr bool isNegLit(Literal L) { return (L & 1) != 0; }

/// Returns the negative literal `!V`.
static constexpr Literal negLit(Variable V) { return (V << 1) | 1; }

/// Returns the negated literal `!L`, obtained by flipping the polarity bit.
static constexpr Literal notLit(Literal L) { return L ^ 1u; }

/// Returns the variable of `L`, obtained by discarding the polarity bit.
static constexpr Variable var(Literal L) { return L / 2; }

/// A clause, identified by a strictly positive integer.
using ClauseID = uint32_t;

/// Placeholder clause identifier. It does not denote a clause of any formula
/// and serves as a sentinel (e.g. as the end marker of linked lists of
/// clauses).
static constexpr ClauseID NullClause = 0;
86 
87 /// A boolean formula in conjunctive normal form.
88 struct CNFFormula {
89   /// `LargestVar` is equal to the largest positive integer that represents a
90   /// variable in the formula.
91   const Variable LargestVar;
92 
93   /// Literals of all clauses in the formula.
94   ///
95   /// The element at index 0 stands for the literal in the null clause. It is
96   /// set to 0 and isn't used. Literals of clauses in the formula start from the
97   /// element at index 1.
98   ///
99   /// For example, for the formula `(L1 v L2) ^ (L2 v L3 v L4)` the elements of
100   /// `Clauses` will be `[0, L1, L2, L2, L3, L4]`.
101   std::vector<Literal> Clauses;
102 
103   /// Start indices of clauses of the formula in `Clauses`.
104   ///
105   /// The element at index 0 stands for the start index of the null clause. It
106   /// is set to 0 and isn't used. Start indices of clauses in the formula start
107   /// from the element at index 1.
108   ///
109   /// For example, for the formula `(L1 v L2) ^ (L2 v L3 v L4)` the elements of
110   /// `ClauseStarts` will be `[0, 1, 3]`. Note that the literals of the first
111   /// clause always start at index 1. The start index for the literals of the
112   /// second clause depends on the size of the first clause and so on.
113   std::vector<size_t> ClauseStarts;
114 
115   /// Maps literals (indices of the vector) to clause identifiers (elements of
116   /// the vector) that watch the respective literals.
117   ///
118   /// For a given clause, its watched literal is always its first literal in
119   /// `Clauses`. This invariant is maintained when watched literals change.
120   std::vector<ClauseID> WatchedHead;
121 
122   /// Maps clause identifiers (elements of the vector) to identifiers of other
123   /// clauses that watch the same literals, forming a set of linked lists.
124   ///
125   /// The element at index 0 stands for the identifier of the clause that
126   /// follows the null clause. It is set to 0 and isn't used. Identifiers of
127   /// clauses in the formula start from the element at index 1.
128   std::vector<ClauseID> NextWatched;
129 
130   /// Stores the variable identifier and Atom for atomic booleans in the
131   /// formula.
132   llvm::DenseMap<Variable, Atom> Atomics;
133 
134   /// Indicates that we already know the formula is unsatisfiable.
135   /// During construction, we catch simple cases of conflicting unit-clauses.
136   bool KnownContradictory;
137 
CNFFormulaclang::dataflow::CNFFormula138   explicit CNFFormula(Variable LargestVar,
139                       llvm::DenseMap<Variable, Atom> Atomics)
140       : LargestVar(LargestVar), Atomics(std::move(Atomics)),
141         KnownContradictory(false) {
142     Clauses.push_back(0);
143     ClauseStarts.push_back(0);
144     NextWatched.push_back(0);
145     const size_t NumLiterals = 2 * LargestVar + 1;
146     WatchedHead.resize(NumLiterals + 1, 0);
147   }
148 
149   /// Adds the `L1 v ... v Ln` clause to the formula.
150   /// Requirements:
151   ///
152   ///  `Li` must not be `NullLit`.
153   ///
154   ///  All literals in the input that are not `NullLit` must be distinct.
addClauseclang::dataflow::CNFFormula155   void addClause(ArrayRef<Literal> lits) {
156     assert(!lits.empty());
157     assert(llvm::all_of(lits, [](Literal L) { return L != NullLit; }));
158 
159     const ClauseID C = ClauseStarts.size();
160     const size_t S = Clauses.size();
161     ClauseStarts.push_back(S);
162     Clauses.insert(Clauses.end(), lits.begin(), lits.end());
163 
164     // Designate the first literal as the "watched" literal of the clause.
165     NextWatched.push_back(WatchedHead[lits.front()]);
166     WatchedHead[lits.front()] = C;
167   }
168 
169   /// Returns the number of literals in clause `C`.
clauseSizeclang::dataflow::CNFFormula170   size_t clauseSize(ClauseID C) const {
171     return C == ClauseStarts.size() - 1 ? Clauses.size() - ClauseStarts[C]
172                                         : ClauseStarts[C + 1] - ClauseStarts[C];
173   }
174 
175   /// Returns the literals of clause `C`.
clauseLiteralsclang::dataflow::CNFFormula176   llvm::ArrayRef<Literal> clauseLiterals(ClauseID C) const {
177     return llvm::ArrayRef<Literal>(&Clauses[ClauseStarts[C]], clauseSize(C));
178   }
179 };
180 
181 /// Applies simplifications while building up a BooleanFormula.
182 /// We keep track of unit clauses, which tell us variables that must be
183 /// true/false in any model that satisfies the overall formula.
184 /// Such variables can be dropped from subsequently-added clauses, which
185 /// may in turn yield more unit clauses or even a contradiction.
186 /// The total added complexity of this preprocessing is O(N) where we
187 /// for every clause, we do a lookup for each unit clauses.
188 /// The lookup is O(1) on average. This method won't catch all
189 /// contradictory formulas, more passes can in principle catch
190 /// more cases but we leave all these and the general case to the
191 /// proper SAT solver.
192 struct CNFFormulaBuilder {
193   // Formula should outlive CNFFormulaBuilder.
CNFFormulaBuilderclang::dataflow::CNFFormulaBuilder194   explicit CNFFormulaBuilder(CNFFormula &CNF)
195       : Formula(CNF) {}
196 
197   /// Adds the `L1 v ... v Ln` clause to the formula. Applies
198   /// simplifications, based on single-literal clauses.
199   ///
200   /// Requirements:
201   ///
202   ///  `Li` must not be `NullLit`.
203   ///
204   ///  All literals must be distinct.
addClauseclang::dataflow::CNFFormulaBuilder205   void addClause(ArrayRef<Literal> Literals) {
206     // We generate clauses with up to 3 literals in this file.
207     assert(!Literals.empty() && Literals.size() <= 3);
208     // Contains literals of the simplified clause.
209     llvm::SmallVector<Literal> Simplified;
210     for (auto L : Literals) {
211       assert(L != NullLit &&
212              llvm::all_of(Simplified,
213                           [L](Literal S) { return  S != L; }));
214       auto X = var(L);
215       if (trueVars.contains(X)) { // X must be true
216         if (isPosLit(L))
217           return; // Omit clause `(... v X v ...)`, it is `true`.
218         else
219           continue; // Omit `!X` from `(... v !X v ...)`.
220       }
221       if (falseVars.contains(X)) { // X must be false
222         if (isNegLit(L))
223           return; // Omit clause `(... v !X v ...)`, it is `true`.
224         else
225           continue; // Omit `X` from `(... v X v ...)`.
226       }
227       Simplified.push_back(L);
228     }
229     if (Simplified.empty()) {
230       // Simplification made the clause empty, which is equivalent to `false`.
231       // We already know that this formula is unsatisfiable.
232       Formula.KnownContradictory = true;
233       // We can add any of the input literals to get an unsatisfiable formula.
234       Formula.addClause(Literals[0]);
235       return;
236     }
237     if (Simplified.size() == 1) {
238       // We have new unit clause.
239       const Literal lit = Simplified.front();
240       const Variable v = var(lit);
241       if (isPosLit(lit))
242         trueVars.insert(v);
243       else
244         falseVars.insert(v);
245     }
246     Formula.addClause(Simplified);
247   }
248 
249   /// Returns true if we observed a contradiction while adding clauses.
250   /// In this case then the formula is already known to be unsatisfiable.
isKnownContradictoryclang::dataflow::CNFFormulaBuilder251   bool isKnownContradictory() { return Formula.KnownContradictory; }
252 
253 private:
254   CNFFormula &Formula;
255   llvm::DenseSet<Variable> trueVars;
256   llvm::DenseSet<Variable> falseVars;
257 };
258 
259 /// Converts the conjunction of `Vals` into a formula in conjunctive normal
260 /// form where each clause has at least one and at most three literals.
buildCNF(const llvm::ArrayRef<const Formula * > & Vals)261 CNFFormula buildCNF(const llvm::ArrayRef<const Formula *> &Vals) {
262   // The general strategy of the algorithm implemented below is to map each
263   // of the sub-values in `Vals` to a unique variable and use these variables in
264   // the resulting CNF expression to avoid exponential blow up. The number of
265   // literals in the resulting formula is guaranteed to be linear in the number
266   // of sub-formulas in `Vals`.
267 
268   // Map each sub-formula in `Vals` to a unique variable.
269   llvm::DenseMap<const Formula *, Variable> SubValsToVar;
270   // Store variable identifiers and Atom of atomic booleans.
271   llvm::DenseMap<Variable, Atom> Atomics;
272   Variable NextVar = 1;
273   {
274     std::queue<const Formula *> UnprocessedSubVals;
275     for (const Formula *Val : Vals)
276       UnprocessedSubVals.push(Val);
277     while (!UnprocessedSubVals.empty()) {
278       Variable Var = NextVar;
279       const Formula *Val = UnprocessedSubVals.front();
280       UnprocessedSubVals.pop();
281 
282       if (!SubValsToVar.try_emplace(Val, Var).second)
283         continue;
284       ++NextVar;
285 
286       for (const Formula *F : Val->operands())
287         UnprocessedSubVals.push(F);
288       if (Val->kind() == Formula::AtomRef)
289         Atomics[Var] = Val->getAtom();
290     }
291   }
292 
293   auto GetVar = [&SubValsToVar](const Formula *Val) {
294     auto ValIt = SubValsToVar.find(Val);
295     assert(ValIt != SubValsToVar.end());
296     return ValIt->second;
297   };
298 
299   CNFFormula CNF(NextVar - 1, std::move(Atomics));
300   std::vector<bool> ProcessedSubVals(NextVar, false);
301   CNFFormulaBuilder builder(CNF);
302 
303   // Add a conjunct for each variable that represents a top-level conjunction
304   // value in `Vals`.
305   for (const Formula *Val : Vals)
306     builder.addClause(posLit(GetVar(Val)));
307 
308   // Add conjuncts that represent the mapping between newly-created variables
309   // and their corresponding sub-formulas.
310   std::queue<const Formula *> UnprocessedSubVals;
311   for (const Formula *Val : Vals)
312     UnprocessedSubVals.push(Val);
313   while (!UnprocessedSubVals.empty()) {
314     const Formula *Val = UnprocessedSubVals.front();
315     UnprocessedSubVals.pop();
316     const Variable Var = GetVar(Val);
317 
318     if (ProcessedSubVals[Var])
319       continue;
320     ProcessedSubVals[Var] = true;
321 
322     switch (Val->kind()) {
323     case Formula::AtomRef:
324       break;
325     case Formula::Literal:
326       CNF.addClause(Val->literal() ? posLit(Var) : negLit(Var));
327       break;
328     case Formula::And: {
329       const Variable LHS = GetVar(Val->operands()[0]);
330       const Variable RHS = GetVar(Val->operands()[1]);
331 
332       if (LHS == RHS) {
333         // `X <=> (A ^ A)` is equivalent to `(!X v A) ^ (X v !A)` which is
334         // already in conjunctive normal form. Below we add each of the
335         // conjuncts of the latter expression to the result.
336         builder.addClause({negLit(Var), posLit(LHS)});
337         builder.addClause({posLit(Var), negLit(LHS)});
338       } else {
339         // `X <=> (A ^ B)` is equivalent to `(!X v A) ^ (!X v B) ^ (X v !A v
340         // !B)` which is already in conjunctive normal form. Below we add each
341         // of the conjuncts of the latter expression to the result.
342         builder.addClause({negLit(Var), posLit(LHS)});
343         builder.addClause({negLit(Var), posLit(RHS)});
344         builder.addClause({posLit(Var), negLit(LHS), negLit(RHS)});
345       }
346       break;
347     }
348     case Formula::Or: {
349       const Variable LHS = GetVar(Val->operands()[0]);
350       const Variable RHS = GetVar(Val->operands()[1]);
351 
352       if (LHS == RHS) {
353         // `X <=> (A v A)` is equivalent to `(!X v A) ^ (X v !A)` which is
354         // already in conjunctive normal form. Below we add each of the
355         // conjuncts of the latter expression to the result.
356         builder.addClause({negLit(Var), posLit(LHS)});
357         builder.addClause({posLit(Var), negLit(LHS)});
358       } else {
359         // `X <=> (A v B)` is equivalent to `(!X v A v B) ^ (X v !A) ^ (X v
360         // !B)` which is already in conjunctive normal form. Below we add each
361         // of the conjuncts of the latter expression to the result.
362         builder.addClause({negLit(Var), posLit(LHS), posLit(RHS)});
363         builder.addClause({posLit(Var), negLit(LHS)});
364         builder.addClause({posLit(Var), negLit(RHS)});
365       }
366       break;
367     }
368     case Formula::Not: {
369       const Variable Operand = GetVar(Val->operands()[0]);
370 
371       // `X <=> !Y` is equivalent to `(!X v !Y) ^ (X v Y)` which is
372       // already in conjunctive normal form. Below we add each of the
373       // conjuncts of the latter expression to the result.
374       builder.addClause({negLit(Var), negLit(Operand)});
375       builder.addClause({posLit(Var), posLit(Operand)});
376       break;
377     }
378     case Formula::Implies: {
379       const Variable LHS = GetVar(Val->operands()[0]);
380       const Variable RHS = GetVar(Val->operands()[1]);
381 
382       // `X <=> (A => B)` is equivalent to
383       // `(X v A) ^ (X v !B) ^ (!X v !A v B)` which is already in
384       // conjunctive normal form. Below we add each of the conjuncts of
385       // the latter expression to the result.
386       builder.addClause({posLit(Var), posLit(LHS)});
387       builder.addClause({posLit(Var), negLit(RHS)});
388       builder.addClause({negLit(Var), negLit(LHS), posLit(RHS)});
389       break;
390     }
391     case Formula::Equal: {
392       const Variable LHS = GetVar(Val->operands()[0]);
393       const Variable RHS = GetVar(Val->operands()[1]);
394 
395       if (LHS == RHS) {
396         // `X <=> (A <=> A)` is equivalent to `X` which is already in
397         // conjunctive normal form. Below we add each of the conjuncts of the
398         // latter expression to the result.
399         builder.addClause(posLit(Var));
400 
401         // No need to visit the sub-values of `Val`.
402         continue;
403       }
404       // `X <=> (A <=> B)` is equivalent to
405       // `(X v A v B) ^ (X v !A v !B) ^ (!X v A v !B) ^ (!X v !A v B)` which
406       // is already in conjunctive normal form. Below we add each of the
407       // conjuncts of the latter expression to the result.
408       builder.addClause({posLit(Var), posLit(LHS), posLit(RHS)});
409       builder.addClause({posLit(Var), negLit(LHS), negLit(RHS)});
410       builder.addClause({negLit(Var), posLit(LHS), negLit(RHS)});
411       builder.addClause({negLit(Var), negLit(LHS), posLit(RHS)});
412       break;
413     }
414     }
415     if (builder.isKnownContradictory()) {
416       return CNF;
417     }
418     for (const Formula *Child : Val->operands())
419       UnprocessedSubVals.push(Child);
420   }
421 
422   // Unit clauses that were added later were not
423   // considered for the simplification of earlier clauses. Do a final
424   // pass to find more opportunities for simplification.
425   CNFFormula FinalCNF(NextVar - 1, std::move(CNF.Atomics));
426   CNFFormulaBuilder FinalBuilder(FinalCNF);
427 
428   // Collect unit clauses.
429   for (ClauseID C = 1; C < CNF.ClauseStarts.size(); ++C) {
430     if (CNF.clauseSize(C) == 1) {
431       FinalBuilder.addClause(CNF.clauseLiterals(C)[0]);
432     }
433   }
434 
435   // Add all clauses that were added previously, preserving the order.
436   for (ClauseID C = 1; C < CNF.ClauseStarts.size(); ++C) {
437     FinalBuilder.addClause(CNF.clauseLiterals(C));
438     if (FinalBuilder.isKnownContradictory()) {
439       break;
440     }
441   }
442   // It is possible there were new unit clauses again, but
443   // we stop here and leave the rest to the solver algorithm.
444   return FinalCNF;
445 }
446 
447 class WatchedLiteralsSolverImpl {
448   /// A boolean formula in conjunctive normal form that the solver will attempt
449   /// to prove satisfiable. The formula will be modified in the process.
450   CNFFormula CNF;
451 
452   /// The search for a satisfying assignment of the variables in `Formula` will
453   /// proceed in levels, starting from 1 and going up to `Formula.LargestVar`
454   /// (inclusive). The current level is stored in `Level`. At each level the
455   /// solver will assign a value to an unassigned variable. If this leads to a
456   /// consistent partial assignment, `Level` will be incremented. Otherwise, if
457   /// it results in a conflict, the solver will backtrack by decrementing
458   /// `Level` until it reaches the most recent level where a decision was made.
459   size_t Level = 0;
460 
461   /// Maps levels (indices of the vector) to variables (elements of the vector)
462   /// that are assigned values at the respective levels.
463   ///
464   /// The element at index 0 isn't used. Variables start from the element at
465   /// index 1.
466   std::vector<Variable> LevelVars;
467 
468   /// State of the solver at a particular level.
469   enum class State : uint8_t {
470     /// Indicates that the solver made a decision.
471     Decision = 0,
472 
473     /// Indicates that the solver made a forced move.
474     Forced = 1,
475   };
476 
477   /// State of the solver at a particular level. It keeps track of previous
478   /// decisions that the solver can refer to when backtracking.
479   ///
480   /// The element at index 0 isn't used. States start from the element at index
481   /// 1.
482   std::vector<State> LevelStates;
483 
484   enum class Assignment : int8_t {
485     Unassigned = -1,
486     AssignedFalse = 0,
487     AssignedTrue = 1
488   };
489 
490   /// Maps variables (indices of the vector) to their assignments (elements of
491   /// the vector).
492   ///
493   /// The element at index 0 isn't used. Variable assignments start from the
494   /// element at index 1.
495   std::vector<Assignment> VarAssignments;
496 
497   /// A set of unassigned variables that appear in watched literals in
498   /// `Formula`. The vector is guaranteed to contain unique elements.
499   std::vector<Variable> ActiveVars;
500 
501 public:
WatchedLiteralsSolverImpl(const llvm::ArrayRef<const Formula * > & Vals)502   explicit WatchedLiteralsSolverImpl(
503       const llvm::ArrayRef<const Formula *> &Vals)
504       : CNF(buildCNF(Vals)), LevelVars(CNF.LargestVar + 1),
505         LevelStates(CNF.LargestVar + 1) {
506     assert(!Vals.empty());
507 
508     // Initialize the state at the root level to a decision so that in
509     // `reverseForcedMoves` we don't have to check that `Level >= 0` on each
510     // iteration.
511     LevelStates[0] = State::Decision;
512 
513     // Initialize all variables as unassigned.
514     VarAssignments.resize(CNF.LargestVar + 1, Assignment::Unassigned);
515 
516     // Initialize the active variables.
517     for (Variable Var = CNF.LargestVar; Var != NullVar; --Var) {
518       if (isWatched(posLit(Var)) || isWatched(negLit(Var)))
519         ActiveVars.push_back(Var);
520     }
521   }
522 
523   // Returns the `Result` and the number of iterations "remaining" from
524   // `MaxIterations` (that is, `MaxIterations` - iterations in this call).
solve(std::int64_t MaxIterations)525   std::pair<Solver::Result, std::int64_t> solve(std::int64_t MaxIterations) && {
526     if (CNF.KnownContradictory) {
527       // Short-cut the solving process. We already found out at CNF
528       // construction time that the formula is unsatisfiable.
529       return std::make_pair(Solver::Result::Unsatisfiable(), MaxIterations);
530     }
531     size_t I = 0;
532     while (I < ActiveVars.size()) {
533       if (MaxIterations == 0)
534         return std::make_pair(Solver::Result::TimedOut(), 0);
535       --MaxIterations;
536 
537       // Assert that the following invariants hold:
538       // 1. All active variables are unassigned.
539       // 2. All active variables form watched literals.
540       // 3. Unassigned variables that form watched literals are active.
541       // FIXME: Consider replacing these with test cases that fail if the any
542       // of the invariants is broken. That might not be easy due to the
543       // transformations performed by `buildCNF`.
544       assert(activeVarsAreUnassigned());
545       assert(activeVarsFormWatchedLiterals());
546       assert(unassignedVarsFormingWatchedLiteralsAreActive());
547 
548       const Variable ActiveVar = ActiveVars[I];
549 
550       // Look for unit clauses that contain the active variable.
551       const bool unitPosLit = watchedByUnitClause(posLit(ActiveVar));
552       const bool unitNegLit = watchedByUnitClause(negLit(ActiveVar));
553       if (unitPosLit && unitNegLit) {
554         // We found a conflict!
555 
556         // Backtrack and rewind the `Level` until the most recent non-forced
557         // assignment.
558         reverseForcedMoves();
559 
560         // If the root level is reached, then all possible assignments lead to
561         // a conflict.
562         if (Level == 0)
563           return std::make_pair(Solver::Result::Unsatisfiable(), MaxIterations);
564 
565         // Otherwise, take the other branch at the most recent level where a
566         // decision was made.
567         LevelStates[Level] = State::Forced;
568         const Variable Var = LevelVars[Level];
569         VarAssignments[Var] = VarAssignments[Var] == Assignment::AssignedTrue
570                                   ? Assignment::AssignedFalse
571                                   : Assignment::AssignedTrue;
572 
573         updateWatchedLiterals();
574       } else if (unitPosLit || unitNegLit) {
575         // We found a unit clause! The value of its unassigned variable is
576         // forced.
577         ++Level;
578 
579         LevelVars[Level] = ActiveVar;
580         LevelStates[Level] = State::Forced;
581         VarAssignments[ActiveVar] =
582             unitPosLit ? Assignment::AssignedTrue : Assignment::AssignedFalse;
583 
584         // Remove the variable that was just assigned from the set of active
585         // variables.
586         if (I + 1 < ActiveVars.size()) {
587           // Replace the variable that was just assigned with the last active
588           // variable for efficient removal.
589           ActiveVars[I] = ActiveVars.back();
590         } else {
591           // This was the last active variable. Repeat the process from the
592           // beginning.
593           I = 0;
594         }
595         ActiveVars.pop_back();
596 
597         updateWatchedLiterals();
598       } else if (I + 1 == ActiveVars.size()) {
599         // There are no remaining unit clauses in the formula! Make a decision
600         // for one of the active variables at the current level.
601         ++Level;
602 
603         LevelVars[Level] = ActiveVar;
604         LevelStates[Level] = State::Decision;
605         VarAssignments[ActiveVar] = decideAssignment(ActiveVar);
606 
607         // Remove the variable that was just assigned from the set of active
608         // variables.
609         ActiveVars.pop_back();
610 
611         updateWatchedLiterals();
612 
613         // This was the last active variable. Repeat the process from the
614         // beginning.
615         I = 0;
616       } else {
617         ++I;
618       }
619     }
620     return std::make_pair(Solver::Result::Satisfiable(buildSolution()),
621                           MaxIterations);
622   }
623 
624 private:
625   /// Returns a satisfying truth assignment to the atoms in the boolean formula.
buildSolution()626   llvm::DenseMap<Atom, Solver::Result::Assignment> buildSolution() {
627     llvm::DenseMap<Atom, Solver::Result::Assignment> Solution;
628     for (auto &Atomic : CNF.Atomics) {
629       // A variable may have a definite true/false assignment, or it may be
630       // unassigned indicating its truth value does not affect the result of
631       // the formula. Unassigned variables are assigned to true as a default.
632       Solution[Atomic.second] =
633           VarAssignments[Atomic.first] == Assignment::AssignedFalse
634               ? Solver::Result::Assignment::AssignedFalse
635               : Solver::Result::Assignment::AssignedTrue;
636     }
637     return Solution;
638   }
639 
640   /// Reverses forced moves until the most recent level where a decision was
641   /// made on the assignment of a variable.
reverseForcedMoves()642   void reverseForcedMoves() {
643     for (; LevelStates[Level] == State::Forced; --Level) {
644       const Variable Var = LevelVars[Level];
645 
646       VarAssignments[Var] = Assignment::Unassigned;
647 
648       // If the variable that we pass through is watched then we add it to the
649       // active variables.
650       if (isWatched(posLit(Var)) || isWatched(negLit(Var)))
651         ActiveVars.push_back(Var);
652     }
653   }
654 
655   /// Updates watched literals that are affected by a variable assignment.
updateWatchedLiterals()656   void updateWatchedLiterals() {
657     const Variable Var = LevelVars[Level];
658 
659     // Update the watched literals of clauses that currently watch the literal
660     // that falsifies `Var`.
661     const Literal FalseLit = VarAssignments[Var] == Assignment::AssignedTrue
662                                  ? negLit(Var)
663                                  : posLit(Var);
664     ClauseID FalseLitWatcher = CNF.WatchedHead[FalseLit];
665     CNF.WatchedHead[FalseLit] = NullClause;
666     while (FalseLitWatcher != NullClause) {
667       const ClauseID NextFalseLitWatcher = CNF.NextWatched[FalseLitWatcher];
668 
669       // Pick the first non-false literal as the new watched literal.
670       const size_t FalseLitWatcherStart = CNF.ClauseStarts[FalseLitWatcher];
671       size_t NewWatchedLitIdx = FalseLitWatcherStart + 1;
672       while (isCurrentlyFalse(CNF.Clauses[NewWatchedLitIdx]))
673         ++NewWatchedLitIdx;
674       const Literal NewWatchedLit = CNF.Clauses[NewWatchedLitIdx];
675       const Variable NewWatchedLitVar = var(NewWatchedLit);
676 
677       // Swap the old watched literal for the new one in `FalseLitWatcher` to
678       // maintain the invariant that the watched literal is at the beginning of
679       // the clause.
680       CNF.Clauses[NewWatchedLitIdx] = FalseLit;
681       CNF.Clauses[FalseLitWatcherStart] = NewWatchedLit;
682 
683       // If the new watched literal isn't watched by any other clause and its
684       // variable isn't assigned we need to add it to the active variables.
685       if (!isWatched(NewWatchedLit) && !isWatched(notLit(NewWatchedLit)) &&
686           VarAssignments[NewWatchedLitVar] == Assignment::Unassigned)
687         ActiveVars.push_back(NewWatchedLitVar);
688 
689       CNF.NextWatched[FalseLitWatcher] = CNF.WatchedHead[NewWatchedLit];
690       CNF.WatchedHead[NewWatchedLit] = FalseLitWatcher;
691 
692       // Go to the next clause that watches `FalseLit`.
693       FalseLitWatcher = NextFalseLitWatcher;
694     }
695   }
696 
697   /// Returns true if and only if one of the clauses that watch `Lit` is a unit
698   /// clause.
watchedByUnitClause(Literal Lit) const699   bool watchedByUnitClause(Literal Lit) const {
700     for (ClauseID LitWatcher = CNF.WatchedHead[Lit]; LitWatcher != NullClause;
701          LitWatcher = CNF.NextWatched[LitWatcher]) {
702       llvm::ArrayRef<Literal> Clause = CNF.clauseLiterals(LitWatcher);
703 
704       // Assert the invariant that the watched literal is always the first one
705       // in the clause.
706       // FIXME: Consider replacing this with a test case that fails if the
707       // invariant is broken by `updateWatchedLiterals`. That might not be easy
708       // due to the transformations performed by `buildCNF`.
709       assert(Clause.front() == Lit);
710 
711       if (isUnit(Clause))
712         return true;
713     }
714     return false;
715   }
716 
717   /// Returns true if and only if `Clause` is a unit clause.
isUnit(llvm::ArrayRef<Literal> Clause) const718   bool isUnit(llvm::ArrayRef<Literal> Clause) const {
719     return llvm::all_of(Clause.drop_front(),
720                         [this](Literal L) { return isCurrentlyFalse(L); });
721   }
722 
723   /// Returns true if and only if `Lit` evaluates to `false` in the current
724   /// partial assignment.
isCurrentlyFalse(Literal Lit) const725   bool isCurrentlyFalse(Literal Lit) const {
726     return static_cast<int8_t>(VarAssignments[var(Lit)]) ==
727            static_cast<int8_t>(Lit & 1);
728   }
729 
730   /// Returns true if and only if `Lit` is watched by a clause in `Formula`.
isWatched(Literal Lit) const731   bool isWatched(Literal Lit) const {
732     return CNF.WatchedHead[Lit] != NullClause;
733   }
734 
735   /// Returns an assignment for an unassigned variable.
decideAssignment(Variable Var) const736   Assignment decideAssignment(Variable Var) const {
737     return !isWatched(posLit(Var)) || isWatched(negLit(Var))
738                ? Assignment::AssignedFalse
739                : Assignment::AssignedTrue;
740   }
741 
742   /// Returns a set of all watched literals.
watchedLiterals() const743   llvm::DenseSet<Literal> watchedLiterals() const {
744     llvm::DenseSet<Literal> WatchedLiterals;
745     for (Literal Lit = 2; Lit < CNF.WatchedHead.size(); Lit++) {
746       if (CNF.WatchedHead[Lit] == NullClause)
747         continue;
748       WatchedLiterals.insert(Lit);
749     }
750     return WatchedLiterals;
751   }
752 
753   /// Returns true if and only if all active variables are unassigned.
activeVarsAreUnassigned() const754   bool activeVarsAreUnassigned() const {
755     return llvm::all_of(ActiveVars, [this](Variable Var) {
756       return VarAssignments[Var] == Assignment::Unassigned;
757     });
758   }
759 
760   /// Returns true if and only if all active variables form watched literals.
activeVarsFormWatchedLiterals() const761   bool activeVarsFormWatchedLiterals() const {
762     const llvm::DenseSet<Literal> WatchedLiterals = watchedLiterals();
763     return llvm::all_of(ActiveVars, [&WatchedLiterals](Variable Var) {
764       return WatchedLiterals.contains(posLit(Var)) ||
765              WatchedLiterals.contains(negLit(Var));
766     });
767   }
768 
769   /// Returns true if and only if all unassigned variables that are forming
770   /// watched literals are active.
unassignedVarsFormingWatchedLiteralsAreActive() const771   bool unassignedVarsFormingWatchedLiteralsAreActive() const {
772     const llvm::DenseSet<Variable> ActiveVarsSet(ActiveVars.begin(),
773                                                  ActiveVars.end());
774     for (Literal Lit : watchedLiterals()) {
775       const Variable Var = var(Lit);
776       if (VarAssignments[Var] != Assignment::Unassigned)
777         continue;
778       if (ActiveVarsSet.contains(Var))
779         continue;
780       return false;
781     }
782     return true;
783   }
784 };
785 
786 Solver::Result
solve(llvm::ArrayRef<const Formula * > Vals)787 WatchedLiteralsSolver::solve(llvm::ArrayRef<const Formula *> Vals) {
788   if (Vals.empty())
789     return Solver::Result::Satisfiable({{}});
790   auto [Res, Iterations] = WatchedLiteralsSolverImpl(Vals).solve(MaxIterations);
791   MaxIterations = Iterations;
792   return Res;
793 }
794 
795 } // namespace dataflow
796 } // namespace clang
797