CachingSolver.cpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "klee/Solver.h"
00012
00013 #include "klee/Constraints.h"
00014 #include "klee/Expr.h"
00015 #include "klee/IncompleteSolver.h"
00016 #include "klee/SolverImpl.h"
00017
00018 #include "SolverStats.h"
00019
00020 #include <tr1/unordered_map>
00021
00022 using namespace klee;
00023
00024 class CachingSolver : public SolverImpl {
00025 private:
00026 ref<Expr> canonicalizeQuery(ref<Expr> originalQuery,
00027 bool &negationUsed);
00028
00029 void cacheInsert(const Query& query,
00030 IncompleteSolver::PartialValidity result);
00031
00032 bool cacheLookup(const Query& query,
00033 IncompleteSolver::PartialValidity &result);
00034
00035 struct CacheEntry {
00036 CacheEntry(const ConstraintManager &c, ref<Expr> q)
00037 : constraints(c), query(q) {}
00038
00039 CacheEntry(const CacheEntry &ce)
00040 : constraints(ce.constraints), query(ce.query) {}
00041
00042 ConstraintManager constraints;
00043 ref<Expr> query;
00044
00045 bool operator==(const CacheEntry &b) const {
00046 return constraints==b.constraints && *query.get()==*b.query.get();
00047 }
00048 };
00049
00050 struct CacheEntryHash {
00051 unsigned operator()(const CacheEntry &ce) const {
00052 unsigned result = ce.query->hash();
00053
00054 for (ConstraintManager::constraint_iterator it = ce.constraints.begin();
00055 it != ce.constraints.end(); ++it)
00056 result ^= (*it)->hash();
00057
00058 return result;
00059 }
00060 };
00061
00062 typedef std::tr1::unordered_map<CacheEntry,
00063 IncompleteSolver::PartialValidity,
00064 CacheEntryHash> cache_map;
00065
00066 Solver *solver;
00067 cache_map cache;
00068
00069 public:
00070 CachingSolver(Solver *s) : solver(s) {}
00071 ~CachingSolver() { cache.clear(); delete solver; }
00072
00073 bool computeValidity(const Query&, Solver::Validity &result);
00074 bool computeTruth(const Query&, bool &isValid);
00075 bool computeValue(const Query& query, ref<Expr> &result) {
00076 return solver->impl->computeValue(query, result);
00077 }
00078 bool computeInitialValues(const Query& query,
00079 const std::vector<const Array*> &objects,
00080 std::vector< std::vector<unsigned char> > &values,
00081 bool &hasSolution) {
00082 return solver->impl->computeInitialValues(query, objects, values,
00083 hasSolution);
00084 }
00085 };
00086
00090 ref<Expr> CachingSolver::canonicalizeQuery(ref<Expr> originalQuery,
00091 bool &negationUsed) {
00092 ref<Expr> negatedQuery = Expr::createNot(originalQuery);
00093
00094
00095 if (originalQuery.compare(negatedQuery) < 0) {
00096 negationUsed = false;
00097 return originalQuery;
00098 } else {
00099 negationUsed = true;
00100 return negatedQuery;
00101 }
00102 }
00103
00106 bool CachingSolver::cacheLookup(const Query& query,
00107 IncompleteSolver::PartialValidity &result) {
00108 bool negationUsed;
00109 ref<Expr> canonicalQuery = canonicalizeQuery(query.expr, negationUsed);
00110
00111 CacheEntry ce(query.constraints, canonicalQuery);
00112 cache_map::iterator it = cache.find(ce);
00113
00114 if (it != cache.end()) {
00115 result = (negationUsed ?
00116 IncompleteSolver::negatePartialValidity(it->second) :
00117 it->second);
00118 return true;
00119 }
00120
00121 return false;
00122 }
00123
00125 void CachingSolver::cacheInsert(const Query& query,
00126 IncompleteSolver::PartialValidity result) {
00127 bool negationUsed;
00128 ref<Expr> canonicalQuery = canonicalizeQuery(query.expr, negationUsed);
00129
00130 CacheEntry ce(query.constraints, canonicalQuery);
00131 IncompleteSolver::PartialValidity cachedResult =
00132 (negationUsed ? IncompleteSolver::negatePartialValidity(result) : result);
00133
00134 cache.insert(std::make_pair(ce, cachedResult));
00135 }
00136
00137 bool CachingSolver::computeValidity(const Query& query,
00138 Solver::Validity &result) {
00139 IncompleteSolver::PartialValidity cachedResult;
00140 bool tmp, cacheHit = cacheLookup(query, cachedResult);
00141
00142 if (cacheHit) {
00143 ++stats::queryCacheHits;
00144
00145 switch(cachedResult) {
00146 case IncompleteSolver::MustBeTrue:
00147 result = Solver::True;
00148 return true;
00149 case IncompleteSolver::MustBeFalse:
00150 result = Solver::False;
00151 return true;
00152 case IncompleteSolver::TrueOrFalse:
00153 result = Solver::Unknown;
00154 return true;
00155 case IncompleteSolver::MayBeTrue: {
00156 if (!solver->impl->computeTruth(query, tmp))
00157 return false;
00158 if (tmp) {
00159 cacheInsert(query, IncompleteSolver::MustBeTrue);
00160 result = Solver::True;
00161 return true;
00162 } else {
00163 cacheInsert(query, IncompleteSolver::TrueOrFalse);
00164 result = Solver::Unknown;
00165 return true;
00166 }
00167 }
00168 case IncompleteSolver::MayBeFalse: {
00169 if (!solver->impl->computeTruth(query.negateExpr(), tmp))
00170 return false;
00171 if (tmp) {
00172 cacheInsert(query, IncompleteSolver::MustBeFalse);
00173 result = Solver::False;
00174 return true;
00175 } else {
00176 cacheInsert(query, IncompleteSolver::TrueOrFalse);
00177 result = Solver::Unknown;
00178 return true;
00179 }
00180 }
00181 default: assert(0 && "unreachable");
00182 }
00183 }
00184
00185 ++stats::queryCacheMisses;
00186
00187 if (!solver->impl->computeValidity(query, result))
00188 return false;
00189
00190 switch (result) {
00191 case Solver::True:
00192 cachedResult = IncompleteSolver::MustBeTrue; break;
00193 case Solver::False:
00194 cachedResult = IncompleteSolver::MustBeFalse; break;
00195 default:
00196 cachedResult = IncompleteSolver::TrueOrFalse; break;
00197 }
00198
00199 cacheInsert(query, cachedResult);
00200 return true;
00201 }
00202
00203 bool CachingSolver::computeTruth(const Query& query,
00204 bool &isValid) {
00205 IncompleteSolver::PartialValidity cachedResult;
00206 bool cacheHit = cacheLookup(query, cachedResult);
00207
00208
00209
00210 if (cacheHit && cachedResult != IncompleteSolver::MayBeTrue) {
00211 ++stats::queryCacheHits;
00212 isValid = (cachedResult == IncompleteSolver::MustBeTrue);
00213 return true;
00214 }
00215
00216 ++stats::queryCacheMisses;
00217
00218
00219 if (!solver->impl->computeTruth(query, isValid))
00220 return false;
00221
00222 if (isValid) {
00223 cachedResult = IncompleteSolver::MustBeTrue;
00224 } else if (cacheHit) {
00225
00226
00227 assert(cachedResult == IncompleteSolver::MayBeTrue);
00228 cachedResult = IncompleteSolver::TrueOrFalse;
00229 } else {
00230 cachedResult = IncompleteSolver::MayBeFalse;
00231 }
00232
00233 cacheInsert(query, cachedResult);
00234 return true;
00235 }
00236
00238
00239 Solver *klee::createCachingSolver(Solver *_solver) {
00240 return new Solver(new CachingSolver(_solver));
00241 }