Skip to content

Commit

Permalink
Support enum basic type in StartknetTypedData (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
DelevoXDG committed Apr 8, 2024
1 parent 3ba857c commit 0f8acec
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 23 deletions.
167 changes: 145 additions & 22 deletions Sources/Starknet/Data/TypedData/StarknetTypedData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public enum StarknetTypedDataError: Error, Equatable {
case basicTypeRedefinition(String)
case invalidTypeName(String)
case danglingType(String)
case unsupportedType(String)
case dependencyNotDefined(String)
case contextNotDefined
case parentNotDefined
Expand Down Expand Up @@ -129,13 +130,22 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {

let basicTypes = getBasicTypes()

let referencedTypes = Set(types.values.flatMap { type in
type.map { param in
let referencedTypes = try Set(types.values.flatMap { type in
try type.flatMap { param in
switch param {
case let .enum(enumType):
return [enumType.contains]
case let .merkletree(merkle):
merkle.contains
return [merkle.contains]
case let .standard(standard):
standard.type.strippingPointer()
if standard.type.isEnum() {
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType(standard.type)
}
return try standard.type.extractEnumTypes()
} else {
return [standard.type.strippingPointer()]
}
}
}
} + [domain.separatorName, primaryType])
Expand All @@ -144,8 +154,11 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
guard !basicTypes.contains(typeName) else {
throw StarknetTypedDataError.basicTypeRedefinition(typeName)
}

guard !typeName.isEmpty, !typeName.isArray() else {
guard !typeName.isEmpty,
!typeName.isArray(),
!typeName.isEnum(),
!typeName.contains(",")
else {
throw StarknetTypedDataError.invalidTypeName(typeName)
}
guard referencedTypes.contains(typeName) else {
Expand All @@ -154,20 +167,42 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
}
}

private func getDependencies(of type: String) -> [String] {
private func getDependencies(of type: String) throws -> [String] {
func extractTypes(from param: TypeDeclarationWrapper) throws -> [String] {
switch param {
case let .enum(enumType):
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType("enum")
}
return [enumType.contains]
default:
let paramType = param.type.type
if paramType.isEnum() {
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType(paramType)
}
return try paramType.extractEnumTypes()
} else {
return [paramType]
}
}
}

var dependencies = [type]
var toVisit = [type]

while !toVisit.isEmpty {
let currentType = toVisit.removeFirst()
let params = types[currentType] ?? []

params.forEach { param in
let typeStripped = param.type.type.strippingPointer()
try params.forEach { param in
let extractedTypes = try extractTypes(from: param).map { $0.strippingPointer() }

if types.keys.contains(typeStripped), !dependencies.contains(typeStripped) {
dependencies.append(typeStripped)
toVisit.append(typeStripped)
extractedTypes.forEach { extractedType in
if types.keys.contains(extractedType), !dependencies.contains(extractedType) {
dependencies.append(extractedType)
toVisit.append(extractedType)
}
}
}
}
Expand All @@ -178,25 +213,51 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
}

private func encode(dependency: String) throws -> String {
guard let params = types[dependency] else {
throw StarknetTypedDataError.dependencyNotDefined(dependency)
}
func escape(_ string: String) -> String {
switch revision {
case .v0: string
case .v1: "\"\(string)\""
}
}
func resolveTargetType(from param: TypeDeclarationWrapper) throws -> String {
switch param {
case let .enum(enumType):
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType("enum")
}
return enumType.contains
default:
return param.type.type
}
}
func encodeEnumTypes(from type: String) throws -> String {
guard revision == .v1 else {
throw StarknetTypedDataError.unsupportedType("enum")
}

let enumTypes = try type.extractEnumTypes().map(escape).joined(separator: ",")
return "(\(enumTypes))"
}

guard let params = types[dependency] else {
throw StarknetTypedDataError.dependencyNotDefined(dependency)
}

let encodedParams = params.map {
"\(escape($0.type.name)):\(escape($0.type.type))"
let encodedParams = try params.map {
let targetType = try resolveTargetType(from: $0)
let typeString = if targetType.isEnum() {
try encodeEnumTypes(from: targetType)
} else {
escape(targetType)
}
return "\(escape($0.type.name)):\(typeString)"
}.joined(separator: ",")

return "\(escape(dependency))(\(encodedParams))"
}

func encode(type: String) throws -> String {
let dependencies = getDependencies(of: type)
let dependencies = try getDependencies(of: type)

return try dependencies.map {
try encode(dependency: $0)
Expand Down Expand Up @@ -233,6 +294,11 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
return try hashArray(unwrapLongString(from: element))
case ("selector", _):
return try unwrapSelector(from: element)
case ("enum", .v1):
guard let context else {
throw StarknetTypedDataError.contextNotDefined
}
return try unwrapEnum(from: element, context: context)
case ("merkletree", _):
guard let context else {
throw StarknetTypedDataError.contextNotDefined
Expand Down Expand Up @@ -405,7 +471,7 @@ public extension StarknetTypedData {

private extension StarknetTypedData {
static let basicTypesV0: Set = ["felt", "bool", "string", "selector", "merkletree"]
static let basicTypesV1: Set = basicTypesV0.union(["u128", "i128", "ContractAddress", "ClassHash", "timestamp", "shortstring"])
static let basicTypesV1: Set = basicTypesV0.union(["enum", "u128", "i128", "ContractAddress", "ClassHash", "timestamp", "shortstring"])

func getBasicTypes() -> Set<String> {
switch revision {
Expand Down Expand Up @@ -524,6 +590,44 @@ extension StarknetTypedData {
}
}

func unwrapEnum(from element: Element, context: Context) throws -> Felt {
let object = try unwrapObject(from: element)

guard let variant = object.first else {
throw StarknetTypedDataError.decodingError
}
let variantName = variant.key
guard case let .array(variantData) = variant.value else {
throw StarknetTypedDataError.decodingError
}

let variants = try getEnumVariants(context: context)
let variantType = variants.first { $0.type.name == variantName }
guard let variantType else {
throw StarknetTypedDataError.decodingError
}
guard let variantIndex = variants.firstIndex(of: variantType) else {
throw StarknetTypedDataError.decodingError
}

let encodedSubtypes = try variantType.type.type.extractEnumTypes().enumerated().map { index, subtype in
let subtypeData = variantData[index]
return try encode(element: subtypeData, forType: subtype)
}

return hashArray([Felt(variantIndex)!] + encodedSubtypes)
}

private func getEnumVariants(context: Context) throws -> [TypeDeclarationWrapper] {
let enumType: EnumType = try resolveType(context)

guard let variants = types[enumType.contains] else {
throw StarknetTypedDataError.dependencyNotDefined(enumType.contains)
}

return variants
}

func prepareMerkleTreeRoot(from element: Element, context: Context) throws -> Felt {
let leavesType = try getMerkleTreeLeavesType(context: context)

Expand All @@ -539,7 +643,13 @@ extension StarknetTypedData {
return merkleTree.rootHash
}

func getMerkleTreeLeavesType(context: Context) throws -> String {
private func getMerkleTreeLeavesType(context: Context) throws -> String {
let merkleType: MerkleTreeType = try resolveType(context)

return merkleType.contains
}

private func resolveType<T: TypeDeclaration>(_ context: Context) throws -> T {
let (parent, key) = (context.parent, context.key)

guard let parentType = types[parent] else {
Expand All @@ -548,11 +658,11 @@ extension StarknetTypedData {
guard let targetType = parentType.first(where: { $0.type.name == key }) else {
throw StarknetTypedDataError.keyNotDefined
}
guard let merkleType = targetType.type as? MerkleTreeType else {
guard let targetType = targetType.type as? T else {
throw StarknetTypedDataError.decodingError
}

return merkleType.contains
return targetType
}
}

Expand All @@ -565,7 +675,20 @@ private extension String {
return self
}

func extractEnumTypes() throws -> [String] {
guard self.isEnum() else {
throw StarknetTypedDataError.decodingError
}

let content = self[self.index(after: self.startIndex) ..< self.index(before: self.endIndex)]
return content.isEmpty ? [] : content.split(separator: ",").map { String($0) }
}

func isArray() -> Bool {
self.hasSuffix("*")
}

func isEnum() -> Bool {
self.hasPrefix("(") && self.hasSuffix(")")
}
}
29 changes: 29 additions & 0 deletions Sources/Starknet/Data/TypedData/TypeDeclaration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ public extension StarknetTypedData {
}
}

struct EnumType: TypeDeclaration {
public let name: String
public let type: String = "enum"
public let contains: String

public init(name: String, contains: String) {
self.name = name
self.contains = contains
}

fileprivate enum CodingKeys: String, CodingKey {
case name
case contains
}
}

struct MerkleTreeType: TypeDeclaration {
public let name: String
public let type: String = "merkletree"
Expand All @@ -33,15 +49,19 @@ public extension StarknetTypedData {
enum TypeDeclarationWrapper: Codable, Hashable, Equatable {
fileprivate enum Keys: String, CodingKey {
case type
case contains
}

case standard(StandardType)
case `enum`(EnumType)
case merkletree(MerkleTreeType)

public var type: any TypeDeclaration {
switch self {
case let .standard(type):
type
case let .enum(type):
type
case let .merkletree(type):
type
}
Expand All @@ -51,6 +71,8 @@ public extension StarknetTypedData {
switch type {
case let type as StandardType:
self = .standard(type)
case let type as EnumType:
self = .enum(type)
case let type as MerkleTreeType:
self = .merkletree(type)
default:
Expand All @@ -61,8 +83,15 @@ public extension StarknetTypedData {
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: Keys.self)
let type = try container.decode(String.self, forKey: Keys.type)
let contains = try container.decodeIfPresent(String.self, forKey: Keys.contains)

switch type {
case "enum":
self = if contains != nil {
try .enum(EnumType(from: decoder))
} else {
try .standard(StandardType(from: decoder))
}
case "merkletree":
self = try .merkletree(MerkleTreeType(from: decoder))
default:
Expand Down
Loading

0 comments on commit 0f8acec

Please sign in to comment.