Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blockstatement ordering #238

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/main/scala/uclid/SymbolicSimulator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,7 @@ class SymbolicSimulator (module : Module) {
case AssignStmt(lhss,rhss) =>
val es = rhss.map(i => evaluate(i, symbolTable, frameTable, frameNumber, scope));
return simulateAssign(lhss, es, symbolTable, label)
case BlockStmt(vars, stmts) =>
case BlockStmt(vars, stmts, _ ) =>
val declaredVars = vars.flatMap(vs => vs.ids.map(v => (v, vs.typ)))
val initSymbolTable = symbolTable
val localSymbolTable = declaredVars.foldLeft(initSymbolTable) {
Expand Down Expand Up @@ -1929,7 +1929,7 @@ class SymbolicSimulator (module : Module) {
}
case AssignStmt(lhss,rhss) =>
return lhss.map(lhs => lhs.ident).toSet
case BlockStmt(vars, stmts) =>
case BlockStmt(vars, stmts, _) =>
val declaredVars : Set[Identifier] = vars.flatMap(vs => vs.ids.map(id => id)).toSet
return writeSets(stmts) -- declaredVars
case IfElseStmt(e,then_branch,else_branch) =>
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/uclid/UclidMain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ object UclidMain {
val newInitDecl = initAccDecls match {
case Some(initAcc) => initModuleDecls match {
case Some(initMod) => List(InitDecl(BlockStmt(List[BlockVarsDecl](),
List(initAcc.asInstanceOf[InitDecl].body, initMod.asInstanceOf[InitDecl].body))))
List(initAcc.asInstanceOf[InitDecl].body, initMod.asInstanceOf[InitDecl].body), true)))
case None => List(initAcc)
}
case None => initModuleDecls match {
Expand Down Expand Up @@ -468,6 +468,9 @@ object UclidMain {
passManager.addPass(new ModuleTypeChecker())
// optimisation, has previously been called
passManager.addPass(new SemanticAnalyzer())
// reorder statements if necessary.
// Pass MUST be run after variable renamers
//passManager.addPass(new BlockSorter())
// known bugs in the following passes
if (config.enumToNumeric) passManager.addPass(new EnumTypeAnalysis())
if (config.enumToNumeric) passManager.addPass(new EnumTypeRenamer("BV"))
Expand Down
42 changes: 42 additions & 0 deletions src/main/scala/uclid/lang/ASTVisitorUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,48 @@ class ExprRewriter(name: String, rewrites : Map[Expr, Expr])
}
}

// used to rewrite only the read expressions
class ReadSetExprRewriter(name: String, rewrites : Map[Expr, Expr])
extends ASTRewriter(name, new ExprRewriterPass(rewrites))
{
def rewriteExpr(e : Expr, context : Scope) : Expr = {
e match {
case OperatorApplication(OldOperator(), _) => e
case OperatorApplication(HistoryOperator(), _) => e
case _ => visitExpr(e, context).get
}
}

def rewriteStatements(stmts : List[Statement], context : Scope) : List[Statement] = {
return stmts.flatMap(visitStatement(_, context))
}

def rewriteStatement(stmt : Statement, context : Scope) : Option[Statement] = {
visitStatement(stmt, context)
}

// do nothing for the LHS
override def visitLhs(lhs: Lhs, context: Scope): Option[Lhs] = Some(lhs)

override def visitOperatorApp(opapp : OperatorApplication, context : Scope) : Option[Expr] = {

opapp match {
case OperatorApplication(HistoryOperator(), _) => {
Some(opapp)
}
case OperatorApplication(OldOperator(), _) => Some(opapp)
case _ => {
val opAppP = visitOperator(opapp.op, context).flatMap((op) => {
pass.rewriteOperatorApp(OperatorApplication(op, opapp.operands.map(visitExpr(_, context + opapp)).flatten), context)
})
return ASTNode.introducePos(true, true, opAppP, opapp.position) }
}
}

}



// This class has been modified to handle the abstract class: ModifiableEntity.
class OldExprRewriterPass(rewrites : Map[ModifiableEntity, Identifier]) extends RewritePass
{
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/ASTVistors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ class ASTRewriter (_passName : String, _pass: RewritePass, setFilename : Boolean
log.debug("visitBlockStatement\n{}", Utils.join(blkStmt.toLines, "\n"))
val contextP = context + blkStmt.vars
val varsP = blkStmt.vars.map(v => visitBlockVars(v, contextP)).flatten
val blkStmtP1 = BlockStmt(varsP, blkStmt.stmts.flatMap(st => visitStatement(st, contextP)))
val blkStmtP1 = BlockStmt(varsP, blkStmt.stmts.flatMap(st => visitStatement(st, contextP)), blkStmt.isProcedural)
val blkStmtP = pass.rewriteBlock(blkStmtP1, context)
return ASTNode.introducePos(setPosition, setFilename, blkStmtP, blkStmt.position)
}
Expand Down
66 changes: 63 additions & 3 deletions src/main/scala/uclid/lang/BlockFlattener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ package uclid
package lang

import com.typesafe.scalalogging.Logger
import java.util.jar.Attributes.Name


class BlockVariableRenamerPass extends RewritePass {
def renameVarList (vars : List[(Identifier, Type)], context : Scope) : List[(Identifier, Identifier, Type)] = {
Expand All @@ -63,7 +65,7 @@ class BlockVariableRenamerPass extends RewritePass {
val rewriter = new ExprRewriter("BlockVariableRenamerPass:Block", rewriteMap)
val statementsP = rewriter.rewriteStatements(blkStmt.stmts, context + blkStmt.vars)
val varsP = varTuples.map(p => BlockVarsDecl(List(p._2), p._3))
Some(BlockStmt(varsP, statementsP))
Some(BlockStmt(varsP, statementsP, blkStmt.isProcedural))
}
override def rewriteProcedure(proc : ProcedureDecl, context : Scope) : Option[ProcedureDecl] = {
val argTuples = renameVarList(proc.sig.inParams, context)
Expand Down Expand Up @@ -138,10 +140,65 @@ class BlockFlattenerPass extends RewritePass {
val stmtsP = rewriter.rewriteStatements(blk.stmts, context + blk.vars)
(stmtsP, varDecls)
}

def addConcurrentVars (blkStmt : BlockStmt, context: Scope) : BlockStmt = {
val filteredStmts = blkStmt.stmts.filter(_.isInstanceOf[BlockStmt])

if(filteredStmts.size != blkStmt.stmts.size)
logger.debug("BlockFlattener: block contains blk statements and other statements")

val nonSequentialBlockCount = filteredStmts.count(_.asInstanceOf[BlockStmt].isProcedural == false)
logger.debug("Number of blocks: " + filteredStmts.size.toString())

if(!blkStmt.isProcedural && filteredStmts.size >1)
{
val reads = filteredStmts.foldLeft(Set.empty[Identifier]) {
(acc, blk) => {
val readSet = StatementScheduler.readSets(blk.asInstanceOf[BlockStmt].stmts, context)
acc ++ readSet
}
}.filter(id => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar] && !id.name.startsWith("__ucld"))

val writes = filteredStmts.foldLeft(Set.empty[Identifier]) {
(acc, blk) => {
val writeSet = StatementScheduler.writeSets(blk.asInstanceOf[BlockStmt].stmts, context)
acc ++ writeSet
}
}.filter(id => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar])

// create new vars. We only need new variables for the reads that are also written to
// because there should only be
// one write to a variable in a concurrent block. Blocks with more than one write will have been
// caught earlier
val varPairs: Map[Expr, Expr] =
reads.intersect(writes).map(
id => (id.asInstanceOf[Expr] -> NameProvider.get("block_" + id.toString()).asInstanceOf[Expr])).toMap
logger.debug("New vars: " + varPairs.toString())

val rewriter = new ReadSetExprRewriter("BlockFlattener:Rewrite", varPairs)
val stmtsP = rewriter.rewriteStatements(blkStmt.stmts, context + blkStmt.vars)

// create variable declarations for the new read variables.
val vars = varPairs.map(p => BlockVarsDecl(List(p._2.asInstanceOf[Identifier]), context.map(p._1.asInstanceOf[Identifier]).asInstanceOf[Scope.StateVar].typ))
// create assign statements for the new variables
val readVarAssigns = varPairs.map(p => AssignStmt(List(LhsId(p._2.asInstanceOf[Identifier])), List(p._1.asInstanceOf[Expr]))).toList

// new block statement
val blkStmtP = BlockStmt(blkStmt.vars ++ vars, readVarAssigns ++ stmtsP, blkStmt.isProcedural)
logger.debug("New block statement:\n" + blkStmtP.toString())
blkStmtP
}
else{
blkStmt
}
}

override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = {
logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString()))
val init = (List.empty[Statement], Map.empty[Identifier, Type])
val (stmtsP, mapOut) = blkStmt.stmts.foldLeft(init) {

val blkStmtP = addConcurrentVars(blkStmt, context)
val (stmtsP, mapOut) = blkStmtP.stmts.foldLeft(init) {
(acc, st) => {
val (stP, mapOut) = st match {
case blk : BlockStmt => renameBlock(blk, context, acc._2)
Expand All @@ -150,8 +207,9 @@ class BlockFlattenerPass extends RewritePass {
(acc._1 ++ stP, mapOut)
}
}

val vars = mapOut.map(p => BlockVarsDecl(List(p._1), p._2))
val result = BlockStmt(blkStmt.vars ++ vars, stmtsP)
val result = BlockStmt(blkStmtP.vars ++ vars, stmtsP, blkStmt.isProcedural)
logger.debug("<== Result:\n" + result.toString())
Some(result)
}
Expand All @@ -170,6 +228,8 @@ class BlockFlattener() extends ASTRewriter(BlockFlattener.getName(), new BlockFl
override val repeatUntilNoChange = true
}



object Optimizer {
var index = 0
def getName() : String = {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/LoopUnroller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ForLoopRewriterPass(forStmtsToRewrite: Set[ForStmt]) extends RewritePass {
rewriter.rewriteStatement(st.body, ctx)
}
val stmts = (low to high).foldLeft(List.empty[Statement])((acc, i) => acc ++ rewriteForValue(i).toList)
Some(BlockStmt(List.empty, stmts))
Some(BlockStmt(List.empty, stmts, true))
} else {
Some(st)
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/MacroRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class MacroReplacerPass(macroId : Identifier, newMacroBody : BlockStmt) extends
case _ =>
}
}
BlockStmt(st.vars, leftStmts)
BlockStmt(st.vars, leftStmts, false)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/ModSetAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ModSetRewriterPass() extends RewritePass {
* @param modSetMap The modifies set map inferred by the ModSetAnalysis pass. Should contain a map from procedures to thier inferred modifies sets.
*/
def getStmtModSet(stmt: Statement, modSetMap: Map[Identifier, Set[ModifiableEntity]], varIdSet: Set[Identifier], locVarIdSet: Set[Identifier]): Set[ModifiableEntity] = stmt match {
case BlockStmt(vars, stmts) => {
case BlockStmt(vars, stmts,_) => {
val locVarIdSetP = vars.foldLeft(locVarIdSet)((acc, bvd) => acc ++ bvd.ids.toSet)
stmts.foldLeft(Set.empty[ModifiableEntity])((acc, stmt) => acc ++ getStmtModSet(stmt, modSetMap, varIdSet, locVarIdSetP))
}
Expand Down
Loading
Loading