Skip to content

Commit

Permalink
update IsRewritable in RewritePolymorphicSelect to handle functions; …
Browse files Browse the repository at this point in the history
…add test case to show off difference
  • Loading branch information
FedericoAureliano authored and polgreen committed Jul 4, 2024
1 parent be53fa4 commit d506959
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 32 deletions.
78 changes: 46 additions & 32 deletions src/main/scala/uclid/lang/RewritePolymorphicSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,35 +95,49 @@ class RewritePolymorphicSelectPass extends RewritePass {
}
}

def IsRewritable(opapp : OperatorApplication, context:Scope): Boolean = {
opapp.op match {
case PolymorphicSelect(id) =>
val expr = opapp.operands(0)
def IsRewritable(candidate: Expr, context: Scope): Boolean = {
candidate match {
case OperatorApplication(PolymorphicSelect(id), operands) =>
val expr = operands(0)
expr match {
case arg : Identifier => isVarState(arg,id,context)||isVarInModule(id,arg,context)
case subopp: OperatorApplication =>{
if(IsRewritable(subopp,context))
case arg : Identifier =>
isVarState(arg,id,context)||isVarInModule(id,arg,context)
case OperatorApplication(_, _) | FuncApplication(_, _) => {
if(IsRewritable(expr,context)) {
true
}
else
{
val LastInstance = getLastInstance(subopp,context);
val LastInstance = getLastInstance(expr, context);
LastInstance match{
case mid : Identifier => isVarInModule(id,mid,context)
case mid : Identifier =>
isVarInModule(id,mid,context)
case _ => false
}
}
}
case _ => false
}
case _ => {
val expr = opapp.operands(0)
case OperatorApplication(op, operands) => {
val expr = operands(0)
expr match {
case arg : Identifier => isVarState(arg,Identifier(""),context)
case arg : Identifier =>
isVarState(arg,Identifier(""),context)
case subopp: OperatorApplication =>
IsRewritable(subopp,context)
case _ => false
}
}
case FuncApplication(func, operands) => {
func match {
case arg : Identifier =>
isVarState(arg,Identifier(""),context)
case subopp: OperatorApplication =>
IsRewritable(subopp,context)
case _ => false
}
}
case _ => false
}
}

Expand Down Expand Up @@ -155,18 +169,18 @@ class RewritePolymorphicSelectPass extends RewritePass {
}
}

def getLastInstance(opapp : OperatorApplication, context:Scope): Expr ={
opapp.op match {
case PolymorphicSelect(id) =>{
opapp.operands(0) match {
def getLastInstance(target: Expr, context: Scope): Expr ={
target match {
case OperatorApplication(PolymorphicSelect(id), operands) => {
operands(0) match {
case arg : Identifier =>
{
context.map.get(arg) match
{
case Some(module:Scope.ModuleDefinition) => {
checkIdDecl(module.mod.decls,id) match{
case Some(ident) => ident
case _ => opapp
case _ => target
}
}
case Some(Scope.Instance(instD)) => {
Expand All @@ -176,13 +190,13 @@ class RewritePolymorphicSelectPass extends RewritePass {

checkIdDecl(module.mod.decls,id) match{
case Some(ident) => ident
case _ => opapp
case _ => target
}
}
case _ => opapp
case _ => target
}
}
case _ => opapp
case _ => target
}
}
case subopp: OperatorApplication =>{
Expand All @@ -195,7 +209,7 @@ class RewritePolymorphicSelectPass extends RewritePass {
case Some(module:Scope.ModuleDefinition) => {
checkIdDecl(module.mod.decls,id) match{
case Some(ident) => ident
case _ => opapp
case _ => target
}
}
case Some(Scope.Instance(instD)) => {
Expand All @@ -204,29 +218,29 @@ class RewritePolymorphicSelectPass extends RewritePass {
{
checkIdDecl(module.mod.decls,id) match{
case Some(ident) => ident
case _ => opapp
case _ => target
}
}
case _ => opapp
case _ => target
}
}
case _ => opapp
case _ => target
}
}
case _ => opapp
case _ => target
}
}
case _ => opapp
case _ => target
}
}
case GetNextValueOp() =>
val expr = opapp.operands(0)
case OperatorApplication(GetNextValueOp(), operands) =>
val expr = operands(0)
expr match {
case subopp: OperatorApplication =>
getLastInstance(subopp,context)
case _ => opapp
case _ => target
}
case _ => opapp
case _ => target
}
}

Expand Down Expand Up @@ -257,14 +271,14 @@ class RewritePolymorphicSelectPass extends RewritePass {
}
}

def isVarState(arg: Identifier,id:Identifier,context:Scope): Boolean = {
def isVarState(arg: Identifier, id: Identifier, context: Scope): Boolean = {
UclidMain.printDebugRewriteRecord("We are going to check "+arg+"\n")
UclidMain.printDebugRewriteRecord("its type is "+context.map.get(arg)+"\n")
context.map.get(arg) match{
case Some(Scope.ProcedureInputArg(_,_)) | Some(Scope.StateVar(_,_)) | Some(Scope.ProcedureOutputArg(_,_))|
Some(Scope.BlockVar(_,_)) | Some(Scope.FunctionArg(_,_)) | Some(Scope.LambdaVar(_,_))|
Some(Scope.InputVar(_,_)) | Some(Scope.OutputVar(_,_)) | Some(Scope.SharedVar(_,_)) |
Some(Scope.ConstantVar(_,_)) | Some(Scope.SelectorField(_))
Some(Scope.ConstantVar(_,_)) | Some(Scope.SelectorField(_)) | Some(Scope.Function(_,_))
=>{
if(id.toString.startsWith("_") && id.toString.substring(1).forall(Character.isDigit))
false
Expand Down
6 changes: 6 additions & 0 deletions src/test/scala/ParserSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,12 @@ class ParserSpec extends AnyFlatSpec {
assert (instantiatedModules.size == 1)
}

"test-rewrite-polymorphic-select.ucl" should "parse successfully." in {
val fileModules = UclidMain.compile(ConfigCons.createConfig("test/test-rewrite-polymorphic-select.ucl"), lang.Identifier("main"))
val instantiatedModules = UclidMain.instantiateModules(UclidMain.Config(), fileModules, lang.Identifier("main"))
assert (instantiatedModules.size == 1)
}

"test-array-record.ucl" should "parse successfully." in {
val fileModules = UclidMain.compile(ConfigCons.createConfig("test/test-array-record.ucl"), lang.Identifier("main"))
val instantiatedModules = UclidMain.instantiateModules(UclidMain.Config(), fileModules, lang.Identifier("main"))
Expand Down
8 changes: 8 additions & 0 deletions test/test-rewrite-polymorphic-select.ucl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module main {
type r1 = record {x: integer};
type r2 = record {y: r1};
function z(i: integer): r2;
init {
assume(z(100).y.x == 0);
}
}

0 comments on commit d506959

Please sign in to comment.