From aed1c26e726aa50b0a0f14ac95a3f56fa14a99a4 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Thu, 10 Aug 2023 12:50:07 -0400 Subject: [PATCH] Situationally gen Eq/Ord/Hash glue for tag unions --- crates/glue/src/RustGlue.roc | 242 +++++++++++++++++++---------------- 1 file changed, 131 insertions(+), 111 deletions(-) diff --git a/crates/glue/src/RustGlue.roc b/crates/glue/src/RustGlue.roc index 482e3a9047f..8c5d6a47184 100644 --- a/crates/glue/src/RustGlue.roc +++ b/crates/glue/src/RustGlue.roc @@ -426,107 +426,122 @@ deriveDebugTagUnion = \buf, types, tagUnionType, tags -> } """ -deriveEqTagUnion : Str, Str -> Str -deriveEqTagUnion = \buf, tagUnionType -> - """ - \(buf) +deriveEqTagUnion : Str, Types, Shape, Str -> Str +deriveEqTagUnion = \buf, types, shape, tagUnionType -> + if canSupportEqHashOrd types shape then + """ + \(buf) - impl Eq for \(tagUnionType) {} - """ + impl Eq for \(tagUnionType) {} + """ + else + buf -derivePartialEqTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str -derivePartialEqTagUnion = \buf, tagUnionType, tags -> - checks = - List.walk tags "" \accum, { name: tagName } -> - """ - \(accum) - \(tagName) => self.payload.\(tagName) == other.payload.\(tagName), - """ +derivePartialEqTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str +derivePartialEqTagUnion = \buf, types, shape, tagUnionType, tags -> + if canSupportPartialEqOrd types shape then + checks = + List.walk tags "" \accum, { name: tagName } -> + """ + \(accum) + \(tagName) => self.payload.\(tagName) == other.payload.\(tagName), + """ - """ - \(buf) + """ + \(buf) - impl PartialEq for \(tagUnionType) { - fn eq(&self, other: &Self) -> bool { - use discriminant_\(tagUnionType)::*; + impl PartialEq for \(tagUnionType) { + fn eq(&self, other: &Self) -> bool { + use discriminant_\(tagUnionType)::*; - if self.discriminant != other.discriminant { - return false; - } + if self.discriminant != other.discriminant { + return false; + } - unsafe { - match self.discriminant {\(checks) + unsafe { + match self.discriminant {\(checks) + } } } } - } - """ + """ + else + buf -deriveOrdTagUnion : Str, Str -> Str -deriveOrdTagUnion = \buf, tagUnionType -> - """ - \(buf) +deriveOrdTagUnion : Str, Types, Shape, Str -> Str +deriveOrdTagUnion = \buf, types, shape, tagUnionType -> + if canSupportEqHashOrd types shape then + """ + \(buf) - impl Ord for \(tagUnionType) { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.partial_cmp(other).unwrap() + impl Ord for \(tagUnionType) { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap() + } } - } - """ + """ + else + buf -derivePartialOrdTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str -derivePartialOrdTagUnion = \buf, tagUnionType, tags -> - checks = - List.walk tags "" \accum, { name: tagName } -> - """ - \(accum) - \(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)), - """ +derivePartialOrdTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str +derivePartialOrdTagUnion = \buf, types, shape, tagUnionType, tags -> + if canSupportPartialEqOrd types shape then + checks = + List.walk tags "" \accum, { name: tagName } -> + """ + \(accum) + \(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)), + """ - """ - \(buf) + """ + \(buf) - impl PartialOrd for \(tagUnionType) { - fn partial_cmp(&self, other: &Self) -> Option { - use discriminant_\(tagUnionType)::*; + impl PartialOrd for \(tagUnionType) { + fn partial_cmp(&self, other: &Self) -> Option { + use discriminant_\(tagUnionType)::*; - use std::cmp::Ordering::*; + use std::cmp::Ordering::*; - match self.discriminant.cmp(&other.discriminant) { - Less => Option::Some(Less), - Greater => Option::Some(Greater), - Equal => unsafe { - match self.discriminant {\(checks) - } - }, + match self.discriminant.cmp(&other.discriminant) { + Less => Option::Some(Less), + Greater => Option::Some(Greater), + Equal => unsafe { + match self.discriminant {\(checks) + } + }, + } } } - } - """ + """ + else + buf -deriveHashTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str -deriveHashTagUnion = \buf, tagUnionType, tags -> - checks = - List.walk tags "" \accum, { name: tagName } -> - """ - \(accum) - \(tagName) => self.payload.\(tagName).hash(state), - """ +deriveHashTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str +deriveHashTagUnion = \buf, types, shape, tagUnionType, tags -> + if canSupportEqHashOrd types shape then + checks = + List.walk tags "" \accum, { name: tagName } -> + """ + \(accum) + \(tagName) => self.payload.\(tagName).hash(state), + """ - """ - \(buf) + """ + \(buf) - impl core::hash::Hash for \(tagUnionType) { - fn hash(&self, state: &mut H) { - use discriminant_\(tagUnionType)::*; + impl core::hash::Hash for \(tagUnionType) { + fn hash(&self, state: &mut H) { + use discriminant_\(tagUnionType)::*; - unsafe { - match self.discriminant {\(checks) + unsafe { + match self.discriminant {\(checks) + } } } } - } - """ + """ + else + buf generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str generateConstructorFunctions = \buf, types, tagUnionType, tags -> @@ -646,6 +661,7 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di sizeOfSelf = Num.toStr (Types.size types id) alignOfSelf = Num.toStr (Types.alignment types id) + shape = Types.shape types id # TODO: this value can be different than the alignment of `id` align = @@ -701,16 +717,16 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di """ |> deriveCloneTagUnion escapedName tags |> deriveDebugTagUnion types escapedName tags - |> deriveEqTagUnion escapedName - |> derivePartialEqTagUnion escapedName tags - |> deriveOrdTagUnion escapedName - |> derivePartialOrdTagUnion escapedName tags - |> deriveHashTagUnion escapedName tags + |> deriveEqTagUnion types shape escapedName + |> derivePartialEqTagUnion types shape escapedName tags + |> deriveOrdTagUnion types shape escapedName + |> derivePartialOrdTagUnion types shape escapedName tags + |> deriveHashTagUnion types shape escapedName tags |> generateDestructorFunctions types escapedName tags |> generateConstructorFunctions types escapedName tags |> \b -> type = Types.shape types id - if cannotDeriveCopy types type then + if cannotSupportCopy types type then # A custom drop impl is only needed when we can't derive copy. b |> Str.concat @@ -942,7 +958,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz |> Str.joinWith "\n" partialEqImpl = - if canDerivePartialEq types (Types.shape types id) then + if canSupportPartialEqOrd types (Types.shape types id) then """ impl PartialEq for \(escapedName) { fn eq(&self, other: &Self) -> bool { @@ -1027,7 +1043,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz hashImpl = - if canDerivePartialEq types (Types.shape types id) then + if canSupportPartialEqOrd types (Types.shape types id) then """ impl core::hash::Hash for \(escapedName) { fn hash(&self, state: &mut H) { @@ -1067,7 +1083,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz |> Str.joinWith "\n" partialOrdImpl = - if canDerivePartialEq types (Types.shape types id) then + if canSupportPartialEqOrd types (Types.shape types id) then """ impl PartialOrd for \(escapedName) { fn partial_cmp(&self, other: &Self) -> Option { @@ -1198,7 +1214,7 @@ generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, disc buf |> writeTagImpls tags discriminantName indents \name, payload -> when payload is - Some id if cannotDeriveCopy types (Types.shape types id) -> + Some id if cannotSupportCopy types (Types.shape types id) -> "unsafe { core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) }," _ -> @@ -1272,7 +1288,7 @@ generateUnionField = \types -> type = Types.shape types id fullTypeStr = - if cannotDeriveCopy types type then + if cannotSupportCopy types type then # types with pointers need ManuallyDrop # because rust unions don't (and can't) # know how to drop them automatically! @@ -1673,54 +1689,58 @@ generateDeriveStr = \buf, types, type, includeDebug -> buf |> Str.concat "#[derive(Clone, " - |> condWrite (!(cannotDeriveCopy types type)) "Copy, " - |> condWrite (!(cannotDeriveDefault types type)) "Default, " + |> condWrite (!(cannotSupportCopy types type)) "Copy, " + |> condWrite (!(cannotSupportDefault types type)) "Default, " |> condWrite deriveDebug "Debug, " - |> condWrite (canDerivePartialEq types type) "PartialEq, PartialOrd, " - |> condWrite (!(hasFloat types type) && (canDerivePartialEq types type)) "Eq, Ord, Hash, " + |> condWrite (canSupportPartialEqOrd types type) "PartialEq, PartialOrd, " + |> condWrite (canSupportEqHashOrd types type) "Eq, Ord, Hash, " |> Str.concat ")]\n" -canDerivePartialEq : Types, Shape -> Bool -canDerivePartialEq = \types, type -> +canSupportEqHashOrd : Types, Shape -> Bool +canSupportEqHashOrd = \types, type -> + !(hasFloat types type) && (canSupportPartialEqOrd types type) + +canSupportPartialEqOrd : Types, Shape -> Bool +canSupportPartialEqOrd = \types, type -> when type is Function rocFn -> runtimeRepresentation = Types.shape types rocFn.lambdaSet - canDerivePartialEq types runtimeRepresentation + canSupportPartialEqOrd types runtimeRepresentation Unsized -> Bool.false Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true RocStr -> Bool.true RocList inner | RocSet inner | RocBox inner -> innerType = Types.shape types inner - canDerivePartialEq types innerType + canSupportPartialEqOrd types innerType RocDict k v -> kType = Types.shape types k vType = Types.shape types v - canDerivePartialEq types kType && canDerivePartialEq types vType + canSupportPartialEqOrd types kType && canSupportPartialEqOrd types vType TagUnion (Recursive { tags }) -> List.all tags \{ payload } -> when payload is None -> Bool.true - Some id -> canDerivePartialEq types (Types.shape types id) + Some id -> canSupportPartialEqOrd types (Types.shape types id) TagUnion (NullableWrapped { tags }) -> List.all tags \{ payload } -> when payload is None -> Bool.true - Some id -> canDerivePartialEq types (Types.shape types id) + Some id -> canSupportPartialEqOrd types (Types.shape types id) TagUnion (NonNullableUnwrapped { payload }) -> - canDerivePartialEq types (Types.shape types payload) + canSupportPartialEqOrd types (Types.shape types payload) TagUnion (NullableUnwrapped { nonNullPayload }) -> - canDerivePartialEq types (Types.shape types nonNullPayload) + canSupportPartialEqOrd types (Types.shape types nonNullPayload) RecursivePointer _ -> Bool.true TagUnion (SingleTagStruct { payload: HasNoClosure fields }) -> - List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) + List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id) TagUnion (SingleTagStruct { payload: HasClosure _ }) -> Bool.false @@ -1728,23 +1748,23 @@ canDerivePartialEq = \types, type -> TagUnion (NonRecursive { tags }) -> List.all tags \{ payload } -> when payload is - Some id -> canDerivePartialEq types (Types.shape types id) + Some id -> canSupportPartialEqOrd types (Types.shape types id) None -> Bool.true RocResult okId errId -> okShape = Types.shape types okId errShape = Types.shape types errId - canDerivePartialEq types okShape && canDerivePartialEq types errShape + canSupportPartialEqOrd types okShape && canSupportPartialEqOrd types errShape Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } -> - List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) + List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id) Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } -> - List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) + List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id) -cannotDeriveCopy : Types, Shape -> Bool -cannotDeriveCopy = \types, type -> +cannotSupportCopy : Types, Shape -> Bool +cannotSupportCopy = \types, type -> !(canDeriveCopy types type) canDeriveCopy : Types, Shape -> Bool @@ -1780,22 +1800,22 @@ canDeriveCopy = \types, type -> Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } -> List.all fields \{ id } -> canDeriveCopy types (Types.shape types id) -cannotDeriveDefault = \types, type -> +cannotSupportDefault = \types, type -> when type is Unit | Unsized | EmptyTagUnion | TagUnion _ | RocResult _ _ | RecursivePointer _ | Function _ -> Bool.true RocStr | Bool | Num _ -> Bool.false RocList id | RocSet id | RocBox id -> - cannotDeriveDefault types (Types.shape types id) + cannotSupportDefault types (Types.shape types id) TagUnionPayload { fields: HasClosure _ } -> Bool.true RocDict keyId valId -> - cannotDeriveCopy types (Types.shape types keyId) - || cannotDeriveCopy types (Types.shape types valId) + cannotSupportCopy types (Types.shape types keyId) + || cannotSupportCopy types (Types.shape types valId) Struct { fields: HasClosure _ } -> Bool.true Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } -> - List.any fields \{ id } -> cannotDeriveDefault types (Types.shape types id) + List.any fields \{ id } -> cannotSupportDefault types (Types.shape types id) hasFloat = \types, type -> hasFloatHelp types type (Set.empty {})