diff --git a/src/main/scala/uclid/lang/ASTVistors.scala b/src/main/scala/uclid/lang/ASTVistors.scala index 7db3684ab..1b8eee95e 100644 --- a/src/main/scala/uclid/lang/ASTVistors.scala +++ b/src/main/scala/uclid/lang/ASTVistors.scala @@ -123,6 +123,9 @@ trait ReadOnlyPass[T] { def applyOnEnumType(d : TraversalDirection.T, enumT : EnumType, in : T, context : Scope) : T = { in } def applyOnTupleType(d : TraversalDirection.T, tupleT : TupleType, in : T, context : Scope) : T = { in } def applyOnRecordType(d : TraversalDirection.T, recordT : RecordType, in : T, context : Scope) : T = { in } + def applyOnDataType(d : TraversalDirection.T, dataT : DataType, in : T, context : Scope) : T = { in } + def applyOnConstructor(d : TraversalDirection.T, constructor : ConstructorType, in : T, context : Scope) : T = { in } + def applyOnSelector(d : TraversalDirection.T, selector : (Identifier, Type), in : T, context : Scope) : T = { in } def applyOnMapType(d : TraversalDirection.T, mapT : MapType, in : T, context : Scope) : T = { in } def applyOnProcedureType(d : TraversalDirection.T, procT : ProcedureType, in : T, context : Scope) : T = { in } def applyOnArrayType(d : TraversalDirection.T, arrayT : ArrayType, in : T, context : Scope) : T = { in } @@ -229,6 +232,9 @@ trait RewritePass { def rewriteIntType(intT : IntegerType, context : Scope) : Option[IntegerType] = { Some(intT) } def rewriteBitVectorType(bvT : BitVectorType, context : Scope) : Option[BitVectorType] = { Some(bvT) } def rewriteEnumType(enumT : EnumType, context : Scope) : Option[EnumType] = { Some(enumT) } + def rewriteSelector(sel : (Identifier, Type), context : Scope) : Option[(Identifier, Type)] = { Some(sel) } + def rewriteConstructor(cstor : ConstructorType, context : Scope) : Option[ConstructorType] = { Some(cstor) } + def rewriteDataType(dataT : DataType, context : Scope) : Option[DataType] = { Some(dataT) } def rewriteTupleType(tupleT : TupleType, context : Scope) : Option[TupleType] = { Some(tupleT) } def rewriteRecordType(recordT : RecordType, context : Scope) : Option[RecordType] = { Some(recordT) } def rewriteMapType(mapT : MapType, context : Scope) : Option[MapType] = { Some(mapT) } @@ -652,6 +658,8 @@ class ASTAnalyzer[T] (_passName : String, _pass: ReadOnlyPass[T]) extends ASTAna case groupT : GroupType => visitGroupType(groupT, result, context) case floatT: FloatType => visitFloatType(floatT, result, context) case realT: RealType => visitRealType(realT, result, context) + case dataT: DataType => visitDataType(dataT, result, context) + case constT: ConstructorType => visitConstructor(constT, result, context) } result = pass.applyOnType(TraversalDirection.Up, typ, result, context) return result @@ -767,6 +775,39 @@ class ASTAnalyzer[T] (_passName : String, _pass: ReadOnlyPass[T]) extends ASTAna return result } + def visitConstructor(cstor : ConstructorType, in : T, context : Scope) : T = { + var result : T = in + result = pass.applyOnConstructor(TraversalDirection.Down, cstor, result, context) + result = visitIdentifier(cstor.id, result, context) + result = cstor.inTypes.foldLeft(result)((r, sel) => visitSelector(sel, r, context)) + result = pass.applyOnConstructor(TraversalDirection.Up, cstor, result, context) + return result + } + + def visitSelector(sel : (Identifier, Type), in : T, context : Scope) : T = { + var result : T = in + result = pass.applyOnSelector(TraversalDirection.Down, sel, result, context) + result = visitIdentifier(sel._1, result, context) + result = visitType(sel._2, result, context) + result = pass.applyOnSelector(TraversalDirection.Up, sel, result, context) + return result + } + + def visitDataType(dataT : DataType, in: T, context: Scope): T = { + var result: T = in + result = pass.applyOnDataType(TraversalDirection.Down, dataT, result, context) + result = visitIdentifier(dataT.id, result, context) + result = dataT.constructors.foldLeft(result)((r, constructor) => { + val r2 = visitIdentifier(constructor._1, r, context) + constructor._2.foldLeft(r2)((r, s) => { + val r3 = visitIdentifier(s._1, r, context) + visitType(s._2, r3, context) + }) + }) + result = pass.applyOnDataType(TraversalDirection.Up, dataT, result, context) + return result + } + def visitExternalType(extT : ExternalType, in : T, context : Scope) : T = { var result : T = in result = pass.applyOnExternalType(TraversalDirection.Down, extT, result, context) @@ -1731,6 +1772,8 @@ class ASTRewriter (_passName : String, _pass: RewritePass, setFilename : Boolean case groupT : GroupType => visitGroupType(groupT, context) case floatT : FloatType => visitFloatType(floatT, context) case realT : RealType => visitRealType(realT, context) + case dataT : DataType => visitDataType(dataT, context) + case cstor : ConstructorType => visitConstructor(cstor, context) }).flatMap(pass.rewriteType(_, context)) return ASTNode.introducePos(setPosition, setFilename, typP, typ.position) } @@ -1835,6 +1878,55 @@ class ASTRewriter (_passName : String, _pass: RewritePass, setFilename : Boolean return ASTNode.introducePos(setPosition, setFilename, realTP, realT.position) } + def visitSelector(sel : (Identifier, Type), context : Scope) : Option[(Identifier, Type)] = { + val selId = visitIdentifier(sel._1, context) + val selType = visitType(sel._2, context) + (selId, selType) match { + case (Some(selId), Some(selType)) => + pass.rewriteSelector((selId, selType), context) + case _ => + None + } + } + + def visitConstructor(cstor : ConstructorType, context : Scope) : Option[ConstructorType] = { + val cstorId = visitIdentifier(cstor.id, context) + val cstorSelectors = cstor.inTypes.map((sel => visitSelector(sel, context))) + (cstorId, cstorSelectors) match { + case (Some(cstorId), cstorSelectors) if cstorSelectors.forall( s => s.isDefined) => + pass.rewriteConstructor(ConstructorType(cstorId, cstorSelectors.map(s => s.get), cstor.outTyp), context) + case _ => + None + } + } + + def visitDataType(dataT : DataType, context: Scope): Option[Type] = { + val id = visitIdentifier(dataT.id, context) + val cstors = dataT.constructors.map((cstor => { + val cid = visitIdentifier(cstor._1, context) + val cs = cstor._2.map(s => { + val sid = visitIdentifier(s._1, context) + val st = visitType(s._2, context) + if (sid.isDefined && st.isDefined) { + Some((sid.get, st.get)) + } else { + None + } + }) + if (cid.isDefined && cs.forall( s => s.isDefined)) { + Some((cid.get, cs.map(s => s.get))) + } else { + None + } + })) + val dt = if (cstors.forall( s => s.isDefined) && id.isDefined) { + pass.rewriteDataType(DataType(id.get, cstors.map(s => s.get)), context) + } else { + None + } + return ASTNode.introducePos(setPosition, setFilename, dt, dataT.position) + } + def visitExternalType(extT : ExternalType, context : Scope) : Option[Type] = { val moduleIdP = visitIdentifier(extT.moduleId, context) val typeIdP = visitIdentifier(extT.typeId, context) diff --git a/src/main/scala/uclid/lang/RewriteRecordSelect.scala b/src/main/scala/uclid/lang/RewriteRecordSelect.scala index d3e74612c..eebe5619c 100644 --- a/src/main/scala/uclid/lang/RewriteRecordSelect.scala +++ b/src/main/scala/uclid/lang/RewriteRecordSelect.scala @@ -20,6 +20,20 @@ class RewriteRecordSelectPass extends RewritePass { } } + override def rewriteDataType(dataT : DataType, context : Scope) : Option[DataType] = { + Some(DataType(dataT.id, dataT.constructors.map(c => (c._1, c._2.map(s => { + if(!hasRecPrefix(s)) + { + (Identifier(recordPrefix+s._1.toString), s._2) + } + else + { + UclidMain.printDebugRewriteRecord("we have not rewritten this selector " + dataT.toString ) + s + } + }))))) + } + } class RewriteRecordSelect extends ASTRewriter( diff --git a/src/main/scala/uclid/lang/Scope.scala b/src/main/scala/uclid/lang/Scope.scala index 599b370ac..cae681a83 100644 --- a/src/main/scala/uclid/lang/Scope.scala +++ b/src/main/scala/uclid/lang/Scope.scala @@ -257,7 +257,16 @@ case class Scope ( case instD : InstanceDecl => Scope.addToMap(mapAcc, Scope.Instance(instD)) case ProcedureDecl(id, sig, _, _, _, _, _) => Scope.addToMap(mapAcc, Scope.Procedure(id, sig.typ)) - case TypeDecl(id, typ) => Scope.addToMap(mapAcc, Scope.TypeSynonym(id, typ)) + case TypeDecl(id, typ) => { + typ match { + case DataType(id, constructors) => { + constructors.foldLeft(mapAcc)((out, c) => { + Scope.addToMap(out, Scope.Function(c._1, ConstructorType(c._1, c._2, typ))) + }) + } + case _ => mapAcc + } + } case StateVarsDecl(ids, typ) => ids.foldLeft(mapAcc)((acc, id) => Scope.addToMap(acc, Scope.StateVar(id, typ))) case InputVarsDecl(ids, typ) => ids.foldLeft(mapAcc)((acc, id) => Scope.addToMap(acc, Scope.InputVar(id, typ))) case OutputVarsDecl(ids, typ) => ids.foldLeft(mapAcc)((acc, id) => Scope.addToMap(acc, Scope.OutputVar(id, typ))) @@ -317,7 +326,18 @@ case class Scope ( val m1 = sig.args.foldLeft(mapAcc)((mapAcc2, operand) => Scope.addTypeToMap(mapAcc2, operand._2, Some(m))) val m2 = Scope.addTypeToMap(m1, sig.retType, Some(m)) m2 - case TypeDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m)) + case TypeDecl(_, typ) => { + typ match { + case DataType(id, constructors) => { + constructors.foldLeft(mapAcc)((acc, c) => { + val m1 = c._2.foldLeft(mapAcc)((mapAcc2, operand) => Scope.addTypeToMap(mapAcc2, operand._2, Some(m))) + val m2 = Scope.addTypeToMap(m1, typ, Some(m)) + m2 + }) + } + case _ => mapAcc + } + } case StateVarsDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m)) case InputVarsDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m)) case OutputVarsDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m)) diff --git a/src/main/scala/uclid/lang/SemanticAnalyzer.scala b/src/main/scala/uclid/lang/SemanticAnalyzer.scala index fded8d8a0..8882dbafc 100644 --- a/src/main/scala/uclid/lang/SemanticAnalyzer.scala +++ b/src/main/scala/uclid/lang/SemanticAnalyzer.scala @@ -88,7 +88,27 @@ class SemanticAnalyzerPass extends ReadOnlyPass[List[ModuleError]] { if (d == TraversalDirection.Down) { // val moduleIds = module.decls.filter((d) => d.declNames.isDefined).map((d) => (d.declName.get, d.position)) val moduleIds = module.decls.flatMap((d) => d.declNames.map((n) => (n, d.position))) - SemanticAnalyzerPass.checkIdRedeclaration(moduleIds, in) + val selectorIds = module.decls.flatMap((d) => { + d match { + case TypeDecl(id, typ) => typ match { + case DataType(id, constructors) => constructors.flatMap((c) => c._2.map(s => (s._1, s._1.position))) + case _ => List() + } + case _ => List() + } + }) + val constructorIds = module.decls.flatMap((d) => { + d match { + case TypeDecl(id, typ) => typ match { + case DataType(id, constructors) => constructors.map((c) => (c._1, c._1.position)) + case _ => List() + } + case _ => List() + } + }) + SemanticAnalyzerPass.checkIdRedeclaration(moduleIds, in) ++ + SemanticAnalyzerPass.checkIdRedeclaration(selectorIds, in) ++ + SemanticAnalyzerPass.checkIdRedeclaration(constructorIds, in) } else { in } } override def applyOnProcedure(d : TraversalDirection.T, proc : ProcedureDecl, in : List[ModuleError], context : Scope) : List[ModuleError] = { @@ -118,6 +138,18 @@ class SemanticAnalyzerPass extends ReadOnlyPass[List[ModuleError]] { in } } + + override def applyOnDataType(d : TraversalDirection.T, dataT : DataType, in : List[ModuleError], context : Scope) : List[ModuleError] = { + if (d == TraversalDirection.Down) { + val cstor_ids = dataT.constructors.map(x => (x._1, x._1.position)).toSeq + val selector_ids = dataT.constructors.flatMap(x => x._2.map(y => (y._1, y._1.position))).toSeq + SemanticAnalyzerPass.checkIdRedeclaration(cstor_ids, in) + SemanticAnalyzerPass.checkIdRedeclaration(selector_ids, in) + } else { + in + } + } + override def applyOnInstance(d : TraversalDirection.T, inst : InstanceDecl, in : List[ModuleError], context : Scope) : List[ModuleError] = { if (d == TraversalDirection.Down) { // val modType = inst.modType.get diff --git a/src/main/scala/uclid/lang/TypeChecker.scala b/src/main/scala/uclid/lang/TypeChecker.scala index 3ba54c2dd..ff82b3150 100644 --- a/src/main/scala/uclid/lang/TypeChecker.scala +++ b/src/main/scala/uclid/lang/TypeChecker.scala @@ -98,6 +98,7 @@ class TypeSynonymFinderPass extends ReadOnlyPass[Unit] TupleType(fieldTypes.map(simplifyType(_, visited, m))) case RecordType(fields) => RecordType(fields.map((f) => (f._1, simplifyType(f._2, visited, m)))) + case dt: DataType => dt case MapType(inTypes, outType) => MapType(inTypes.map(simplifyType(_, visited, m)), simplifyType(outType, visited, m)) case ArrayType(inTypes, outType) => @@ -585,6 +586,14 @@ class ExpressionTypeCheckerPass extends ReadOnlyPass[Set[Utils.TypeError]] selectFromInstance.pos = opapp.op.pos polyOpMap.put(opapp.op.astNodeId, selectFromInstance) fldT.get + case dt : DataType => + val allSels = dt.constructors.flatMap(c => c._2) + val typOption = allSels.find((p) => p._1 == field).flatMap(e => Some(e._2)) + checkTypeError(!typOption.isEmpty, "Field '" + field.toString + "' does not exist in " + dt.toString(), opapp.pos, c.filename) + val recordSelect = RecordSelect(field) + recordSelect.pos = opapp.op.pos + polyOpMap.put(opapp.op.astNodeId, recordSelect) + typOption.get case _ => checkTypeError(false, "Argument to select operator must be of type record or instance", opapp.pos, c.filename) new UndefinedType() @@ -602,6 +611,11 @@ class ExpressionTypeCheckerPass extends ReadOnlyPass[Set[Utils.TypeError]] val indexI = indexS.toInt checkTypeError(indexI >= 1 && indexI <= tupType.numFields, "Invalid tuple index: " + indexS, opapp.pos, c.filename) tupType.fieldTypes(indexI-1) + case dt : DataType => + val allSels = dt.constructors.flatMap(c => c._2) + val typOption = allSels.find((p) => p._1 == field).flatMap(e => Some(e._2)) + checkTypeError(!typOption.isEmpty, "Field '" + field.toString + "' does not exist in " + dt.toString(), opapp.pos, c.filename) + typOption.get case _ => checkTypeError(false, "Argument to select operator must be of type record", opapp.pos, c.filename) new UndefinedType() @@ -692,6 +706,14 @@ class ExpressionTypeCheckerPass extends ReadOnlyPass[Set[Utils.TypeError]] def funcAppType(fapp : FuncApplication) : Type = { val funcType1 = typeOf(fapp.e, c) lazy val typeErrorMsg = "Cannot apply %s, which is of type %s".format(fapp.e.toString, funcType1.toString) + + if (funcType1.isInstanceOf[ConstructorType]) { + val funcType = funcType1.asInstanceOf[ConstructorType] + val argTypes = fapp.args.map(typeOf(_, c)) + checkTypeError(funcType.inTypes.map(s => s._2) == argTypes, "Argument type error in application", fapp.pos, c.filename) + return funcType.outTyp + } + checkTypeError(funcType1.isInstanceOf[MapType], typeErrorMsg, fapp.pos, c.filename) val funcType = funcType1.asInstanceOf[MapType] val argTypes = fapp.args.map(typeOf(_, c)) diff --git a/src/main/scala/uclid/lang/UclidLanguage.scala b/src/main/scala/uclid/lang/UclidLanguage.scala index bda312fdf..8e9d7bf2a 100644 --- a/src/main/scala/uclid/lang/UclidLanguage.scala +++ b/src/main/scala/uclid/lang/UclidLanguage.scala @@ -1201,6 +1201,25 @@ case class MapType(inTypes: List[Type], outType: Type) extends Type { override def isMap = true } +case class DataType(id : Identifier, constructors: List[(Identifier, List[(Identifier, Type)])]) extends Type { + override def toString = { + id.name + " = | " + constructors.map(c => c.toString()).mkString(" | ") + } + + override def equals(other: Any) = other match { + case that: DataType => that.id.name == this.id.name + case that: SynonymType => that.id.name == this.id.name + case _ => false + } + + override def matches(t2: Type): Boolean = this.equals(t2) +} + +case class ConstructorType(id: Identifier, inTypes: List[(Identifier, Type)], outTyp: Type) extends Type { + override def toString = id + " {" + inTypes.map(s => s._1 + ": " + s._2.toString()).mkString(" ") + "}" + override def isMap = true +} + case class ProcedureType(inTypes : List[Type], outTypes: List[Type]) extends Type { override def toString = "procedure (" + Utils.join(inTypes.map(_.toString), ", ") + ") returns " + @@ -1221,6 +1240,7 @@ case class SynonymType(id: Identifier) extends Type { override def toString = id.toString override def equals(other: Any) = other match { case that: SynonymType => that.id.name == this.id.name + case that: DataType => that.id.name == this.id.name case _ => false } override def codegenUclidLang: Option[Type] = ULContext.smtToLangSynonym(id.name) diff --git a/src/main/scala/uclid/lang/UclidParser.scala b/src/main/scala/uclid/lang/UclidParser.scala index e57b8f963..8fe444495 100644 --- a/src/main/scala/uclid/lang/UclidParser.scala +++ b/src/main/scala/uclid/lang/UclidParser.scala @@ -141,6 +141,7 @@ object UclidParser extends UclidTokenParsers with PackratParsers { lazy val KwSingle = "single" lazy val KwDouble = "double" lazy val KwEnum = "enum" + lazy val KwData = "datatype" lazy val KwRecord = "record" lazy val KwReturns = "returns" lazy val KwAssume = "assume" @@ -212,7 +213,7 @@ object UclidParser extends UclidTokenParsers with PackratParsers { "false", "true", "bv", "fp", KwProcedure, KwBoolean, KwInteger, KwReal, KwHalf, KwSingle, KwDouble , KwReturns, KwAssume, KwAssert, KwSharedVar, KwVar, KwHavoc, KwCall, KwImport, KwIf, KwThen, KwElse, KwCase, KwEsac, KwFor, KwIn, KwRange, KwWhile, - KwInstance, KwInput, KwOutput, KwConst, KwConstRecord, KwModule, KwType, KwEnum, + KwInstance, KwInput, KwOutput, KwConst, KwConstRecord, KwModule, KwType, KwEnum, KwData, KwRecord, KwSkip, KwDefine, KwFunction, KwOracle, KwControl, KwInit, KwNext, KwLambda, KwModifies, KwProperty, KwDefineAxiom, KwForall, KwExists, KwFiniteForall, KwFiniteExists, KwGroup, KwDefault, KwSynthesis, KwGrammar, KwRequires, @@ -501,6 +502,14 @@ object UclidParser extends UclidTokenParsers with PackratParsers { KwRecord ~> ("{" ~> IdType) ~ rep("," ~> IdType) <~ "}" ^^ { case id ~ ids => lang.RecordType(id::ids) } } + lazy val Constructor : PackratParser[(lang.Identifier, List[(lang.Identifier, lang.Type)])] = { + Id ~ ("(" ~> (IdType ~ rep("," ~> IdType)).? <~ ")").? ^^ { case name ~ sels => sels match { + case Some(None) => (name, List.empty[(Identifier, Type)]) + case Some(Some(pair)) => (name, pair._1::pair._2) + case None => throw new Utils.SyntaxError("Missing parentheses after constructor "+ name,Some(name.pos),name.filename) + }} + } + lazy val MapType : PackratParser[lang.MapType] = positioned { PrimitiveType ~ rep ("*" ~> PrimitiveType) ~ ("->" ~> Type) ^^ { case t ~ ts ~ rt => lang.MapType(t :: ts, rt)} } @@ -701,7 +710,8 @@ object UclidParser extends UclidTokenParsers with PackratParsers { lazy val TypeDecl : PackratParser[lang.TypeDecl] = positioned { KwType ~> Id ~ ("=" ~> Type) <~ ";" ^^ { case id ~ t => lang.TypeDecl(id,t) } | - KwType ~> Id <~ ";" ^^ { case id => lang.TypeDecl(id, lang.UninterpretedType(id)) } + KwType ~> Id <~ ";" ^^ { case id => lang.TypeDecl(id, lang.UninterpretedType(id)) } | + KwData ~> Id ~ ("=" ~> "|".? ~> Constructor) ~ rep("|" ~> Constructor) <~ ";" ^^ {case dtname ~ ctr ~ ctrs => lang.TypeDecl(dtname, DataType(dtname, ctr :: ctrs))} } lazy val ModuleImportDecl : PackratParser[lang.ModuleImportDecl] = positioned { diff --git a/src/main/scala/uclid/smt/Context.scala b/src/main/scala/uclid/smt/Context.scala index 2947217d3..75d75c097 100644 --- a/src/main/scala/uclid/smt/Context.scala +++ b/src/main/scala/uclid/smt/Context.scala @@ -196,10 +196,33 @@ abstract trait Context { val typeName = uniqueNamer("EnumType", None) val synMapP = synMap.addSynonym(typeName, enumType) (synMapP.get(typeName).get, synMapP) + case dataType : DataType => + // create new type + var smap = synMap + val newConstructors = dataType.cstors.map(c => { + val (newsels, m) = flattenFieldList(c.inTypes.map(s => (s._1, s._2)), smap) + smap = m + ConstructorType(c.id, newsels.map(s => (s._1, s._2)), dataType) + }) + val newDataType = DataType(dataType.id, newConstructors) + + val synMapP = smap.addSynonym(dataType.id, newDataType) + (synMapP.get(dataType.id).get, synMapP) + case ConstructorType(id, inTypes, outTyp) => + // create new type + val (newInTypes, synMapP1) = flattenTypeList(inTypes.map(s => s._2), synMap) + val (newOut, synMapP2) = flatten(outTyp, synMapP1) + val newConstructorType = ConstructorType(id, inTypes.map(s => s._1).zip(newInTypes), newOut) + // add to map + val typeName = uniqueNamer("ConstructorType", None) + val synMapP = synMapP2.addSynonym(typeName, newConstructorType) + (synMapP.get(typeName).get, synMapP) case synTyp : SynonymType => val (newType, synMapP1) = flatten(synTyp.typ, synMap) val synMapP = synMapP1.addSynonym(synTyp.name, newType) (newType, synMapP) + case selfTyp : SelfReferenceType => + (selfTyp, synMap) case UndefinedType => throw new Utils.AssertionError("Undefined types are not expected here.") } diff --git a/src/main/scala/uclid/smt/Converter.scala b/src/main/scala/uclid/smt/Converter.scala index 3d5e832b1..5ee41d2fd 100644 --- a/src/main/scala/uclid/smt/Converter.scala +++ b/src/main/scala/uclid/smt/Converter.scala @@ -69,6 +69,15 @@ object Converter { smt.RecordType(fields.map((f) => (f._1.toString, typeToSMT(f._2)))) case lang.EnumType(ids) => smt.EnumType(ids.map(_.name)) + case dt : lang.DataType => + smt.DataType(dt.id.name, dt.constructors.map(c => ConstructorType(c._1.name, c._2.map(s => { + s._2 match { + case lang.SynonymType(id2) if id2 == dt.id => (s._1.name, smt.SelfReferenceType(id2.name)) + case _ => (s._1.name, typeToSMT(s._2)) + } + }), smt.SelfReferenceType(dt.id.name)))) + case lang.ConstructorType(id, inTypes, outTyp) => + smt.ConstructorType(id.name, inTypes.map(t => (t._1.name, typeToSMT(t._2))), typeToSMT(outTyp)) case lang.SynonymType(_) => throw new Utils.UnimplementedException("Synonym types must have been eliminated by now.") case lang.UndefinedType() | lang.ProcedureType(_, _) | lang.ExternalType(_, _) | @@ -99,6 +108,8 @@ object Converter { lang.RecordType(fields.map((f) => (lang.Identifier(f._1), smtToType(f._2)))) case smt.EnumType(ids) => lang.EnumType(ids.map(lang.Identifier(_))) + case dt: smt.DataType => + lang.DataType(lang.Identifier(dt.id), dt.cstors.map(cstor => (lang.Identifier(cstor.id), cstor.inTypes.map(slctor => (lang.Identifier(slctor._1), smtToType(slctor._2)))))) case _ => throw new AssertionError("Type '" + typ.toString + "' not expected here.") } diff --git a/src/main/scala/uclid/smt/SMTLIB2Interface.scala b/src/main/scala/uclid/smt/SMTLIB2Interface.scala index 8df6e20cc..208a8b4ea 100644 --- a/src/main/scala/uclid/smt/SMTLIB2Interface.scala +++ b/src/main/scala/uclid/smt/SMTLIB2Interface.scala @@ -75,7 +75,8 @@ trait SMTLIB2Base { counterId += 1 "_let_" + counterId.toString() + "_" } - def generateInputDataTypes(t : Type) : (List[String]) = { + + def generateInputDataTypes(t : Type) : (List[String]) = { t match { case MapType(inputTyp, _) => inputTyp.foldLeft(List.empty[String]) { @@ -96,8 +97,8 @@ trait SMTLIB2Base { t match { case EnumType(members) => val typeName = getTypeName(t.typeNamePrefix) - val memStr = Utils.join(members.map(s => "[" + s + "]"), " ") - val declDatatype = "(declare-datatype [%s 0] (%s))".format(typeName, memStr) + val memStr = Utils.join(members.map(s => "(" + s + ")"), " ") + val declDatatype = "(declare-datatypes ((%s 0)) (%s))".format(typeName, memStr) typeMap = typeMap.addSynonym(typeName, t) // throw new RuntimeException("need a stack trace!") (typeName, List(declDatatype)) @@ -123,11 +124,26 @@ trait SMTLIB2Base { } } val fieldString = (fieldNames zip fieldTypes).map(p => "(%s %s)".format(p._1.toString(), p._2.toString())) - val nameString = "([%s 0])".format(typeName) + val nameString = "((%s 0))".format(typeName) val argString = "[" + Utils.join(mkTupleFn :: fieldString, " ") + "]" val newType = "(declare-datatypes %s ((%s)))".format(nameString, argString) typeMap = typeMap.addSynonym(typeName, t) (typeName, newType :: newTypes1) + case dt : DataType => + val typeName = dt.id + val nameString = "((%s 0))".format(typeName) + val constructorsString = Utils.join(dt.cstors.map(c => { + val sels = Utils.join(c.inTypes.map(s => { + val inner = generateDatatype(s._2) + val sel = "(%s %s)".format(Context.getFieldName(s._1), inner._1) + sel + }), " ") + val constru = "(%s %s)".format(c.id, sels) + constru + }), " ") + val newType = "(declare-datatypes %s ((%s)))".format(nameString, constructorsString) + typeMap = typeMap.addSynonym(typeName, t) + (typeName, newType :: List.empty) case BoolType => typeMap = typeMap.addSynonym("Bool", t) ("Bool", List.empty) @@ -154,11 +170,23 @@ trait SMTLIB2Base { } } (typeStr, newTypes) + case ConstructorType(id, inTypes, outType) => + val (typeStr, newTypes1) = generateDatatype(outType) + val (_, newTypes) = inTypes.foldRight((List.empty[String], newTypes1)) { + (typ, acc) => { + val (typeStr, newTypes2) = generateDatatype(typ._2) + (acc._1 :+ typeStr, acc._2 ++ newTypes2) + } + } + (typeStr, newTypes) case UninterpretedType(typeName) => // TODO: sorts with arity greater than 1? Does uclid allow such a thing? val declDatatype = "(declare-sort %s 0)".format(typeName) typeMap = typeMap.addSynonym(typeName, t) (typeName, List(declDatatype)) + case SelfReferenceType(name) => + typeMap = typeMap.addSynonym(name, t) + (name, List.empty) case _ => throw new Utils.UnimplementedException("TODO: Implement more types in SMTLIB2Interface.generateDatatype: " + t.toString()); } @@ -424,11 +452,13 @@ class SMTLIB2Interface(args: List[String], var disableLetify: Boolean=false) ext var synthDeclCommands : String = "" def generateDeclaration(sym: Symbol) = { - val (typeName, newTypes) = generateDatatype(sym.typ) - Utils.assert(newTypes.size == 0, "No new types are expected here.") - val inputTypes = generateInputDataTypes(sym.typ).mkString(" ") - val cmd = "(declare-fun %s (%s) %s)".format(sym, inputTypes, typeName) - writeCommand(cmd) + if (!sym.typ.isInstanceOf[ConstructorType]) { + val (typeName, newTypes) = generateDatatype(sym.typ) + Utils.assert(newTypes.size == 0, "No new types are expected here.") + val inputTypes = generateInputDataTypes(sym.typ).mkString(" ") + val cmd = "(declare-fun %s (%s) %s)".format(sym, inputTypes, typeName) + writeCommand(cmd) + } } /** diff --git a/src/main/scala/uclid/smt/SMTLanguage.scala b/src/main/scala/uclid/smt/SMTLanguage.scala index a23d5022c..74918c89b 100644 --- a/src/main/scala/uclid/smt/SMTLanguage.scala +++ b/src/main/scala/uclid/smt/SMTLanguage.scala @@ -199,6 +199,35 @@ case object UndefinedType extends Type { override def isUndefined = true } +case class DataType(id : String, cstors : List[ConstructorType]) extends Type { + override val hashId = 111 + override val hashCode = finalize(hashId, 0) + override val md5hashCode = computeMD5Hash + override def toString = "data " + cstors // TODO + override val typeNamePrefix = "data" + override def isUndefined = true +} + +case class ConstructorType(id: String, inTypes: List[(String, Type)], outTyp: Type) extends Type { + override val hashId = 113 + override val hashCode = computeHash(id, inTypes) + override val md5hashCode = computeMD5Hash(id, inTypes) + override def toString = { + "constructor " + id + " " + inTypes // TODO add selectors to the toString + } + override def isMap = true + override val typeNamePrefix = "constructor" +} + +case class SelfReferenceType(name: String) extends Type { + override val hashId = 112 + override val hashCode = computeHash(name) + override val md5hashCode = computeMD5Hash(name) + override def toString = "self %s".format(name) + override def isSynonym = true + val typeNamePrefix = "self" +} + trait Operator extends Hashable { override val hashBaseId : Int = 22446 // Random number. def resultType(args: List[Expr]) : Type @@ -635,7 +664,16 @@ case class RecordSelectOp(name : String) extends Operator { Utils.assert(args(0).typ.asInstanceOf[ProductType].hasField(name), "Field '" + name + "' does not exist in product type.") } def resultType(args: List[Expr]) : Type = { - args(0).typ.asInstanceOf[ProductType].fieldType(name).get + args(0).typ match { + case t: TupleType => t.asInstanceOf[ProductType].fieldType(name).get + case r: RecordType => r.asInstanceOf[ProductType].fieldType(name).get + case DataType(id, cstors) => { + val sels = cstors.flatMap(c => c.inTypes) + sels.find(p => p._1 == name).get._2 + } + case _ => + throw new Utils.RuntimeError("Must not use symbolToZ3 on: " + args(0).typ.toString() + ".") + } } } case class RecordUpdateOp(name: String) extends Operator { @@ -983,7 +1021,13 @@ case class LetExpression(letBindings : List[(Symbol, Expr)], expr : Expr) extend //For uninterpreted function symbols or anonymous functions defined by Lambda expressions case class FunctionApplication(e: Expr, args: List[Expr]) - extends Expr (e.typ.asInstanceOf[MapType].outType) + extends Expr ({ + if (e.typ.isInstanceOf[ConstructorType]) { + e.typ.asInstanceOf[ConstructorType].outTyp + } else { + e.typ.asInstanceOf[MapType].outType + } + }) { override val hashId = 311 override val hashCode = computeHash(args, e) diff --git a/src/main/scala/uclid/smt/Z3Interface.scala b/src/main/scala/uclid/smt/Z3Interface.scala index b588ef352..3420391b6 100644 --- a/src/main/scala/uclid/smt/Z3Interface.scala +++ b/src/main/scala/uclid/smt/Z3Interface.scala @@ -293,6 +293,28 @@ class Z3Interface() extends Context { } } }) + + val getDataSort = new Memo[(String, List[ConstructorType]), z3.DatatypeSort[_]]((dt: (String, List[ConstructorType])) => { + val constructors : Array[z3.Constructor[_]] = (dt._2.map(c => { + val name = c.id + val recognizer = "is-" + c.id + val fieldNames = c.inTypes.map(pair => pair._1).toArray + val sorts = c.inTypes.map(pair => + pair._2 match { + case SelfReferenceType(name) => null + case t => getZ3Sort(t) + }).toArray + val sortRefs = c.inTypes.map(pair => + pair._2 match { + case SelfReferenceType(name) => 0 + case t => 1 + }).toArray + ctx.mkConstructor(name, recognizer, fieldNames, sorts, sortRefs) + }).toArray) + ctx.mkDatatypeSort(dt._1, constructors.asInstanceOf[Array[com.microsoft.z3.Constructor[Any]]]) + }) + + val getArraySort = new Memo[(List[Type], Type), z3.ArraySort[_, _]]((arrayType : (List[Type], Type)) => { val indexTypeIn = arrayType._1 val z3IndexType = getArrayIndexSort(indexTypeIn) @@ -314,8 +336,9 @@ class Z3Interface() extends Context { case TupleType(ts) => getTupleSort(ts) case RecordType(rs) => getRecordSort(rs) case ArrayType(rs, d) => getArraySort(rs, d) + case DataType(id, cstors) => getDataSort((id, cstors)) case EnumType(ids) => getEnumSort(ids) - case SynonymType(_, _) | MapType(_, _) | UndefinedType => + case SynonymType(_, _) | MapType(_, _) | UndefinedType | SelfReferenceType(_) | ConstructorType(_, _, _) => throw new Utils.RuntimeError("Must not use getZ3Sort to convert type: " + typ.toString() + ".") } } @@ -357,6 +380,7 @@ class Z3Interface() extends Context { abstract class ExprSort case class VarSort(sort : z3.Sort) extends ExprSort case class MapSort(ins : List[Type], out : Type) extends ExprSort + case class ConstructorSort(name: String, out : Type) extends ExprSort val exprSort = (sym.typ) match { case UninterpretedType(name) => VarSort(getUninterpretedSort(name)) @@ -370,7 +394,9 @@ class Z3Interface() extends Context { case MapType(ins, out) => MapSort(ins, out) case ArrayType(ins, out) => VarSort(getArraySort(ins, out)) case EnumType(ids) => VarSort(getEnumSort(ids)) - case SynonymType(_, _) | UndefinedType => + case DataType(id, cstors) => VarSort(getDataSort(id, cstors)) + case ConstructorType(id, _, outTyp) => ConstructorSort(id, outTyp) + case SynonymType(_, _) | UndefinedType | SelfReferenceType(_) => throw new Utils.RuntimeError("Must not use symbolToZ3 on: " + sym.typ.toString() + ".") } @@ -379,6 +405,9 @@ class Z3Interface() extends Context { ctx.mkConst(sym.id, s) case MapSort(ins, out) => ctx.mkFuncDecl(sym.id, ins.map(getZ3Sort _).toArray, getZ3Sort(out)) + case ConstructorSort(name, out) => + val adt = getZ3Sort(out).asInstanceOf[z3.DatatypeSort[_]] + adt.getConstructors().find(c => c.getName().toString() == name).get } } @@ -493,11 +522,18 @@ class Z3Interface() extends Context { } }.toArray ctx.mkExists(qVars, boolArgs(0), 1, qPatterns, null, getExistsName(), getSkolemName()) - case RecordSelectOp(fld) => + case RecordSelectOp(fld) if operands(0).typ.isInstanceOf[ProductType] => val prodType = operands(0).typ.asInstanceOf[ProductType] val fieldIndex = prodType.fieldIndex(fld) val prodSort = getProductSort(prodType) prodSort.getFieldDecls()(fieldIndex).apply(exprArgs(0)) + case RecordSelectOp(fld) if operands(0).typ.isInstanceOf[DataType] => + // find the right selector to apply based on fld and dataType + val dataType = operands(0).typ.asInstanceOf[DataType] + val z3adt = getDataSort(dataType.id, dataType.cstors) + val sel = z3adt.getAccessors().flatMap(a => a).find(a => a.getName().toString() == fld).get + // apply it and return the result + sel.apply(exprArgs(0)) case RecordUpdateOp(fld) => val prodType = operands(0).typ.asInstanceOf[ProductType] val fieldIndex = prodType.fieldIndex(fld) diff --git a/src/test/scala/ParserSpec.scala b/src/test/scala/ParserSpec.scala index 1544ec8ac..1c58953ea 100644 --- a/src/test/scala/ParserSpec.scala +++ b/src/test/scala/ParserSpec.scala @@ -45,6 +45,94 @@ import uclid.{lang => l} import java.io.File class ParserSpec extends AnyFlatSpec { + "test-adt-5-reusingdatatypename.ucl" should "not parse successfully." in { + try { + val filename = "test/test-adt-5-reusingdatatypename.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.ParserErrorList => + assert (p.errors.size == 1) + } + } + "test-adt-6-reusingselectorname.ucl" should "not parse successfully." in { + try { + val filename = "test/test-adt-6-reusingselectorname.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.ParserErrorList => + assert (p.errors.size == 1) + } + } + "test-adt-9-badconstructing.ucl" should "not typecheck." in { + try { + val filename = "test/test-adt-9-badconstructing.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.TypeErrorList => + assert (p.errors.size == 1) + } + } + "test-adt-10-badconstructing.ucl" should "not typecheck." in { + try { + val filename = "test/test-adt-10-badconstructing.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.TypeErrorList => + assert (p.errors.size == 1) + } + } + "test-adt-11-badconstructing.ucl" should "not typecheck." in { + try { + val filename = "test/test-adt-11-badconstructing.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.TypeErrorList => + assert (p.errors.size > 0) + } + } + "test-adt-12-badselecting.ucl" should "not typecheck." in { + try { + val filename = "test/test-adt-12-badselecting.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.TypeErrorList => + assert (p.errors.size > 0) + } + } + "test-adt-13-badselecting.ucl" should "not parse successfully." in { + try { + val filename = "test/test-adt-13-badselecting.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.ParserErrorList => + assert (p.errors.size == 1) + } + } + "test-adt-15-multiplemodules.ucl" should "not parse successfully." in { + try { + val filename = "test/test-adt-15-multiplemodules.ucl" + val fileModules = UclidMain.compile(ConfigCons.createConfig(filename), lang.Identifier("main")) + assert (fileModules.size == 1) + } + catch { + case p : Utils.ParserErrorList => + assert (p.errors.size > 0) + } + } "test-type1.ucl" should "not parse successfully." in { try { val filename = "test/test-type1.ucl" diff --git a/src/test/scala/VerifierSpec.scala b/src/test/scala/VerifierSpec.scala index f8c0104bc..4a8d967f5 100644 --- a/src/test/scala/VerifierSpec.scala +++ b/src/test/scala/VerifierSpec.scala @@ -88,6 +88,37 @@ class VerifierSanitySpec extends AnyFlatSpec { "test-assert-1.ucl" should "verify successfully." in { VerifierSpec.expectedFails("./test/test-assert-1.ucl", 0) } + "test-adt-0.ucl" should "verify all but one assertion." in { + VerifierSpec.expectedFails("./test/test-adt-0.ucl", 1) + } + "test-adt-1.ucl" should "verify successfully." in { + VerifierSpec.expectedFails("./test/test-adt-1.ucl", 0) + } + "test-adt-2.ucl" should "fail to verify 6 assertions." in { + VerifierSpec.expectedFails("./test/test-adt-2.ucl", 6) + } + "test-adt-3.ucl" should "verify successfully." in { + VerifierSpec.expectedFails("./test/test-adt-3.ucl", 0) + } + "test-adt-4.ucl" should "fail to verify 2 assertions." in { + VerifierSpec.expectedFails("./test/test-adt-4.ucl", 2) + } + + "test-adt-7-testingacyclicality.ucl" should "fail to verify 3 assertions." in { + VerifierSpec.expectedFails("./test/test-adt-7-testingacyclicality.ucl", 3) + } + "test-adt-8-testingacyclicality.ucl" should "fail to verify 3 assertions." in { + VerifierSpec.expectedFails("./test/test-adt-8-testingacyclicality.ucl", 3) + } + "test-adt-14-goodselecting.ucl" should "fail to verify 2 assertions." in { + VerifierSpec.expectedFails("./test/test-adt-14-goodselecting.ucl", 2) + } + "test-adt-16-multiplemodules.ucl" should "verify successfully." in { + VerifierSpec.expectedFails("./test/test-adt-16-multiplemodules.ucl", 0) + } + "test-adt-17-procedures.ucl" should "verify successfully." in { + VerifierSpec.expectedFails("./test/test-adt-17-procedures.ucl", 0) + } "test-array-0.ucl" should "verify successfully." in { VerifierSpec.expectedFails("./test/test-array-0.ucl", 0) } diff --git a/test/test-adt-0.ucl b/test/test-adt-0.ucl new file mode 100644 index 000000000..89c02d3f6 --- /dev/null +++ b/test/test-adt-0.ucl @@ -0,0 +1,23 @@ +module main { + // should pass + + datatype list = cons(head: integer, tail: list) | nil() ; + + var l : list; + + init { + l = nil(); + } + + next { + l' = cons(1, l); + } + + invariant test : l.head == 1; + + control { + induction; + check; + print_results; + } +} diff --git a/test/test-adt-1.ucl b/test/test-adt-1.ucl new file mode 100644 index 000000000..75a86a1e4 --- /dev/null +++ b/test/test-adt-1.ucl @@ -0,0 +1,28 @@ +module main { + // should pass + // TODO: parser error when selector name ommitted + + datatype tree = join(left: tree, right: tree) | leaf(node: integer); + + var t1 : tree; + var t2 : tree; + + + init { + t1 = join(leaf(1), leaf(1)); + t2 = leaf(1); + } + + next { + t1' = join(t1, t1); + t2' = join(t2, t2); + } + + invariant test : t1.left == t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-10-badconstructing.ucl b/test/test-adt-10-badconstructing.ucl new file mode 100644 index 000000000..e6972bf47 --- /dev/null +++ b/test/test-adt-10-badconstructing.ucl @@ -0,0 +1,20 @@ +module main { + // should parse error on line 11 + + datatype list = cons(head: integer, tail: list) | nil(); + + + var l1 : list; + var l2: list; + + init { + l1 = cons(); + } + + + control { + bmc(2); + check; + print_results; + } +} diff --git a/test/test-adt-11-badconstructing.ucl b/test/test-adt-11-badconstructing.ucl new file mode 100644 index 000000000..913caa4ed --- /dev/null +++ b/test/test-adt-11-badconstructing.ucl @@ -0,0 +1,20 @@ +module main { + // should parse error on line 11 + + datatype list = cons(head: integer, tail: list) | nil(); + + + var l1 : list; + var l2: list; + + init { + l1 = cons(l2, l2); + } + + + control { + bmc(2); + check; + print_results; + } +} diff --git a/test/test-adt-12-badselecting.ucl b/test/test-adt-12-badselecting.ucl new file mode 100644 index 000000000..a9b18d015 --- /dev/null +++ b/test/test-adt-12-badselecting.ucl @@ -0,0 +1,20 @@ +module main { + // should parse error on line 11 + + datatype list = cons(head: integer, tail: list) | nil(); + datatype tree = | join(left: list, right: list) | leaf(node: integer); + + var l1 : list; + var l2: list; + + init { + l1 = l2.left; + } + + + control { + bmc(2); + check; + print_results; + } +} diff --git a/test/test-adt-13-badselecting.ucl b/test/test-adt-13-badselecting.ucl new file mode 100644 index 000000000..69a3ee054 --- /dev/null +++ b/test/test-adt-13-badselecting.ucl @@ -0,0 +1,20 @@ +module main { + // should parse error on line 11 + + datatype list = cons(head: integer, tail: list) | nil(); + datatype tree = | join(left: tree, right: tree) | leaf(node: integer); + + var l1 : list; + var l2: tree; + + init { + l1 = l2.left; + } + + + control { + bmc(2); + check; + print_results; + } +} diff --git a/test/test-adt-14-goodselecting.ucl b/test/test-adt-14-goodselecting.ucl new file mode 100644 index 000000000..36295ad6b --- /dev/null +++ b/test/test-adt-14-goodselecting.ucl @@ -0,0 +1,26 @@ +module main { + // should fail + + datatype list = cons(head: integer, tail: list) | nil(); + datatype tree = | join(left: tree, right: tree) | leaf(node: integer); + + var l1 : tree; + var l2: tree; + + init { + l1 = l2.left; + havoc l2; + } + + next { + l1' = l2.left; + } + + invariant test : l1 == leaf(1); + + control { + induction; + check; + print_results; + } +} diff --git a/test/test-adt-15-multiplemodules.ucl b/test/test-adt-15-multiplemodules.ucl new file mode 100644 index 000000000..853abd100 --- /dev/null +++ b/test/test-adt-15-multiplemodules.ucl @@ -0,0 +1,33 @@ +module aux { + datatype tree = | join(left: tree, right: tree) | leaf(node: integer); + +} + +module main { + // parse error at line 10 but only if we includ line 8 + type * = aux.*; + + datatype tree = | join(left: tree, right: tree) | leaf(node: integer); + + var t1 : tree; + var t2 : tree; + + + init { + t1 = join(leaf(1), leaf(1)); + t2 = leaf(1); + } + + next { + t1' = join(t1, t1); + t2' = join(t2, t2); + } + + invariant test : t1.left == t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-16-multiplemodules.ucl b/test/test-adt-16-multiplemodules.ucl new file mode 100644 index 000000000..397a33c54 --- /dev/null +++ b/test/test-adt-16-multiplemodules.ucl @@ -0,0 +1,33 @@ +module aux { + datatype tree = | join(left: tree, right: tree) | leaf(node: integer); + +} + + + +module main { + // should pass + datatype tree = | join(left: tree, right: tree) | leaf(node: integer); + + var t1 : tree; + var t2 : tree; + + + init { + t1 = join(leaf(1), leaf(1)); + t2 = leaf(1); + } + + next { + t1' = join(t1, t1); + t2' = join(t2, t2); + } + + invariant test : t1.left == t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-17-procedures.ucl b/test/test-adt-17-procedures.ucl new file mode 100644 index 000000000..23e63a33b --- /dev/null +++ b/test/test-adt-17-procedures.ucl @@ -0,0 +1,38 @@ + +module main { + // should pass + datatype tree = | join(left: tree, right: tree) | leaf(node: integer); + + // flip tree + procedure flip_tree(t: tree) + returns (new_tree : tree) + { + var l : tree; + var r : tree; + l = t.left; + r = t.right; + new_tree = join(r, l); + } + + var t1 : tree; + var t2 : tree; + + + init { + t1 = join(leaf(1), leaf(1)); + t2 = leaf(1); + } + + next { + call (t1') = flip_tree(join(t1, t1)); + call (t2') = flip_tree(join(t2, t2)); + } + + invariant test : t1.left == t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-2.ucl b/test/test-adt-2.ucl new file mode 100644 index 000000000..c2bd44ac3 --- /dev/null +++ b/test/test-adt-2.ucl @@ -0,0 +1,34 @@ +module main { + // should fail + // TODO: fix parser so we don't have to write A() with brackets for constants + + datatype myEnum = A() | B(); + + var t1 : myEnum; + var t2 : myEnum; + + + init { + t1 = A(); + t2 = B(); + } + + next { + case + (t1 == A()) : {t1' = B();} + (t1 == B()) : {t1' = A();} + esac + case + (t2 == A()) : {t2' = B();} + (t2 == B()) : {t2' = A();} + esac + } + + invariant test : t1 == t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-3.ucl b/test/test-adt-3.ucl new file mode 100644 index 000000000..ed5ebeb53 --- /dev/null +++ b/test/test-adt-3.ucl @@ -0,0 +1,26 @@ +module main { + // should pass + datatype myRecord = | rec(A: integer, B: integer, C: integer); + + var t1 : myRecord; + var t2 : myRecord; + + + init { + t1 = rec(1, 2, 3); + t2 = rec(3, 2, 1); + } + + next { + t1' = rec(t1.C, t1.A, t1.B); + t2' = rec(t2.C, t2.A, t2.B); + } + + invariant test : t1 != t2; + + control { + induction; + check; + print_results; + } +} diff --git a/test/test-adt-4.ucl b/test/test-adt-4.ucl new file mode 100644 index 000000000..8025335bd --- /dev/null +++ b/test/test-adt-4.ucl @@ -0,0 +1,25 @@ +module main { + // should fail every third step + datatype myRecord = | rec(A: integer, B: integer, C: integer); + + var t1 : myRecord; + var t2 : myRecord; + + + init { + t1 = rec(1, 2, 3); + t2 = rec(3, 1, 2); + } + + next { + t2' = rec(t2.C, t2.A, t2.B); + } + + invariant test : t1 != t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-5-reusingdatatypename.ucl b/test/test-adt-5-reusingdatatypename.ucl new file mode 100644 index 000000000..3120ee92d --- /dev/null +++ b/test/test-adt-5-reusingdatatypename.ucl @@ -0,0 +1,26 @@ +module main { + // should throw a parse error on line 4 + datatype myRecord = | rec(A: integer, B: integer, C: integer); + datatype myRecord = | rec2(E: integer, F: integer, G: integer); + + var t1 : myRecord; + var t2 : myRecord; + + + init { + t1 = rec(1, 2, 3); + t2 = rec(3, 1, 2); + } + + next { + t2' = t2; + } + + invariant test : t1 != t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-6-reusingselectorname.ucl b/test/test-adt-6-reusingselectorname.ucl new file mode 100644 index 000000000..d418dc24d --- /dev/null +++ b/test/test-adt-6-reusingselectorname.ucl @@ -0,0 +1,26 @@ +module main { + // should throw a parse error on line 4 + datatype myRecord1 = | rec1(A: integer, B: integer, C: integer); + datatype myRecord = | rec(A: integer, F: integer, G: integer); + + var t1 : myRecord; + var t2 : myRecord; + + + init { + t1 = rec(1, 2, 3); + t2 = rec(3, 1, 2); + } + + next { + t2' = t2; + } + + invariant test : t1 != t2; + + control { + bmc(5); + check; + print_results; + } +} diff --git a/test/test-adt-7-testingacyclicality.ucl b/test/test-adt-7-testingacyclicality.ucl new file mode 100644 index 000000000..4f7ba6264 --- /dev/null +++ b/test/test-adt-7-testingacyclicality.ucl @@ -0,0 +1,19 @@ +module main { + + datatype list = cons(head: integer, tail: list) | nil(); + + var l : list; + + init { + l = nil(); + } + + + invariant test : l.tail == l && l == nil(); + + control { + bmc(2); + check; + print_results; + } +} diff --git a/test/test-adt-8-testingacyclicality.ucl b/test/test-adt-8-testingacyclicality.ucl new file mode 100644 index 000000000..7e2c8a234 --- /dev/null +++ b/test/test-adt-8-testingacyclicality.ucl @@ -0,0 +1,16 @@ +module main { + // should fail + + datatype list = cons(head: integer, tail: list) | nil(); + + var l1 : list; + var l2: list; + + invariant test : l1.tail == l2 && l2.tail == l1 && l1 != nil() && l2 != nil(); + + control { + bmc(2); + check; + print_results; + } +} diff --git a/test/test-adt-9-badconstructing.ucl b/test/test-adt-9-badconstructing.ucl new file mode 100644 index 000000000..958b215fb --- /dev/null +++ b/test/test-adt-9-badconstructing.ucl @@ -0,0 +1,20 @@ +module main { + // should parse error on line 11 + + datatype list = cons(head: integer, tail: list) | nil(); + + + var l1 : list; + var l2: list; + + init { + l1 = cons(l1); + } + + + control { + bmc(2); + check; + print_results; + } +}