Skip to content

Commit

Permalink
Update BlockFlattener.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
polgreen committed May 8, 2024
1 parent a3ef1d7 commit 9ce6dd4
Showing 1 changed file with 41 additions and 25 deletions.
66 changes: 41 additions & 25 deletions src/main/scala/uclid/lang/BlockFlattener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,39 +139,57 @@ class BlockFlattenerPass extends RewritePass {
(stmtsP, varDecls)
}

override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = {
logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString()))
val init = (List.empty[Statement], Map.empty[Identifier, Type])
def addConcurrentVars (blkStmt : BlockStmt, context: Scope) : BlockStmt = {

val filteredStmts = blkStmt.stmts.filter(_.isInstanceOf[BlockStmt])
val nonSequentialBlockCount = filteredStmts.count(_.asInstanceOf[BlockStmt].isProcedural == false)

if(nonSequentialBlockCount!=filteredStmts.size)
UclidMain.printError("BlockFlattener: All blocks inside a block must be either procedural or non-procedural.")
if(filteredStmts.size != blkStmt.stmts.size)
UclidMain.printError("BlockFlattener: block contains blk statements and other statements")

if(nonSequentialBlockCount>1 && blkStmt.isProcedural)
val nonSequentialBlockCount = filteredStmts.count(_.asInstanceOf[BlockStmt].isProcedural == false)
logger.debug("Number of blocks: " + filteredStmts.size.toString())

if(!blkStmt.isProcedural && filteredStmts.size >1)
{
// get set of state vars that are read from.
// val readVars:
// get set of state vars that are assigned to
// val writeVars:

// create map of readVars and writeVars to temporary new vars
val reads = filteredStmts.foldLeft(Set.empty[Identifier]) {
(acc, blk) => {
val readSet = StatementScheduler.readSets(blk.asInstanceOf[BlockStmt].stmts, context)
acc ++ readSet
}
}.filter(id => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar])

// rewrite expressions in the block using ExprRewriter
// create new vars. We only need new variables for the reads because there should only be
// one write to a variable in a concurrent block. Blocks with more than one write will have been
// caught earlier
val varPairs: Map[Expr, Expr] =
reads.map(
id => (id.asInstanceOf[Expr] -> NameProvider.get("block_" + id.toString()).asInstanceOf[Expr])).toMap
logger.debug("New vars: " + varPairs.toString())

// add variable declarations to block vars
val rewriter = new ReadSetExprRewriter("BlockFlattener:Rewrite", varPairs)
val stmtsP = rewriter.rewriteStatements(blkStmt.stmts, context + blkStmt.vars)

// create statements assigning the readVars to the new variables.
// val readVarAssigns
// create variable declarations for the new read variables.
val vars = varPairs.map(p => BlockVarsDecl(List(p._2.asInstanceOf[Identifier]), context.map(p._1.asInstanceOf[Identifier]).asInstanceOf[Scope.StateVar].typ))
// create assign statements for the new variables
val readVarAssigns = varPairs.map(p => AssignStmt(List(LhsId(p._2.asInstanceOf[Identifier])), List(p._1.asInstanceOf[Expr]))).toList

// create statements assigning the new variables to the writeVars.
// val writeVarAssigns
// new block statement
val blkStmtP = BlockStmt(blkStmt.vars ++ vars, readVarAssigns ++ stmtsP, blkStmt.isProcedural)
logger.debug("New block statement:\n" + blkStmtP.toString())
blkStmtP
}
else{
blkStmt
}
}

override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = {
logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString()))
val init = (List.empty[Statement], Map.empty[Identifier, Type])


val (stmtsP, mapOut) = blkStmt.stmts.foldLeft(init) {
val blkStmtP = addConcurrentVars(blkStmt, context)
val (stmtsP, mapOut) = blkStmtP.stmts.foldLeft(init) {
(acc, st) => {
val (stP, mapOut) = st match {
case blk : BlockStmt => renameBlock(blk, context, acc._2)
Expand All @@ -180,11 +198,9 @@ class BlockFlattenerPass extends RewritePass {
(acc._1 ++ stP, mapOut)
}
}
// this line might need to be changed if it is pulling in duplicate variables

val vars = mapOut.map(p => BlockVarsDecl(List(p._1), p._2))
// change the lines below to add the readVar and writeVars to the variable list
// and the readVarAssigns to the beginning of the block, and the writevar assigns to the end of the block.
val result = BlockStmt(blkStmt.vars ++ vars, stmtsP, blkStmt.isProcedural)
val result = BlockStmt(blkStmtP.vars ++ vars, stmtsP, blkStmt.isProcedural)
logger.debug("<== Result:\n" + result.toString())
Some(result)
}
Expand Down

0 comments on commit 9ce6dd4

Please sign in to comment.