diff --git a/src/main/scala/uclid/UclidMain.scala b/src/main/scala/uclid/UclidMain.scala index 1874ceeae..6adf31d67 100644 --- a/src/main/scala/uclid/UclidMain.scala +++ b/src/main/scala/uclid/UclidMain.scala @@ -468,6 +468,9 @@ object UclidMain { passManager.addPass(new ModuleTypeChecker()) // optimisation, has previously been called passManager.addPass(new SemanticAnalyzer()) + // reorder statements if necessary. + // Pass MUST be run after variable renamers + passManager.addPass(new BlockSorter()) // known bugs in the following passes if (config.enumToNumeric) passManager.addPass(new EnumTypeAnalysis()) if (config.enumToNumeric) passManager.addPass(new EnumTypeRenamer("BV")) diff --git a/src/main/scala/uclid/lang/BlockFlattener.scala b/src/main/scala/uclid/lang/BlockFlattener.scala index 85915e37d..81993b478 100644 --- a/src/main/scala/uclid/lang/BlockFlattener.scala +++ b/src/main/scala/uclid/lang/BlockFlattener.scala @@ -138,9 +138,11 @@ class BlockFlattenerPass extends RewritePass { val stmtsP = rewriter.rewriteStatements(blk.stmts, context + blk.vars) (stmtsP, varDecls) } + override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = { logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString())) val init = (List.empty[Statement], Map.empty[Identifier, Type]) + val (stmtsP, mapOut) = blkStmt.stmts.foldLeft(init) { (acc, st) => { val (stP, mapOut) = st match { @@ -150,6 +152,7 @@ class BlockFlattenerPass extends RewritePass { (acc._1 ++ stP, mapOut) } } + val vars = mapOut.map(p => BlockVarsDecl(List(p._1), p._2)) val result = BlockStmt(blkStmt.vars ++ vars, stmtsP) logger.debug("<== Result:\n" + result.toString()) @@ -170,6 +173,61 @@ class BlockFlattener() extends ASTRewriter(BlockFlattener.getName(), new BlockFl override val repeatUntilNoChange = true } +// This pass changes the order of any statements in a block so that +// assignments to state variables are moved to the end of the block. +// This avoids issues where one submodule reads from a variable after another +// submodule has written to it, without introducing additional variables +// It must be run after the procedures have been inlined and converted into SSA. +class BlockSorterPass extends RewritePass { + lazy val logger = Logger(classOf[BlockSorterPass]) + + override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = { + logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString())) + val init = (List.empty[Statement], Map.empty[Identifier, Type]) + + def isStateVarAssign(st : Statement) : Boolean = { + st match { + case AssignStmt(lhss, rhs) => { + lhss.exists { + case LhsId(id) => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar] && !(id.toString contains "_ucld_") + case _ => false + } + } + case _ => false + } + } + + val endStmts = blkStmt.stmts.filter(st => isStateVarAssign(st)) + logger.debug("Moving the following statements to the end of the block: " + endStmts.toString()) + val filteredStmts = blkStmt.stmts.filter(endStmts.contains(_)==false) + logger.debug("Moving these statements to the start of the block " + filteredStmts.toString()) + val result = BlockStmt(blkStmt.vars, filteredStmts++endStmts) + logger.debug("<== Result:\n" + result.toString()) + Some(result) + } +} + +object BlockSorter { + var index = 0 + def getName() : String = { + index += 1 + "BlockSorter:" + index.toString() + } +} + +class BlockSorter() extends ASTRewriter(BlockSorter.getName(), new BlockSorterPass()) +{ + override val repeatUntilNoChange = true + // Don't reorder procedural code + override def visitInit(init : InitDecl, context : Scope) : Option[InitDecl] = Some(init) + // Don't reorder procedural code + override def visitProcedure(proc : ProcedureDecl, contextIn : Scope) : Option[ProcedureDecl] = Some(proc); + +} + + + + object Optimizer { var index = 0 def getName() : String = { diff --git a/src/test/scala/VerifierSpec.scala b/src/test/scala/VerifierSpec.scala index f8c0104bc..bf698bac1 100644 --- a/src/test/scala/VerifierSpec.scala +++ b/src/test/scala/VerifierSpec.scala @@ -484,6 +484,9 @@ class ModuleVerifSpec extends AnyFlatSpec { "test-module-import-0.ucl" should "verify all assertions." in { VerifierSpec.expectedFails("./test/test-module-import-0.ucl", 0) } + "test-module-ordering.ucl" should "verify all assertions." in { + VerifierSpec.expectedFails("./test/test-module-ordering.ucl", 0) + } "test-type-import.ucl" should "verify all assertions." in { VerifierSpec.expectedFails("./test/test-type-import.ucl", 0) } diff --git a/test/test-module-ordering.ucl b/test/test-module-ordering.ucl new file mode 100644 index 000000000..6683ba88d --- /dev/null +++ b/test/test-module-ordering.ucl @@ -0,0 +1,45 @@ +module test { + input a : integer; + output b : integer; + + init { + b = 0; + } + + next { + b' = a+1; + } +} + +module main { + var x: integer; + var y: integer; + + // test1 reads in x and updates y'=x+1 + instance test1 : test(a : (x), b: (y)); + // test2 reads in y and updates x'=y+1 + instance test2 : test(a : (y), b: (x)); + + init { + x = 0; + y = 0; + + } + + next { + // both assertions should pass regardless of the ordering of these statements + next(test1); + next(test2); + } + + invariant test1_lt2: test1.b < 2; + invariant test2lt2: test2.b < 2; + + control { + print_module; + v = bmc(1); + check; + print_results; + v.print_cex; + } +}