diff --git a/src/main/scala/uclid/lang/RewritePolymorphicSelect.scala b/src/main/scala/uclid/lang/RewritePolymorphicSelect.scala index eb86d07d..6534a7ac 100644 --- a/src/main/scala/uclid/lang/RewritePolymorphicSelect.scala +++ b/src/main/scala/uclid/lang/RewritePolymorphicSelect.scala @@ -95,35 +95,49 @@ class RewritePolymorphicSelectPass extends RewritePass { } } - def IsRewritable(opapp : OperatorApplication, context:Scope): Boolean = { - opapp.op match { - case PolymorphicSelect(id) => - val expr = opapp.operands(0) + def IsRewritable(candidate: Expr, context: Scope): Boolean = { + candidate match { + case OperatorApplication(PolymorphicSelect(id), operands) => + val expr = operands(0) expr match { - case arg : Identifier => isVarState(arg,id,context)||isVarInModule(id,arg,context) - case subopp: OperatorApplication =>{ - if(IsRewritable(subopp,context)) + case arg : Identifier => + isVarState(arg,id,context)||isVarInModule(id,arg,context) + case OperatorApplication(_, _) | FuncApplication(_, _) => { + if(IsRewritable(expr,context)) { true + } else { - val LastInstance = getLastInstance(subopp,context); + val LastInstance = getLastInstance(expr, context); LastInstance match{ - case mid : Identifier => isVarInModule(id,mid,context) + case mid : Identifier => + isVarInModule(id,mid,context) case _ => false } } } case _ => false } - case _ => { - val expr = opapp.operands(0) + case OperatorApplication(op, operands) => { + val expr = operands(0) expr match { - case arg : Identifier => isVarState(arg,Identifier(""),context) + case arg : Identifier => + isVarState(arg,Identifier(""),context) case subopp: OperatorApplication => IsRewritable(subopp,context) case _ => false } } + case FuncApplication(func, operands) => { + func match { + case arg : Identifier => + isVarState(arg,Identifier(""),context) + case subopp: OperatorApplication => + IsRewritable(subopp,context) + case _ => false + } + } + case _ => false } } @@ -155,10 +169,10 @@ class RewritePolymorphicSelectPass extends RewritePass { } } - def getLastInstance(opapp : OperatorApplication, context:Scope): Expr ={ - opapp.op match { - case PolymorphicSelect(id) =>{ - opapp.operands(0) match { + def getLastInstance(target: Expr, context: Scope): Expr ={ + target match { + case OperatorApplication(PolymorphicSelect(id), operands) => { + operands(0) match { case arg : Identifier => { context.map.get(arg) match @@ -166,7 +180,7 @@ class RewritePolymorphicSelectPass extends RewritePass { case Some(module:Scope.ModuleDefinition) => { checkIdDecl(module.mod.decls,id) match{ case Some(ident) => ident - case _ => opapp + case _ => target } } case Some(Scope.Instance(instD)) => { @@ -176,13 +190,13 @@ class RewritePolymorphicSelectPass extends RewritePass { checkIdDecl(module.mod.decls,id) match{ case Some(ident) => ident - case _ => opapp + case _ => target } } - case _ => opapp + case _ => target } } - case _ => opapp + case _ => target } } case subopp: OperatorApplication =>{ @@ -195,7 +209,7 @@ class RewritePolymorphicSelectPass extends RewritePass { case Some(module:Scope.ModuleDefinition) => { checkIdDecl(module.mod.decls,id) match{ case Some(ident) => ident - case _ => opapp + case _ => target } } case Some(Scope.Instance(instD)) => { @@ -204,29 +218,29 @@ class RewritePolymorphicSelectPass extends RewritePass { { checkIdDecl(module.mod.decls,id) match{ case Some(ident) => ident - case _ => opapp + case _ => target } } - case _ => opapp + case _ => target } } - case _ => opapp + case _ => target } } - case _ => opapp + case _ => target } } - case _ => opapp + case _ => target } } - case GetNextValueOp() => - val expr = opapp.operands(0) + case OperatorApplication(GetNextValueOp(), operands) => + val expr = operands(0) expr match { case subopp: OperatorApplication => getLastInstance(subopp,context) - case _ => opapp + case _ => target } - case _ => opapp + case _ => target } } @@ -257,14 +271,14 @@ class RewritePolymorphicSelectPass extends RewritePass { } } - def isVarState(arg: Identifier,id:Identifier,context:Scope): Boolean = { + def isVarState(arg: Identifier, id: Identifier, context: Scope): Boolean = { UclidMain.printDebugRewriteRecord("We are going to check "+arg+"\n") UclidMain.printDebugRewriteRecord("its type is "+context.map.get(arg)+"\n") context.map.get(arg) match{ case Some(Scope.ProcedureInputArg(_,_)) | Some(Scope.StateVar(_,_)) | Some(Scope.ProcedureOutputArg(_,_))| Some(Scope.BlockVar(_,_)) | Some(Scope.FunctionArg(_,_)) | Some(Scope.LambdaVar(_,_))| Some(Scope.InputVar(_,_)) | Some(Scope.OutputVar(_,_)) | Some(Scope.SharedVar(_,_)) | - Some(Scope.ConstantVar(_,_)) | Some(Scope.SelectorField(_)) + Some(Scope.ConstantVar(_,_)) | Some(Scope.SelectorField(_)) | Some(Scope.Function(_,_)) =>{ if(id.toString.startsWith("_") && id.toString.substring(1).forall(Character.isDigit)) false diff --git a/src/test/scala/ParserSpec.scala b/src/test/scala/ParserSpec.scala index d958967f..6617d0fb 100644 --- a/src/test/scala/ParserSpec.scala +++ b/src/test/scala/ParserSpec.scala @@ -761,6 +761,12 @@ class ParserSpec extends AnyFlatSpec { assert (instantiatedModules.size == 1) } + "test-rewrite-polymorphic-select.ucl" should "parse successfully." in { + val fileModules = UclidMain.compile(ConfigCons.createConfig("test/test-rewrite-polymorphic-select.ucl"), lang.Identifier("main")) + val instantiatedModules = UclidMain.instantiateModules(UclidMain.Config(), fileModules, lang.Identifier("main")) + assert (instantiatedModules.size == 1) + } + "test-array-record.ucl" should "parse successfully." in { val fileModules = UclidMain.compile(ConfigCons.createConfig("test/test-array-record.ucl"), lang.Identifier("main")) val instantiatedModules = UclidMain.instantiateModules(UclidMain.Config(), fileModules, lang.Identifier("main")) diff --git a/test/test-rewrite-polymorphic-select.ucl b/test/test-rewrite-polymorphic-select.ucl new file mode 100644 index 00000000..69877ecd --- /dev/null +++ b/test/test-rewrite-polymorphic-select.ucl @@ -0,0 +1,8 @@ +module main { + type r1 = record {x: integer}; + type r2 = record {y: r1}; + function z(i: integer): r2; + init { + assume(z(100).y.x == 0); + } +} \ No newline at end of file