Skip to content

Commit

Permalink
Cleanup of RefinementSolver
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHaas committed Aug 22, 2023
1 parent 77ce57f commit 7158193
Showing 1 changed file with 129 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,19 @@
Refinement is a custom solving procedure that starts from a weak memory model (possibly the empty model)
and iteratively refines it to perform a verification task.
It can be understood as a lazy offline-SMT solver.
More concretely, it iteratively:
- Finds some assertion-violating execution w.r.t. to some (very weak) baseline memory model
- Checks the consistency of this execution using a custom theory solver (CAAT-Solver)
- Refines the used memory model if the found execution was inconsistent, using the explanations
More concretely, it iteratively
- finds some assertion-violating execution w.r.t. to some (very weak) baseline memory model
- checks the consistency of this execution using a custom theory solver (CAAT-Solver)
- refines the used memory model if the found execution was inconsistent, using the explanations
provided by the theory solver.
*/
@Options
public class RefinementSolver extends ModelChecker {

private static final Logger logger = LogManager.getLogger(RefinementSolver.class);

private final SolverContext ctx;
private final ProverEnvironment prover;
private final VerificationTask task;

// =========================== Configurables ===========================
// ================================================================================================================
// Configuration

@Option(name=BASELINE,
description="Refinement starts from this baseline WMM.",
Expand All @@ -93,34 +90,83 @@ public class RefinementSolver extends ModelChecker {
toUppercase=true)
private boolean printCovReport = false;

// ======================================================================
// ================================================================================================================
// Data classes

private RefinementSolver(SolverContext c, ProverEnvironment p, VerificationTask t) {
ctx = c;
prover = p;
task = t;
private enum SMTStatus {
SAT, UNSAT, UNKNOWN
}

private record RefinementIteration(
SMTStatus smtStatus,
long nativeSmtTime,
long caatTime,
long refineTime,
// The following are only meaningful if <smtStatus>==SAT
CAATSolver.Status caatStatus,
BooleanFormula refinementFormula,
// The following are only for statistics keeping
WMMSolver.Statistics caatStats,
DNF<CoreLiteral> inconsistencyReasons,
List<Event> observedEvents
) {
public boolean isInconclusive() { return smtStatus == SMTStatus.SAT && caatStatus == INCONSISTENT; }
public boolean isConclusive() { return !isInconclusive(); }
}

private record RefinementTrace(List<RefinementIteration> iterations) {
public RefinementIteration getFinalIteration() { return iterations.get(iterations.size() - 1); }

public SMTStatus getFinalResult() {
final RefinementIteration finalIteration = getFinalIteration();
if (finalIteration.smtStatus != SMTStatus.SAT) {
return finalIteration.smtStatus;
} else if (finalIteration.caatStatus == CONSISTENT) {
return SMTStatus.SAT;
} else {
return SMTStatus.UNKNOWN;
}
}

public long getNativeSmtTime() { return iterations.stream().mapToLong(RefinementIteration::nativeSmtTime).sum(); }
public long getCaatTime() { return iterations.stream().mapToLong(RefinementIteration::caatTime).sum(); }
public long getRefiningTime() { return iterations.stream().mapToLong(RefinementIteration::refineTime).sum(); }

public Set<Event> getObservedEvents() {
return iterations.stream().filter(iter -> iter.observedEvents != null)
.flatMap(iter -> iter.observedEvents.stream()).collect(Collectors.toSet());
}

public List<BooleanFormula> getRefinementFormulas() {
return iterations.stream().filter(iter -> iter.refinementFormula != null)
.map(RefinementIteration::refinementFormula).toList();
}

public RefinementTrace concat(RefinementTrace other) {
return new RefinementTrace(Lists.newArrayList(Iterables.concat(this.iterations, other.iterations)));
}
}

// ================================================================================================================
// Refinement solver

private RefinementSolver() {
}

//TODO: We do not yet use Witness information. The problem is that WitnessGraph.encode() generates
// constraints on hb, which is not encoded in Refinement.
//TODO (2): Add possibility for Refinement to handle CAT-properties (it ignores them for now).
public static RefinementSolver run(SolverContext ctx, ProverEnvironment prover, VerificationTask task)
throws InterruptedException, SolverException, InvalidConfigurationException {
RefinementSolver solver = new RefinementSolver(ctx, prover, task);
RefinementSolver solver = new RefinementSolver();
task.getConfig().inject(solver);
logger.info("{}: {}", BASELINE, solver.baselines);
solver.run();
solver.runInternal(ctx, prover, task);
return solver;
}

private record RefinementContext(
VerificationTask task,
WMMSolver solver,
Refiner refiner,
EncodingContext encCtx
) {}

private void run() throws InterruptedException, SolverException, InvalidConfigurationException {
private void runInternal(SolverContext ctx, ProverEnvironment prover, VerificationTask task)
throws InterruptedException, SolverException, InvalidConfigurationException {
final Program program = task.getProgram();
final Wmm memoryModel = task.getMemoryModel();
final Wmm baselineModel = createDefaultWmm();
Expand Down Expand Up @@ -169,7 +215,6 @@ private void run() throws InterruptedException, SolverException, InvalidConfigur
final WMMSolver solver = WMMSolver.withContext(context, cutRelations, task, analysisContext);
final Refiner refiner = new Refiner(analysisContext);
final Property.Type propertyType = Property.getCombinedType(task.getProperty(), task);
final RefinementContext refineCtx = new RefinementContext(task, solver, refiner, context);

logger.info("Starting encoding using " + ctx.getVersion());
prover.addConstraint(programEncoder.encodeFullProgram());
Expand All @@ -183,8 +228,8 @@ private void run() throws InterruptedException, SolverException, InvalidConfigur
prover.push();
prover.addConstraint(propertyEncoder.encodeProperties(task.getProperty()));

final RefinementTrace propertyTrace = runRefinement(prover, refineCtx);
final SMTStatus smtStatus = propertyTrace.getFinalResult();
final RefinementTrace propertyTrace = runRefinement(task, prover, solver, refiner);
SMTStatus smtStatus = propertyTrace.getFinalResult();

if (logger.isInfoEnabled()) {
final String message = switch (smtStatus) {
Expand Down Expand Up @@ -214,14 +259,20 @@ private void run() throws InterruptedException, SolverException, InvalidConfigur
prover.addConstraint(propertyEncoder.encodeBoundEventExec());
// Add back the refinement clauses we already found, hoping that this improves the performance.
prover.addConstraint(bmgr.and(propertyTrace.getRefinementFormulas()));
final RefinementTrace boundTrace = runRefinement(prover, refineCtx);
res = boundTrace.getFinalResult() == SMTStatus.UNSAT ? PASS : UNKNOWN;
final RefinementTrace boundTrace = runRefinement(task, prover, solver, refiner);
boundCheckTime = System.currentTimeMillis() - lastTime;

smtStatus = boundTrace.getFinalResult();
combinedTrace = combinedTrace.concat(boundTrace);
if (res == PASS) {
logger.info("Bounds are unreachable: Unbounded specification proven.");
} else {
logger.info("Bounds are reachable: Unbounded specification unknown.");
res = smtStatus == SMTStatus.UNSAT ? PASS : UNKNOWN;

if (logger.isInfoEnabled()) {
final String message = switch (smtStatus) {
case UNKNOWN -> "Bound check was inconclusive (bug?)";
case SAT -> "Bounds are reachable: Unbounded specification unknown.";
case UNSAT -> "Bounds are unreachable: Unbounded specification proven.";
};
logger.info(message);
}
} else {
res = FAIL;
Expand Down Expand Up @@ -251,78 +302,24 @@ private void run() throws InterruptedException, SolverException, InvalidConfigur
logger.info("Verification finished with result " + res);
}

private enum SMTStatus {
SAT, UNSAT, UNKNOWN
}

private record RefinementTrace(List<RefinementIteration> iterations) {
public RefinementIteration getFinalIteration() { return iterations.get(iterations.size() - 1); }

public SMTStatus getFinalResult() {
final RefinementIteration finalIteration = getFinalIteration();
if (finalIteration.smtStatus != SMTStatus.SAT) {
return finalIteration.smtStatus;
} else if (finalIteration.caatStatus == CONSISTENT) {
return SMTStatus.SAT;
} else {
return SMTStatus.UNKNOWN;
}
}

public long getNativeSmtTime() { return iterations.stream().mapToLong(RefinementIteration::nativeSmtTime).sum(); }
public long getCaatTime() { return iterations.stream().mapToLong(RefinementIteration::caatTime).sum(); }
public long getRefiningTime() { return iterations.stream().mapToLong(RefinementIteration::refineTime).sum(); }
// ================================================================================================================
// Refinement core algorithm

public Set<Event> getObservedEvents() {
return iterations.stream().filter(iter -> iter.observedEvents != null)
.flatMap(iter -> iter.observedEvents.stream()).collect(Collectors.toSet());
}

public List<BooleanFormula> getRefinementFormulas() {
return iterations.stream().filter(iter -> iter.refinementFormula != null)
.map(RefinementIteration::refinementFormula).toList();
}

public RefinementTrace concat(RefinementTrace other) {
return new RefinementTrace(Lists.newArrayList(Iterables.concat(this.iterations, other.iterations)));
}
}

private record RefinementIteration(
SMTStatus smtStatus,
long nativeSmtTime,
long caatTime,
long refineTime,
// The following are only meaningful if <smtStatus>==SAT
CAATSolver.Status caatStatus,
BooleanFormula refinementFormula,
// The following are only for statistics keeping
WMMSolver.Statistics caatStats,
DNF<CoreLiteral> inconsistencyReasons,
List<Event> observedEvents
) {
public boolean isInconclusive() { return smtStatus == SMTStatus.SAT && caatStatus == INCONSISTENT; }
public boolean isConclusive() { return !isInconclusive(); }
public boolean isConclusivelyUnknown() {
return smtStatus == SMTStatus.UNKNOWN || (smtStatus == SMTStatus.SAT && caatStatus == INCONCLUSIVE);
}
}

// Starts a refinement run on a given prover environment.
private RefinementTrace runRefinement(ProverEnvironment prover, RefinementContext refineCtx)
// TODO: We could expose the following method(s) to allow for more general application of refinement.
private RefinementTrace runRefinement(VerificationTask task, ProverEnvironment prover, WMMSolver solver, Refiner refiner)
throws SolverException, InterruptedException {

final List<RefinementIteration> trace = new ArrayList<>();
boolean isFinalIteration = false;
while (!isFinalIteration) {

final RefinementIteration iteration = doRefinementIteration(prover, refineCtx);
final RefinementIteration iteration = doRefinementIteration(prover, solver, refiner);
trace.add(iteration);
isFinalIteration = iteration.isConclusive();

// ----------- Debugging/Logging -----------
// ------------------------- Debugging/Logging -------------------------
if (REFINEMENT_GENERATE_GRAPHVIZ_DEBUG_FILES) {
generateGraphvizFiles(task, refineCtx.solver.getExecution(), trace.size(), iteration.inconsistencyReasons);
generateGraphvizFiles(task, solver.getExecution(), trace.size(), iteration.inconsistencyReasons);
}
if (logger.isDebugEnabled()) {
// ---- Internal SMT stats after the first iteration ----
Expand Down Expand Up @@ -360,12 +357,9 @@ Native solving time(ms): %s
return new RefinementTrace(trace);
}

private RefinementIteration doRefinementIteration(ProverEnvironment prover, RefinementContext refineCtx)
private RefinementIteration doRefinementIteration(ProverEnvironment prover, WMMSolver solver, Refiner refiner)
throws SolverException, InterruptedException {

final Refiner refiner = refineCtx.refiner;
final WMMSolver solver = refineCtx.solver;

long nativeTime = 0;
long caatTime = 0;
long refineTime = 0;
Expand Down Expand Up @@ -411,11 +405,12 @@ private RefinementIteration doRefinementIteration(ProverEnvironment prover, Refi
);
}

// ======================= Helper Methods ======================
// ================================================================================================================
// Special memory model processing

// This method cuts off negated relations that are dependencies of some
// consistency axiom. It ignores dependencies of flagged axioms, as those get
// eagarly encoded and can be completely ignored for Refinement.
// eagerly encoded and can be completely ignored for Refinement.
private static Set<Relation> cutRelationDifferences(Wmm targetWmm, Wmm baselineWmm) {
// TODO: Add support to move flagged axioms to the baselineWmm
Set<Relation> cutRelations = new HashSet<>();
Expand Down Expand Up @@ -491,7 +486,39 @@ private Relation[] copy(Relation[] r) {
}
}

// -------------------- Printing -----------------------------
private Wmm createDefaultWmm() {
Wmm baseline = new Wmm();
Relation rf = baseline.getRelation(RF);
if (baselines.contains(Baseline.UNIPROC)) {
// ---- acyclic(po-loc | com) ----
baseline.addConstraint(new Acyclic(baseline.addDefinition(new Union(baseline.newRelation(),
baseline.getRelation(POLOC),
rf,
baseline.getRelation(CO),
baseline.getRelation(FR)))));
}
if (baselines.contains(Baseline.NO_OOTA)) {
// ---- acyclic (dep | rf) ----
baseline.addConstraint(new Acyclic(baseline.addDefinition(new Union(baseline.newRelation(),
baseline.getRelation(CTRL),
baseline.getRelation(DATA),
baseline.getRelation(ADDR),
rf))));
}
if (baselines.contains(Baseline.ATOMIC_RMW)) {
// ---- empty (rmw & fre;coe) ----
Relation rmw = baseline.getRelation(RMW);
Relation coe = baseline.getRelation(COE);
Relation fre = baseline.getRelation(FRE);
Relation frecoe = baseline.addDefinition(new Composition(baseline.newRelation(), fre, coe));
Relation rmwANDfrecoe = baseline.addDefinition(new Intersection(baseline.newRelation(), rmw, frecoe));
baseline.addConstraint(new Empty(rmwANDfrecoe));
}
return baseline;
}

// ================================================================================================================
// Statistics & Debugging

private static CharSequence generateSummary(RefinementTrace trace, long boundCheckTime) {
final List<WMMSolver.Statistics> statList = trace.iterations.stream()
Expand Down Expand Up @@ -553,7 +580,7 @@ private static CharSequence generateCoverageReport(Set<Event> coveredEvents, Pro
final BranchEquivalence cf = analysisContext.requires(BranchEquivalence.class);

final Set<Event> programEvents = program.getThreadEvents(MemoryEvent.class).stream()
// TODO: Can we have events with source information but without parse id?
// TODO: Can we have events with source information but without oid?
.filter(e -> e.hasMetadata(SourceLocation.class) && e.hasMetadata(OriginalId.class))
.collect(Collectors.toSet());

Expand Down Expand Up @@ -644,35 +671,4 @@ private static void generateGraphvizFiles(VerificationTask task, ExecutionModel
generateGraphvizFile(model, iterationCount, (x, y) -> true, (x, y) -> true, (x, y) -> true, directoryName,
fileNameBase + "-full", emptySynContext);
}

private Wmm createDefaultWmm() {
Wmm baseline = new Wmm();
Relation rf = baseline.getRelation(RF);
if (baselines.contains(Baseline.UNIPROC)) {
// ---- acyclic(po-loc | com) ----
baseline.addConstraint(new Acyclic(baseline.addDefinition(new Union(baseline.newRelation(),
baseline.getRelation(POLOC),
rf,
baseline.getRelation(CO),
baseline.getRelation(FR)))));
}
if (baselines.contains(Baseline.NO_OOTA)) {
// ---- acyclic (dep | rf) ----
baseline.addConstraint(new Acyclic(baseline.addDefinition(new Union(baseline.newRelation(),
baseline.getRelation(CTRL),
baseline.getRelation(DATA),
baseline.getRelation(ADDR),
rf))));
}
if (baselines.contains(Baseline.ATOMIC_RMW)) {
// ---- empty (rmw & fre;coe) ----
Relation rmw = baseline.getRelation(RMW);
Relation coe = baseline.getRelation(COE);
Relation fre = baseline.getRelation(FRE);
Relation frecoe = baseline.addDefinition(new Composition(baseline.newRelation(), fre, coe));
Relation rmwANDfrecoe = baseline.addDefinition(new Intersection(baseline.newRelation(), rmw, frecoe));
baseline.addConstraint(new Empty(rmwANDfrecoe));
}
return baseline;
}
}

0 comments on commit 7158193

Please sign in to comment.