Skip to content

Commit

Permalink
Adding support for Algebraic Datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
amarshah1 committed Apr 16, 2024
1 parent b508697 commit b9594b7
Show file tree
Hide file tree
Showing 32 changed files with 945 additions and 19 deletions.
92 changes: 92 additions & 0 deletions src/main/scala/uclid/lang/ASTVistors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ trait ReadOnlyPass[T] {
def applyOnEnumType(d : TraversalDirection.T, enumT : EnumType, in : T, context : Scope) : T = { in }
def applyOnTupleType(d : TraversalDirection.T, tupleT : TupleType, in : T, context : Scope) : T = { in }
def applyOnRecordType(d : TraversalDirection.T, recordT : RecordType, in : T, context : Scope) : T = { in }
def applyOnDataType(d : TraversalDirection.T, dataT : DataType, in : T, context : Scope) : T = { in }
def applyOnConstructor(d : TraversalDirection.T, constructor : ConstructorType, in : T, context : Scope) : T = { in }
def applyOnSelector(d : TraversalDirection.T, selector : (Identifier, Type), in : T, context : Scope) : T = { in }
def applyOnMapType(d : TraversalDirection.T, mapT : MapType, in : T, context : Scope) : T = { in }
def applyOnProcedureType(d : TraversalDirection.T, procT : ProcedureType, in : T, context : Scope) : T = { in }
def applyOnArrayType(d : TraversalDirection.T, arrayT : ArrayType, in : T, context : Scope) : T = { in }
Expand Down Expand Up @@ -229,6 +232,9 @@ trait RewritePass {
def rewriteIntType(intT : IntegerType, context : Scope) : Option[IntegerType] = { Some(intT) }
def rewriteBitVectorType(bvT : BitVectorType, context : Scope) : Option[BitVectorType] = { Some(bvT) }
def rewriteEnumType(enumT : EnumType, context : Scope) : Option[EnumType] = { Some(enumT) }
def rewriteSelector(sel : (Identifier, Type), context : Scope) : Option[(Identifier, Type)] = { Some(sel) }
def rewriteConstructor(cstor : ConstructorType, context : Scope) : Option[ConstructorType] = { Some(cstor) }
def rewriteDataType(dataT : DataType, context : Scope) : Option[DataType] = { Some(dataT) }
def rewriteTupleType(tupleT : TupleType, context : Scope) : Option[TupleType] = { Some(tupleT) }
def rewriteRecordType(recordT : RecordType, context : Scope) : Option[RecordType] = { Some(recordT) }
def rewriteMapType(mapT : MapType, context : Scope) : Option[MapType] = { Some(mapT) }
Expand Down Expand Up @@ -652,6 +658,8 @@ class ASTAnalyzer[T] (_passName : String, _pass: ReadOnlyPass[T]) extends ASTAna
case groupT : GroupType => visitGroupType(groupT, result, context)
case floatT: FloatType => visitFloatType(floatT, result, context)
case realT: RealType => visitRealType(realT, result, context)
case dataT: DataType => visitDataType(dataT, result, context)
case constT: ConstructorType => visitConstructor(constT, result, context)
}
result = pass.applyOnType(TraversalDirection.Up, typ, result, context)
return result
Expand Down Expand Up @@ -767,6 +775,39 @@ class ASTAnalyzer[T] (_passName : String, _pass: ReadOnlyPass[T]) extends ASTAna
return result
}

def visitConstructor(cstor : ConstructorType, in : T, context : Scope) : T = {
var result : T = in
result = pass.applyOnConstructor(TraversalDirection.Down, cstor, result, context)
result = visitIdentifier(cstor.id, result, context)
result = cstor.inTypes.foldLeft(result)((r, sel) => visitSelector(sel, r, context))
result = pass.applyOnConstructor(TraversalDirection.Up, cstor, result, context)
return result
}

def visitSelector(sel : (Identifier, Type), in : T, context : Scope) : T = {
var result : T = in
result = pass.applyOnSelector(TraversalDirection.Down, sel, result, context)
result = visitIdentifier(sel._1, result, context)
result = visitType(sel._2, result, context)
result = pass.applyOnSelector(TraversalDirection.Up, sel, result, context)
return result
}

def visitDataType(dataT : DataType, in: T, context: Scope): T = {
var result: T = in
result = pass.applyOnDataType(TraversalDirection.Down, dataT, result, context)
result = visitIdentifier(dataT.id, result, context)
result = dataT.constructors.foldLeft(result)((r, constructor) => {
val r2 = visitIdentifier(constructor._1, r, context)
constructor._2.foldLeft(r2)((r, s) => {
val r3 = visitIdentifier(s._1, r, context)
visitType(s._2, r3, context)
})
})
result = pass.applyOnDataType(TraversalDirection.Up, dataT, result, context)
return result
}

def visitExternalType(extT : ExternalType, in : T, context : Scope) : T = {
var result : T = in
result = pass.applyOnExternalType(TraversalDirection.Down, extT, result, context)
Expand Down Expand Up @@ -1731,6 +1772,8 @@ class ASTRewriter (_passName : String, _pass: RewritePass, setFilename : Boolean
case groupT : GroupType => visitGroupType(groupT, context)
case floatT : FloatType => visitFloatType(floatT, context)
case realT : RealType => visitRealType(realT, context)
case dataT : DataType => visitDataType(dataT, context)
case cstor : ConstructorType => visitConstructor(cstor, context)
}).flatMap(pass.rewriteType(_, context))
return ASTNode.introducePos(setPosition, setFilename, typP, typ.position)
}
Expand Down Expand Up @@ -1835,6 +1878,55 @@ class ASTRewriter (_passName : String, _pass: RewritePass, setFilename : Boolean
return ASTNode.introducePos(setPosition, setFilename, realTP, realT.position)
}

def visitSelector(sel : (Identifier, Type), context : Scope) : Option[(Identifier, Type)] = {
val selId = visitIdentifier(sel._1, context)
val selType = visitType(sel._2, context)
(selId, selType) match {
case (Some(selId), Some(selType)) =>
pass.rewriteSelector((selId, selType), context)
case _ =>
None
}
}

def visitConstructor(cstor : ConstructorType, context : Scope) : Option[ConstructorType] = {
val cstorId = visitIdentifier(cstor.id, context)
val cstorSelectors = cstor.inTypes.map((sel => visitSelector(sel, context)))
(cstorId, cstorSelectors) match {
case (Some(cstorId), cstorSelectors) if cstorSelectors.forall( s => s.isDefined) =>
pass.rewriteConstructor(ConstructorType(cstorId, cstorSelectors.map(s => s.get), cstor.outTyp), context)
case _ =>
None
}
}

def visitDataType(dataT : DataType, context: Scope): Option[Type] = {
val id = visitIdentifier(dataT.id, context)
val cstors = dataT.constructors.map((cstor => {
val cid = visitIdentifier(cstor._1, context)
val cs = cstor._2.map(s => {
val sid = visitIdentifier(s._1, context)
val st = visitType(s._2, context)
if (sid.isDefined && st.isDefined) {
Some((sid.get, st.get))
} else {
None
}
})
if (cid.isDefined && cs.forall( s => s.isDefined)) {
Some((cid.get, cs.map(s => s.get)))
} else {
None
}
}))
val dt = if (cstors.forall( s => s.isDefined) && id.isDefined) {
pass.rewriteDataType(DataType(id.get, cstors.map(s => s.get)), context)
} else {
None
}
return ASTNode.introducePos(setPosition, setFilename, dt, dataT.position)
}

def visitExternalType(extT : ExternalType, context : Scope) : Option[Type] = {
val moduleIdP = visitIdentifier(extT.moduleId, context)
val typeIdP = visitIdentifier(extT.typeId, context)
Expand Down
14 changes: 14 additions & 0 deletions src/main/scala/uclid/lang/RewriteRecordSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ class RewriteRecordSelectPass extends RewritePass {
}
}

override def rewriteDataType(dataT : DataType, context : Scope) : Option[DataType] = {
Some(DataType(dataT.id, dataT.constructors.map(c => (c._1, c._2.map(s => {
if(!hasRecPrefix(s))
{
(Identifier(recordPrefix+s._1.toString), s._2)
}
else
{
UclidMain.printDebugRewriteRecord("we have not rewritten this selector " + dataT.toString )
s
}
})))))
}

}

class RewriteRecordSelect extends ASTRewriter(
Expand Down
24 changes: 22 additions & 2 deletions src/main/scala/uclid/lang/Scope.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,16 @@ case class Scope (
case instD : InstanceDecl =>
Scope.addToMap(mapAcc, Scope.Instance(instD))
case ProcedureDecl(id, sig, _, _, _, _, _) => Scope.addToMap(mapAcc, Scope.Procedure(id, sig.typ))
case TypeDecl(id, typ) => Scope.addToMap(mapAcc, Scope.TypeSynonym(id, typ))
case TypeDecl(id, typ) => {
typ match {
case DataType(id, constructors) => {
constructors.foldLeft(mapAcc)((out, c) => {
Scope.addToMap(out, Scope.Function(c._1, ConstructorType(c._1, c._2, typ)))
})
}
case _ => mapAcc
}
}
case StateVarsDecl(ids, typ) => ids.foldLeft(mapAcc)((acc, id) => Scope.addToMap(acc, Scope.StateVar(id, typ)))
case InputVarsDecl(ids, typ) => ids.foldLeft(mapAcc)((acc, id) => Scope.addToMap(acc, Scope.InputVar(id, typ)))
case OutputVarsDecl(ids, typ) => ids.foldLeft(mapAcc)((acc, id) => Scope.addToMap(acc, Scope.OutputVar(id, typ)))
Expand Down Expand Up @@ -317,7 +326,18 @@ case class Scope (
val m1 = sig.args.foldLeft(mapAcc)((mapAcc2, operand) => Scope.addTypeToMap(mapAcc2, operand._2, Some(m)))
val m2 = Scope.addTypeToMap(m1, sig.retType, Some(m))
m2
case TypeDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m))
case TypeDecl(_, typ) => {
typ match {
case DataType(id, constructors) => {
constructors.foldLeft(mapAcc)((acc, c) => {
val m1 = c._2.foldLeft(mapAcc)((mapAcc2, operand) => Scope.addTypeToMap(mapAcc2, operand._2, Some(m)))
val m2 = Scope.addTypeToMap(m1, typ, Some(m))
m2
})
}
case _ => mapAcc
}
}
case StateVarsDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m))
case InputVarsDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m))
case OutputVarsDecl(_, typ) => Scope.addTypeToMap(mapAcc, typ, Some(m))
Expand Down
34 changes: 33 additions & 1 deletion src/main/scala/uclid/lang/SemanticAnalyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,27 @@ class SemanticAnalyzerPass extends ReadOnlyPass[List[ModuleError]] {
if (d == TraversalDirection.Down) {
// val moduleIds = module.decls.filter((d) => d.declNames.isDefined).map((d) => (d.declName.get, d.position))
val moduleIds = module.decls.flatMap((d) => d.declNames.map((n) => (n, d.position)))
SemanticAnalyzerPass.checkIdRedeclaration(moduleIds, in)
val selectorIds = module.decls.flatMap((d) => {
d match {
case TypeDecl(id, typ) => typ match {
case DataType(id, constructors) => constructors.flatMap((c) => c._2.map(s => (s._1, s._1.position)))
case _ => List()
}
case _ => List()
}
})
val constructorIds = module.decls.flatMap((d) => {
d match {
case TypeDecl(id, typ) => typ match {
case DataType(id, constructors) => constructors.map((c) => (c._1, c._1.position))
case _ => List()
}
case _ => List()
}
})
SemanticAnalyzerPass.checkIdRedeclaration(moduleIds, in) ++
SemanticAnalyzerPass.checkIdRedeclaration(selectorIds, in) ++
SemanticAnalyzerPass.checkIdRedeclaration(constructorIds, in)
} else { in }
}
override def applyOnProcedure(d : TraversalDirection.T, proc : ProcedureDecl, in : List[ModuleError], context : Scope) : List[ModuleError] = {
Expand Down Expand Up @@ -118,6 +138,18 @@ class SemanticAnalyzerPass extends ReadOnlyPass[List[ModuleError]] {
in
}
}

override def applyOnDataType(d : TraversalDirection.T, dataT : DataType, in : List[ModuleError], context : Scope) : List[ModuleError] = {
if (d == TraversalDirection.Down) {
val cstor_ids = dataT.constructors.map(x => (x._1, x._1.position)).toSeq
val selector_ids = dataT.constructors.flatMap(x => x._2.map(y => (y._1, y._1.position))).toSeq
SemanticAnalyzerPass.checkIdRedeclaration(cstor_ids, in)
SemanticAnalyzerPass.checkIdRedeclaration(selector_ids, in)
} else {
in
}
}

override def applyOnInstance(d : TraversalDirection.T, inst : InstanceDecl, in : List[ModuleError], context : Scope) : List[ModuleError] = {
if (d == TraversalDirection.Down) {
// val modType = inst.modType.get
Expand Down
22 changes: 22 additions & 0 deletions src/main/scala/uclid/lang/TypeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class TypeSynonymFinderPass extends ReadOnlyPass[Unit]
TupleType(fieldTypes.map(simplifyType(_, visited, m)))
case RecordType(fields) =>
RecordType(fields.map((f) => (f._1, simplifyType(f._2, visited, m))))
case dt: DataType => dt
case MapType(inTypes, outType) =>
MapType(inTypes.map(simplifyType(_, visited, m)), simplifyType(outType, visited, m))
case ArrayType(inTypes, outType) =>
Expand Down Expand Up @@ -585,6 +586,14 @@ class ExpressionTypeCheckerPass extends ReadOnlyPass[Set[Utils.TypeError]]
selectFromInstance.pos = opapp.op.pos
polyOpMap.put(opapp.op.astNodeId, selectFromInstance)
fldT.get
case dt : DataType =>
val allSels = dt.constructors.flatMap(c => c._2)
val typOption = allSels.find((p) => p._1 == field).flatMap(e => Some(e._2))
checkTypeError(!typOption.isEmpty, "Field '" + field.toString + "' does not exist in " + dt.toString(), opapp.pos, c.filename)
val recordSelect = RecordSelect(field)
recordSelect.pos = opapp.op.pos
polyOpMap.put(opapp.op.astNodeId, recordSelect)
typOption.get
case _ =>
checkTypeError(false, "Argument to select operator must be of type record or instance", opapp.pos, c.filename)
new UndefinedType()
Expand All @@ -602,6 +611,11 @@ class ExpressionTypeCheckerPass extends ReadOnlyPass[Set[Utils.TypeError]]
val indexI = indexS.toInt
checkTypeError(indexI >= 1 && indexI <= tupType.numFields, "Invalid tuple index: " + indexS, opapp.pos, c.filename)
tupType.fieldTypes(indexI-1)
case dt : DataType =>
val allSels = dt.constructors.flatMap(c => c._2)
val typOption = allSels.find((p) => p._1 == field).flatMap(e => Some(e._2))
checkTypeError(!typOption.isEmpty, "Field '" + field.toString + "' does not exist in " + dt.toString(), opapp.pos, c.filename)
typOption.get
case _ =>
checkTypeError(false, "Argument to select operator must be of type record", opapp.pos, c.filename)
new UndefinedType()
Expand Down Expand Up @@ -692,6 +706,14 @@ class ExpressionTypeCheckerPass extends ReadOnlyPass[Set[Utils.TypeError]]
def funcAppType(fapp : FuncApplication) : Type = {
val funcType1 = typeOf(fapp.e, c)
lazy val typeErrorMsg = "Cannot apply %s, which is of type %s".format(fapp.e.toString, funcType1.toString)

if (funcType1.isInstanceOf[ConstructorType]) {
val funcType = funcType1.asInstanceOf[ConstructorType]
val argTypes = fapp.args.map(typeOf(_, c))
checkTypeError(funcType.inTypes.map(s => s._2) == argTypes, "Argument type error in application", fapp.pos, c.filename)
return funcType.outTyp
}

checkTypeError(funcType1.isInstanceOf[MapType], typeErrorMsg, fapp.pos, c.filename)
val funcType = funcType1.asInstanceOf[MapType]
val argTypes = fapp.args.map(typeOf(_, c))
Expand Down
20 changes: 20 additions & 0 deletions src/main/scala/uclid/lang/UclidLanguage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,25 @@ case class MapType(inTypes: List[Type], outType: Type) extends Type {
override def isMap = true
}

case class DataType(id : Identifier, constructors: List[(Identifier, List[(Identifier, Type)])]) extends Type {
override def toString = {
id.name + " = | " + constructors.map(c => c.toString()).mkString(" | ")
}

override def equals(other: Any) = other match {
case that: DataType => that.id.name == this.id.name
case that: SynonymType => that.id.name == this.id.name
case _ => false
}

override def matches(t2: Type): Boolean = this.equals(t2)
}

case class ConstructorType(id: Identifier, inTypes: List[(Identifier, Type)], outTyp: Type) extends Type {
override def toString = id + " {" + inTypes.map(s => s._1 + ": " + s._2.toString()).mkString(" ") + "}"
override def isMap = true
}

case class ProcedureType(inTypes : List[Type], outTypes: List[Type]) extends Type {
override def toString =
"procedure (" + Utils.join(inTypes.map(_.toString), ", ") + ") returns " +
Expand All @@ -1221,6 +1240,7 @@ case class SynonymType(id: Identifier) extends Type {
override def toString = id.toString
override def equals(other: Any) = other match {
case that: SynonymType => that.id.name == this.id.name
case that: DataType => that.id.name == this.id.name
case _ => false
}
override def codegenUclidLang: Option[Type] = ULContext.smtToLangSynonym(id.name)
Expand Down
14 changes: 12 additions & 2 deletions src/main/scala/uclid/lang/UclidParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ object UclidParser extends UclidTokenParsers with PackratParsers {
lazy val KwSingle = "single"
lazy val KwDouble = "double"
lazy val KwEnum = "enum"
lazy val KwData = "datatype"
lazy val KwRecord = "record"
lazy val KwReturns = "returns"
lazy val KwAssume = "assume"
Expand Down Expand Up @@ -212,7 +213,7 @@ object UclidParser extends UclidTokenParsers with PackratParsers {
"false", "true", "bv", "fp", KwProcedure, KwBoolean, KwInteger, KwReal, KwHalf, KwSingle, KwDouble , KwReturns,
KwAssume, KwAssert, KwSharedVar, KwVar, KwHavoc, KwCall, KwImport,
KwIf, KwThen, KwElse, KwCase, KwEsac, KwFor, KwIn, KwRange, KwWhile,
KwInstance, KwInput, KwOutput, KwConst, KwConstRecord, KwModule, KwType, KwEnum,
KwInstance, KwInput, KwOutput, KwConst, KwConstRecord, KwModule, KwType, KwEnum, KwData,
KwRecord, KwSkip, KwDefine, KwFunction, KwOracle, KwControl, KwInit,
KwNext, KwLambda, KwModifies, KwProperty, KwDefineAxiom,
KwForall, KwExists, KwFiniteForall, KwFiniteExists, KwGroup, KwDefault, KwSynthesis, KwGrammar, KwRequires,
Expand Down Expand Up @@ -501,6 +502,14 @@ object UclidParser extends UclidTokenParsers with PackratParsers {
KwRecord ~> ("{" ~> IdType) ~ rep("," ~> IdType) <~ "}" ^^ { case id ~ ids => lang.RecordType(id::ids) }
}

lazy val Constructor : PackratParser[(lang.Identifier, List[(lang.Identifier, lang.Type)])] = {
Id ~ ("(" ~> (IdType ~ rep("," ~> IdType)).? <~ ")").? ^^ { case name ~ sels => sels match {
case Some(None) => (name, List.empty[(Identifier, Type)])
case Some(Some(pair)) => (name, pair._1::pair._2)
case None => throw new Utils.SyntaxError("Missing parentheses after constructor "+ name,Some(name.pos),name.filename)
}}
}

lazy val MapType : PackratParser[lang.MapType] = positioned {
PrimitiveType ~ rep ("*" ~> PrimitiveType) ~ ("->" ~> Type) ^^ { case t ~ ts ~ rt => lang.MapType(t :: ts, rt)}
}
Expand Down Expand Up @@ -701,7 +710,8 @@ object UclidParser extends UclidTokenParsers with PackratParsers {

lazy val TypeDecl : PackratParser[lang.TypeDecl] = positioned {
KwType ~> Id ~ ("=" ~> Type) <~ ";" ^^ { case id ~ t => lang.TypeDecl(id,t) } |
KwType ~> Id <~ ";" ^^ { case id => lang.TypeDecl(id, lang.UninterpretedType(id)) }
KwType ~> Id <~ ";" ^^ { case id => lang.TypeDecl(id, lang.UninterpretedType(id)) } |
KwData ~> Id ~ ("=" ~> "|".? ~> Constructor) ~ rep("|" ~> Constructor) <~ ";" ^^ {case dtname ~ ctr ~ ctrs => lang.TypeDecl(dtname, DataType(dtname, ctr :: ctrs))}
}

lazy val ModuleImportDecl : PackratParser[lang.ModuleImportDecl] = positioned {
Expand Down
Loading

0 comments on commit b9594b7

Please sign in to comment.