Skip to content

Commit

Permalink
Support preset types in StarknetTypedData (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
DelevoXDG committed Apr 8, 2024
1 parent 0f8acec commit 0600ddd
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 22 deletions.
78 changes: 59 additions & 19 deletions Sources/Starknet/Data/TypedData/StarknetTypedData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Foundation
public enum StarknetTypedDataError: Error, Equatable {
case decodingError
case invalidRevision(Felt)
case presetTypeRedefinition(String)
case basicTypeRedefinition(String)
case invalidTypeName(String)
case danglingType(String)
Expand Down Expand Up @@ -73,15 +74,15 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
public let domain: Domain
public let message: [String: Element]

public var revision: Revision {
try! domain.resolveRevision()
}
public let revision: Revision
private let allTypes: [String: [TypeDeclarationWrapper]]
private let hashMethod: StarknetHashMethod

private var hashMethod: StarknetHashMethod {
switch revision {
case .v0: .pedersen
case .v1: .poseidon
}
fileprivate enum CodingKeys: CodingKey {
case types
case primaryType
case domain
case message
}

private func hashArray(_ values: [Felt]) -> Felt {
Expand All @@ -96,6 +97,18 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
self.domain = domain
self.message = message

self.revision = try domain.resolveRevision()

self.allTypes = self.types.merging(
Self.getPresetTypes(revision: self.revision),
uniquingKeysWith: { current, _ in current }
)

self.hashMethod = switch revision {
case .v0: .pedersen
case .v1: .poseidon
}

try self.verifyTypes()
}

Expand Down Expand Up @@ -128,7 +141,8 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
throw StarknetTypedDataError.dependencyNotDefined(domain.separatorName)
}

let basicTypes = getBasicTypes()
let basicTypes = Self.getBasicTypes(revision: revision)
let presetTypes = Self.getPresetTypes(revision: revision)

let referencedTypes = try Set(types.values.flatMap { type in
try type.flatMap { param in
Expand All @@ -154,6 +168,9 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
guard !basicTypes.contains(typeName) else {
throw StarknetTypedDataError.basicTypeRedefinition(typeName)
}
guard presetTypes[typeName] == nil else {
throw StarknetTypedDataError.presetTypeRedefinition(typeName)
}
guard !typeName.isEmpty,
!typeName.isArray(),
!typeName.isEnum(),
Expand Down Expand Up @@ -193,13 +210,13 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {

while !toVisit.isEmpty {
let currentType = toVisit.removeFirst()
let params = types[currentType] ?? []
let params = allTypes[currentType] ?? []

try params.forEach { param in
let extractedTypes = try extractTypes(from: param).map { $0.strippingPointer() }

extractedTypes.forEach { extractedType in
if types.keys.contains(extractedType), !dependencies.contains(extractedType) {
if allTypes.keys.contains(extractedType), !dependencies.contains(extractedType) {
dependencies.append(extractedType)
toVisit.append(extractedType)
}
Expand Down Expand Up @@ -239,7 +256,7 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
return "(\(enumTypes))"
}

guard let params = types[dependency] else {
guard let params = allTypes[dependency] else {
throw StarknetTypedDataError.dependencyNotDefined(dependency)
}

Expand All @@ -265,7 +282,7 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
}

func encode(element: Element, forType typeName: String, context: Context? = nil) throws -> Felt {
if types.keys.contains(typeName) {
if allTypes.keys.contains(typeName) {
let object = try unwrapObject(from: element)

return try getStructHash(typeName: typeName, data: object)
Expand Down Expand Up @@ -312,7 +329,7 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
private func encode(data: [String: Element], forType typeName: String) throws -> [Felt] {
var values: [Felt] = []

guard let types = types[typeName] else {
guard let types = allTypes[typeName] else {
throw StarknetTypedDataError.encodingError
}

Expand Down Expand Up @@ -473,13 +490,36 @@ private extension StarknetTypedData {
static let basicTypesV0: Set = ["felt", "bool", "string", "selector", "merkletree"]
static let basicTypesV1: Set = basicTypesV0.union(["enum", "u128", "i128", "ContractAddress", "ClassHash", "timestamp", "shortstring"])

func getBasicTypes() -> Set<String> {
static let presetTypesV1 = [
"u256": [
StandardType(name: "low", type: "u128"),
StandardType(name: "high", type: "u128"),
],
"TokenAmount": [
StandardType(name: "token_address", type: "ContractAddress"),
StandardType(name: "amount", type: "u256"),
],
"NftId": [
StandardType(name: "collection_address", type: "ContractAddress"),
StandardType(name: "token_id", type: "u256"),
],
]

static func getBasicTypes(revision: Revision) -> Set<String> {
switch revision {
case .v0:
Self.basicTypesV0
basicTypesV0
case .v1:
Self.basicTypesV1
basicTypesV1
}
}

static func getPresetTypes(revision: Revision) -> [String: [TypeDeclarationWrapper]] {
let types: [String: [any TypeDeclaration]] = switch revision {
case .v0: [:]
case .v1: Self.presetTypesV1
}
return types.mapValues { $0.map { TypeDeclarationWrapper($0) } }
}
}

Expand Down Expand Up @@ -621,7 +661,7 @@ extension StarknetTypedData {
private func getEnumVariants(context: Context) throws -> [TypeDeclarationWrapper] {
let enumType: EnumType = try resolveType(context)

guard let variants = types[enumType.contains] else {
guard let variants = allTypes[enumType.contains] else {
throw StarknetTypedDataError.dependencyNotDefined(enumType.contains)
}

Expand Down Expand Up @@ -652,7 +692,7 @@ extension StarknetTypedData {
private func resolveType<T: TypeDeclaration>(_ context: Context) throws -> T {
let (parent, key) = (context.parent, context.key)

guard let parentType = types[parent] else {
guard let parentType = allTypes[parent] else {
throw StarknetTypedDataError.parentNotDefined
}
guard let targetType = parentType.first(where: { $0.type.name == key }) else {
Expand Down
28 changes: 25 additions & 3 deletions Tests/StarknetTests/Data/TypedDataTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ final class TypedDataTests: XCTestCase {
static let td = try! loadTypedDataFromFile(name: "typed_data_rev_1_example")
static let tdFeltMerkleTree = try! loadTypedDataFromFile(name: "typed_data_rev_1_felt_merkletree_example")
static let tdBasicTypes = try! loadTypedDataFromFile(name: "typed_data_rev_1_basic_types_example")
static let tdPresetTypes = try! loadTypedDataFromFile(name: "typed_data_rev_1_preset_types_example")
static let tdEnum = try! loadTypedDataFromFile(name: "typed_data_rev_1_enum_example")
}

Expand Down Expand Up @@ -95,23 +96,32 @@ final class TypedDataTests: XCTestCase {
}

func testTypesRedifintion() throws {
func testTypeRedifintion(_ type: String, _ revision: StarknetTypedData.Revision) throws {
func testBasicTypeRedifintion(_ type: String, _ revision: StarknetTypedData.Revision) throws {
try XCTAssertThrowsError(makeTypedData(type, revision)) { error in
XCTAssertEqual(error as? StarknetTypedDataError, .basicTypeRedefinition(type))
}
}
func testPresetTypeRedifintion(_ type: String, _ revision: StarknetTypedData.Revision) throws {
try XCTAssertThrowsError(makeTypedData(type, revision)) { error in
XCTAssertEqual(error as? StarknetTypedDataError, .presetTypeRedefinition(type))
}
}

let basicTypesV0 = [
"felt", "bool", "string", "selector", "merkletree",
]
let basicTypesV1 = basicTypesV0 + ["enum", "u128", "i128", "ContractAddress", "ClassHash", "timestamp", "shortstring"]
let presetTypesV1 = ["u256", "TokenAmount", "NftId"]

try XCTAssertNoThrow(makeTypedData("myType", .v0))
try basicTypesV0.forEach { type in
try testTypeRedifintion(type, .v0)
try testBasicTypeRedifintion(type, .v0)
}
try basicTypesV1.forEach { type in
try testTypeRedifintion(type, .v1)
try testBasicTypeRedifintion(type, .v1)
}
try presetTypesV1.forEach { type in
try testPresetTypeRedifintion(type, .v1)
}
}

Expand Down Expand Up @@ -287,6 +297,7 @@ final class TypedDataTests: XCTestCase {
(Self.CasesRev1.td, "Person", "0x30f7aa21b8d67cb04c30f962dd29b95ab320cb929c07d1605f5ace304dadf34"),
(Self.CasesRev1.td, "Mail", "0x560430bf7a02939edd1a5c104e7b7a55bbab9f35928b1cf5c7c97de3a907bd"),
(Self.CasesRev1.tdBasicTypes, "Example", "0x1f94cd0be8b4097a41486170fdf09a4cd23aefbc74bb2344718562994c2c111"),
(Self.CasesRev1.tdPresetTypes, "Example", "0x1a25a8bb84b761090b1fadaebe762c4b679b0d8883d2bedda695ea340839a55"),
(Self.CasesRev1.tdEnum, "Example", "0x380a54d417fb58913b904675d94a8a62e2abc3467f4b5439de0fd65fafdd1a8"),
(Self.CasesRev1.tdFeltMerkleTree, "Example", "0x160b9c0e8a7c561f9c5d9e3cc2990a1b4d26e94aa319e9eb53e163cd06c71be"),
]
Expand Down Expand Up @@ -349,6 +360,12 @@ final class TypedDataTests: XCTestCase {
"message",
"0x391d09a51a31dd17f7270aaa9904688fbeeb9c56a7e2d15c5a6af32e981c730"
),
(
Self.CasesRev1.tdPresetTypes,
"Example",
"message",
"0x74fba3f77f8a6111a9315bac313bf75ecfa46d1234e0fda60312fb6a6517667"
),
(
Self.CasesRev1.tdEnum,
"Example",
Expand Down Expand Up @@ -416,6 +433,11 @@ final class TypedDataTests: XCTestCase {
"0xcd2a3d9f938e13cd947ec05abc7fe734df8dd826",
"0x2d80b87b8bc32068247c779b2ef0f15f65c9c449325e44a9df480fb01eb43ec"
),
(
Self.CasesRev1.tdPresetTypes,
"0xcd2a3d9f938e13cd947ec05abc7fe734df8dd826",
"0x185b339d5c566a883561a88fb36da301051e2c0225deb325c91bb7aa2f3473a"
),
(
Self.CasesRev1.tdEnum,
"0xcd2a3d9f938e13cd947ec05abc7fe734df8dd826",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"types": {
"StarknetDomain": [
{ "name": "name", "type": "shortstring" },
{ "name": "version", "type": "shortstring" },
{ "name": "chainId", "type": "shortstring" },
{ "name": "revision", "type": "shortstring" }
],
"Example": [
{ "name": "n0", "type": "TokenAmount" },
{ "name": "n1", "type": "NftId" }
]
},
"primaryType": "Example",
"domain": {
"name": "StarkNet Mail",
"version": "1",
"chainId": "1",
"revision": "1"
},
"message": {
"n0": {
"token_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
"amount": {
"low": "0x3e8",
"high": "0x0"
}
},
"n1": {
"collection_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
"token_id": {
"low": "0x3e8",
"high": "0x0"
}
}
}
}

0 comments on commit 0600ddd

Please sign in to comment.