Skip to content

Commit

Permalink
Support custom chain IDs (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
DelevoXDG committed Feb 19, 2024
1 parent 505a2ef commit 6a78ba4
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 10 deletions.
43 changes: 35 additions & 8 deletions Sources/Starknet/Data/StarknetChainId.swift
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import BigInt
import Foundation

public enum StarknetChainId: String, Codable, Equatable {
case mainnet = "0x534e5f4d41494e"
case goerli = "0x534e5f474f45524c49"
case sepolia = "0x534e5f5345504f4c4941"
case integration_sepolia = "0x534e5f494e544547524154494f4e5f5345504f4c4941"

public var feltValue: Felt {
Felt(fromHex: self.rawValue)!
public struct StarknetChainId: Codable, Equatable {
public let value: Felt

public static let main = StarknetChainId(fromHex: "0x534e5f4d41494e")
public static let goerli = StarknetChainId(fromHex: "0x534e5f474f45524c49")
public static let sepolia = StarknetChainId(fromHex: "0x534e5f5345504f4c4941")
public static let integration_sepolia = StarknetChainId(fromHex: "0x534e5f494e544547524154494f4e5f5345504f4c4941")

public init(_ value: Felt) {
self.value = value
}

public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer()
let value = try container.decode(String.self)
self.value = Felt(fromHex: value)!
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
try container.encode(self.value)
}

enum CodingKeys: String, CodingKey {
Expand All @@ -18,3 +31,17 @@ public enum StarknetChainId: String, Codable, Equatable {
case integration_sepolia = "0x534e5f494e544547524154494f4e5f5345504f4c4941"
}
}

public extension StarknetChainId {
init(fromHex hex: String) {
self.value = Felt(fromHex: hex)!
}

init(fromNetworkName networkName: String) {
self.value = Felt.fromShortString(networkName)!
}

func toNetworkName() -> String {
self.value.toShortString()
}
}
4 changes: 2 additions & 2 deletions Sources/Starknet/Data/Transaction/TransactionHash.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class StarknetTransactionHashCalculator {
entryPointSelector,
StarknetCurve.pedersenOn(calldata),
maxFee,
chainId.feltValue,
chainId.value,
nonce
)
}
Expand All @@ -45,7 +45,7 @@ public class StarknetTransactionHashCalculator {
+ StarknetTransactionHashCalculator.resourceBoundsForFee(resourceBounds)
),
StarknetPoseidon.poseidonHash(paymasterData),
chainId.feltValue,
chainId.value,
nonce,
StarknetTransactionHashCalculator.dataAvailabilityModes(
feeDataAvailabilityMode: feeDataAvailabilityMode,
Expand Down
64 changes: 64 additions & 0 deletions Tests/StarknetTests/Data/StarknetChainIdTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import XCTest

@testable import Starknet

final class StarknetChainIdTests: XCTestCase {
static var chainIdCases: [(StarknetChainId, String, String)] = []

override class func setUp() {
self.chainIdCases = [
(StarknetChainId.main, "SN_MAIN", "0x534e5f4d41494e"),
(StarknetChainId.goerli, "SN_GOERLI", "0x534e5f474f45524c49"),
(StarknetChainId.sepolia, "SN_SEPOLIA", "0x534e5f5345504f4c4941"),
(StarknetChainId.integration_sepolia, "SN_INTEGRATION_SEPOLIA", "0x534e5f494e544547524154494f4e5f5345504f4c4941"),
]
}

func testFromHexInitializer() {
for (chainId, _, hex) in Self.chainIdCases {
XCTAssertEqual(StarknetChainId(fromHex: hex), chainId)
}
}

func testFromNetworkNameInitializer() {
for (chainId, name, _) in Self.chainIdCases {
XCTAssertEqual(StarknetChainId(fromNetworkName: name), chainId)
}
}

func testToNetworkName() {
for (chainId, name, _) in Self.chainIdCases {
XCTAssertEqual(chainId.toNetworkName(), name)
}
}

func testEncoding() {
for (chainId, _, hexString) in Self.chainIdCases {
do {
let data = try JSONEncoder().encode(chainId)
let expectedData = Data("\"\(hexString)\"".utf8)
XCTAssertEqual(data, expectedData)
} catch {
XCTFail("Failed to encode \(chainId)")
}
}
}

func testDecoding() {
for (chainId, _, hexString) in Self.chainIdCases {
do {
let data = Data("\"\(hexString)\"".utf8)
let decoded = try JSONDecoder().decode(StarknetChainId.self, from: data)
XCTAssertEqual(decoded, chainId)
} catch {
XCTFail("Failed to decode \(chainId)")
}
}
}

func testCustomChainId() {
let chainIdFromHex = StarknetChainId(fromHex: "0x4b4154414e41")
let chainIdFromNetworkName = StarknetChainId(fromNetworkName: "KATANA")
XCTAssertEqual(chainIdFromHex, chainIdFromNetworkName)
}
}

0 comments on commit 6a78ba4

Please sign in to comment.