1 //== SMTConstraintManager.h -------------------------------------*- C++ -*--==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file defines a SMT generic API, which will be the base class for
10 //  every SMT solver specific class.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONSTRAINTMANAGER_H
15 #define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONSTRAINTMANAGER_H
16 
17 #include "clang/Basic/JsonSupport.h"
18 #include "clang/Basic/TargetInfo.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h"
20 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTConv.h"
21 #include <optional>
22 
23 typedef llvm::ImmutableSet<
24     std::pair<clang::ento::SymbolRef, const llvm::SMTExpr *>>
25     ConstraintSMTType;
26 REGISTER_TRAIT_WITH_PROGRAMSTATE(ConstraintSMT, ConstraintSMTType)
27 
28 namespace clang {
29 namespace ento {
30 
31 class SMTConstraintManager : public clang::ento::SimpleConstraintManager {
32   mutable llvm::SMTSolverRef Solver = llvm::CreateZ3Solver();
33 
34 public:
35   SMTConstraintManager(clang::ento::ExprEngine *EE,
36                        clang::ento::SValBuilder &SB)
37       : SimpleConstraintManager(EE, SB) {}
38   virtual ~SMTConstraintManager() = default;
39 
40   //===------------------------------------------------------------------===//
41   // Implementation for interface from SimpleConstraintManager.
42   //===------------------------------------------------------------------===//
43 
44   ProgramStateRef assumeSym(ProgramStateRef State, SymbolRef Sym,
45                             bool Assumption) override {
46     ASTContext &Ctx = getBasicVals().getContext();
47 
48     QualType RetTy;
49     bool hasComparison;
50 
51     llvm::SMTExprRef Exp =
52         SMTConv::getExpr(Solver, Ctx, Sym, &RetTy, &hasComparison);
53 
54     // Create zero comparison for implicit boolean cast, with reversed
55     // assumption
56     if (!hasComparison && !RetTy->isBooleanType())
57       return assumeExpr(
58           State, Sym,
59           SMTConv::getZeroExpr(Solver, Ctx, Exp, RetTy, !Assumption));
60 
61     return assumeExpr(State, Sym, Assumption ? Exp : Solver->mkNot(Exp));
62   }
63 
64   ProgramStateRef assumeSymInclusiveRange(ProgramStateRef State, SymbolRef Sym,
65                                           const llvm::APSInt &From,
66                                           const llvm::APSInt &To,
67                                           bool InRange) override {
68     ASTContext &Ctx = getBasicVals().getContext();
69     return assumeExpr(
70         State, Sym, SMTConv::getRangeExpr(Solver, Ctx, Sym, From, To, InRange));
71   }
72 
73   ProgramStateRef assumeSymUnsupported(ProgramStateRef State, SymbolRef Sym,
74                                        bool Assumption) override {
75     // Skip anything that is unsupported
76     return State;
77   }
78 
79   //===------------------------------------------------------------------===//
80   // Implementation for interface from ConstraintManager.
81   //===------------------------------------------------------------------===//
82 
83   ConditionTruthVal checkNull(ProgramStateRef State, SymbolRef Sym) override {
84     ASTContext &Ctx = getBasicVals().getContext();
85 
86     QualType RetTy;
87     // The expression may be casted, so we cannot call getZ3DataExpr() directly
88     llvm::SMTExprRef VarExp = SMTConv::getExpr(Solver, Ctx, Sym, &RetTy);
89     llvm::SMTExprRef Exp =
90         SMTConv::getZeroExpr(Solver, Ctx, VarExp, RetTy, /*Assumption=*/true);
91 
92     // Negate the constraint
93     llvm::SMTExprRef NotExp =
94         SMTConv::getZeroExpr(Solver, Ctx, VarExp, RetTy, /*Assumption=*/false);
95 
96     ConditionTruthVal isSat = checkModel(State, Sym, Exp);
97     ConditionTruthVal isNotSat = checkModel(State, Sym, NotExp);
98 
99     // Zero is the only possible solution
100     if (isSat.isConstrainedTrue() && isNotSat.isConstrainedFalse())
101       return true;
102 
103     // Zero is not a solution
104     if (isSat.isConstrainedFalse() && isNotSat.isConstrainedTrue())
105       return false;
106 
107     // Zero may be a solution
108     return ConditionTruthVal();
109   }
110 
111   const llvm::APSInt *getSymVal(ProgramStateRef State,
112                                 SymbolRef Sym) const override {
113     BasicValueFactory &BVF = getBasicVals();
114     ASTContext &Ctx = BVF.getContext();
115 
116     if (const SymbolData *SD = dyn_cast<SymbolData>(Sym)) {
117       QualType Ty = Sym->getType();
118       assert(!Ty->isRealFloatingType());
119       llvm::APSInt Value(Ctx.getTypeSize(Ty),
120                          !Ty->isSignedIntegerOrEnumerationType());
121 
122       // TODO: this should call checkModel so we can use the cache, however,
123       // this method tries to get the interpretation (the actual value) from
124       // the solver, which is currently not cached.
125 
126       llvm::SMTExprRef Exp = SMTConv::fromData(Solver, Ctx, SD);
127 
128       Solver->reset();
129       addStateConstraints(State);
130 
131       // Constraints are unsatisfiable
132       std::optional<bool> isSat = Solver->check();
133       if (!isSat || !*isSat)
134         return nullptr;
135 
136       // Model does not assign interpretation
137       if (!Solver->getInterpretation(Exp, Value))
138         return nullptr;
139 
140       // A value has been obtained, check if it is the only value
141       llvm::SMTExprRef NotExp = SMTConv::fromBinOp(
142           Solver, Exp, BO_NE,
143           Ty->isBooleanType() ? Solver->mkBoolean(Value.getBoolValue())
144                               : Solver->mkBitvector(Value, Value.getBitWidth()),
145           /*isSigned=*/false);
146 
147       Solver->addConstraint(NotExp);
148 
149       std::optional<bool> isNotSat = Solver->check();
150       if (!isNotSat || *isNotSat)
151         return nullptr;
152 
153       // This is the only solution, store it
154       return &BVF.getValue(Value);
155     }
156 
157     if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym)) {
158       SymbolRef CastSym = SC->getOperand();
159       QualType CastTy = SC->getType();
160       // Skip the void type
161       if (CastTy->isVoidType())
162         return nullptr;
163 
164       const llvm::APSInt *Value;
165       if (!(Value = getSymVal(State, CastSym)))
166         return nullptr;
167       return &BVF.Convert(SC->getType(), *Value);
168     }
169 
170     if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
171       const llvm::APSInt *LHS, *RHS;
172       if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE)) {
173         LHS = getSymVal(State, SIE->getLHS());
174         RHS = &SIE->getRHS();
175       } else if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE)) {
176         LHS = &ISE->getLHS();
177         RHS = getSymVal(State, ISE->getRHS());
178       } else if (const SymSymExpr *SSM = dyn_cast<SymSymExpr>(BSE)) {
179         // Early termination to avoid expensive call
180         LHS = getSymVal(State, SSM->getLHS());
181         RHS = LHS ? getSymVal(State, SSM->getRHS()) : nullptr;
182       } else {
183         llvm_unreachable("Unsupported binary expression to get symbol value!");
184       }
185 
186       if (!LHS || !RHS)
187         return nullptr;
188 
189       llvm::APSInt ConvertedLHS, ConvertedRHS;
190       QualType LTy, RTy;
191       std::tie(ConvertedLHS, LTy) = SMTConv::fixAPSInt(Ctx, *LHS);
192       std::tie(ConvertedRHS, RTy) = SMTConv::fixAPSInt(Ctx, *RHS);
193       SMTConv::doIntTypeConversion<llvm::APSInt, &SMTConv::castAPSInt>(
194           Solver, Ctx, ConvertedLHS, LTy, ConvertedRHS, RTy);
195       return BVF.evalAPSInt(BSE->getOpcode(), ConvertedLHS, ConvertedRHS);
196     }
197 
198     llvm_unreachable("Unsupported expression to get symbol value!");
199   }
200 
201   ProgramStateRef removeDeadBindings(ProgramStateRef State,
202                                      SymbolReaper &SymReaper) override {
203     auto CZ = State->get<ConstraintSMT>();
204     auto &CZFactory = State->get_context<ConstraintSMT>();
205 
206     for (const auto &Entry : CZ) {
207       if (SymReaper.isDead(Entry.first))
208         CZ = CZFactory.remove(CZ, Entry);
209     }
210 
211     return State->set<ConstraintSMT>(CZ);
212   }
213 
214   void printJson(raw_ostream &Out, ProgramStateRef State, const char *NL = "\n",
215                  unsigned int Space = 0, bool IsDot = false) const override {
216     ConstraintSMTType Constraints = State->get<ConstraintSMT>();
217 
218     Indent(Out, Space, IsDot) << "\"constraints\": ";
219     if (Constraints.isEmpty()) {
220       Out << "null," << NL;
221       return;
222     }
223 
224     ++Space;
225     Out << '[' << NL;
226     for (ConstraintSMTType::iterator I = Constraints.begin();
227          I != Constraints.end(); ++I) {
228       Indent(Out, Space, IsDot)
229           << "{ \"symbol\": \"" << I->first << "\", \"range\": \"";
230       I->second->print(Out);
231       Out << "\" }";
232 
233       if (std::next(I) != Constraints.end())
234         Out << ',';
235       Out << NL;
236     }
237 
238     --Space;
239     Indent(Out, Space, IsDot) << "],";
240   }
241 
242   bool haveEqualConstraints(ProgramStateRef S1,
243                             ProgramStateRef S2) const override {
244     return S1->get<ConstraintSMT>() == S2->get<ConstraintSMT>();
245   }
246 
247   bool canReasonAbout(SVal X) const override {
248     const TargetInfo &TI = getBasicVals().getContext().getTargetInfo();
249 
250     std::optional<nonloc::SymbolVal> SymVal = X.getAs<nonloc::SymbolVal>();
251     if (!SymVal)
252       return true;
253 
254     const SymExpr *Sym = SymVal->getSymbol();
255     QualType Ty = Sym->getType();
256 
257     // Complex types are not modeled
258     if (Ty->isComplexType() || Ty->isComplexIntegerType())
259       return false;
260 
261     // Non-IEEE 754 floating-point types are not modeled
262     if ((Ty->isSpecificBuiltinType(BuiltinType::LongDouble) &&
263          (&TI.getLongDoubleFormat() == &llvm::APFloat::x87DoubleExtended() ||
264           &TI.getLongDoubleFormat() == &llvm::APFloat::PPCDoubleDouble())))
265       return false;
266 
267     if (Ty->isRealFloatingType())
268       return Solver->isFPSupported();
269 
270     if (isa<SymbolData>(Sym))
271       return true;
272 
273     SValBuilder &SVB = getSValBuilder();
274 
275     if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym))
276       return canReasonAbout(SVB.makeSymbolVal(SC->getOperand()));
277 
278     if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
279       if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE))
280         return canReasonAbout(SVB.makeSymbolVal(SIE->getLHS()));
281 
282       if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE))
283         return canReasonAbout(SVB.makeSymbolVal(ISE->getRHS()));
284 
285       if (const SymSymExpr *SSE = dyn_cast<SymSymExpr>(BSE))
286         return canReasonAbout(SVB.makeSymbolVal(SSE->getLHS())) &&
287                canReasonAbout(SVB.makeSymbolVal(SSE->getRHS()));
288     }
289 
290     llvm_unreachable("Unsupported expression to reason about!");
291   }
292 
293   /// Dumps SMT formula
294   LLVM_DUMP_METHOD void dump() const { Solver->dump(); }
295 
296 protected:
297   // Check whether a new model is satisfiable, and update the program state.
298   virtual ProgramStateRef assumeExpr(ProgramStateRef State, SymbolRef Sym,
299                                      const llvm::SMTExprRef &Exp) {
300     // Check the model, avoid simplifying AST to save time
301     if (checkModel(State, Sym, Exp).isConstrainedTrue())
302       return State->add<ConstraintSMT>(std::make_pair(Sym, Exp));
303 
304     return nullptr;
305   }
306 
307   /// Given a program state, construct the logical conjunction and add it to
308   /// the solver
309   virtual void addStateConstraints(ProgramStateRef State) const {
310     // TODO: Don't add all the constraints, only the relevant ones
311     auto CZ = State->get<ConstraintSMT>();
312     auto I = CZ.begin(), IE = CZ.end();
313 
314     // Construct the logical AND of all the constraints
315     if (I != IE) {
316       std::vector<llvm::SMTExprRef> ASTs;
317 
318       llvm::SMTExprRef Constraint = I++->second;
319       while (I != IE) {
320         Constraint = Solver->mkAnd(Constraint, I++->second);
321       }
322 
323       Solver->addConstraint(Constraint);
324     }
325   }
326 
327   // Generate and check a Z3 model, using the given constraint.
328   ConditionTruthVal checkModel(ProgramStateRef State, SymbolRef Sym,
329                                const llvm::SMTExprRef &Exp) const {
330     ProgramStateRef NewState =
331         State->add<ConstraintSMT>(std::make_pair(Sym, Exp));
332 
333     llvm::FoldingSetNodeID ID;
334     NewState->get<ConstraintSMT>().Profile(ID);
335 
336     unsigned hash = ID.ComputeHash();
337     auto I = Cached.find(hash);
338     if (I != Cached.end())
339       return I->second;
340 
341     Solver->reset();
342     addStateConstraints(NewState);
343 
344     std::optional<bool> res = Solver->check();
345     if (!res)
346       Cached[hash] = ConditionTruthVal();
347     else
348       Cached[hash] = ConditionTruthVal(*res);
349 
350     return Cached[hash];
351   }
352 
353   // Cache the result of an SMT query (true, false, unknown). The key is the
354   // hash of the constraints in a state
355   mutable llvm::DenseMap<unsigned, ConditionTruthVal> Cached;
356 }; // end class SMTConstraintManager
357 
358 } // namespace ento
359 } // namespace clang
360 
361 #endif
362