From ef9133a5bea87fc5bba5d2360f26e4c6fe701422 Mon Sep 17 00:00:00 2001 From: Elizabeth Polgreen Date: Fri, 3 May 2024 15:33:17 +0100 Subject: [PATCH] add new pass to sort blocks --- src/main/scala/uclid/UclidMain.scala | 3 + .../scala/uclid/lang/BlockFlattener.scala | 58 ++++++++++++++++--- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/src/main/scala/uclid/UclidMain.scala b/src/main/scala/uclid/UclidMain.scala index 1874ceea..6adf31d6 100644 --- a/src/main/scala/uclid/UclidMain.scala +++ b/src/main/scala/uclid/UclidMain.scala @@ -468,6 +468,9 @@ object UclidMain { passManager.addPass(new ModuleTypeChecker()) // optimisation, has previously been called passManager.addPass(new SemanticAnalyzer()) + // reorder statements if necessary. + // Pass MUST be run after variable renamers + passManager.addPass(new BlockSorter()) // known bugs in the following passes if (config.enumToNumeric) passManager.addPass(new EnumTypeAnalysis()) if (config.enumToNumeric) passManager.addPass(new EnumTypeRenamer("BV")) diff --git a/src/main/scala/uclid/lang/BlockFlattener.scala b/src/main/scala/uclid/lang/BlockFlattener.scala index 8e9051aa..d630438b 100644 --- a/src/main/scala/uclid/lang/BlockFlattener.scala +++ b/src/main/scala/uclid/lang/BlockFlattener.scala @@ -138,15 +138,10 @@ class BlockFlattenerPass extends RewritePass { val stmtsP = rewriter.rewriteStatements(blk.stmts, context + blk.vars) (stmtsP, varDecls) } + override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = { logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString())) val init = (List.empty[Statement], Map.empty[Identifier, Type]) - - val endStmts = blkStmt.stmts.filter( - st => st.isInstanceOf[AssignStmt] && - st.asInstanceOf[AssignStmt].lhss.head.isInstanceOf[LhsId] - && context.map.contains(st.asInstanceOf[AssignStmt].lhss.head.asInstanceOf[LhsId].id)) - logger.debug("Moving the following statements to the end of the block: " + endStmts.toString()) val (stmtsP, mapOut) = blkStmt.stmts.foldLeft(init) { (acc, st) => { @@ -157,10 +152,9 @@ class BlockFlattenerPass extends RewritePass { (acc._1 ++ stP, mapOut) } } - val filteredStmts = stmtsP.filter(endStmts.contains(_)==false) - logger.debug("Moving these statements to the start of the block " + filteredStmts.toString()) + val vars = mapOut.map(p => BlockVarsDecl(List(p._1), p._2)) - val result = BlockStmt(blkStmt.vars ++ vars, filteredStmts++endStmts) + val result = BlockStmt(blkStmt.vars ++ vars, stmtsP) logger.debug("<== Result:\n" + result.toString()) Some(result) } @@ -179,6 +173,52 @@ class BlockFlattener() extends ASTRewriter(BlockFlattener.getName(), new BlockFl override val repeatUntilNoChange = true } + +class BlockSorterPass extends RewritePass { + lazy val logger = Logger(classOf[BlockFlattenerPass]) + + override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = { + logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString())) + val init = (List.empty[Statement], Map.empty[Identifier, Type]) + + def isStateVarAssign(st : Statement) : Boolean = { + st match { + case AssignStmt(lhss, rhs) => { + lhss.exists { + case LhsId(id) => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar] + case _ => false + } + } + case _ => false + } + } + + val endStmts = blkStmt.stmts.filter(st => isStateVarAssign(st)) + logger.debug("Moving the following statements to the end of the block: " + endStmts.toString()) + val filteredStmts = blkStmt.stmts.filter(endStmts.contains(_)==false) + logger.debug("Moving these statements to the start of the block " + filteredStmts.toString()) + val result = BlockStmt(blkStmt.vars, filteredStmts++endStmts) + logger.debug("<== Result:\n" + result.toString()) + Some(result) + } +} + +object BlockSorter { + var index = 0 + def getName() : String = { + index += 1 + "BlockSorter:" + index.toString() + } +} + +class BlockSorter() extends ASTRewriter(BlockSorter.getName(), new BlockSorterPass()) +{ + override val repeatUntilNoChange = true +} + + + + object Optimizer { var index = 0 def getName() : String = {